aiwrapper
Version:
A Universal AI Wrapper for JavaScript & TypeScript
100 lines • 4.21 kB
JavaScript
import { httpRequestWithRetry as fetch, } from "../../http-request.js";
import { processResponseStream } from "../../process-response-stream.js";
import { LangResultWithMessages, LanguageProvider, } from "../language-provider.js";
import { models } from 'aimodels';
import { calculateModelResponseTokens } from "../utils/token-calculator.js";
/**
 * Language provider that talks to Cohere's v2 Chat endpoint with streaming
 * enabled.
 *
 * Model metadata is looked up in the `aimodels` database. When it is found and
 * no explicit `maxTokens` was configured, the response token budget is derived
 * from it via `calculateModelResponseTokens`.
 */
export class CohereLang extends LanguageProvider {
  /**
   * @param {object} options - Provider configuration.
   * @param {string} options.apiKey - Cohere API key, sent as a Bearer token.
   * @param {string} [options.model="command-r-plus-08-2024"] - Cohere model id.
   * @param {string} [options.systemPrompt=""] - Instructions sent to the model
   *   as a `system` message.
   * @param {number} [options.maxTokens] - Explicit cap on response tokens; when
   *   omitted it is computed from the model metadata (if available).
   */
  constructor(options) {
    const modelName = options.model || "command-r-plus-08-2024";
    super(modelName);
    // An unknown model is reported but not fatal; token budgeting in chat()
    // is simply skipped because modelInfo stays falsy.
    const modelInfo = models.id(modelName);
    if (!modelInfo) {
      console.error(`Invalid Cohere model: ${modelName}. Model not found in aimodels database.`);
    }
    this.modelInfo = modelInfo;
    this._apiKey = options.apiKey;
    this._model = modelName;
    this._systemPrompt = options.systemPrompt || "";
    this._maxTokens = options.maxTokens;
  }

  /**
   * Sends a single user prompt, preceded by the configured system prompt (if
   * any), and streams the answer.
   *
   * @param {*} prompt - Content of the user message.
   * @param {(result: LangResultWithMessages) => void} [onResult] - Called on
   *   every streamed chunk and once more when the stream finishes.
   * @returns {Promise<LangResultWithMessages>} The accumulated result.
   */
  async ask(prompt, onResult) {
    const messages = [];
    if (this._systemPrompt) {
      messages.push({
        role: "system",
        content: this._systemPrompt,
      });
    }
    messages.push({
      role: "user",
      content: prompt,
    });
    return await this.chat(messages, onResult);
  }

  /**
   * Streams a chat completion for the given conversation.
   *
   * `result.answer` accumulates the streamed text. `result.messages` is the
   * input conversation plus one assistant message holding the answer so far.
   * `result.finished` is set to `true` when a `message-end` event arrives.
   *
   * @param {Array<{role: string, content: *}>} messages - Conversation so far.
   * @param {(result: LangResultWithMessages) => void} [onResult] - Called after
   *   each text delta and once on `message-end`.
   * @returns {Promise<LangResultWithMessages>} The accumulated result.
   * @throws {Error} On an invalid API key (401), a rejected request (400/422,
   *   with the response body as the message), or a transport failure.
   */
  async chat(messages, onResult) {
    const result = new LangResultWithMessages(messages);

    // Derive a token budget from the model metadata only when none was given.
    let maxTokens = this._maxTokens;
    if (this.modelInfo && !maxTokens) {
      maxTokens = calculateModelResponseTokens(this.modelInfo, messages, this._maxTokens);
    }

    // System instructions travel as a `system` message (see _toCohereMessages).
    // The v1-only `preamble_override` field is intentionally not sent: it
    // duplicated the system prompt that ask() already puts in `messages`.
    const requestBody = {
      messages: this._toCohereMessages(messages),
      model: this._model,
      stream: true,
      max_tokens: maxTokens,
      temperature: 0.7,
    };

    const response = await fetch("https://api.cohere.com/v2/chat?alt=sse", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${this._apiKey}`,
        "Accept": "text/event-stream",
      },
      body: JSON.stringify(requestBody),
      onNotOkResponse: async (res, decision) => {
        if (res.status === 401) {
          decision.retry = false;
          throw new Error("API key is invalid. Please check your API key and try again.");
        }
        if (res.status === 400 || res.status === 422) {
          const data = await res.text();
          decision.retry = false;
          throw new Error(data);
        }
        return decision;
      },
    }).catch((err) => {
      // Rethrow Errors unchanged. `new Error(err)` would stringify the error,
      // prefixing its message with "Error: " and discarding the original stack.
      throw err instanceof Error ? err : new Error(String(err));
    });

    const onData = (data) => {
      if (data.type === "message-end") {
        result.finished = true;
        if (onResult != null) {
          onResult(result);
        }
        return;
      }

      // Only content-delta events carry answer text; everything else is ignored.
      if (data.type !== "content-delta") {
        return;
      }
      const delta = data.delta;
      const text = delta != null && delta.message != null && delta.message.content != null
        ? delta.message.content.text
        : undefined;
      if (!text) {
        return;
      }

      result.answer += text;
      result.messages = [...messages, {
        role: "assistant",
        content: result.answer,
      }];
      if (onResult != null) {
        onResult(result);
      }
    };

    await processResponseStream(response, onData);
    return result;
  }

  /**
   * Converts provider-agnostic messages into Cohere v2 chat messages.
   *
   * Cohere v2 accepts `system`, `user` and `assistant` roles directly, so
   * `system` and `assistant` pass through and any other role is sent as `user`.
   * If the conversation carries no system message, the configured system
   * prompt is prepended. Direct `chat()` callers therefore still get it, and
   * `ask()` callers, whose messages already include it, do not get it twice.
   *
   * @param {Array<{role: string, content: *}>} messages - Input conversation.
   * @returns {Array<{role: string, content: *}>} A new array; input is not mutated.
   */
  _toCohereMessages(messages) {
    const converted = messages.map((msg) => ({
      role: msg.role === "system" || msg.role === "assistant" ? msg.role : "user",
      content: msg.content,
    }));
    const hasSystemMessage = converted.some((msg) => msg.role === "system");
    if (this._systemPrompt && !hasSystemMessage) {
      converted.unshift({ role: "system", content: this._systemPrompt });
    }
    return converted;
  }
}
//# sourceMappingURL=cohere-lang.js.map