UNPKG

parea-ai

Version:

Client SDK library to connect to Parea AI.

249 lines (248 loc) 10.6 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.OpenAIWrapper = void 0; exports.patchOpenAI = patchOpenAI; const TraceManager_1 = require("../core/TraceManager"); const StreamHandler_1 = require("../core/StreamHandler"); const message_converters_1 = require("../message-converters"); const model_prices_1 = require("../model-prices"); const helpers_1 = require("../../helpers"); const parser_1 = require("openai/lib/parser"); /** * Wrapper class for OpenAI API methods with tracing functionality. */ class OpenAIWrapper { /** * Wraps an OpenAI API method with tracing functionality. * @param method The method to wrap. * @param thisArg The `this` argument for the method. * @returns The wrapped method. */ static wrapMethod(method, thisArg) { return ((...args) => { return this.traceManager.runInContext(() => { const traceDisabled = process.env.PAREA_TRACE_ENABLED === 'false'; const parentTrace = this.traceManager.getCurrentTrace(); const insideEvalFuncSkipLogging = parentTrace ? parentTrace.getIsRunningEval() : false; if (traceDisabled || insideEvalFuncSkipLogging || this.isBetaCall(args)) { return method.apply(thisArg, args); } const configuration = this.extractConfiguration(args); const traceName = configuration?.model ? `llm-${configuration.model}` : 'llm-openai'; const trace = this.traceManager.createTrace(traceName, {}, true); const result = method.apply(thisArg, args); if (result instanceof Promise) { return result.then((value) => { if (this.isStreamingEnabled(args)) { const streamHandler = new StreamHandler_1.StreamHandler(trace, configuration); return streamHandler.handle(value); } else { this.finalizeTrace(trace, configuration, value); return value; } }, (error) => { this.finalizeTrace(trace, configuration, undefined, error); throw error; }); } else { this.finalizeTrace(trace, configuration, result); return result; } }); }); } /** * Wraps the `beta.chat.completions.parse` method with tracing functionality. * @param method The method to wrap. * @param thisArg The `this` argument for the method. * @returns The wrapped method. */ static wrapBetaParse(method, thisArg) { return ((...args) => { const traceDisabled = process.env.PAREA_TRACE_ENABLED === 'false'; const parentTrace = this.traceManager.getCurrentTrace(); const insideEvalFuncSkipLogging = parentTrace ? parentTrace.getIsRunningEval() : false; if (traceDisabled || insideEvalFuncSkipLogging) { return method.apply(thisArg, args); } const configuration = this.extractConfiguration(args); const traceName = configuration?.model ? `llm-${configuration.model}` : 'llm-openai-beta-parse'; const trace = this.traceManager.createTrace(traceName, {}, true); try { // Get the original create result const createResult = thisArg._client.chat.completions.create(args[0], { ...args[1], headers: { ...args[1]?.headers, 'X-Stainless-Helper-Method': 'beta.chat.completions.parse', }, }); return createResult .then((completion) => { const parsedResult = (0, parser_1.parseChatCompletion)(completion, args[0]); this.finalizeTrace(trace, configuration, parsedResult); return parsedResult; }) .catch((error) => { this.finalizeTrace(trace, configuration, undefined, error); throw error; }); } catch (error) { this.finalizeTrace(trace, configuration, undefined, error); throw error; } }); } /** * Checks if streaming is enabled in the given arguments. * @param args The arguments to check. * @returns True if streaming is enabled, false otherwise. */ static isStreamingEnabled(args) { return args[0]?.stream === true; } /** * Checks if trace was invoked from `beta.chat.completions.parse` method. * @param args The arguments to check. * @returns True if trace was invoked from `beta.chat.completions.parse` method, false otherwise. */ static isBetaCall(args) { if (!Array.isArray(args) || args.length < 2) { return false; } return args[1]?.headers?.['X-Stainless-Helper-Method'] === 'beta.chat.completions.parse'; } /** * Finalizes the trace with the given parameters. * @param trace The trace to finalize. * @param configuration The LLM configuration. * @param result The result of the API call. * @param error Optional error if the API call failed. */ static finalizeTrace(trace, configuration, result, error) { const endTime = new Date(); const end_timestamp = (0, helpers_1.toDateTimeString)(endTime); const latency = (endTime.getTime() - trace.startTime.getTime()) / 1000; const output = result ? this.getOutput(result) : undefined; const status = error ? 'error' : 'success'; if (result?.model) { configuration.model = result?.model; } trace.updateLog({ configuration, output, status, latency, end_timestamp, error: typeof error === 'string' ? error : error?.toString(), input_tokens: result?.usage?.prompt_tokens ?? 0, output_tokens: result?.usage?.completion_tokens ?? 0, total_tokens: result?.usage?.total_tokens ?? 0, cost: this.calculateCost(configuration?.model, result?.usage), }); this.traceManager.finalizeTrace(trace, true); } /** * Extracts the LLM configuration from the given arguments. * @param args The arguments to extract the configuration from. * @returns The extracted LLM configuration. */ static extractConfiguration(args) { try { const [options] = args; const inputs = options; const functions = inputs?.functions || inputs?.tools?.map((tool) => tool?.function) || []; const functionCallDefault = functions?.length > 0 ? 'auto' : null; const modelParams = { temp: inputs.temperature ?? 1.0, max_length: inputs.max_tokens || undefined, top_p: inputs.top_p ?? 1.0, frequency_penalty: inputs.frequency_penalty ?? 0.0, presence_penalty: inputs.presence_penalty ?? 0.0, response_format: inputs?.response_format, }; return { model: inputs?.model, messages: this.getMessages(inputs), functions: functions, function_call: options?.function_call || options?.tool_choice || functionCallDefault, model_params: modelParams, }; } catch (e) { console.error('Error extracting configuration:', e); return {}; } } /** * Calculates the cost of the API call based on the model and usage. * @param model The model used for the API call. * @param usage The token usage information. * @returns The calculated cost. */ static calculateCost(model, usage) { if (!model) { console.error(`Unknown model: ${model}. Please provide a valid OpenAI model name. Known models are: ${Object.keys(model_prices_1.MODEL_COST_MAPPING).join(', ')}`); return 0; } if (!usage) { return 0; } const modelCost = model_prices_1.MODEL_COST_MAPPING[model] || { prompt: 0, completion: 0 }; const promptCost = usage?.prompt_tokens * modelCost.prompt; const completionCost = usage?.completion_tokens * modelCost.completion; return (promptCost + completionCost) / 1000000; } /** * Extracts the output from the API result. * @param result The API result. * @returns The extracted output as a string. */ static getOutput(result) { try { const responseMessage = result?.choices[0]?.message; if (responseMessage) { return this.messageConverter.convert(responseMessage).content; } else { return JSON.stringify(result); } } catch (e) { console.error('Error extracting output:', e); return `${result}`; } } /** * Extracts the messages from OpenAi args. * @param inputs The inputs to extract messages from. * @returns The extracted messages. */ static getMessages(inputs) { try { return inputs?.messages?.map((message) => this.messageConverter.convert(message)); } catch (e) { console.error(`Error extracting messages from: ${inputs}`, e); return []; } } } exports.OpenAIWrapper = OpenAIWrapper; OpenAIWrapper.traceManager = TraceManager_1.TraceManager.getInstance(); OpenAIWrapper.messageConverter = new message_converters_1.OpenAIMessageConverter(); /** * Patches an OpenAI instance with tracing functionality. * @param openai The OpenAI instance to patch. */ function patchOpenAI(openai) { const originalCreate = openai.chat.completions.create; openai.chat.completions.create = OpenAIWrapper.wrapMethod(originalCreate, openai.chat.completions); if (openai.beta?.chat?.completions?.parse) { const originalParse = openai.beta.chat.completions.parse; openai.beta.chat.completions.parse = OpenAIWrapper.wrapBetaParse(originalParse, openai.beta.chat.completions); } }