// ai-utils.js: Build AI applications, chatbots, and agents with JavaScript and TypeScript.
import SecureJSON from "secure-json-parse";
import z from "zod";
import { AbstractModel } from "../../model-function/AbstractModel.js";
import { AsyncQueue } from "../../model-function/generate-text/AsyncQueue.js";
import { parseEventSourceReadableStream } from "../../model-function/generate-text/parseEventSourceReadableStream.js";
import { countTokens } from "../../model-function/tokenize-text/countTokens.js";
import { PromptMappingTextGenerationModel } from "../../prompt/PromptMappingTextGenerationModel.js";
import { callWithRetryAndThrottle } from "../../util/api/callWithRetryAndThrottle.js";
import { createJsonResponseHandler, postJsonToApi, } from "../../util/api/postToApi.js";
import { failedOpenAICallResponseHandler } from "./OpenAIError.js";
import { TikTokenTokenizer } from "./TikTokenTokenizer.js";
/**
* @see https://platform.openai.com/docs/models/
* @see https://openai.com/pricing
*/
export const OPENAI_TEXT_GENERATION_MODELS = {
"text-davinci-003": {
contextWindowSize: 4096,
tokenCostInMillicents: 2,
},
"text-davinci-002": {
contextWindowSize: 4096,
tokenCostInMillicents: 2,
},
"code-davinci-002": {
contextWindowSize: 8000,
tokenCostInMillicents: 2,
},
davinci: {
contextWindowSize: 2048,
tokenCostInMillicents: 2,
},
"text-curie-001": {
contextWindowSize: 2048,
tokenCostInMillicents: 0.2,
},
curie: {
contextWindowSize: 2048,
tokenCostInMillicents: 0.2,
},
"text-babbage-001": {
contextWindowSize: 2048,
tokenCostInMillicents: 0.05,
},
babbage: {
contextWindowSize: 2048,
tokenCostInMillicents: 0.05,
},
"text-ada-001": {
contextWindowSize: 2048,
tokenCostInMillicents: 0.04,
},
ada: {
contextWindowSize: 2048,
tokenCostInMillicents: 0.04,
},
};
export const isOpenAITextGenerationModel = (model) => model in OPENAI_TEXT_GENERATION_MODELS;
export const calculateOpenAITextGenerationCostInMillicents = ({ model, response }) =>
    response.usage.total_tokens *
        OPENAI_TEXT_GENERATION_MODELS[model].tokenCostInMillicents;
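/**
 * Illustrative cost check (rates follow the table above): a "text-davinci-003"
 * response that reports `usage.total_tokens: 1000` costs 1000 * 2 = 2000
 * millicents, i.e. 2 cents.
 *
 * @example
 * const costInMillicents = calculateOpenAITextGenerationCostInMillicents({
 *   model: "text-davinci-003",
 *   response, // a parsed completion response containing a `usage` object
 * });
 */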
/**
* Create a text generation model that calls the OpenAI text completion API.
*
* @see https://platform.openai.com/docs/api-reference/completions/create
*
* @example
* const model = new OpenAITextGenerationModel({
* model: "text-davinci-003",
* temperature: 0.7,
* maxTokens: 500,
* retry: retryWithExponentialBackoff({ maxTries: 5 }),
* });
*
* const { text } = await generateText(
* model,
* "Write a short story about a robot learning to love:\n\n"
* );
*/
export class OpenAITextGenerationModel extends AbstractModel {
constructor(settings) {
super({ settings });
Object.defineProperty(this, "provider", {
enumerable: true,
configurable: true,
writable: true,
value: "openai"
});
Object.defineProperty(this, "contextWindowSize", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "tokenizer", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.tokenizer = new TikTokenTokenizer({ model: settings.model });
this.contextWindowSize =
OPENAI_TEXT_GENERATION_MODELS[settings.model].contextWindowSize;
}
get modelName() {
return this.settings.model;
}
get apiKey() {
const apiKey = this.settings.apiKey ?? process.env.OPENAI_API_KEY;
if (apiKey == null) {
throw new Error(`OpenAI API key is missing. Pass it as an argument to the constructor or set it as an environment variable named OPENAI_API_KEY.`);
}
return apiKey;
}
async countPromptTokens(input) {
return countTokens(this.tokenizer, input);
}
async callAPI(prompt, options) {
const { run, settings, responseFormat } = options;
const callSettings = Object.assign({
apiKey: this.apiKey,
user: this.settings.isUserIdForwardingEnabled ? run?.userId : undefined,
}, this.settings, settings, {
abortSignal: run?.abortSignal,
prompt,
responseFormat,
});
return callWithRetryAndThrottle({
retry: callSettings.retry,
throttle: callSettings.throttle,
call: async () => callOpenAITextGenerationAPI(callSettings),
});
}
generateTextResponse(prompt, options) {
return this.callAPI(prompt, {
...options,
responseFormat: OpenAITextResponseFormat.json,
});
}
extractText(response) {
return response.choices[0].text;
}
generateDeltaStreamResponse(prompt, options) {
return this.callAPI(prompt, {
...options,
responseFormat: OpenAITextResponseFormat.deltaIterable,
});
}
extractTextDelta(fullDelta) {
return fullDelta[0].delta;
}
mapPrompt(promptMapping) {
return new PromptMappingTextGenerationModel({
model: this.withStopTokens(promptMapping.stopTokens),
promptMapping,
});
}
withSettings(additionalSettings) {
return new OpenAITextGenerationModel(Object.assign({}, this.settings, additionalSettings));
}
get maxCompletionTokens() {
return this.settings.maxTokens;
}
withMaxCompletionTokens(maxCompletionTokens) {
return this.withSettings({ maxTokens: maxCompletionTokens });
}
withStopTokens(stopTokens) {
return this.withSettings({ stop: stopTokens });
}
}
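/**
 * Usage sketch: `withSettings()` returns a new model instance with merged
 * settings, so derived configurations never mutate the original model.
 *
 * @example
 * const base = new OpenAITextGenerationModel({ model: "text-davinci-003" });
 * const creative = base.withSettings({ temperature: 0.9 });
 * const bounded = creative.withMaxCompletionTokens(256);
 */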
const openAITextGenerationResponseSchema = z.object({
id: z.string(),
object: z.literal("text_completion"),
created: z.number(),
model: z.string(),
choices: z.array(z.object({
text: z.string(),
index: z.number(),
logprobs: z.nullable(z.any()),
finish_reason: z.string(),
})),
usage: z.object({
prompt_tokens: z.number(),
completion_tokens: z.number(),
total_tokens: z.number(),
}),
});
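/**
 * For reference, a minimal object that passes the schema above
 * (all values illustrative):
 *
 * {
 *   id: "cmpl-abc123",
 *   object: "text_completion",
 *   created: 1690000000,
 *   model: "text-davinci-003",
 *   choices: [
 *     { text: "Hello!", index: 0, logprobs: null, finish_reason: "stop" },
 *   ],
 *   usage: { prompt_tokens: 5, completion_tokens: 2, total_tokens: 7 },
 * }
 */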
/**
* Call the OpenAI Text Completion API to generate a text completion for the given prompt.
*
* @see https://platform.openai.com/docs/api-reference/completions/create
*
* @example
* const response = await callOpenAITextGenerationAPI({
* apiKey: OPENAI_API_KEY,
* model: "text-davinci-003",
* prompt: "Write a short story about a robot learning to love:\n\n",
* temperature: 0.7,
* maxTokens: 500,
* });
*
* console.log(response.choices[0].text);
*/
async function callOpenAITextGenerationAPI({
    baseUrl = "https://api.openai.com/v1", abortSignal, responseFormat,
    apiKey, model, prompt, suffix, maxTokens, temperature, topP, n, logprobs,
    echo, stop, presencePenalty, frequencyPenalty, bestOf, user,
}) {
return postJsonToApi({
url: `${baseUrl}/completions`,
apiKey,
body: {
stream: responseFormat.stream,
model,
prompt,
suffix,
max_tokens: maxTokens,
temperature,
top_p: topP,
n,
logprobs,
echo,
stop,
presence_penalty: presencePenalty,
frequency_penalty: frequencyPenalty,
best_of: bestOf,
user,
},
failedResponseHandler: failedOpenAICallResponseHandler,
successfulResponseHandler: responseFormat.handler,
abortSignal,
});
}
export const OpenAITextResponseFormat = {
/**
* Returns the response as a JSON object.
*/
json: {
stream: false,
handler: createJsonResponseHandler(openAITextGenerationResponseSchema),
},
/**
* Returns an async iterable over the full deltas (all choices, including full current state at time of event)
* of the response stream.
*/
deltaIterable: {
stream: true,
handler: async ({ response }) => createOpenAITextFullDeltaIterableQueue(response.body),
},
};
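/**
 * Streaming sketch (assumes `AsyncQueue` is async-iterable, as its use below
 * suggests): request the `deltaIterable` format and consume delta events as
 * they arrive.
 *
 * @example
 * const model = new OpenAITextGenerationModel({ model: "text-davinci-003" });
 * const deltas = await model.generateDeltaStreamResponse("Hello", {});
 * for await (const event of deltas) {
 *   if (event.type === "delta") {
 *     process.stdout.write(event.fullDelta[0].delta);
 *   }
 * }
 */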
const textResponseStreamEventSchema = z.object({
choices: z.array(z.object({
text: z.string(),
finish_reason: z.enum(["stop", "length"]).nullable(),
index: z.number(),
})),
created: z.number(),
id: z.string(),
model: z.string(),
object: z.string(),
});
async function createOpenAITextFullDeltaIterableQueue(stream) {
const queue = new AsyncQueue();
const streamDelta = [];
    // process the stream asynchronously (no 'await' on purpose):
parseEventSourceReadableStream({
stream,
callback: (event) => {
if (event.type !== "event") {
return;
}
const data = event.data;
if (data === "[DONE]") {
queue.close();
return;
}
try {
const json = SecureJSON.parse(data);
const parseResult = textResponseStreamEventSchema.safeParse(json);
if (!parseResult.success) {
queue.push({
type: "error",
error: parseResult.error,
});
queue.close();
return;
}
                const eventData = parseResult.data;
                for (let i = 0; i < eventData.choices.length; i++) {
                    const eventChoice = eventData.choices[i];
const delta = eventChoice.text;
if (streamDelta[i] == null) {
streamDelta[i] = {
content: "",
isComplete: false,
delta: "",
};
}
const choice = streamDelta[i];
choice.delta = delta;
if (eventChoice.finish_reason != null) {
choice.isComplete = true;
}
choice.content += delta;
}
// Since we're mutating the choices array in an async scenario,
// we need to make a deep copy:
const streamDeltaDeepCopy = JSON.parse(JSON.stringify(streamDelta));
queue.push({
type: "delta",
fullDelta: streamDeltaDeepCopy,
});
}
catch (error) {
queue.push({ type: "error", error });
queue.close();
return;
}
},
});
return queue;
}
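/**
 * Each "delta" event carries the full accumulated state of every choice, not
 * just the newest chunk. E.g. after the chunks "Hel" and "lo" arrive for
 * choice 0, the queue yields (values illustrative):
 *
 * {
 *   type: "delta",
 *   fullDelta: [{ content: "Hello", isComplete: false, delta: "lo" }],
 * }
 */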