parea-ai
Version:
Client SDK library to connect to Parea AI.
348 lines (347 loc) • 13.9 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Parea = void 0;
const types_1 = require("./types");
const api_client_1 = require("./api-client");
const parea_logger_1 = require("./parea_logger");
const helpers_1 = require("./helpers");
const project_1 = require("./project");
const datasets_1 = require("./experiment/datasets");
const experiment_1 = require("./experiment/experiment");
const TraceManager_1 = require("./utils/core/TraceManager");
// REST endpoints, relative to the client's base URL.
// Placeholders in braces (e.g. `{experiment_uuid}`) are filled in with
// String.prototype.replace at the call site.

// Completions, deployed prompts and feedback
const COMPLETION_ENDPOINT = '/completion';
const DEPLOYED_PROMPT_ENDPOINT = '/deployed-prompt';
const RECORD_FEEDBACK_ENDPOINT = '/feedback';

// Experiments
const EXPERIMENT_ENDPOINT = '/experiment';
const LIST_EXPERIMENTS_ENDPOINT = '/experiments';
const EXPERIMENT_STATS_ENDPOINT = '/experiment/{experiment_uuid}/stats';
const EXPERIMENT_FINISHED_ENDPOINT = '/experiment/{experiment_uuid}/finished';
const GET_EXP_LOGS_ENDPOINT = '/experiment/{experiment_uuid}/trace_logs';

// Datasets (test collections) and their test cases
const CREATE_COLLECTION_ENDPOINT = '/collection';
const GET_COLLECTION_ENDPOINT = '/collection/{test_collection_identifier}';
const ADD_TEST_CASES_ENDPOINT = '/testcases';
const UPDATE_TEST_CASE_ENDPOINT = '/update_test_case/{dataset_id}/{test_case_id}';

// Trace logs
const GET_TRACE_LOG_ENDPOINT = '/trace_log/{trace_id}';
const GET_TRACE_LOGS_ENDPOINT = '/get_trace_logs';
/**
* Main class for interacting with the Parea API.
*/
class Parea {
    /**
     * Creates a new Parea instance.
     *
     * Configures the shared HTTP client (API key, base URL, optional mock mode),
     * hands it to the project and logger singletons, and, when an API key is
     * given, starts resolving the project UUID in the background.
     * @param apiKey - The API key for authentication.
     * @param projectName - The name of the project (default: 'default').
     */
    constructor(apiKey = '', projectName = 'default') {
        this.apiKey = apiKey;
        this.client = api_client_1.HTTPClient.getInstance();
        this.client.setApiKey(this.apiKey);
        this.client.setBaseURL(process.env.PAREA_BASE_URL || 'https://parea-ai-backend-us-9ac16cdbc7a7b006.onporter.run/api/parea/v1');
        if (process.env.PAREA_TEST_MODE === 'true') {
            this.enableTestMode(true);
        }
        project_1.pareaProject.setProjectName(projectName);
        project_1.pareaProject.setClient(this.client);
        parea_logger_1.pareaLogger.setClient(this.client);
        if (this.apiKey) {
            // Fire and forget: methods that need the UUID (updateDataAndTrace,
            // createExperiment) ask pareaProject for it again. A failure here
            // must therefore not escape as an unhandled rejection, which
            // terminates the process by default on modern Node.
            this.getProjectUUID().catch((e) => {
                console.warn(`Failed to resolve Parea project UUID: ${e}`);
            });
        }
    }
    /**
     * Retrieves and sets the project UUID, and forwards it to the shared logger.
     */
    async getProjectUUID() {
        this.project_uuid = await project_1.pareaProject.getProjectUUID();
        parea_logger_1.pareaLogger.setProjectUUID(this.project_uuid);
    }
    /**
     * Enables or disables test mode (mock mode on the HTTP client).
     * @param enable - Whether to enable test mode.
     */
    enableTestMode(enable) {
        this.client.enableMockMode(enable);
    }
    /**
     * Sets a mock handler for testing.
     * @param mockMessage - The mock message to use.
     */
    setMockHandler(mockMessage) {
        this.client.setMockHandler(mockMessage);
    }
    /**
     * Sends a completion request to the API.
     * The request is first enriched with trace/project identifiers (see updateDataAndTrace).
     * @param data - The completion request data.
     * @returns A promise resolving to the completion response.
     */
    async completion(data) {
        const requestData = await this.updateDataAndTrace(data);
        const response = await this.client.request({
            method: 'POST',
            endpoint: COMPLETION_ENDPOINT,
            data: requestData,
        });
        return response.data;
    }
    /**
     * Retrieves a deployed prompt from the API.
     * @param data - The request data for retrieving the prompt.
     * @returns A promise resolving to the deployed prompt response.
     */
    async getPrompt(data) {
        const response = await this.client.request({ method: 'POST', endpoint: DEPLOYED_PROMPT_ENDPOINT, data });
        return response.data;
    }
    /**
     * Records feedback for a completion.
     * Waits a fixed 2 seconds before sending so that the referenced log is
     * likely to have been written already.
     * @param data - The feedback request data.
     */
    async recordFeedback(data) {
        await new Promise((resolve) => setTimeout(resolve, 2000)); // give logs time to update
        await this.client.request({ method: 'POST', endpoint: RECORD_FEEDBACK_ENDPOINT, data });
    }
    /**
     * Creates a new experiment in the current project.
     * @param data - The experiment creation request data.
     * @returns A promise resolving to the created experiment schema.
     */
    async createExperiment(data) {
        const response = await this.client.request({
            method: 'POST',
            endpoint: EXPERIMENT_ENDPOINT,
            data: {
                ...data,
                project_uuid: await project_1.pareaProject.getProjectUUID(),
            },
        });
        return response.data;
    }
    /**
     * Retrieves statistics for a specific experiment.
     * @param experimentUUID - The UUID of the experiment.
     * @returns A promise resolving to the experiment statistics.
     */
    async getExperimentStats(experimentUUID) {
        const response = await this.client.request({
            method: 'GET',
            endpoint: EXPERIMENT_STATS_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
        });
        return response.data;
    }
    /**
     * Marks an experiment as finished and retrieves its statistics.
     * @param experimentUUID - The UUID of the experiment.
     * @param fin_req - The finish experiment request data.
     * @returns A promise resolving to the experiment statistics.
     */
    async finishExperiment(experimentUUID, fin_req) {
        const response = await this.client.request({
            method: 'POST',
            endpoint: EXPERIMENT_FINISHED_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
            data: fin_req,
        });
        return response.data;
    }
    /**
     * Retrieves a test case collection.
     * @param testCollectionIdentifier - The identifier (name or id) of the test collection.
     * @returns A promise resolving to the test case collection or null if not found.
     */
    async getCollection(testCollectionIdentifier) {
        // NOTE(review): the identifier is interpolated into the path unescaped; names
        // containing '/', '?' or '#' will not route correctly. Confirm whether the
        // HTTP client encodes paths before adding encodeURIComponent here.
        const response = await this.client.request({
            method: 'GET',
            endpoint: GET_COLLECTION_ENDPOINT.replace('{test_collection_identifier}', String(testCollectionIdentifier)),
        });
        if (!response.data) {
            console.error(`No test collection found with identifier ${testCollectionIdentifier}`);
            return null;
        }
        const { id, name, created_at, last_updated_at, column_names, test_cases } = response.data;
        return new types_1.TestCaseCollection(id, name, created_at, last_updated_at, column_names, test_cases);
    }
    /**
     * Creates a new test case collection.
     * @param data - The test case data.
     * @param name - Optional name for the collection.
     */
    async createTestCollection(data, name) {
        const request = await (0, datasets_1.createTestCollection)(data, name);
        await this.client.request({
            method: 'POST',
            endpoint: CREATE_COLLECTION_ENDPOINT,
            data: request,
        });
    }
    /**
     * Adds test cases to an existing collection.
     * @param data - The test case data to add.
     * @param name - Optional name for the test cases.
     * @param datasetId - Optional dataset ID to add the test cases to.
     */
    async addTestCases(data, name, datasetId) {
        const request = {
            id: datasetId,
            name,
            test_cases: await (0, datasets_1.createTestCases)(data),
        };
        await this.client.request({
            method: 'POST',
            endpoint: ADD_TEST_CASES_ENDPOINT,
            data: request,
        });
    }
    /**
     * Updates a specific test case.
     * @param testCaseId - The ID of the test case to update.
     * @param datasetId - The ID of the dataset containing the test case.
     * @param updateRequest - The update request data.
     */
    async updateTestCase(testCaseId, datasetId, updateRequest) {
        await this.client.request({
            method: 'POST',
            endpoint: UPDATE_TEST_CASE_ENDPOINT.replace('{dataset_id}', String(datasetId)).replace('{test_case_id}', String(testCaseId)),
            data: updateRequest,
        });
    }
    /**
     * Instantiates an experiment on a dataset.
     * @param name - The name of the experiment.
     * @param data - If your dataset is defined locally it should be an iterable of k/v pairs matching the expected inputs of your function. To reference a dataset you have saved on Parea, use the dataset name as a string or the dataset id as an int.
     * @param func - The function to run. This function should accept inputs that match the keys of the data field.
     * @param options - Additional options for the experiment.
     * @returns An Experiment instance.
     * @throws Error when tracing is disabled via PAREA_TRACE_ENABLED=false.
     */
    experiment(name, data, func, options) {
        const traceDisabled = process.env.PAREA_TRACE_ENABLED === 'false';
        if (traceDisabled) {
            throw new Error('Tracing is disabled. Please enable tracing to run experiments.');
        }
        return new experiment_1.Experiment(name, data, func, options || {}, this);
    }
    /**
     * Lists experiments based on provided filters.
     * @param filters - Filters to apply when listing experiments.
     * @returns A promise resolving to an array of experiments with stats.
     */
    async listExperiments(filters = {}) {
        const response = await this.client.request({
            method: 'POST',
            endpoint: LIST_EXPERIMENTS_ENDPOINT,
            data: filters,
        });
        return response.data;
    }
    /**
     * Retrieves logs for a specific experiment.
     * @param experimentUUID - The UUID of the experiment.
     * @param filter - Optional filters to apply to the logs.
     * @returns A promise resolving to an array of trace log trees.
     */
    async getExperimentLogs(experimentUUID, filter = {}) {
        const response = await this.client.request({
            method: 'POST',
            endpoint: GET_EXP_LOGS_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
            data: filter,
        });
        return response.data;
    }
    /**
     * Get the trace log tree for the given trace ID.
     * @param traceId - The trace ID to fetch the log for.
     * @returns The trace log tree.
     */
    async getTraceLog(traceId) {
        const response = await this.client.request({
            method: 'GET',
            endpoint: GET_TRACE_LOG_ENDPOINT.replace('{trace_id}', traceId),
        });
        return response.data;
    }
    /**
     * Get the evaluation scores from the trace log. If the scores are not present in the trace log, fetch them from the DB.
     * @param traceId - The trace ID to get the scores for.
     * @param checkContext - If true, will check the context for the scores first before fetching from the DB.
     * @returns A list of evaluation results.
     */
    async getTraceLogScores(traceId, checkContext = true) {
        if (checkContext) {
            const currentTrace = TraceManager_1.TraceManager.getInstance().getCurrentTrace();
            // The in-context trace can only answer for itself; using it for any other
            // traceId would return scores belonging to a different trace.
            if (currentTrace && currentTrace.id === traceId) {
                const scores = currentTrace.getLog()?.scores;
                // An empty list means nothing is recorded locally, so fall through to the DB.
                if (Array.isArray(scores) && scores.length > 0) {
                    return scores;
                }
            }
        }
        const tree = await this.getTraceLog(traceId);
        return extractScores(tree);
    }
    /**
     * Fetches trace logs for a given query.
     * Caller-supplied parameters override the library defaults key by key.
     * @param queryParams - The query parameters for the trace logs.
     * @returns A paginated response of trace logs.
     */
    async getTraceLogs(queryParams = types_1.defaultQueryParams) {
        const response = await this.client.request({
            method: 'POST',
            endpoint: GET_TRACE_LOGS_ENDPOINT,
            data: { ...types_1.defaultQueryParams, ...queryParams },
        });
        return response.data;
    }
    /**
     * Updates the data and trace information for a completion request.
     *
     * Assigns a fresh inference id and the project UUID, links the request to the
     * current trace (if any) as a child, and tags it with the experiment UUID from
     * PAREA_OS_ENV_EXPERIMENT_UUID when set. Trace-linking failures are logged at
     * debug level and do not fail the request.
     * @param data - The completion request data.
     * @returns The updated completion request data.
     * @private
     */
    async updateDataAndTrace(data) {
        // @ts-ignore
        const requestData = (0, helpers_1.serializeMetadataValues)(data);
        const traceManager = TraceManager_1.TraceManager.getInstance();
        const inference_id = (0, helpers_1.genTraceId)();
        requestData.inference_id = inference_id;
        // Prefer the UUID cached by the constructor's background lookup, if it has finished.
        requestData.project_uuid = this.project_uuid || (await project_1.pareaProject.getProjectUUID());
        try {
            const parentTrace = traceManager.getCurrentTrace();
            // Without a parent trace, this completion starts its own root trace.
            requestData.root_trace_id = parentTrace ? parentTrace.getLog().root_trace_id : inference_id;
            requestData.parent_trace_id = parentTrace ? parentTrace.id : undefined;
            if (process.env.PAREA_OS_ENV_EXPERIMENT_UUID) {
                requestData.experiment_uuid = process.env.PAREA_OS_ENV_EXPERIMENT_UUID;
            }
            if (parentTrace) {
                parentTrace.addChild(inference_id);
            }
        }
        catch (e) {
            // Best effort: a broken trace context must not block the completion itself.
            console.debug(`Error updating trace ids for completion. Trace log will be absent: ${e}`);
        }
        return requestData;
    }
}
exports.Parea = Parea;
/**
* Extracts evaluation scores from a trace log tree.
* @param tree - The trace log tree to extract scores from.
* @returns An array of evaluation results.
*/
/**
 * Extracts evaluation scores from a trace log tree.
 *
 * Walks the tree depth-first (each node before its children, children in
 * order) and concatenates every node's `scores` array.
 * @param tree - The trace log tree to extract scores from. A null/undefined tree yields [].
 * @returns An array of evaluation results.
 */
function extractScores(tree) {
    const scores = [];
    function traverse(node) {
        // The fetched tree (or a child entry) may be null; there is nothing to collect there.
        if (node == null) {
            return;
        }
        if (Array.isArray(node.scores)) {
            scores.push(...node.scores);
        }
        // If `children_logs` is missing (e.g. a leaf or a partial response), treat it
        // as having no children instead of throwing on iteration of undefined.
        for (const child of node.children_logs ?? []) {
            traverse(child);
        }
    }
    traverse(tree);
    return scores;
}