// ai-utils.js — Build AI applications, chatbots, and agents with JavaScript and TypeScript.
// Compiled CommonJS output: Cohere text generation model (Co.Generate API).
"use strict";
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.CohereTextGenerationResponseFormat = exports.CohereTextGenerationModel = exports.COHERE_TEXT_GENERATION_MODELS = void 0;
const secure_json_parse_1 = __importDefault(require("secure-json-parse"));
const zod_1 = require("zod");
const AbstractModel_js_1 = require("../../model-function/AbstractModel.cjs");
const AsyncQueue_js_1 = require("../../model-function/generate-text/AsyncQueue.cjs");
const countTokens_js_1 = require("../../model-function/tokenize-text/countTokens.cjs");
const PromptMappingTextGenerationModel_js_1 = require("../../prompt/PromptMappingTextGenerationModel.cjs");
const callWithRetryAndThrottle_js_1 = require("../../util/api/callWithRetryAndThrottle.cjs");
const postToApi_js_1 = require("../../util/api/postToApi.cjs");
const CohereError_js_1 = require("./CohereError.cjs");
const CohereTokenizer_js_1 = require("./CohereTokenizer.cjs");
exports.COHERE_TEXT_GENERATION_MODELS = {
command: {
contextWindowSize: 2048,
},
"command-nightly": {
contextWindowSize: 2048,
},
"command-light": {
contextWindowSize: 2048,
},
"command-light-nightly": {
contextWindowSize: 2048,
},
};
/**
* Create a text generation model that calls the Cohere Co.Generate API.
*
* @see https://docs.cohere.com/reference/generate
*
* @example
* const model = new CohereTextGenerationModel({
* model: "command-nightly",
* temperature: 0.7,
* maxTokens: 500,
* });
*
* const { text } = await generateText(
* model,
* "Write a short story about a robot learning to love:\n\n"
* );
*/
class CohereTextGenerationModel extends AbstractModel_js_1.AbstractModel {
constructor(settings) {
super({ settings });
Object.defineProperty(this, "provider", {
enumerable: true,
configurable: true,
writable: true,
value: "cohere"
});
Object.defineProperty(this, "contextWindowSize", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "tokenizer", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.contextWindowSize =
exports.COHERE_TEXT_GENERATION_MODELS[this.settings.model].contextWindowSize;
this.tokenizer = new CohereTokenizer_js_1.CohereTokenizer({
baseUrl: this.settings.baseUrl,
apiKey: this.settings.apiKey,
model: this.settings.model,
retry: this.settings.tokenizerSettings?.retry,
throttle: this.settings.tokenizerSettings?.throttle,
});
}
get modelName() {
return this.settings.model;
}
get apiKey() {
const apiKey = this.settings.apiKey ?? process.env.COHERE_API_KEY;
if (apiKey == null) {
throw new Error("No Cohere API key provided. Pass an API key to the constructor or set the COHERE_API_KEY environment variable.");
}
return apiKey;
}
async countPromptTokens(input) {
return (0, countTokens_js_1.countTokens)(this.tokenizer, input);
}
async callAPI(prompt, options) {
const { run, settings, responseFormat } = options;
const callSettings = Object.assign({
apiKey: this.apiKey,
}, this.settings, settings, {
abortSignal: run?.abortSignal,
prompt,
responseFormat,
});
return (0, callWithRetryAndThrottle_js_1.callWithRetryAndThrottle)({
retry: this.settings.retry,
throttle: this.settings.throttle,
call: async () => callCohereTextGenerationAPI(callSettings),
});
}
generateTextResponse(prompt, options) {
return this.callAPI(prompt, {
...options,
responseFormat: exports.CohereTextGenerationResponseFormat.json,
});
}
extractText(response) {
return response.generations[0].text;
}
generateDeltaStreamResponse(prompt, options) {
return this.callAPI(prompt, {
...options,
responseFormat: exports.CohereTextGenerationResponseFormat.deltaIterable,
});
}
extractTextDelta(fullDelta) {
return fullDelta.delta;
}
mapPrompt(promptMapping) {
return new PromptMappingTextGenerationModel_js_1.PromptMappingTextGenerationModel({
model: this.withStopTokens(promptMapping.stopTokens),
promptMapping,
});
}
withSettings(additionalSettings) {
return new CohereTextGenerationModel(Object.assign({}, this.settings, additionalSettings));
}
get maxCompletionTokens() {
return this.settings.maxTokens;
}
withMaxCompletionTokens(maxCompletionTokens) {
return this.withSettings({ maxTokens: maxCompletionTokens });
}
withStopTokens(stopTokens) {
// use endSequences instead of stopSequences
// to exclude stop tokens from the generated text
return this.withSettings({ endSequences: stopTokens });
}
}
exports.CohereTextGenerationModel = CohereTextGenerationModel;
const cohereTextGenerationResponseSchema = zod_1.z.object({
id: zod_1.z.string(),
generations: zod_1.z.array(zod_1.z.object({
id: zod_1.z.string(),
text: zod_1.z.string(),
finish_reason: zod_1.z.string().optional(),
})),
prompt: zod_1.z.string(),
meta: zod_1.z
.object({
api_version: zod_1.z.object({
version: zod_1.z.string(),
}),
})
.optional(),
});
/**
* Call the Cohere Co.Generate API to generate a text completion for the given prompt.
*
* @see https://docs.cohere.com/reference/generate
*
* @example
* const response = await callCohereTextGenerationAPI({
* apiKey: COHERE_API_KEY,
* model: "command-nightly",
* prompt: "Write a short story about a robot learning to love:\n\n",
* temperature: 0.7,
* maxTokens: 500,
* });
*
* console.log(response.generations[0].text);
*/
async function callCohereTextGenerationAPI({ baseUrl = "https://api.cohere.ai/v1", abortSignal, responseFormat, apiKey, model, prompt, numGenerations, maxTokens, temperature, k, p, frequencyPenalty, presencePenalty, endSequences, stopSequences, returnLikelihoods, logitBias, truncate, }) {
return (0, postToApi_js_1.postJsonToApi)({
url: `${baseUrl}/generate`,
apiKey,
body: {
stream: responseFormat.stream,
model,
prompt,
num_generations: numGenerations,
max_tokens: maxTokens,
temperature,
k,
p,
frequency_penalty: frequencyPenalty,
presence_penalty: presencePenalty,
end_sequences: endSequences,
stop_sequences: stopSequences,
return_likelihoods: returnLikelihoods,
logit_bias: logitBias,
truncate,
},
failedResponseHandler: CohereError_js_1.failedCohereCallResponseHandler,
successfulResponseHandler: responseFormat.handler,
abortSignal,
});
}
const cohereTextStreamingResponseSchema = zod_1.z.discriminatedUnion("is_finished", [
zod_1.z.object({
text: zod_1.z.string(),
is_finished: zod_1.z.literal(false),
}),
zod_1.z.object({
is_finished: zod_1.z.literal(true),
finish_reason: zod_1.z.string(),
response: cohereTextGenerationResponseSchema,
}),
]);
async function createCohereTextGenerationFullDeltaIterableQueue(stream) {
const queue = new AsyncQueue_js_1.AsyncQueue();
let accumulatedText = "";
function processLine(line) {
const event = cohereTextStreamingResponseSchema.parse(secure_json_parse_1.default.parse(line));
if (event.is_finished === true) {
queue.push({
type: "delta",
fullDelta: {
content: accumulatedText,
isComplete: true,
delta: "",
},
});
}
else {
accumulatedText += event.text;
queue.push({
type: "delta",
fullDelta: {
content: accumulatedText,
isComplete: false,
delta: event.text,
},
});
}
}
// process the stream asynchonously (no 'await' on purpose):
(async () => {
let unprocessedText = "";
const reader = new ReadableStreamDefaultReader(stream);
const utf8Decoder = new TextDecoder("utf-8");
// eslint-disable-next-line no-constant-condition
while (true) {
const { value: chunk, done } = await reader.read();
if (done) {
break;
}
unprocessedText += utf8Decoder.decode(chunk, { stream: true });
const processableLines = unprocessedText.split(/\r\n|\n|\r/g);
unprocessedText = processableLines.pop() || "";
processableLines.forEach(processLine);
}
// processing remaining text:
if (unprocessedText) {
processLine(unprocessedText);
}
queue.close();
})();
return queue;
}
exports.CohereTextGenerationResponseFormat = {
/**
* Returns the response as a JSON object.
*/
json: {
stream: false,
handler: (0, postToApi_js_1.createJsonResponseHandler)(cohereTextGenerationResponseSchema),
},
/**
* Returns an async iterable over the full deltas (all choices, including full current state at time of event)
* of the response stream.
*/
deltaIterable: {
stream: true,
handler: async ({ response }) => createCohereTextGenerationFullDeltaIterableQueue(response.body),
},
};