@maximai/maxim-js
Version:
Maxim AI JS SDK. Visit https://getmaxim.ai for more info.
488 lines • 29.8 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.createTestRunBuilder = void 0;
const dataset_1 = require("../apis/dataset");
const evaluator_1 = require("../apis/evaluator");
const testRun_1 = require("../apis/testRun");
const csvParser_1 = require("../utils/csvParser");
const semaphore_1 = require("../utils/semaphore");
const defaultLogger_1 = require("./defaultLogger");
const dataset_2 = require("../dataset/dataset");
const utils_1 = require("../utils/utils");
const runUtils_1 = require("./runUtils");
const sanitizationUtils_1 = require("./sanitizationUtils");
const utils_2 = require("./utils");
/**
 * Validates addresses passed to `withHumanEvaluationConfig`. Hoisted so the
 * pattern is compiled once; it has no `g`/`y` flag, so reusing it with
 * `.test()` is stateless.
 */
const HUMAN_EVALUATION_EMAIL_REGEX = /^(?!\.)(?!.*\.\.)([A-Z0-9_'+\-\.]*)[A-Z0-9_+-]@([A-Z0-9][A-Z0-9\-]*\.)+[A-Z]{2,}$/i;
/**
 * Serializes a builder config for the `cause` of a validation error.
 * The API key is redacted so it does not leak into logs or error reporters,
 * and a serialization failure (e.g. a circular reference inside a
 * user-provided object) yields a placeholder instead of masking the real error.
 *
 * @param {object} cfg - Builder configuration.
 * @returns {string} Pretty-printed JSON (or a placeholder) describing `cfg`.
 */
const describeConfigForError = (cfg) => {
    try {
        return JSON.stringify({ config: { ...cfg, apiKey: cfg.apiKey ? "<redacted>" : cfg.apiKey } }, null, 2);
    }
    catch (_serializationError) {
        return "<test run config could not be serialized>";
    }
};
/**
 * Returns `value` unless it is null/undefined, in which case `fallback`.
 * Equivalent to `value ?? fallback`; written out because this compiled file
 * down-levels ES2020 syntax (`??`, `?.`).
 */
const withDefault = (value, fallback) => (value === null || value === undefined ? fallback : value);
/**
 * Returns `value.id`, or `undefined` when `value` is null/undefined
 * (equivalent to `value?.id`, see `withDefault` above for why `?.` is avoided).
 */
const idOf = (value) => (value === null || value === undefined ? undefined : value.id);
/**
 * Creates an immutable test-run builder.
 *
 * Every `with*` / `yieldsOutput` call returns a NEW builder whose config is a
 * shallow copy of `config` with one field replaced; the receiving builder is
 * left untouched. `run` validates the accumulated config, creates the test run
 * on the platform, pushes one entry per data row (at most `concurrency` rows
 * in flight, default 10), then polls until the platform reports the run as
 * settled.
 *
 * @param {object} config - Accumulated builder configuration (name,
 *   workspaceId, baseUrl, apiKey, data, dataStructure, evaluators, ...).
 * @returns {object} The builder API.
 */
const createTestRunBuilder = (config) => ({
    /** Sanitizes `dataStructure` (throws if it is invalid) and stores it. */
    withDataStructure: (dataStructure) => {
        (0, dataset_2.sanitizeDataStructure)(dataStructure);
        return (0, exports.createTestRunBuilder)({ ...config, dataStructure });
    },
    /**
     * Sanitizes `data` using the data structure configured so far (so call
     * order relative to `withDataStructure` matters) and stores it.
     */
    withData: (data) => {
        (0, sanitizationUtils_1.sanitizeData)(data, config.dataStructure);
        return (0, exports.createTestRunBuilder)({ ...config, data });
    },
    /** Sanitizes and stores the evaluators: strings name platform evaluators, other values are local evaluators. */
    withEvaluators: (...evaluators) => {
        (0, sanitizationUtils_1.sanitizeEvaluators)(evaluators);
        return (0, exports.createTestRunBuilder)({ ...config, evaluators: [...evaluators] });
    },
    /**
     * Stores the human-evaluation config after checking every address in
     * `humanEvaluationConfig.emails`.
     *
     * @throws {Error} `Invalid email address: <email>` for the first invalid address.
     */
    withHumanEvaluationConfig: (humanEvaluationConfig) => {
        humanEvaluationConfig.emails.forEach((email) => {
            if (!HUMAN_EVALUATION_EMAIL_REGEX.test(email)) {
                throw new Error(`Invalid email address: ${email}`);
            }
        });
        return (0, exports.createTestRunBuilder)({ ...config, humanEvaluationConfig });
    },
    withPromptVersionId: (id, contextToEvaluate) => (0, exports.createTestRunBuilder)({ ...config, promptVersion: { id, contextToEvaluate } }),
    withPromptChainVersionId: (id, contextToEvaluate) => (0, exports.createTestRunBuilder)({ ...config, promptChainVersion: { id, contextToEvaluate } }),
    withWorkflowId: (id, contextToEvaluate) => (0, exports.createTestRunBuilder)({ ...config, workflow: { id, contextToEvaluate } }),
    yieldsOutput: (outputFunction) => (0, exports.createTestRunBuilder)({ ...config, outputFunction }),
    withLogger: (logger) => (0, exports.createTestRunBuilder)({ ...config, logger }),
    getConfig: () => config,
    withConcurrency: (concurrency) => (0, exports.createTestRunBuilder)({ ...config, concurrency }),
    /**
     * Executes the test run and waits for the platform to finish it.
     *
     * @param {number} [timeoutInMinutes=15] - How long to poll for completion
     *   (rounded to whole minutes for the iteration budget).
     * @returns {Promise<{testRunResult: object, failedEntryIndices: number[]}>}
     *   The platform's final result (with an absolute `link`) and the indices of
     *   rows that failed locally and were skipped.
     * @throws {Error} On missing configuration, on a human evaluator without a
     *   human-evaluation config, when the run FAILED/STOPPED, or on timeout.
     */
    run: async (timeoutInMinutes = 15) => {
        const logger = config.logger !== null && config.logger !== undefined ? config.logger : new defaultLogger_1.DefaultLogger();
        logger.info("Running sanitization checks...");
        const errors = [];
        if (!config.name) {
            errors.push("Name is required to run a test.");
        }
        if (!config.workspaceId) {
            errors.push("Workspace Id is required to run a test.");
        }
        // Exactly one source of outputs may be configured.
        const outputSourceCount = [config.outputFunction, config.promptVersion, config.promptChainVersion, config.workflow].filter(Boolean).length;
        if (outputSourceCount === 0) {
            errors.push("Output function or prompt version id, prompt chain version id, or workflow id is required to run a test. You can use either yieldsOutput, withPromptVersionId, withPromptChainVersionId or withWorkflowId to set them respectively.");
        }
        if (outputSourceCount !== 1) {
            errors.push("Exactly one of outputFunction, promptVersionId, promptChainVersionId, or workflowId must be set.");
        }
        if (!config.data) {
            errors.push("Data or dataset id is required to run a test.");
        }
        if (errors.length > 0) {
            throw new Error(`Missing required configuration for test run${config.name ? ` "${config.name}"` : ""}:\n\t${errors.join(", \n\t")}`, {
                cause: describeConfigForError(config),
            });
        }
        (0, dataset_2.sanitizeDataStructure)(config.dataStructure);
        (0, sanitizationUtils_1.sanitizeData)(config.data, config.dataStructure);
        (0, sanitizationUtils_1.sanitizeEvaluators)(config.evaluators);
        // Tolerate a builder that never called `withEvaluators` rather than crashing on `.filter`.
        const evaluators = withDefault(config.evaluators, []);
        // Strings name platform evaluators; any other value is a local evaluator run by `runLocalEvaluations`.
        const platformEvaluatorNames = evaluators.filter((e) => typeof e === "string");
        const localEvaluators = evaluators.filter((e) => typeof e !== "string");
        const APIEvaluatorService = new evaluator_1.MaximEvaluatorAPI(config.baseUrl, config.apiKey);
        const platformEvaluatorsConfig = await Promise.all(platformEvaluatorNames.map((evaluatorName) => APIEvaluatorService.fetchPlatformEvaluator(evaluatorName, config.workspaceId)));
        if (platformEvaluatorsConfig.some((e) => e.type === "Human") && !config.humanEvaluationConfig) {
            throw new Error("Human evaluator found in evaluators, but no human evaluation config was provided.");
        }
        const { dataStructure, name, workspaceId, data, humanEvaluationConfig, outputFunction, promptVersion, promptChainVersion, workflow } = config;
        const concurrency = withDefault(config.concurrency, 10);
        // Entry indices whose local processing failed; they are skipped and reported back to the caller.
        const failedEntryIndices = [];
        const localEvaluatorNameToIdAndPassFailCriteriaMap = (0, utils_2.getLocalEvaluatorNameToIdAndPassFailCriteriaMap)(evaluators);
        const APITestRunService = new testRun_1.MaximTestRunAPI(config.baseUrl, config.apiKey);
        /**
         * Picks the output function for one row: the user's `yieldsOutput`
         * function when set, otherwise a closure that runs the configured
         * workflow / prompt version / prompt chain version via the test-run API.
         */
        const resolveOutputFunction = (input) => {
            if (outputFunction) {
                return outputFunction;
            }
            if (workflow) {
                return (0, runUtils_1.workflowIdOutputFunctionClosure)(workflow.id, APITestRunService, workflow.contextToEvaluate);
            }
            if (promptVersion) {
                return (0, runUtils_1.promptVersionIdOutputFunctionClosure)(promptVersion.id, withDefault(input, ""), APITestRunService, promptVersion.contextToEvaluate);
            }
            if (promptChainVersion) {
                return (0, runUtils_1.promptChainVersionIdOutputFunctionClosure)(promptChainVersion.id, withDefault(input, ""), APITestRunService, promptChainVersion.contextToEvaluate);
            }
            throw new Error("Found no output function to execute, please make sure you have either `yieldsOutput`, `withPromptVersionId`, `withPromptChainVersionId` or `withWorkflowId` set.");
        };
        /**
         * `contextToEvaluate` sent when outputs are generated on the platform:
         * the configured workflow/prompt context wins, otherwise the dataset's
         * context column NAME is sent.
         * NOTE(review): this is a key, not a value — presumably resolved server-side; confirm.
         */
        const platformContextToEvaluate = (mappingKeys) => {
            if (workflow && workflow.contextToEvaluate) {
                return workflow.contextToEvaluate;
            }
            if (promptVersion && promptVersion.contextToEvaluate) {
                return promptVersion.contextToEvaluate;
            }
            if (promptChainVersion && promptChainVersion.contextToEvaluate) {
                return promptChainVersion.contextToEvaluate;
            }
            return typeof mappingKeys.contextToEvaluate === "string" ? mappingKeys.contextToEvaluate : undefined;
        };
        /** Converts an output function's `meta` into the snake_case `runConfig` the API expects. */
        const toRunConfig = (meta) => {
            if (!meta) {
                return undefined;
            }
            let usage;
            if (meta.usage) {
                usage = "completionTokens" in meta.usage
                    ? {
                        completion_tokens: meta.usage.completionTokens,
                        prompt_tokens: meta.usage.promptTokens,
                        total_tokens: meta.usage.totalTokens,
                        latency: meta.usage.latency,
                    }
                    : { latency: meta.usage.latency };
            }
            return { cost: meta.cost, usage };
        };
        /**
         * Loads row `index` via `getRow`, produces its output (locally when an
         * output function or local evaluators are involved), and pushes the
         * resulting entry to the test run.
         *
         * @param {number} index - Index passed to `getRow`.
         * @param {number} [entryIndex=index] - Position of the row across the
         *   whole run, used in log messages (differs from `index` for paged data).
         */
        async function processEntry(testRun, index, mappingKeys, getRow, datasetId, entryIndex = index) {
            const row = await getRow(index);
            if (!row) {
                throw new Error(`No row found at index ${index}`);
            }
            // Missing keys and falsy cell values both map to `undefined`.
            const readColumnAsString = (key) => {
                if (!key) {
                    return undefined;
                }
                const cell = row.data[key];
                return cell ? String(cell) : undefined;
            };
            const input = readColumnAsString(mappingKeys.input);
            const expectedOutput = readColumnAsString(mappingKeys.expectedOutput);
            let contextToEvaluate = mappingKeys.contextToEvaluate && row.data[mappingKeys.contextToEvaluate] !== null
                ? row.data[mappingKeys.contextToEvaluate]
                : undefined;
            if (outputFunction || localEvaluators.length > 0) {
                // Outputs are produced here so that local evaluators can score them.
                const output = await (0, runUtils_1.runOutputFunction)(resolveOutputFunction(input), row.data);
                if (output.retrievedContextToEvaluate) {
                    if (contextToEvaluate) {
                        logger.info(`Detected retrieved context returned from output function for row ${entryIndex + 1} that had contextToEvaluate set from the dataset.\nOverriding the contextToEvaluate from dataset with the retrieved context`);
                    }
                    contextToEvaluate = output.retrievedContextToEvaluate;
                }
                const localEvaluationResults = evaluators.length > 0
                    ? await (0, runUtils_1.runLocalEvaluations)(localEvaluators, row.data, {
                        output: output.data,
                        contextToEvaluate,
                    })
                    : undefined;
                await APITestRunService.pushTestRunEntry({
                    testRun: { ...testRun, datasetId, datasetEntryId: row.id },
                    runConfig: toRunConfig(output.meta),
                    entry: {
                        input,
                        output: output.data,
                        expectedOutput,
                        contextToEvaluate,
                        dataEntry: row.data,
                        localEvaluationResults: localEvaluationResults
                            ? localEvaluationResults.map((result) => ({
                                ...result,
                                id: localEvaluatorNameToIdAndPassFailCriteriaMap.get(result.name).id,
                            }))
                            : undefined,
                    },
                });
                logger.processed(`Ran test run entry ${entryIndex + 1}`, {
                    datasetEntry: row.data,
                    output,
                    evaluationResults: localEvaluationResults,
                });
                return;
            }
            await APITestRunService.pushTestRunEntry({
                testRun: { ...testRun, datasetId, datasetEntryId: row.id },
                entry: {
                    input,
                    expectedOutput,
                    contextToEvaluate: platformContextToEvaluate(mappingKeys),
                    dataEntry: row.data,
                },
            });
            logger.processed(`Ran test run entry ${entryIndex + 1}`, {
                datasetEntry: row.data,
            });
        }
        /** Starts `runRow(i)` for i in [0, totalRows) and waits for all of them. */
        const runRows = (totalRows, runRow) => {
            const promises = [];
            for (let i = 0; i < totalRows; i++) {
                promises.push(runRow(i));
            }
            return Promise.all(promises);
        };
        /** True once the platform will make no further progress on the run. */
        const isTestRunSettled = (status) => status.testRunStatus === "FAILED" ||
            status.testRunStatus === "STOPPED" ||
            (status.testRunStatus === "COMPLETE" &&
                status.entryStatus.total === status.entryStatus.completed + status.entryStatus.failed + status.entryStatus.stopped);
        try {
            logger.info(`Creating test run "${name}"...`);
            const evalConfig = [
                ...platformEvaluatorsConfig,
                ...Array.from(localEvaluatorNameToIdAndPassFailCriteriaMap.entries()).map(([evaluatorName, value]) => ({
                    id: value.id,
                    name: evaluatorName,
                    type: "Local",
                    builtin: false,
                    reversed: undefined,
                    config: {
                        passFailCriteria: {
                            entryLevel: {
                                // Boolean criteria are sent to the platform as "Yes"/"No".
                                value: typeof value.passFailCriteria.onEachEntry.value === "boolean"
                                    ? value.passFailCriteria.onEachEntry.value
                                        ? "Yes"
                                        : "No"
                                    : value.passFailCriteria.onEachEntry.value,
                                operator: value.passFailCriteria.onEachEntry.scoreShouldBe,
                                name: "score",
                            },
                            runLevel: {
                                value: value.passFailCriteria.forTestrunOverall.value,
                                operator: value.passFailCriteria.forTestrunOverall.overallShouldBe,
                                name: value.passFailCriteria.forTestrunOverall.for === "average" ? "meanScore" : "queriesPassed",
                            },
                        },
                    },
                })),
            ];
            const testRun = await APITestRunService.createTestRun(name, workspaceId, "SINGLE", evalConfig, localEvaluators.length > 0, idOf(workflow), idOf(promptVersion), idOf(promptChainVersion), humanEvaluationConfig);
            const reportUrl = `${config.baseUrl}/workspace/${config.workspaceId}/testrun/${testRun.id}`;
            try {
                const semaphore = semaphore_1.Semaphore.get(`${workspaceId}:${name}:${testRun.id}`, concurrency);
                /**
                 * Runs one entry under the concurrency semaphore. Failures are
                 * logged and recorded in `failedEntryIndices` instead of aborting
                 * the whole run; the permit is released only if it was acquired.
                 */
                const runEntrySafely = async (entryIndex, rowIndex, mappingKeys, getRow, datasetId) => {
                    let acquired = false;
                    try {
                        await semaphore.acquire();
                        acquired = true;
                        await processEntry(testRun, rowIndex, mappingKeys, getRow, datasetId, entryIndex);
                    }
                    catch (err) {
                        logger.error((0, utils_2.buildErrorMessage)(new Error(`Error while running data entry at index [${entryIndex}]`, {
                            cause: err,
                        })));
                        failedEntryIndices.push(entryIndex);
                    }
                    finally {
                        if (acquired) {
                            semaphore.release();
                        }
                    }
                };
                /** Runs every row of platform dataset `datasetId`. */
                const runPlatformDataset = async (APIDatasetService, datasetId, mappingKeys) => {
                    const totalRows = await APIDatasetService.getDatasetTotalRows(datasetId);
                    await runRows(totalRows, (i) => runEntrySafely(i, i, mappingKeys, (rowIndex) => APIDatasetService.getDatasetRow(datasetId, rowIndex), datasetId));
                };
                // `data` is always truthy here: a falsy value was rejected during validation.
                if (dataStructure) {
                    const mappingKeys = {
                        input: (0, utils_1.getAllKeysByValue)(dataStructure, "INPUT")[0],
                        expectedOutput: (0, utils_1.getAllKeysByValue)(dataStructure, "EXPECTED_OUTPUT")[0],
                        contextToEvaluate: (0, utils_1.getAllKeysByValue)(dataStructure, "CONTEXT_TO_EVALUATE")[0],
                    };
                    if (typeof data === "string") {
                        // Platform dataset id: its structure must match the local one.
                        const APIDatasetService = new dataset_1.MaximDatasetAPI(config.baseUrl, config.apiKey);
                        logger.info(`Fetching dataset "${data}" from platform...`);
                        const platformDataStructure = await APIDatasetService.getDatasetDatastructure(data);
                        (0, dataset_2.validateDataStructure)(dataStructure, platformDataStructure);
                        await APITestRunService.attachDatasetToTestRun(testRun.id, data);
                        await runPlatformDataset(APIDatasetService, data, mappingKeys);
                    }
                    else if (data instanceof csvParser_1.CSVFile) {
                        // Reorder CSV columns to follow the key order of `dataStructure`.
                        const columnStructure = {};
                        Object.keys(dataStructure).forEach((key, columnIndex) => {
                            columnStructure[key] = columnIndex;
                        });
                        const csv = await csvParser_1.CSVFile.restructure(data, columnStructure);
                        const totalRows = await csv.getRowCount();
                        await runRows(totalRows, (i) => runEntrySafely(i, i, mappingKeys, async (rowIndex) => ({ data: await csv.getRow(rowIndex) })));
                    }
                    else if (Array.isArray(data)) {
                        await runRows(data.length, (i) => runEntrySafely(i, i, mappingKeys, (rowIndex) => ({ data: data[rowIndex] })));
                    }
                    else if (typeof data === "function") {
                        // Paged source: `data(page)` is called for page 0, 1, ... until it
                        // returns null/undefined. Pages run one after another; rows within a
                        // page run concurrently. NOTE: a source that keeps returning `[]`
                        // never terminates.
                        let page = 0;
                        let nextEntryIndex = 0;
                        while (true) {
                            const currentPage = page++;
                            const fetchedData = await data(currentPage);
                            if (fetchedData === null || fetchedData === undefined) {
                                break;
                            }
                            try {
                                (0, sanitizationUtils_1.sanitizeData)(fetchedData, dataStructure);
                            }
                            catch (err) {
                                const detail = err instanceof Error ? `: ${err.message}` : "";
                                logger.error((0, utils_2.buildErrorMessage)(new Error(`=> Skipping page ${currentPage}\nError while sanitizing response as per data structure${detail}\n\tGot response: ${JSON.stringify(fetchedData)}`, {
                                    cause: err,
                                })));
                                continue;
                            }
                            const pageStart = nextEntryIndex;
                            nextEntryIndex += fetchedData.length;
                            await runRows(fetchedData.length, (i) => runEntrySafely(pageStart + i, i, mappingKeys, (rowIndex) => ({ data: fetchedData[rowIndex] })));
                        }
                    }
                    else {
                        throw new Error(`Invalid data type ${typeof data}. Expected string, CSVFile or array of valid data type.`);
                    }
                }
                else {
                    // No local structure: `data` is treated as a platform dataset id and
                    // the platform's own structure drives the column mapping.
                    const datasetId = data;
                    const APIDatasetService = new dataset_1.MaximDatasetAPI(config.baseUrl, config.apiKey);
                    logger.info(`Fetching dataset "${datasetId}" from platform...`);
                    const platformDataStructure = await APIDatasetService.getDatasetDatastructure(datasetId);
                    await APITestRunService.attachDatasetToTestRun(testRun.id, datasetId);
                    await runPlatformDataset(APIDatasetService, datasetId, {
                        input: (0, utils_1.getAllKeysByValue)(platformDataStructure, "INPUT")[0],
                        expectedOutput: (0, utils_1.getAllKeysByValue)(platformDataStructure, "EXPECTED_OUTPUT")[0],
                        contextToEvaluate: (0, utils_1.getAllKeysByValue)(platformDataStructure, "CONTEXT_TO_EVALUATE")[0],
                    });
                }
                logger.info("Marking test run as processed...");
                await APITestRunService.markTestRunProcessed(testRun.id);
                logger.info(`You can now either quit and view the report on our web portal here: \n\t\t${reportUrl}\n\tOR\n\tWait for the test run to complete to get back the results to use through the SDK.`);
            }
            catch (e) {
                try {
                    await APITestRunService.markTestRunFailed(testRun.id);
                }
                catch (markFailedError) {
                    // Keep the original failure as the thrown error; only log the secondary one.
                    logger.error((0, utils_2.buildErrorMessage)(new Error(`Failed to mark test run ${testRun.id} as failed`, {
                        cause: markFailedError,
                    })));
                }
                throw e;
            }
            const pollingInterval = (0, utils_2.calculatePollingInterval)(timeoutInMinutes, platformEvaluatorsConfig.some((e) => e.type === "AI"));
            const maxIterations = Math.ceil((Math.round(timeoutInMinutes) * 60) / pollingInterval);
            logger.info("Waiting for test run to complete...");
            logger.info(`Polling interval: ${pollingInterval} seconds`);
            let pollCount = 0;
            let status;
            while (true) {
                status = await APITestRunService.getTestRunStatus(testRun.id);
                logger.info(`Test run is ${status.testRunStatus}, breakdown:\n${(0, utils_2.createStatusTable)(status.entryStatus)}`);
                // Check for a settled run BEFORE the timeout so a run that finishes on
                // the last allowed poll is not reported as timed out.
                if (isTestRunSettled(status)) {
                    break;
                }
                if (++pollCount > maxIterations) {
                    throw new Error(`Test run is taking over timeout period (${Math.round(timeoutInMinutes)} minutes) to complete, please check the report on our web portal directly: ${reportUrl}`);
                }
                await new Promise((resolve) => setTimeout(resolve, pollingInterval * 1000));
            }
            if (status.testRunStatus === "FAILED") {
                throw new Error(`💥 Test run failed, please check the report on our web portal: ${reportUrl}`);
            }
            if (status.testRunStatus === "STOPPED") {
                throw new Error(`🛑 Test run was stopped, please check the report on our web portal: ${reportUrl}`);
            }
            const testRunResult = await APITestRunService.getTestRunFinalResult(testRun.id);
            // The API returns a path; make it absolute for the caller.
            testRunResult.link = config.baseUrl + testRunResult.link;
            logger.info(`Test run "${name}" completed successfully!🎉 \nView the report here: ${testRunResult.link}`);
            return { testRunResult, failedEntryIndices };
        }
        catch (err) {
            logger.error((0, utils_2.buildErrorMessage)(new Error(`Error while running test run ${name}`, {
                cause: err,
            })));
            throw err;
        }
    },
});
exports.createTestRunBuilder = createTestRunBuilder;
//# sourceMappingURL=testRun.js.map