UNPKG

aiwrapper

Version:

A Universal AI Wrapper for JavaScript & TypeScript

100 lines 4.21 kB
import { httpRequestWithRetry as fetch, } from "../../http-request.js";
import { processResponseStream } from "../../process-response-stream.js";
import { LangResultWithMessages, LanguageProvider, } from "../language-provider.js";
import { models } from 'aimodels';
import { calculateModelResponseTokens } from "../utils/token-calculator.js";

/**
 * Language provider that streams chat completions from Cohere's
 * `https://api.cohere.com/v2/chat` endpoint.
 */
export class CohereLang extends LanguageProvider {
  /**
   * @param {object} options
   * @param {string} options.apiKey - Sent as the `Authorization: Bearer` token.
   * @param {string} [options.model] - Model id; defaults to "command-r-plus-08-2024".
   * @param {string} [options.systemPrompt] - System prompt applied to requests; defaults to "".
   * @param {number} [options.maxTokens] - Explicit `max_tokens` for responses. When
   *   falsy, a value is derived from the model info (if the model is known).
   */
  constructor(options) {
    const modelName = options.model || "command-r-plus-08-2024";
    super(modelName);
    // Look the model up in the aimodels database. An unknown model is only
    // logged, not rejected; the request is still attempted with that name.
    const modelInfo = models.id(modelName);
    if (!modelInfo) {
      console.error(`Invalid Cohere model: ${modelName}. Model not found in aimodels database.`);
    }
    this.modelInfo = modelInfo;
    this._apiKey = options.apiKey;
    this._model = modelName;
    this._systemPrompt = options.systemPrompt || "";
    this._maxTokens = options.maxTokens;
  }

  /**
   * Sends a single user prompt (preceded by the configured system prompt, if
   * any) and streams the answer.
   *
   * @param {string} prompt - The user's message.
   * @param {(result: object) => void} [onResult] - Called on every streamed update.
   * @returns {Promise<object>} The result object produced by `chat`.
   */
  async ask(prompt, onResult) {
    const messages = [];
    if (this._systemPrompt) {
      messages.push({
        role: "system",
        content: this._systemPrompt,
      });
    }
    messages.push({
      role: "user",
      content: prompt,
    });
    return await this.chat(messages, onResult);
  }

  /**
   * Sends a conversation to Cohere and streams the assistant's reply.
   *
   * @param {Array<{role: string, content: string}>} messages - Conversation so far.
   * @param {(result: object) => void} [onResult] - Called after each text delta
   *   and once more when the "message-end" event arrives.
   * @returns {Promise<object>} A `LangResultWithMessages` whose `answer` holds the
   *   accumulated text and whose `messages` is the input plus the assistant reply.
   * @throws {Error} On a 401 (invalid API key), on a 400/422 (the response body
   *   is used as the message), or on whatever the HTTP helper rejects with.
   */
  async chat(messages, onResult) {
    const result = new LangResultWithMessages(messages);

    // Cohere's v2 chat API takes the system prompt as a "system" message, so
    // that role is passed through unchanged. Any other non-assistant role is
    // still coerced to "user".
    const requestMessages = messages.map((msg) => ({
      role: msg.role === "assistant" || msg.role === "system" ? msg.role : "user",
      content: msg.content,
    }));

    // When chat() is called directly with a conversation that has no system
    // message, apply the configured system prompt. This affects only the
    // outgoing request; the caller's messages and result.messages are untouched.
    if (this._systemPrompt && !requestMessages.some((msg) => msg.role === "system")) {
      requestMessages.unshift({ role: "system", content: this._systemPrompt });
    }

    // Derive max tokens from the model info only when none was configured.
    let maxTokens = this._maxTokens;
    if (this.modelInfo && !maxTokens) {
      maxTokens = calculateModelResponseTokens(this.modelInfo, messages, this._maxTokens);
    }

    const requestBody = {
      messages: requestMessages,
      model: this._model,
      stream: true,
      max_tokens: maxTokens,
      temperature: 0.7,
    };

    let response;
    try {
      response = await fetch(`https://api.cohere.com/v2/chat?alt=sse`, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          "Authorization": `Bearer ${this._apiKey}`,
          "Accept": "text/event-stream",
        },
        body: JSON.stringify(requestBody),
        onNotOkResponse: async (res, decision) => {
          // Authentication and validation failures will not succeed on retry.
          if (res.status === 401) {
            decision.retry = false;
            throw new Error("API key is invalid. Please check your API key and try again.");
          }
          if (res.status === 400 || res.status === 422) {
            const data = await res.text();
            decision.retry = false;
            throw new Error(data);
          }
          return decision;
        },
      });
    } catch (err) {
      // Rethrow real errors untouched so their stack, type and message
      // survive; wrap only non-Error throwables.
      throw err instanceof Error ? err : new Error(String(err));
    }

    const notify = () => {
      if (onResult != null) {
        onResult(result);
      }
    };

    const onData = (data) => {
      if (data.type === "message-end") {
        result.finished = true;
        notify();
        return;
      }
      if (data.type !== "content-delta") {
        return;
      }
      // Text deltas arrive as data.delta.message.content.text; ignore events
      // where any level of that path is missing or the text is empty.
      const delta = data.delta;
      const message = delta == null ? undefined : delta.message;
      const content = message == null ? undefined : message.content;
      const text = content == null ? undefined : content.text;
      if (!text) {
        return;
      }
      result.answer += text;
      result.messages = [
        ...messages,
        {
          role: "assistant",
          content: result.answer,
        },
      ];
      notify();
    };

    await processResponseStream(response, onData);
    return result;
  }
}
//# sourceMappingURL=cohere-lang.js.map