UNPKG

ai-utils.js

Version:

Build AI applications, chatbots, and agents with JavaScript and TypeScript.

284 lines (283 loc) 10.1 kB
"use strict"; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.CohereTextGenerationResponseFormat = exports.CohereTextGenerationModel = exports.COHERE_TEXT_GENERATION_MODELS = void 0; const secure_json_parse_1 = __importDefault(require("secure-json-parse")); const zod_1 = require("zod"); const AbstractModel_js_1 = require("../../model-function/AbstractModel.cjs"); const AsyncQueue_js_1 = require("../../model-function/generate-text/AsyncQueue.cjs"); const countTokens_js_1 = require("../../model-function/tokenize-text/countTokens.cjs"); const PromptMappingTextGenerationModel_js_1 = require("../../prompt/PromptMappingTextGenerationModel.cjs"); const callWithRetryAndThrottle_js_1 = require("../../util/api/callWithRetryAndThrottle.cjs"); const postToApi_js_1 = require("../../util/api/postToApi.cjs"); const CohereError_js_1 = require("./CohereError.cjs"); const CohereTokenizer_js_1 = require("./CohereTokenizer.cjs"); exports.COHERE_TEXT_GENERATION_MODELS = { command: { contextWindowSize: 2048, }, "command-nightly": { contextWindowSize: 2048, }, "command-light": { contextWindowSize: 2048, }, "command-light-nightly": { contextWindowSize: 2048, }, }; /** * Create a text generation model that calls the Cohere Co.Generate API. * * @see https://docs.cohere.com/reference/generate * * @example * const model = new CohereTextGenerationModel({ * model: "command-nightly", * temperature: 0.7, * maxTokens: 500, * }); * * const { text } = await generateText( * model, * "Write a short story about a robot learning to love:\n\n" * ); */ class CohereTextGenerationModel extends AbstractModel_js_1.AbstractModel { constructor(settings) { super({ settings }); Object.defineProperty(this, "provider", { enumerable: true, configurable: true, writable: true, value: "cohere" }); Object.defineProperty(this, "contextWindowSize", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "tokenizer", { enumerable: true, configurable: true, writable: true, value: void 0 }); this.contextWindowSize = exports.COHERE_TEXT_GENERATION_MODELS[this.settings.model].contextWindowSize; this.tokenizer = new CohereTokenizer_js_1.CohereTokenizer({ baseUrl: this.settings.baseUrl, apiKey: this.settings.apiKey, model: this.settings.model, retry: this.settings.tokenizerSettings?.retry, throttle: this.settings.tokenizerSettings?.throttle, }); } get modelName() { return this.settings.model; } get apiKey() { const apiKey = this.settings.apiKey ?? process.env.COHERE_API_KEY; if (apiKey == null) { throw new Error("No Cohere API key provided. Pass an API key to the constructor or set the COHERE_API_KEY environment variable."); } return apiKey; } async countPromptTokens(input) { return (0, countTokens_js_1.countTokens)(this.tokenizer, input); } async callAPI(prompt, options) { const { run, settings, responseFormat } = options; const callSettings = Object.assign({ apiKey: this.apiKey, }, this.settings, settings, { abortSignal: run?.abortSignal, prompt, responseFormat, }); return (0, callWithRetryAndThrottle_js_1.callWithRetryAndThrottle)({ retry: this.settings.retry, throttle: this.settings.throttle, call: async () => callCohereTextGenerationAPI(callSettings), }); } generateTextResponse(prompt, options) { return this.callAPI(prompt, { ...options, responseFormat: exports.CohereTextGenerationResponseFormat.json, }); } extractText(response) { return response.generations[0].text; } generateDeltaStreamResponse(prompt, options) { return this.callAPI(prompt, { ...options, responseFormat: exports.CohereTextGenerationResponseFormat.deltaIterable, }); } extractTextDelta(fullDelta) { return fullDelta.delta; } mapPrompt(promptMapping) { return new PromptMappingTextGenerationModel_js_1.PromptMappingTextGenerationModel({ model: this.withStopTokens(promptMapping.stopTokens), promptMapping, }); } withSettings(additionalSettings) { return new CohereTextGenerationModel(Object.assign({}, this.settings, additionalSettings)); } get maxCompletionTokens() { return this.settings.maxTokens; } withMaxCompletionTokens(maxCompletionTokens) { return this.withSettings({ maxTokens: maxCompletionTokens }); } withStopTokens(stopTokens) { // use endSequences instead of stopSequences // to exclude stop tokens from the generated text return this.withSettings({ endSequences: stopTokens }); } } exports.CohereTextGenerationModel = CohereTextGenerationModel; const cohereTextGenerationResponseSchema = zod_1.z.object({ id: zod_1.z.string(), generations: zod_1.z.array(zod_1.z.object({ id: zod_1.z.string(), text: zod_1.z.string(), finish_reason: zod_1.z.string().optional(), })), prompt: zod_1.z.string(), meta: zod_1.z .object({ api_version: zod_1.z.object({ version: zod_1.z.string(), }), }) .optional(), }); /** * Call the Cohere Co.Generate API to generate a text completion for the given prompt. * * @see https://docs.cohere.com/reference/generate * * @example * const response = await callCohereTextGenerationAPI({ * apiKey: COHERE_API_KEY, * model: "command-nightly", * prompt: "Write a short story about a robot learning to love:\n\n", * temperature: 0.7, * maxTokens: 500, * }); * * console.log(response.generations[0].text); */ async function callCohereTextGenerationAPI({ baseUrl = "https://api.cohere.ai/v1", abortSignal, responseFormat, apiKey, model, prompt, numGenerations, maxTokens, temperature, k, p, frequencyPenalty, presencePenalty, endSequences, stopSequences, returnLikelihoods, logitBias, truncate, }) { return (0, postToApi_js_1.postJsonToApi)({ url: `${baseUrl}/generate`, apiKey, body: { stream: responseFormat.stream, model, prompt, num_generations: numGenerations, max_tokens: maxTokens, temperature, k, p, frequency_penalty: frequencyPenalty, presence_penalty: presencePenalty, end_sequences: endSequences, stop_sequences: stopSequences, return_likelihoods: returnLikelihoods, logit_bias: logitBias, truncate, }, failedResponseHandler: CohereError_js_1.failedCohereCallResponseHandler, successfulResponseHandler: responseFormat.handler, abortSignal, }); } const cohereTextStreamingResponseSchema = zod_1.z.discriminatedUnion("is_finished", [ zod_1.z.object({ text: zod_1.z.string(), is_finished: zod_1.z.literal(false), }), zod_1.z.object({ is_finished: zod_1.z.literal(true), finish_reason: zod_1.z.string(), response: cohereTextGenerationResponseSchema, }), ]); async function createCohereTextGenerationFullDeltaIterableQueue(stream) { const queue = new AsyncQueue_js_1.AsyncQueue(); let accumulatedText = ""; function processLine(line) { const event = cohereTextStreamingResponseSchema.parse(secure_json_parse_1.default.parse(line)); if (event.is_finished === true) { queue.push({ type: "delta", fullDelta: { content: accumulatedText, isComplete: true, delta: "", }, }); } else { accumulatedText += event.text; queue.push({ type: "delta", fullDelta: { content: accumulatedText, isComplete: false, delta: event.text, }, }); } } // process the stream asynchonously (no 'await' on purpose): (async () => { let unprocessedText = ""; const reader = new ReadableStreamDefaultReader(stream); const utf8Decoder = new TextDecoder("utf-8"); // eslint-disable-next-line no-constant-condition while (true) { const { value: chunk, done } = await reader.read(); if (done) { break; } unprocessedText += utf8Decoder.decode(chunk, { stream: true }); const processableLines = unprocessedText.split(/\r\n|\n|\r/g); unprocessedText = processableLines.pop() || ""; processableLines.forEach(processLine); } // processing remaining text: if (unprocessedText) { processLine(unprocessedText); } queue.close(); })(); return queue; } exports.CohereTextGenerationResponseFormat = { /** * Returns the response as a JSON object. */ json: { stream: false, handler: (0, postToApi_js_1.createJsonResponseHandler)(cohereTextGenerationResponseSchema), }, /** * Returns an async iterable over the full deltas (all choices, including full current state at time of event) * of the response stream. */ deltaIterable: { stream: true, handler: async ({ response }) => createCohereTextGenerationFullDeltaIterableQueue(response.body), }, };