// ai-utils.js
// Build AI applications, chatbots, and agents with JavaScript and TypeScript.
import SecureJSON from "secure-json-parse";
import { z } from "zod";
import { AbstractModel } from "../../model-function/AbstractModel.js";
import { AsyncQueue } from "../../model-function/generate-text/AsyncQueue.js";
import { countTokens } from "../../model-function/tokenize-text/countTokens.js";
import { PromptMappingTextGenerationModel } from "../../prompt/PromptMappingTextGenerationModel.js";
import { callWithRetryAndThrottle } from "../../util/api/callWithRetryAndThrottle.js";
import { createJsonResponseHandler, postJsonToApi, } from "../../util/api/postToApi.js";
import { failedCohereCallResponseHandler } from "./CohereError.js";
import { CohereTokenizer } from "./CohereTokenizer.js";
// All currently supported Cohere command models share the same context window.
const COMMAND_CONTEXT_WINDOW_SIZE = 2048;
/**
 * Supported Cohere text generation models, keyed by model name,
 * with their per-model capabilities.
 */
export const COHERE_TEXT_GENERATION_MODELS = {
    command: { contextWindowSize: COMMAND_CONTEXT_WINDOW_SIZE },
    "command-nightly": { contextWindowSize: COMMAND_CONTEXT_WINDOW_SIZE },
    "command-light": { contextWindowSize: COMMAND_CONTEXT_WINDOW_SIZE },
    "command-light-nightly": { contextWindowSize: COMMAND_CONTEXT_WINDOW_SIZE },
};
/**
* Create a text generation model that calls the Cohere Co.Generate API.
*
* @see https://docs.cohere.com/reference/generate
*
* @example
* const model = new CohereTextGenerationModel({
* model: "command-nightly",
* temperature: 0.7,
* maxTokens: 500,
* });
*
* const { text } = await generateText(
* model,
* "Write a short story about a robot learning to love:\n\n"
* );
*/
export class CohereTextGenerationModel extends AbstractModel {
    constructor(settings) {
        super({ settings });
        // defineProperty mirrors compiled-TS class-field semantics: the fields
        // become own data properties regardless of the prototype chain.
        Object.defineProperty(this, "provider", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "cohere"
        });
        Object.defineProperty(this, "contextWindowSize", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "tokenizer", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        const modelInformation = COHERE_TEXT_GENERATION_MODELS[this.settings.model];
        if (modelInformation == null) {
            // Fail fast with an actionable message instead of the opaque
            // TypeError that reading `.contextWindowSize` of undefined produces.
            throw new Error(`Unknown Cohere text generation model: ${this.settings.model}. ` +
                `Available models: ${Object.keys(COHERE_TEXT_GENERATION_MODELS).join(", ")}`);
        }
        this.contextWindowSize = modelInformation.contextWindowSize;
        this.tokenizer = new CohereTokenizer({
            baseUrl: this.settings.baseUrl,
            apiKey: this.settings.apiKey,
            model: this.settings.model,
            retry: this.settings.tokenizerSettings?.retry,
            throttle: this.settings.tokenizerSettings?.throttle,
        });
    }
    get modelName() {
        return this.settings.model;
    }
    /**
     * API key from the settings, falling back to the COHERE_API_KEY
     * environment variable.
     *
     * @throws {Error} if neither source provides a key.
     */
    get apiKey() {
        const apiKey = this.settings.apiKey ?? process.env.COHERE_API_KEY;
        if (apiKey == null) {
            throw new Error("No Cohere API key provided. Pass an API key to the constructor or set the COHERE_API_KEY environment variable.");
        }
        return apiKey;
    }
    /**
     * Counts the tokens in `input` using this model's Cohere tokenizer.
     */
    async countPromptTokens(input) {
        return countTokens(this.tokenizer, input);
    }
    /**
     * Calls the Cohere Co.Generate API with retry and throttle applied.
     * `options.settings` overrides the model settings for this call;
     * `options.responseFormat` selects JSON vs. streaming handling.
     */
    async callAPI(prompt, options) {
        const { run, settings, responseFormat } = options;
        // Precedence (last wins): apiKey default < model settings < call settings < fixed values.
        const callSettings = Object.assign({
            apiKey: this.apiKey,
        }, this.settings, settings, {
            abortSignal: run?.abortSignal,
            prompt,
            responseFormat,
        });
        return callWithRetryAndThrottle({
            retry: this.settings.retry,
            throttle: this.settings.throttle,
            call: async () => callCohereTextGenerationAPI(callSettings),
        });
    }
    generateTextResponse(prompt, options) {
        return this.callAPI(prompt, {
            ...options,
            responseFormat: CohereTextGenerationResponseFormat.json,
        });
    }
    extractText(response) {
        // Only the first generation is used for plain-text extraction.
        return response.generations[0].text;
    }
    generateDeltaStreamResponse(prompt, options) {
        return this.callAPI(prompt, {
            ...options,
            responseFormat: CohereTextGenerationResponseFormat.deltaIterable,
        });
    }
    extractTextDelta(fullDelta) {
        return fullDelta.delta;
    }
    /**
     * Wraps this model with a prompt mapping (e.g. instruction or chat prompts),
     * registering the mapping's stop tokens on the wrapped model.
     */
    mapPrompt(promptMapping) {
        return new PromptMappingTextGenerationModel({
            model: this.withStopTokens(promptMapping.stopTokens),
            promptMapping,
        });
    }
    /**
     * Returns a new model instance with `additionalSettings` merged over
     * the current settings (current instance is not mutated).
     */
    withSettings(additionalSettings) {
        return new CohereTextGenerationModel(Object.assign({}, this.settings, additionalSettings));
    }
    get maxCompletionTokens() {
        return this.settings.maxTokens;
    }
    withMaxCompletionTokens(maxCompletionTokens) {
        return this.withSettings({ maxTokens: maxCompletionTokens });
    }
    withStopTokens(stopTokens) {
        // use endSequences instead of stopSequences
        // to exclude stop tokens from the generated text
        return this.withSettings({ endSequences: stopTokens });
    }
}
// Schema for a single completion within a Co.Generate response.
const cohereGenerationSchema = z.object({
    id: z.string(),
    text: z.string(),
    finish_reason: z.string().optional(),
});
// Schema of the complete (non-streaming) Cohere Co.Generate API response.
const cohereTextGenerationResponseSchema = z.object({
    id: z.string(),
    generations: z.array(cohereGenerationSchema),
    prompt: z.string(),
    meta: z
        .object({
        api_version: z.object({
            version: z.string(),
        }),
    })
        .optional(),
});
/**
* Call the Cohere Co.Generate API to generate a text completion for the given prompt.
*
* @see https://docs.cohere.com/reference/generate
*
* @example
* const response = await callCohereTextGenerationAPI({
* apiKey: COHERE_API_KEY,
* model: "command-nightly",
* prompt: "Write a short story about a robot learning to love:\n\n",
* temperature: 0.7,
* maxTokens: 500,
* });
*
* console.log(response.generations[0].text);
*/
async function callCohereTextGenerationAPI({ baseUrl = "https://api.cohere.ai/v1", abortSignal, responseFormat, apiKey, model, prompt, numGenerations, maxTokens, temperature, k, p, frequencyPenalty, presencePenalty, endSequences, stopSequences, returnLikelihoods, logitBias, truncate, }) {
    // Translate the camelCase options into the snake_case field names
    // expected by the Cohere API. Undefined values are dropped during
    // JSON serialization.
    const body = {
        stream: responseFormat.stream,
        model,
        prompt,
        num_generations: numGenerations,
        max_tokens: maxTokens,
        temperature,
        k,
        p,
        frequency_penalty: frequencyPenalty,
        presence_penalty: presencePenalty,
        end_sequences: endSequences,
        stop_sequences: stopSequences,
        return_likelihoods: returnLikelihoods,
        logit_bias: logitBias,
        truncate,
    };
    return postJsonToApi({
        url: `${baseUrl}/generate`,
        apiKey,
        body,
        failedResponseHandler: failedCohereCallResponseHandler,
        successfulResponseHandler: responseFormat.handler,
        abortSignal,
    });
}
// Intermediate streaming event: carries the newest text chunk.
const cohereStreamTextEventSchema = z.object({
    text: z.string(),
    is_finished: z.literal(false),
});
// Final streaming event: carries the finish reason and the full response.
const cohereStreamFinalEventSchema = z.object({
    is_finished: z.literal(true),
    finish_reason: z.string(),
    response: cohereTextGenerationResponseSchema,
});
// Streaming events are discriminated by the `is_finished` flag.
const cohereTextStreamingResponseSchema = z.discriminatedUnion("is_finished", [
    cohereStreamTextEventSchema,
    cohereStreamFinalEventSchema,
]);
/**
 * Converts a Cohere streaming response body (newline-delimited JSON events)
 * into an AsyncQueue of full deltas: each queue entry carries the accumulated
 * text so far plus the newest chunk.
 *
 * @param stream ReadableStream of the raw HTTP response body bytes.
 * @returns the queue (returned immediately; filled asynchronously).
 */
async function createCohereTextGenerationFullDeltaIterableQueue(stream) {
    const queue = new AsyncQueue();
    let accumulatedText = "";
    // Parse one newline-delimited JSON event and push the matching delta.
    function processLine(line) {
        const event = cohereTextStreamingResponseSchema.parse(SecureJSON.parse(line));
        if (event.is_finished === true) {
            queue.push({
                type: "delta",
                fullDelta: {
                    content: accumulatedText,
                    isComplete: true,
                    delta: "",
                },
            });
        }
        else {
            accumulatedText += event.text;
            queue.push({
                type: "delta",
                fullDelta: {
                    content: accumulatedText,
                    isComplete: false,
                    delta: event.text,
                },
            });
        }
    }
    // process the stream asynchronously (no 'await' on purpose):
    (async () => {
        try {
            let unprocessedText = "";
            const reader = new ReadableStreamDefaultReader(stream);
            const utf8Decoder = new TextDecoder("utf-8");
            // eslint-disable-next-line no-constant-condition
            while (true) {
                const { value: chunk, done } = await reader.read();
                if (done) {
                    break;
                }
                // Events can be split across chunks; buffer the trailing
                // partial line until the next chunk arrives.
                unprocessedText += utf8Decoder.decode(chunk, { stream: true });
                const processableLines = unprocessedText.split(/\r\n|\n|\r/g);
                unprocessedText = processableLines.pop() || "";
                processableLines.forEach(processLine);
            }
            // processing remaining text:
            if (unprocessedText) {
                processLine(unprocessedText);
            }
        }
        catch (error) {
            // Surface read/parse errors to the consumer instead of leaving
            // this floating promise to reject unhandled.
            queue.push({ type: "error", error });
        }
        finally {
            // Always close the queue so consumers terminate even on error.
            queue.close();
        }
    })();
    return queue;
}
export const CohereTextGenerationResponseFormat = {
    /**
     * Parses the complete, non-streaming response into a schema-validated
     * JSON object.
     */
    json: {
        stream: false,
        handler: createJsonResponseHandler(cohereTextGenerationResponseSchema),
    },
    /**
     * Consumes the streaming response as an async iterable of full deltas
     * (accumulated text plus the newest chunk at each event).
     */
    deltaIterable: {
        stream: true,
        handler: async ({ response }) => {
            return createCohereTextGenerationFullDeltaIterableQueue(response.body);
        },
    },
};