UNPKG

@maximai/maxim-js

Version:

Maxim AI JS SDK. Visit https://getmaxim.ai for more info.

488 lines 29.8 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.createTestRunBuilder = void 0;
const dataset_1 = require("../apis/dataset");
const evaluator_1 = require("../apis/evaluator");
const testRun_1 = require("../apis/testRun");
const csvParser_1 = require("../utils/csvParser");
const semaphore_1 = require("../utils/semaphore");
const defaultLogger_1 = require("./defaultLogger");
const dataset_2 = require("../dataset/dataset");
const utils_1 = require("../utils/utils");
const runUtils_1 = require("./runUtils");
const sanitizationUtils_1 = require("./sanitizationUtils");
const utils_2 = require("./utils");

/**
 * Pattern every human-evaluation email address must match.
 * Hoisted to module scope so it is compiled once. It has no `g`/`y` flag, so
 * reusing it across `.test()` calls carries no `lastIndex` state.
 */
const EMAIL_REGEX = /^(?!\.)(?!.*\.\.)([A-Z0-9_'+\-\.]*)[A-Z0-9_+-]@([A-Z0-9][A-Z0-9\-]*\.)+[A-Z]{2,}$/i;

/** Number of entries processed concurrently when `withConcurrency` was not used. */
const DEFAULT_CONCURRENCY = 10;

/** True for evaluators that run in-process, i.e. anything not given as a platform evaluator name string. */
function isLocalEvaluator(evaluator) {
    return typeof evaluator !== "string";
}

/** Returns `ref.id`, or `undefined` when `ref` is null/undefined. */
function idOf(ref) {
    return ref == null ? undefined : ref.id;
}

/** Builds the web-portal URL of a test run report. */
function buildReportUrl(baseUrl, workspaceId, testRunId) {
    return `${baseUrl}/workspace/${workspaceId}/testrun/${testRunId}`;
}

/**
 * Collects problems with a builder config that prevent a run from starting.
 *
 * @param {object} config - The accumulated builder config.
 * @returns {string[]} One message per problem; empty when the config is runnable.
 */
function collectConfigErrors(config) {
    const errors = [];
    if (!config.name) {
        errors.push("Name is required to run a test.");
    }
    if (!config.workspaceId) {
        errors.push("Workspace Id is required to run a test.");
    }
    const outputSourceCount = [
        config.outputFunction,
        config.promptVersion,
        config.promptChainVersion,
        config.workflow,
    ].filter(Boolean).length;
    // Report "missing" and "too many" separately; a config with no output source
    // should not also be told that "exactly one" must be set.
    if (outputSourceCount === 0) {
        errors.push("Output function or prompt version id, prompt chain version id, or workflow id is required to run a test. You can use either yieldsOutput, withPromptVersionId, withPromptChainVersionId or withWorkflowId to set them respectively.");
    }
    else if (outputSourceCount > 1) {
        errors.push("Exactly one of outputFunction, promptVersionId, promptChainVersionId, or workflowId must be set.");
    }
    if (!config.data) {
        errors.push("Data or dataset id is required to run a test.");
    }
    return errors;
}

/**
 * Whether a polled test-run status is final.
 * A run is final when it FAILED or was STOPPED, or when it is COMPLETE and every
 * entry is accounted for (completed + failed + stopped === total).
 */
function isTestRunFinished(status) {
    if (status.testRunStatus === "FAILED" || status.testRunStatus === "STOPPED") {
        return true;
    }
    const { total, completed, failed, stopped } = status.entryStatus;
    return status.testRunStatus === "COMPLETE" && total === completed + failed + stopped;
}

/**
 * Maps output-function metadata onto the `runConfig` payload of a pushed entry.
 * Usage objects carrying `completionTokens` are sent with snake_case token
 * counts; any other usage object contributes only its `latency`.
 *
 * @param {object|undefined} meta - `output.meta` returned by `runOutputFunction`.
 * @returns {object|undefined} The payload, or `undefined` when there is no meta.
 */
function toRunConfig(meta) {
    if (!meta) {
        return undefined;
    }
    let usage = undefined;
    if (meta.usage) {
        usage = "completionTokens" in meta.usage
            ? {
                completion_tokens: meta.usage.completionTokens,
                prompt_tokens: meta.usage.promptTokens,
                total_tokens: meta.usage.totalTokens,
                latency: meta.usage.latency,
            }
            : { latency: meta.usage.latency };
    }
    return { cost: meta.cost, usage };
}

/**
 * Converts one entry of the local-evaluator map (`[name, { id, passFailCriteria }]`)
 * into the evaluator config sent to `createTestRun`.
 */
function toLocalEvaluatorConfig([name, value]) {
    const { onEachEntry, forTestrunOverall } = value.passFailCriteria;
    return {
        id: value.id,
        name,
        type: "Local",
        builtin: false,
        reversed: undefined,
        config: {
            passFailCriteria: {
                entryLevel: {
                    // Boolean per-entry criteria are sent as "Yes"/"No".
                    value: typeof onEachEntry.value === "boolean"
                        ? (onEachEntry.value ? "Yes" : "No")
                        : onEachEntry.value,
                    operator: onEachEntry.scoreShouldBe,
                    name: "score",
                },
                runLevel: {
                    value: forTestrunOverall.value,
                    operator: forTestrunOverall.overallShouldBe,
                    name: forTestrunOverall.for === "average" ? "meanScore" : "queriesPassed",
                },
            },
        },
    };
}

/**
 * Starts `task(i)` for every `i` in `[0, count)` up front and waits for all of them.
 * Actual parallelism is bounded by the semaphore each task acquires.
 */
function runAll(count, task) {
    const promises = [];
    for (let i = 0; i < count; i++) {
        promises.push(task(i));
    }
    return Promise.all(promises);
}

/** Resolves after `ms` milliseconds. */
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

/**
 * Creates an immutable test-run builder.
 * Every `with*`/`yieldsOutput` call returns a NEW builder over a copied config;
 * the receiver is never mutated. Calls go through `exports.createTestRunBuilder`
 * (TypeScript CommonJS style), so a replaced export is picked up.
 *
 * @param {object} config - Accumulated configuration (name, workspaceId, baseUrl,
 *   apiKey, data, evaluators, ...). The initial shape is supplied by the caller
 *   outside this file.
 * @returns {object} Builder exposing configuration methods, `getConfig()` and `run()`.
 */
const createTestRunBuilder = (config) => ({
    /** Validates and sets the column-name -> column-type mapping for local data. */
    withDataStructure: (dataStructure) => {
        (0, dataset_2.sanitizeDataStructure)(dataStructure);
        return (0, exports.createTestRunBuilder)({ ...config, dataStructure });
    },
    /** Validates `data` against the data structure set so far and stores it. */
    withData: (data) => {
        (0, sanitizationUtils_1.sanitizeData)(data, config.dataStructure);
        return (0, exports.createTestRunBuilder)({ ...config, data });
    },
    /** Sets the evaluators: platform evaluator names (strings) and/or local evaluator objects. */
    withEvaluators: (...evaluators) => {
        (0, sanitizationUtils_1.sanitizeEvaluators)(evaluators);
        return (0, exports.createTestRunBuilder)({ ...config, evaluators: [...evaluators] });
    },
    /**
     * Sets the human evaluation config.
     * @throws {Error} When any entry of `humanEvaluationConfig.emails` is not a valid email address.
     */
    withHumanEvaluationConfig: (humanEvaluationConfig) => {
        humanEvaluationConfig.emails.forEach((email) => {
            if (!EMAIL_REGEX.test(email)) {
                throw new Error(`Invalid email address: ${email}`);
            }
        });
        return (0, exports.createTestRunBuilder)({ ...config, humanEvaluationConfig });
    },
    withPromptVersionId: (id, contextToEvaluate) => (0, exports.createTestRunBuilder)({ ...config, promptVersion: { id, contextToEvaluate } }),
    withPromptChainVersionId: (id, contextToEvaluate) => (0, exports.createTestRunBuilder)({ ...config, promptChainVersion: { id, contextToEvaluate } }),
    withWorkflowId: (id, contextToEvaluate) => (0, exports.createTestRunBuilder)({ ...config, workflow: { id, contextToEvaluate } }),
    /** Sets a local function that produces the output for each data row. */
    yieldsOutput: (outputFunction) => (0, exports.createTestRunBuilder)({ ...config, outputFunction }),
    withLogger: (logger) => (0, exports.createTestRunBuilder)({ ...config, logger }),
    getConfig: () => config,
    withConcurrency: (concurrency) => (0, exports.createTestRunBuilder)({ ...config, concurrency }),
    /**
     * Validates the config, creates the test run, pushes every data entry, marks
     * the run processed, then polls until the run finishes or times out.
     *
     * @param {number} [timeoutInMinutes=15] - Polling budget; rounded to whole
     *   minutes when computing the iteration limit.
     * @returns {Promise<{testRunResult: object, failedEntryIndices: number[]}>}
     *   The final result (with `link` made absolute) and the indices of entries
     *   that failed locally.
     * @throws {Error} On invalid config, a Human evaluator without a human
     *   evaluation config, a FAILED/STOPPED run, or a polling timeout.
     */
    run: async (timeoutInMinutes = 15) => {
        const logger = config.logger != null ? config.logger : new defaultLogger_1.DefaultLogger();
        logger.info("Running sanitization checks...");
        const errors = collectConfigErrors(config);
        if (errors.length > 0) {
            throw new Error(`Missing required configuration for test run ${config.name ? ` "${config.name}"` : ""}:\n\t${errors.join(", \n\t")}`, {
                cause: JSON.stringify({ config }, null, 2),
            });
        }
        (0, dataset_2.sanitizeDataStructure)(config.dataStructure);
        (0, sanitizationUtils_1.sanitizeData)(config.data, config.dataStructure);
        (0, sanitizationUtils_1.sanitizeEvaluators)(config.evaluators);
        const APIEvaluatorService = new evaluator_1.MaximEvaluatorAPI(config.baseUrl, config.apiKey);
        // String evaluators name platform evaluators; fetch their configs in parallel.
        const platformEvaluatorsConfig = await Promise.all(config.evaluators
            .filter((e) => typeof e === "string")
            .map((e) => APIEvaluatorService.fetchPlatformEvaluator(e, config.workspaceId)));
        if (platformEvaluatorsConfig.some((e) => e.type === "Human") && !config.humanEvaluationConfig) {
            throw new Error("Human evaluator found in evaluators, but no human evaluation config was provided.");
        }
        const { dataStructure, name, workspaceId, data, evaluators, humanEvaluationConfig, outputFunction, promptVersion, promptChainVersion, workflow, } = config;
        const concurrency = config.concurrency != null ? config.concurrency : DEFAULT_CONCURRENCY;
        const failedEntryIndices = [];
        const localEvaluatorNameToIdAndPassFailCriteriaMap = (0, utils_2.getLocalEvaluatorNameToIdAndPassFailCriteriaMap)(evaluators);
        const localEvaluators = evaluators.filter(isLocalEvaluator);
        const APITestRunService = new testRun_1.MaximTestRunAPI(config.baseUrl, config.apiKey);
        /** Reads `rowData[key]` as a string; `undefined` when `key` is unset or the value is falsy. */
        const readString = (rowData, key) => (key && rowData[key] ? String(rowData[key]) : undefined);
        /**
         * Chooses the function that produces an output for one row.
         * A local `outputFunction` wins. Otherwise the output comes from the
         * workflow, prompt version or prompt chain version closure, in that order.
         */
        function resolveOutputFunction(input) {
            if (outputFunction) {
                return outputFunction;
            }
            const promptInput = input != null ? input : "";
            if (workflow) {
                return (0, runUtils_1.workflowIdOutputFunctionClosure)(workflow.id, APITestRunService, workflow.contextToEvaluate);
            }
            if (promptVersion) {
                return (0, runUtils_1.promptVersionIdOutputFunctionClosure)(promptVersion.id, promptInput, APITestRunService, promptVersion.contextToEvaluate);
            }
            if (promptChainVersion) {
                return (0, runUtils_1.promptChainVersionIdOutputFunctionClosure)(promptChainVersion.id, promptInput, APITestRunService, promptChainVersion.contextToEvaluate);
            }
            throw new Error("Found no output function to execute, please make sure you have either `yieldsOutput`, `withPromptVersionId`, `withPromptChainVersionId` or `withWorkflowId` set.");
        }
        /**
         * contextToEvaluate sent when the platform produces the output.
         * The first truthy value wins: workflow, prompt version, prompt chain
         * version, then the mapped column key.
         * NOTE(review): the fallback sends the column NAME, not the row value.
         * Presumably the platform resolves it; confirm.
         */
        function pickPlatformContextToEvaluate(mappingKeys) {
            if (workflow && workflow.contextToEvaluate) {
                return workflow.contextToEvaluate;
            }
            if (promptVersion && promptVersion.contextToEvaluate) {
                return promptVersion.contextToEvaluate;
            }
            if (promptChainVersion && promptChainVersion.contextToEvaluate) {
                return promptChainVersion.contextToEvaluate;
            }
            return typeof mappingKeys.contextToEvaluate === "string" ? mappingKeys.contextToEvaluate : undefined;
        }
        /**
         * Processes one data row and pushes it as a test-run entry.
         * When there is a local output function or any local evaluator, the output
         * is produced (and locally evaluated) here. Otherwise only the row is
         * pushed and the platform produces the output.
         *
         * @param {object} testRun - The created test run.
         * @param {number} index - Index passed to `getRow`; also used in log text.
         * @param {{input?: string, expectedOutput?: string, contextToEvaluate?: string}} mappingKeys - Column names per role.
         * @param {(index: number) => any} getRow - Returns (or resolves to) `{ id?, data }`.
         * @param {string} [datasetId] - Platform dataset id, when rows come from one.
         * @throws {Error} When `getRow` yields no row.
         */
        async function processEntry(testRun, index, mappingKeys, getRow, datasetId) {
            const row = await getRow(index);
            if (!row) {
                throw new Error(`No row found at index ${index}`);
            }
            const input = readString(row.data, mappingKeys.input);
            const expectedOutput = readString(row.data, mappingKeys.expectedOutput);
            let contextToEvaluate = mappingKeys.contextToEvaluate
                ? (row.data[mappingKeys.contextToEvaluate] === null ? undefined : row.data[mappingKeys.contextToEvaluate])
                : undefined;
            if (!outputFunction && localEvaluators.length === 0) {
                await APITestRunService.pushTestRunEntry({
                    testRun: { ...testRun, datasetId, datasetEntryId: row.id },
                    entry: {
                        input,
                        expectedOutput,
                        contextToEvaluate: pickPlatformContextToEvaluate(mappingKeys),
                        dataEntry: row.data,
                    },
                });
                logger.processed(`Ran test run entry ${index + 1}`, {
                    datasetEntry: row.data,
                });
                return;
            }
            const output = await (0, runUtils_1.runOutputFunction)(resolveOutputFunction(input), row.data);
            if (output.retrievedContextToEvaluate) {
                if (contextToEvaluate) {
                    logger.info(`Detected retrieved context returned from output function for row ${index + 1} that had contextToEvaluate set from the dataset.\nOverriding the contextToEvaluate from dataset with the retrieved context`);
                }
                contextToEvaluate = output.retrievedContextToEvaluate;
            }
            let localEvaluationResults = undefined;
            // Gated on ALL evaluators, so runs with only platform evaluators still
            // send an (empty) results list rather than `undefined`.
            if (evaluators.length > 0) {
                localEvaluationResults = await (0, runUtils_1.runLocalEvaluations)(localEvaluators, row.data, {
                    output: output.data,
                    contextToEvaluate,
                });
            }
            await APITestRunService.pushTestRunEntry({
                testRun: { ...testRun, datasetId, datasetEntryId: row.id },
                runConfig: toRunConfig(output.meta),
                entry: {
                    input,
                    output: output.data,
                    expectedOutput,
                    contextToEvaluate,
                    dataEntry: row.data,
                    localEvaluationResults: localEvaluationResults
                        ? localEvaluationResults.map((result) => ({
                            ...result,
                            id: localEvaluatorNameToIdAndPassFailCriteriaMap.get(result.name).id,
                        }))
                        : undefined,
                },
            });
            logger.processed(`Ran test run entry ${index + 1}`, {
                datasetEntry: row.data,
                output,
                evaluationResults: localEvaluationResults,
            });
        }
        try {
            logger.info(`Creating test run "${name}"...`);
            const evalConfig = [
                ...platformEvaluatorsConfig,
                ...Array.from(localEvaluatorNameToIdAndPassFailCriteriaMap.entries()).map(toLocalEvaluatorConfig),
            ];
            const testRun = await APITestRunService.createTestRun(name, workspaceId, "SINGLE", evalConfig, localEvaluators.length > 0, idOf(workflow), idOf(promptVersion), idOf(promptChainVersion), humanEvaluationConfig);
            const reportUrl = buildReportUrl(config.baseUrl, config.workspaceId, testRun.id);
            try {
                const semaphore = semaphore_1.Semaphore.get(`${workspaceId}:${name}:${testRun.id}`, concurrency);
                /**
                 * Runs one entry under the semaphore.
                 * A failure is logged and `reportIndex` is recorded in
                 * `failedEntryIndices`; it does not abort the other entries.
                 * `entryIndex` is what `getRow` receives.
                 * NOTE(review): `release()` also runs if `acquire()` itself rejects.
                 * That behaviour is kept from the original; confirm Semaphore tolerates it.
                 */
                const runGuarded = async (reportIndex, entryIndex, mappingKeys, getRow, datasetId) => {
                    try {
                        await semaphore.acquire();
                        await processEntry(testRun, entryIndex, mappingKeys, getRow, datasetId);
                    }
                    catch (err) {
                        logger.error((0, utils_2.buildErrorMessage)(new Error(`Error while running data entry at index [${reportIndex}]`, {
                            cause: err,
                        })));
                        failedEntryIndices.push(reportIndex);
                    }
                    finally {
                        semaphore.release();
                    }
                };
                // `data` is guaranteed truthy by collectConfigErrors above.
                if (dataStructure) {
                    const mappingKeys = {
                        input: (0, utils_1.getAllKeysByValue)(dataStructure, "INPUT")[0],
                        expectedOutput: (0, utils_1.getAllKeysByValue)(dataStructure, "EXPECTED_OUTPUT")[0],
                        contextToEvaluate: (0, utils_1.getAllKeysByValue)(dataStructure, "CONTEXT_TO_EVALUATE")[0],
                    };
                    if (typeof data === "string") {
                        // Platform dataset id, validated against the local data structure.
                        const APIDatasetService = new dataset_1.MaximDatasetAPI(config.baseUrl, config.apiKey);
                        logger.info(`Fetching dataset "${data}" from platform...`);
                        const platformDataStructure = await APIDatasetService.getDatasetDatastructure(data);
                        (0, dataset_2.validateDataStructure)(dataStructure, platformDataStructure);
                        await APITestRunService.attachDatasetToTestRun(testRun.id, data);
                        const totalRows = await APIDatasetService.getDatasetTotalRows(data);
                        await runAll(totalRows, (i) => runGuarded(i, i, mappingKeys, async (index) => APIDatasetService.getDatasetRow(data, index), data));
                    }
                    else if (data instanceof csvParser_1.CSVFile) {
                        // Reorder CSV columns to follow the key order of the data structure.
                        const columnStructure = {};
                        Object.keys(dataStructure).forEach((key, index) => {
                            columnStructure[key] = index;
                        });
                        const csv = await csvParser_1.CSVFile.restructure(data, columnStructure);
                        const totalRows = await csv.getRowCount();
                        await runAll(totalRows, (i) => runGuarded(i, i, mappingKeys, async (index) => ({ data: await csv.getRow(index) })));
                    }
                    else if (Array.isArray(data)) {
                        await runAll(data.length, (i) => runGuarded(i, i, mappingKeys, (index) => ({ data: data[index] })));
                    }
                    else if (typeof data === "function") {
                        // Paged source: data(page) is called with 0, 1, 2, ... until it returns null/undefined.
                        // Pages are processed one after another; entries within a page concurrently.
                        let page = 0;
                        let entriesSoFar = 0;
                        while (true) {
                            const fetchedData = await data(page++);
                            if (fetchedData === null || fetchedData === undefined) {
                                break;
                            }
                            try {
                                (0, sanitizationUtils_1.sanitizeData)(fetchedData, dataStructure);
                            }
                            catch (err) {
                                const detail = err instanceof Error ? `: ${err.message}` : "";
                                logger.error((0, utils_2.buildErrorMessage)(new Error(`=> Skipping page ${page - 1}\nError while sanitizing response as per data structure${detail}\n\tGot response: ${JSON.stringify(fetchedData)}`, {
                                    cause: err,
                                })));
                                continue;
                            }
                            // Failures are reported with a running index across pages;
                            // getRow receives the index within the current page.
                            const pageOffset = entriesSoFar;
                            entriesSoFar += fetchedData.length;
                            await runAll(fetchedData.length, (i) => runGuarded(pageOffset + i, i, mappingKeys, (index) => ({ data: fetchedData[index] })));
                        }
                    }
                    else {
                        throw new Error(`Invalid data type ${typeof data}. Expected string, CSVFile, array or function of valid data type.`);
                    }
                }
                else {
                    // No local data structure: `data` is treated as a platform dataset id
                    // and column roles come from the platform's data structure.
                    const datasetId = data;
                    const APIDatasetService = new dataset_1.MaximDatasetAPI(config.baseUrl, config.apiKey);
                    logger.info(`Fetching dataset "${datasetId}" from platform...`);
                    const platformDataStructure = await APIDatasetService.getDatasetDatastructure(datasetId);
                    await APITestRunService.attachDatasetToTestRun(testRun.id, datasetId);
                    const mappingKeys = {
                        input: (0, utils_1.getAllKeysByValue)(platformDataStructure, "INPUT")[0],
                        expectedOutput: (0, utils_1.getAllKeysByValue)(platformDataStructure, "EXPECTED_OUTPUT")[0],
                    };
                    const totalRows = await APIDatasetService.getDatasetTotalRows(datasetId);
                    await runAll(totalRows, (i) => runGuarded(i, i, mappingKeys, async (index) => APIDatasetService.getDatasetRow(datasetId, index), datasetId));
                }
                logger.info("Marking test run as processed...");
                await APITestRunService.markTestRunProcessed(testRun.id);
                logger.info(`You can now either quit and view the report on our web portal here: \n\t\t${reportUrl}\n\tOR\n\tWait for the test run to complete to get back the results to use through the SDK.`);
            }
            catch (e) {
                await APITestRunService.markTestRunFailed(testRun.id);
                throw e;
            }
            const pollingInterval = (0, utils_2.calculatePollingInterval)(timeoutInMinutes, platformEvaluatorsConfig.some((e) => e.type === "AI"));
            const maxIterations = Math.ceil((Math.round(timeoutInMinutes) * 60) / pollingInterval);
            logger.info("Waiting for test run to complete...");
            logger.info(`Polling interval: ${pollingInterval} seconds`);
            let pollCount = 0;
            let status;
            while (true) {
                status = await APITestRunService.getTestRunStatus(testRun.id);
                logger.info(`Test run is ${status.testRunStatus}, breakdown:\n${(0, utils_2.createStatusTable)(status.entryStatus)}`);
                // Check for a final state BEFORE the timeout, so a run that finishes on
                // the last allowed poll is not misreported as a timeout.
                if (isTestRunFinished(status)) {
                    break;
                }
                if (++pollCount > maxIterations) {
                    throw new Error(`Test run is taking over timeout period (${Math.round(timeoutInMinutes)} minutes) to complete, please check the report on our web portal directly: ${reportUrl}`);
                }
                await sleep(pollingInterval * 1000);
            }
            if (status.testRunStatus === "FAILED") {
                throw new Error(`💥 Test run failed, please check the report on our web portal: ${reportUrl}`);
            }
            if (status.testRunStatus === "STOPPED") {
                throw new Error(`🛑 Test run was stopped, please check the report on our web portal: ${reportUrl}`);
            }
            const testRunResult = await APITestRunService.getTestRunFinalResult(testRun.id);
            // The platform returns a relative link; make it absolute.
            testRunResult.link = config.baseUrl + testRunResult.link;
            logger.info(`Test run "${name}" completed successfully!🎉 \nView the report here: ${testRunResult.link}`);
            return { testRunResult, failedEntryIndices };
        }
        catch (err) {
            logger.error((0, utils_2.buildErrorMessage)(new Error(`Error while running test run ${name}`, {
                cause: err,
            })));
            throw err;
        }
    },
});
exports.createTestRunBuilder = createTestRunBuilder;
//# sourceMappingURL=testRun.js.map