UNPKG

ai-utils.js

Version:

Build AI applications, chatbots, and agents with JavaScript and TypeScript.

277 lines (276 loc) 9.34 kB
import SecureJSON from "secure-json-parse";
import { z } from "zod";
import { AbstractModel } from "../../model-function/AbstractModel.js";
import { AsyncQueue } from "../../model-function/generate-text/AsyncQueue.js";
import { countTokens } from "../../model-function/tokenize-text/countTokens.js";
import { PromptMappingTextGenerationModel } from "../../prompt/PromptMappingTextGenerationModel.js";
import { callWithRetryAndThrottle } from "../../util/api/callWithRetryAndThrottle.js";
import {
  createJsonResponseHandler,
  postJsonToApi,
} from "../../util/api/postToApi.js";
import { failedCohereCallResponseHandler } from "./CohereError.js";
import { CohereTokenizer } from "./CohereTokenizer.js";

// Known Cohere generation models and their context window sizes.
// Read by the constructor to set `contextWindowSize`.
export const COHERE_TEXT_GENERATION_MODELS = {
  command: {
    contextWindowSize: 2048,
  },
  "command-nightly": {
    contextWindowSize: 2048,
  },
  "command-light": {
    contextWindowSize: 2048,
  },
  "command-light-nightly": {
    contextWindowSize: 2048,
  },
};

/**
 * Create a text generation model that calls the Cohere Co.Generate API.
 *
 * @see https://docs.cohere.com/reference/generate
 *
 * @example
 * const model = new CohereTextGenerationModel({
 *   model: "command-nightly",
 *   temperature: 0.7,
 *   maxTokens: 500,
 * });
 *
 * const { text } = await generateText(
 *   model,
 *   "Write a short story about a robot learning to love:\n\n"
 * );
 */
export class CohereTextGenerationModel extends AbstractModel {
  constructor(settings) {
    super({ settings });
    // Fields are created via `Object.defineProperty` (compiled class-field
    // "Define" semantics) so a same-named accessor on a superclass prototype
    // cannot intercept the write. Kept as-is on purpose.
    Object.defineProperty(this, "provider", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: "cohere",
    });
    Object.defineProperty(this, "contextWindowSize", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "tokenizer", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    this.contextWindowSize =
      COHERE_TEXT_GENERATION_MODELS[this.settings.model].contextWindowSize;
    this.tokenizer = new CohereTokenizer({
      baseUrl: this.settings.baseUrl,
      apiKey: this.settings.apiKey,
      model: this.settings.model,
      retry: this.settings.tokenizerSettings?.retry,
      throttle: this.settings.tokenizerSettings?.throttle,
    });
  }

  get modelName() {
    return this.settings.model;
  }

  /**
   * API key from settings, falling back to the COHERE_API_KEY environment
   * variable.
   *
   * @throws {Error} when neither source provides a key.
   */
  get apiKey() {
    const apiKey = this.settings.apiKey ?? process.env.COHERE_API_KEY;
    if (apiKey == null) {
      throw new Error(
        "No Cohere API key provided. Pass an API key to the constructor or set the COHERE_API_KEY environment variable."
      );
    }
    return apiKey;
  }

  async countPromptTokens(input) {
    return countTokens(this.tokenizer, input);
  }

  /**
   * Calls the Cohere API with retry/throttle applied.
   *
   * Merge order: model settings < call-time `settings` < `prompt`,
   * `abortSignal`, `responseFormat` (always win).
   */
  async callAPI(prompt, options) {
    const { run, settings, responseFormat } = options;
    const callSettings = Object.assign(
      { apiKey: this.apiKey },
      this.settings,
      settings,
      {
        abortSignal: run?.abortSignal,
        prompt,
        responseFormat,
      }
    );
    return callWithRetryAndThrottle({
      retry: this.settings.retry,
      throttle: this.settings.throttle,
      call: async () => callCohereTextGenerationAPI(callSettings),
    });
  }

  generateTextResponse(prompt, options) {
    return this.callAPI(prompt, {
      ...options,
      responseFormat: CohereTextGenerationResponseFormat.json,
    });
  }

  extractText(response) {
    return response.generations[0].text;
  }

  generateDeltaStreamResponse(prompt, options) {
    return this.callAPI(prompt, {
      ...options,
      responseFormat: CohereTextGenerationResponseFormat.deltaIterable,
    });
  }

  extractTextDelta(fullDelta) {
    return fullDelta.delta;
  }

  mapPrompt(promptMapping) {
    return new PromptMappingTextGenerationModel({
      model: this.withStopTokens(promptMapping.stopTokens),
      promptMapping,
    });
  }

  withSettings(additionalSettings) {
    return new CohereTextGenerationModel(
      Object.assign({}, this.settings, additionalSettings)
    );
  }

  get maxCompletionTokens() {
    return this.settings.maxTokens;
  }

  withMaxCompletionTokens(maxCompletionTokens) {
    return this.withSettings({ maxTokens: maxCompletionTokens });
  }

  withStopTokens(stopTokens) {
    // use endSequences instead of stopSequences
    // to exclude stop tokens from the generated text
    return this.withSettings({ endSequences: stopTokens });
  }
}

// Shape of a (non-streaming) Co.Generate response.
const cohereTextGenerationResponseSchema = z.object({
  id: z.string(),
  generations: z.array(
    z.object({
      id: z.string(),
      text: z.string(),
      finish_reason: z.string().optional(),
    })
  ),
  prompt: z.string(),
  meta: z
    .object({
      api_version: z.object({
        version: z.string(),
      }),
    })
    .optional(),
});

/**
 * Call the Cohere Co.Generate API to generate a text completion for the given prompt.
 *
 * @see https://docs.cohere.com/reference/generate
 *
 * @example
 * const response = await callCohereTextGenerationAPI({
 *   apiKey: COHERE_API_KEY,
 *   model: "command-nightly",
 *   prompt: "Write a short story about a robot learning to love:\n\n",
 *   temperature: 0.7,
 *   maxTokens: 500,
 * });
 *
 * console.log(response.generations[0].text);
 */
async function callCohereTextGenerationAPI({
  baseUrl = "https://api.cohere.ai/v1",
  abortSignal,
  responseFormat,
  apiKey,
  model,
  prompt,
  numGenerations,
  maxTokens,
  temperature,
  k,
  p,
  frequencyPenalty,
  presencePenalty,
  endSequences,
  stopSequences,
  returnLikelihoods,
  logitBias,
  truncate,
}) {
  return postJsonToApi({
    url: `${baseUrl}/generate`,
    apiKey,
    // camelCase settings are mapped to the snake_case names the API expects.
    body: {
      stream: responseFormat.stream,
      model,
      prompt,
      num_generations: numGenerations,
      max_tokens: maxTokens,
      temperature,
      k,
      p,
      frequency_penalty: frequencyPenalty,
      presence_penalty: presencePenalty,
      end_sequences: endSequences,
      stop_sequences: stopSequences,
      return_likelihoods: returnLikelihoods,
      logit_bias: logitBias,
      truncate,
    },
    failedResponseHandler: failedCohereCallResponseHandler,
    successfulResponseHandler: responseFormat.handler,
    abortSignal,
  });
}

// Streaming events: text chunks while `is_finished` is false, then one final
// event carrying the finish reason and the complete response.
const cohereTextStreamingResponseSchema = z.discriminatedUnion("is_finished", [
  z.object({
    text: z.string(),
    is_finished: z.literal(false),
  }),
  z.object({
    is_finished: z.literal(true),
    finish_reason: z.string(),
    response: cohereTextGenerationResponseSchema,
  }),
]);

/**
 * Converts a Cohere streaming response body (newline-delimited JSON) into an
 * AsyncQueue of delta events that carry the accumulated text so far.
 */
async function createCohereTextGenerationFullDeltaIterableQueue(stream) {
  const queue = new AsyncQueue();
  let accumulatedText = "";

  // Parses one NDJSON line and pushes the corresponding delta event.
  function processLine(line) {
    const event = cohereTextStreamingResponseSchema.parse(
      SecureJSON.parse(line)
    );
    if (event.is_finished === true) {
      queue.push({
        type: "delta",
        fullDelta: {
          content: accumulatedText,
          isComplete: true,
          delta: "",
        },
      });
    } else {
      accumulatedText += event.text;
      queue.push({
        type: "delta",
        fullDelta: {
          content: accumulatedText,
          isComplete: false,
          delta: event.text,
        },
      });
    }
  }

  // process the stream asynchronously (no 'await' on purpose):
  (async () => {
    let unprocessedText = "";
    const reader = new ReadableStreamDefaultReader(stream);
    const utf8Decoder = new TextDecoder("utf-8");
    try {
      // eslint-disable-next-line no-constant-condition
      while (true) {
        const { value: chunk, done } = await reader.read();
        if (done) {
          break;
        }
        unprocessedText += utf8Decoder.decode(chunk, { stream: true });
        // Keep the trailing (possibly partial) line for the next chunk.
        const processableLines = unprocessedText.split(/\r\n|\n|\r/g);
        unprocessedText = processableLines.pop() ?? "";
        processableLines.forEach(processLine);
      }
      // processing remaining text:
      if (unprocessedText) {
        processLine(unprocessedText);
      }
    } catch (error) {
      // Surface read/parse failures to the consumer instead of leaving an
      // unhandled rejection from this floating promise.
      // NOTE(review): assumes queue consumers handle a `type: "error"`
      // event — confirm against the delta-stream consumer.
      queue.push({ type: "error", error });
    } finally {
      // Always close, so iterators waiting on the queue terminate instead of
      // hanging when the stream errors mid-flight.
      queue.close();
    }
  })();

  return queue;
}

export const CohereTextGenerationResponseFormat = {
  /**
   * Returns the response as a JSON object.
   */
  json: {
    stream: false,
    handler: createJsonResponseHandler(cohereTextGenerationResponseSchema),
  },
  /**
   * Returns an async iterable over the full deltas (all choices, including
   * full current state at time of event) of the response stream.
   */
  deltaIterable: {
    stream: true,
    handler: async ({ response }) =>
      createCohereTextGenerationFullDeltaIterableQueue(response.body),
  },
};