
ai-utils.js

Build AI applications, chatbots, and agents with JavaScript and TypeScript.

import SecureJSON from "secure-json-parse";
import z from "zod";
import { AbstractModel } from "../../model-function/AbstractModel.js";
import { AsyncQueue } from "../../model-function/generate-text/AsyncQueue.js";
import { parseEventSourceReadableStream } from "../../model-function/generate-text/parseEventSourceReadableStream.js";
import { countTokens } from "../../model-function/tokenize-text/countTokens.js";
import { PromptMappingTextGenerationModel } from "../../prompt/PromptMappingTextGenerationModel.js";
import { callWithRetryAndThrottle } from "../../util/api/callWithRetryAndThrottle.js";
import { createJsonResponseHandler, postJsonToApi } from "../../util/api/postToApi.js";
import { failedOpenAICallResponseHandler } from "./OpenAIError.js";
import { TikTokenTokenizer } from "./TikTokenTokenizer.js";

/**
 * @see https://platform.openai.com/docs/models/
 * @see https://openai.com/pricing
 */
export const OPENAI_TEXT_GENERATION_MODELS = {
  "text-davinci-003": {
    contextWindowSize: 4096,
    tokenCostInMillicents: 2,
  },
  "text-davinci-002": {
    contextWindowSize: 4096,
    tokenCostInMillicents: 2,
  },
  "code-davinci-002": {
    contextWindowSize: 8000,
    tokenCostInMillicents: 2,
  },
  davinci: {
    contextWindowSize: 2048,
    tokenCostInMillicents: 2,
  },
  "text-curie-001": {
    contextWindowSize: 2048,
    tokenCostInMillicents: 0.2,
  },
  curie: {
    contextWindowSize: 2048,
    tokenCostInMillicents: 0.2,
  },
  "text-babbage-001": {
    contextWindowSize: 2048,
    tokenCostInMillicents: 0.05,
  },
  babbage: {
    contextWindowSize: 2048,
    tokenCostInMillicents: 0.05,
  },
  "text-ada-001": {
    contextWindowSize: 2048,
    tokenCostInMillicents: 0.04,
  },
  ada: {
    contextWindowSize: 2048,
    tokenCostInMillicents: 0.04,
  },
};

export const isOpenAITextGenerationModel = (model) =>
  model in OPENAI_TEXT_GENERATION_MODELS;

export const calculateOpenAITextGenerationCostInMillicents = ({ model, response }) =>
  response.usage.total_tokens *
  OPENAI_TEXT_GENERATION_MODELS[model].tokenCostInMillicents;

/**
 * Create a text generation model that calls the OpenAI text completion API.
 *
 * @see https://platform.openai.com/docs/api-reference/completions/create
 *
 * @example
 * const model = new OpenAITextGenerationModel({
 *   model: "text-davinci-003",
 *   temperature: 0.7,
 *   maxTokens: 500,
 *   retry: retryWithExponentialBackoff({ maxTries: 5 }),
 * });
 *
 * const { text } = await generateText(
 *   model,
 *   "Write a short story about a robot learning to love:\n\n"
 * );
 */
export class OpenAITextGenerationModel extends AbstractModel {
  constructor(settings) {
    super({ settings });
    Object.defineProperty(this, "provider", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: "openai",
    });
    Object.defineProperty(this, "contextWindowSize", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "tokenizer", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    this.tokenizer = new TikTokenTokenizer({ model: settings.model });
    this.contextWindowSize =
      OPENAI_TEXT_GENERATION_MODELS[settings.model].contextWindowSize;
  }

  get modelName() {
    return this.settings.model;
  }

  get apiKey() {
    const apiKey = this.settings.apiKey ?? process.env.OPENAI_API_KEY;
    if (apiKey == null) {
      throw new Error(
        `OpenAI API key is missing.
Pass it as an argument to the constructor or set it as an environment variable named OPENAI_API_KEY.`
      );
    }
    return apiKey;
  }

  async countPromptTokens(input) {
    return countTokens(this.tokenizer, input);
  }

  async callAPI(prompt, options) {
    const { run, settings, responseFormat } = options;
    const callSettings = Object.assign(
      {
        apiKey: this.apiKey,
        user: this.settings.isUserIdForwardingEnabled ? run?.userId : undefined,
      },
      this.settings,
      settings,
      {
        abortSignal: run?.abortSignal,
        prompt,
        responseFormat,
      }
    );
    return callWithRetryAndThrottle({
      retry: callSettings.retry,
      throttle: callSettings.throttle,
      call: async () => callOpenAITextGenerationAPI(callSettings),
    });
  }

  generateTextResponse(prompt, options) {
    return this.callAPI(prompt, {
      ...options,
      responseFormat: OpenAITextResponseFormat.json,
    });
  }

  extractText(response) {
    return response.choices[0].text;
  }

  generateDeltaStreamResponse(prompt, options) {
    return this.callAPI(prompt, {
      ...options,
      responseFormat: OpenAITextResponseFormat.deltaIterable,
    });
  }

  extractTextDelta(fullDelta) {
    return fullDelta[0].delta;
  }

  mapPrompt(promptMapping) {
    return new PromptMappingTextGenerationModel({
      model: this.withStopTokens(promptMapping.stopTokens),
      promptMapping,
    });
  }

  withSettings(additionalSettings) {
    return new OpenAITextGenerationModel(
      Object.assign({}, this.settings, additionalSettings)
    );
  }

  get maxCompletionTokens() {
    return this.settings.maxTokens;
  }

  withMaxCompletionTokens(maxCompletionTokens) {
    return this.withSettings({ maxTokens: maxCompletionTokens });
  }

  withStopTokens(stopTokens) {
    return this.withSettings({ stop: stopTokens });
  }
}

const openAITextGenerationResponseSchema = z.object({
  id: z.string(),
  object: z.literal("text_completion"),
  created: z.number(),
  model: z.string(),
  choices: z.array(
    z.object({
      text: z.string(),
      index: z.number(),
      logprobs: z.nullable(z.any()),
      finish_reason: z.string(),
    })
  ),
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
    total_tokens: z.number(),
  }),
});

/**
 * Call the OpenAI Text Completion API to generate a text completion for the given prompt.
 *
 * @see https://platform.openai.com/docs/api-reference/completions/create
 *
 * @example
 * const response = await callOpenAITextGenerationAPI({
 *   apiKey: OPENAI_API_KEY,
 *   model: "text-davinci-003",
 *   prompt: "Write a short story about a robot learning to love:\n\n",
 *   temperature: 0.7,
 *   maxTokens: 500,
 * });
 *
 * console.log(response.choices[0].text);
 */
async function callOpenAITextGenerationAPI({
  baseUrl = "https://api.openai.com/v1",
  abortSignal,
  responseFormat,
  apiKey,
  model,
  prompt,
  suffix,
  maxTokens,
  temperature,
  topP,
  n,
  logprobs,
  echo,
  stop,
  presencePenalty,
  frequencyPenalty,
  bestOf,
  user,
}) {
  return postJsonToApi({
    url: `${baseUrl}/completions`,
    apiKey,
    body: {
      stream: responseFormat.stream,
      model,
      prompt,
      suffix,
      max_tokens: maxTokens,
      temperature,
      top_p: topP,
      n,
      logprobs,
      echo,
      stop,
      presence_penalty: presencePenalty,
      frequency_penalty: frequencyPenalty,
      best_of: bestOf,
      user,
    },
    failedResponseHandler: failedOpenAICallResponseHandler,
    successfulResponseHandler: responseFormat.handler,
    abortSignal,
  });
}

export const OpenAITextResponseFormat = {
  /**
   * Returns the response as a JSON object.
   */
  json: {
    stream: false,
    handler: createJsonResponseHandler(openAITextGenerationResponseSchema),
  },

  /**
   * Returns an async iterable over the full deltas (all choices, including the
   * full current state at the time of the event) of the response stream.
   */
  deltaIterable: {
    stream: true,
    handler: async ({ response }) =>
      createOpenAITextFullDeltaIterableQueue(response.body),
  },
};

const textResponseStreamEventSchema = z.object({
  choices: z.array(
    z.object({
      text: z.string(),
      finish_reason: z.enum(["stop", "length"]).nullable(),
      index: z.number(),
    })
  ),
  created: z.number(),
  id: z.string(),
  model: z.string(),
  object: z.string(),
});

// Converts an OpenAI completion server-sent-event stream into an AsyncQueue
// of "delta" events, each carrying the accumulated state of every choice.
async function createOpenAITextFullDeltaIterableQueue(stream) {
  const queue = new AsyncQueue();
  const streamDelta = [];

  // process the stream asynchronously (no 'await' on purpose):
  parseEventSourceReadableStream({
    stream,
    callback: (event) => {
      if (event.type !== "event") {
        return;
      }

      const data = event.data;

      if (data === "[DONE]") {
        queue.close();
        return;
      }

      try {
        const json = SecureJSON.parse(data);
        const parseResult = textResponseStreamEventSchema.safeParse(json);

        if (!parseResult.success) {
          queue.push({
            type: "error",
            error: parseResult.error,
          });
          queue.close();
          return;
        }

        const streamEvent = parseResult.data;

        for (let i = 0; i < streamEvent.choices.length; i++) {
          const eventChoice = streamEvent.choices[i];
          const delta = eventChoice.text;

          if (streamDelta[i] == null) {
            streamDelta[i] = {
              content: "",
              isComplete: false,
              delta: "",
            };
          }

          const choice = streamDelta[i];

          choice.delta = delta;

          if (eventChoice.finish_reason != null) {
            choice.isComplete = true;
          }

          choice.content += delta;
        }

        // Since we're mutating the choices array in an async scenario,
        // we need to make a deep copy:
        const streamDeltaDeepCopy = JSON.parse(JSON.stringify(streamDelta));

        queue.push({
          type: "delta",
          fullDelta: streamDeltaDeepCopy,
        });
      } catch (error) {
        queue.push({ type: "error", error });
        queue.close();
        return;
      }
    },
  });

  return queue;
}
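
For reference, a minimal sketch of how the streaming path above could be consumed. It assumes this file is imported as OpenAITextGenerationModel.js, that AsyncQueue implements the async iterator protocol, and that OPENAI_API_KEY is set in the environment; the prompt text is illustrative.

import { OpenAITextGenerationModel } from "./OpenAITextGenerationModel.js";

const model = new OpenAITextGenerationModel({
  model: "text-davinci-003",
  maxTokens: 500,
});

// generateDeltaStreamResponse resolves to the AsyncQueue produced by
// createOpenAITextFullDeltaIterableQueue:
const deltas = await model.generateDeltaStreamResponse(
  "Write a short story about a robot learning to love:\n\n",
  {}
);

for await (const event of deltas) {
  if (event.type === "error") {
    throw event.error;
  }
  // Each "delta" event carries the full current state of all choices;
  // extractTextDelta returns the newest fragment of the first choice:
  process.stdout.write(model.extractTextDelta(event.fullDelta));
}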