UNPKG

parea-ai

Version:

Client SDK library to connect to Parea AI.

348 lines (347 loc) 13.9 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Parea = void 0;
const types_1 = require("./types");
const api_client_1 = require("./api-client");
const parea_logger_1 = require("./parea_logger");
const helpers_1 = require("./helpers");
const project_1 = require("./project");
const datasets_1 = require("./experiment/datasets");
const experiment_1 = require("./experiment/experiment");
const TraceManager_1 = require("./utils/core/TraceManager");

const COMPLETION_ENDPOINT = '/completion';
const DEPLOYED_PROMPT_ENDPOINT = '/deployed-prompt';
const RECORD_FEEDBACK_ENDPOINT = '/feedback';
const EXPERIMENT_ENDPOINT = '/experiment';
const EXPERIMENT_STATS_ENDPOINT = '/experiment/{experiment_uuid}/stats';
const EXPERIMENT_FINISHED_ENDPOINT = '/experiment/{experiment_uuid}/finished';
const GET_COLLECTION_ENDPOINT = '/collection/{test_collection_identifier}';
const CREATE_COLLECTION_ENDPOINT = '/collection';
const ADD_TEST_CASES_ENDPOINT = '/testcases';
const LIST_EXPERIMENTS_ENDPOINT = '/experiments';
const GET_EXP_LOGS_ENDPOINT = '/experiment/{experiment_uuid}/trace_logs';
const GET_TRACE_LOG_ENDPOINT = '/trace_log/{trace_id}';
const GET_TRACE_LOGS_ENDPOINT = '/get_trace_logs';
const UPDATE_TEST_CASE_ENDPOINT = '/update_test_case/{dataset_id}/{test_case_id}';

/**
 * Main class for interacting with the Parea API.
 */
class Parea {
  /**
   * Creates a new Parea instance.
   * Configures the shared HTTP client (API key, base URL from `PAREA_BASE_URL`
   * or the built-in default), enables mock mode when `PAREA_TEST_MODE === 'true'`,
   * and wires the client into the project and logger singletons.
   * @param apiKey - The API key for authentication.
   * @param projectName - The name of the project (default: 'default').
   */
  constructor(apiKey = '', projectName = 'default') {
    this.apiKey = apiKey;
    this.client = api_client_1.HTTPClient.getInstance();
    this.client.setApiKey(this.apiKey);
    this.client.setBaseURL(
      process.env.PAREA_BASE_URL || 'https://parea-ai-backend-us-9ac16cdbc7a7b006.onporter.run/api/parea/v1',
    );
    if (process.env.PAREA_TEST_MODE === 'true') {
      this.enableTestMode(true);
    }
    project_1.pareaProject.setProjectName(projectName);
    project_1.pareaProject.setClient(this.client);
    parea_logger_1.pareaLogger.setClient(this.client);
    if (this.apiKey) {
      // Fire and forget: resolve the project UUID in the background. A handler is
      // attached so a failed lookup cannot become an unhandled promise rejection
      // (which terminates the process by default on modern Node.js). Callers that
      // need the UUID (e.g. updateDataAndTrace) fall back to fetching it themselves.
      this.getProjectUUID().catch((e) => {
        console.error(`Error fetching Parea project UUID: ${e}`);
      });
    }
  }

  /**
   * Retrieves the project UUID, caches it on this instance and hands it to the logger.
   */
  async getProjectUUID() {
    this.project_uuid = await project_1.pareaProject.getProjectUUID();
    parea_logger_1.pareaLogger.setProjectUUID(this.project_uuid);
  }

  /**
   * Enables or disables test (mock) mode on the HTTP client.
   * @param enable - Whether to enable test mode.
   */
  enableTestMode(enable) {
    this.client.enableMockMode(enable);
  }

  /**
   * Sets a mock handler for testing.
   * @param mockMessage - The mock message to use.
   */
  setMockHandler(mockMessage) {
    this.client.setMockHandler(mockMessage);
  }

  /**
   * Sends a completion request to the API, enriched with trace/project ids.
   * @param data - The completion request data.
   * @returns A promise resolving to the completion response.
   */
  async completion(data) {
    const requestData = await this.updateDataAndTrace(data);
    const response = await this.client.request({
      method: 'POST',
      endpoint: COMPLETION_ENDPOINT,
      data: requestData,
    });
    return response.data;
  }

  /**
   * Retrieves a deployed prompt from the API.
   * @param data - The request data for retrieving the prompt.
   * @returns A promise resolving to the deployed prompt response.
   */
  async getPrompt(data) {
    const response = await this.client.request({ method: 'POST', endpoint: DEPLOYED_PROMPT_ENDPOINT, data });
    return response.data;
  }

  /**
   * Records feedback for a completion.
   * Waits 2 seconds before sending so that recently submitted logs have time to land.
   * @param data - The feedback request data.
   */
  async recordFeedback(data) {
    await new Promise((resolve) => setTimeout(resolve, 2000)); // give logs time to update
    await this.client.request({ method: 'POST', endpoint: RECORD_FEEDBACK_ENDPOINT, data });
  }

  /**
   * Creates a new experiment in the current project.
   * @param data - The experiment creation request data.
   * @returns A promise resolving to the created experiment schema.
   */
  async createExperiment(data) {
    const response = await this.client.request({
      method: 'POST',
      endpoint: EXPERIMENT_ENDPOINT,
      data: {
        ...data,
        project_uuid: await project_1.pareaProject.getProjectUUID(),
      },
    });
    return response.data;
  }

  /**
   * Retrieves statistics for a specific experiment.
   * @param experimentUUID - The UUID of the experiment.
   * @returns A promise resolving to the experiment statistics.
   */
  async getExperimentStats(experimentUUID) {
    const response = await this.client.request({
      method: 'GET',
      endpoint: EXPERIMENT_STATS_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
    });
    return response.data;
  }

  /**
   * Marks an experiment as finished and retrieves its statistics.
   * @param experimentUUID - The UUID of the experiment.
   * @param fin_req - The finish experiment request data.
   * @returns A promise resolving to the experiment statistics.
   */
  async finishExperiment(experimentUUID, fin_req) {
    const response = await this.client.request({
      method: 'POST',
      endpoint: EXPERIMENT_FINISHED_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
      data: fin_req,
    });
    return response.data;
  }

  /**
   * Retrieves a test case collection.
   * @param testCollectionIdentifier - The identifier (name or id) of the test collection.
   * @returns A promise resolving to the test case collection, or null if not found.
   */
  async getCollection(testCollectionIdentifier) {
    const response = await this.client.request({
      method: 'GET',
      endpoint: GET_COLLECTION_ENDPOINT.replace('{test_collection_identifier}', String(testCollectionIdentifier)),
    });
    if (!response.data) {
      console.error(`No test collection found with identifier ${testCollectionIdentifier}`);
      return null;
    }
    const { id, name, created_at, last_updated_at, column_names, test_cases } = response.data;
    return new types_1.TestCaseCollection(id, name, created_at, last_updated_at, column_names, test_cases);
  }

  /**
   * Creates a new test case collection.
   * @param data - The test case data.
   * @param name - Optional name for the collection.
   */
  async createTestCollection(data, name) {
    const request = await (0, datasets_1.createTestCollection)(data, name);
    await this.client.request({
      method: 'POST',
      endpoint: CREATE_COLLECTION_ENDPOINT,
      data: request,
    });
  }

  /**
   * Adds test cases to an existing collection.
   * @param data - The test case data to add.
   * @param name - Optional name for the test cases.
   * @param datasetId - Optional dataset ID to add the test cases to.
   */
  async addTestCases(data, name, datasetId) {
    const request = {
      id: datasetId,
      name,
      test_cases: await (0, datasets_1.createTestCases)(data),
    };
    await this.client.request({
      method: 'POST',
      endpoint: ADD_TEST_CASES_ENDPOINT,
      data: request,
    });
  }

  /**
   * Updates a specific test case.
   * @param testCaseId - The ID of the test case to update.
   * @param datasetId - The ID of the dataset containing the test case.
   * @param updateRequest - The update request data.
   */
  async updateTestCase(testCaseId, datasetId, updateRequest) {
    await this.client.request({
      method: 'POST',
      endpoint: UPDATE_TEST_CASE_ENDPOINT.replace('{dataset_id}', String(datasetId)).replace(
        '{test_case_id}',
        String(testCaseId),
      ),
      data: updateRequest,
    });
  }

  /**
   * Instantiates an experiment on a dataset.
   * @param name - The name of the experiment.
   * @param data - If your dataset is defined locally it should be an iterable of k/v pairs matching the
   *   expected inputs of your function. To reference a dataset you have saved on Parea, use the dataset
   *   name as a string or the dataset id as an int.
   * @param func - The function to run. This function should accept inputs that match the keys of the data field.
   * @param options - Additional options for the experiment.
   * @returns An Experiment instance.
   * @throws Error when tracing is disabled via `PAREA_TRACE_ENABLED=false`.
   */
  experiment(name, data, func, options) {
    const traceDisabled = process.env.PAREA_TRACE_ENABLED === 'false';
    if (traceDisabled) {
      throw new Error('Tracing is disabled. Please enable tracing to run experiments.');
    }
    return new experiment_1.Experiment(name, data, func, options || {}, this);
  }

  /**
   * Lists experiments based on provided filters.
   * @param filters - Filters to apply when listing experiments.
   * @returns A promise resolving to an array of experiments with stats.
   */
  async listExperiments(filters = {}) {
    const response = await this.client.request({
      method: 'POST',
      endpoint: LIST_EXPERIMENTS_ENDPOINT,
      data: filters,
    });
    return response.data;
  }

  /**
   * Retrieves logs for a specific experiment.
   * @param experimentUUID - The UUID of the experiment.
   * @param filter - Optional filters to apply to the logs.
   * @returns A promise resolving to an array of trace log trees.
   */
  async getExperimentLogs(experimentUUID, filter = {}) {
    const response = await this.client.request({
      method: 'POST',
      endpoint: GET_EXP_LOGS_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
      data: filter,
    });
    return response.data;
  }

  /**
   * Get the trace log tree for the given trace ID.
   * @param traceId - The trace ID to fetch the log for.
   * @returns The trace log tree.
   */
  async getTraceLog(traceId) {
    const response = await this.client.request({
      method: 'GET',
      endpoint: GET_TRACE_LOG_ENDPOINT.replace('{trace_id}', traceId),
    });
    return response.data;
  }

  /**
   * Get the evaluation scores for a trace. If the scores are not present in the local
   * trace context, fetch the trace log tree from the API and collect scores from every node.
   * @param traceId - The trace ID to get the scores for.
   * @param checkContext - If true, first check the currently active trace for the scores
   *   (only used when that trace is `traceId` and it has at least one score).
   * @returns A list of evaluation results.
   */
  async getTraceLogScores(traceId, checkContext = true) {
    if (checkContext) {
      const currentTrace = TraceManager_1.TraceManager.getInstance().getCurrentTrace();
      // Only trust in-memory scores when they belong to the requested trace; otherwise we
      // would return the scores of whichever unrelated trace happens to be active.
      if (currentTrace && currentTrace.id === traceId) {
        const scores = currentTrace.getLog()?.scores;
        // An empty array means nothing was recorded locally yet, so fall through to the API.
        if (scores?.length) {
          return scores;
        }
      }
    }
    const tree = await this.getTraceLog(traceId);
    return extractScores(tree);
  }

  /**
   * Fetches trace logs for a given query.
   * @param queryParams - The query parameters for the trace logs (merged over the defaults).
   * @returns A paginated response of trace logs.
   */
  async getTraceLogs(queryParams = types_1.defaultQueryParams) {
    const response = await this.client.request({
      method: 'POST',
      endpoint: GET_TRACE_LOGS_ENDPOINT,
      data: { ...types_1.defaultQueryParams, ...queryParams },
    });
    return response.data;
  }

  /**
   * Updates the data and trace information for a completion request.
   * Assigns a fresh inference id, the project UUID, root/parent trace ids from the
   * currently active trace (if any), and the experiment UUID from
   * `PAREA_OS_ENV_EXPERIMENT_UUID` when set. Failures while reading trace state are
   * logged and the partially-enriched data is still returned.
   * @param data - The completion request data.
   * @returns The updated completion request data.
   * @private
   */
  async updateDataAndTrace(data) {
    const requestData = (0, helpers_1.serializeMetadataValues)(data);
    const traceManager = TraceManager_1.TraceManager.getInstance();
    const inferenceId = (0, helpers_1.genTraceId)();
    requestData.inference_id = inferenceId;
    requestData.project_uuid = this.project_uuid || (await project_1.pareaProject.getProjectUUID());
    try {
      const parentTrace = traceManager.getCurrentTrace();
      // With no active trace, this completion is its own root.
      requestData.root_trace_id = parentTrace ? parentTrace.getLog().root_trace_id : inferenceId;
      requestData.parent_trace_id = parentTrace ? parentTrace.id : undefined;
      const experimentUUID = process.env.PAREA_OS_ENV_EXPERIMENT_UUID;
      if (experimentUUID) {
        requestData.experiment_uuid = experimentUUID;
      }
      if (parentTrace) {
        parentTrace.addChild(inferenceId);
      }
    } catch (e) {
      console.debug(`Error updating trace ids for completion. Trace log will be absent: ${e}`);
    }
    return requestData;
  }
}
exports.Parea = Parea;

/**
 * Extracts evaluation scores from a trace log tree, depth-first (node before its children).
 * Nodes without `scores` or `children_logs` are tolerated, as is a null/undefined tree.
 * @param tree - The trace log tree to extract scores from.
 * @returns An array of evaluation results.
 */
function extractScores(tree) {
  const scores = [];
  const traverse = (node) => {
    if (!node) {
      return;
    }
    if (node.scores) {
      scores.push(...node.scores);
    }
    // Leaf nodes may omit children_logs entirely.
    for (const child of node.children_logs ?? []) {
      traverse(child);
    }
  };
  traverse(tree);
  return scores;
}