UNPKG

@ai-sdk/cohere

Version:

The **[Cohere provider](https://sdk.vercel.ai/providers/ai-sdk-providers/cohere)** for the [AI SDK](https://sdk.vercel.ai/docs) contains language model support for the Cohere API.

697 lines (687 loc) 21.9 kB
"use strict"; var __defProp = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames = Object.getOwnPropertyNames; var __hasOwnProp = Object.prototype.hasOwnProperty; var __export = (target, all) => { for (var name in all) __defProp(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames(from)) if (!__hasOwnProp.call(to, key) && key !== except) __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod); // src/index.ts var src_exports = {}; __export(src_exports, { cohere: () => cohere, createCohere: () => createCohere }); module.exports = __toCommonJS(src_exports); // src/cohere-provider.ts var import_provider_utils4 = require("@ai-sdk/provider-utils"); // src/cohere-chat-language-model.ts var import_provider3 = require("@ai-sdk/provider"); var import_provider_utils2 = require("@ai-sdk/provider-utils"); var import_zod2 = require("zod"); // src/cohere-error.ts var import_provider_utils = require("@ai-sdk/provider-utils"); var import_zod = require("zod"); var cohereErrorDataSchema = import_zod.z.object({ message: import_zod.z.string() }); var cohereFailedResponseHandler = (0, import_provider_utils.createJsonErrorResponseHandler)({ errorSchema: cohereErrorDataSchema, errorToMessage: (data) => data.message }); // src/convert-to-cohere-chat-prompt.ts var import_provider = require("@ai-sdk/provider"); function convertToCohereChatPrompt(prompt) { const messages = []; for (const { role, content } of prompt) { switch (role) { case "system": { messages.push({ role: "system", content }); break; } case "user": { messages.push({ role: "user", content: content.map((part) => { switch (part.type) { case "text": { return part.text; } case "image": { throw new import_provider.UnsupportedFunctionalityError({ functionality: "image-part" }); } } }).join("") }); break; } case "assistant": { let text = ""; const toolCalls = []; for (const part of content) { switch (part.type) { case "text": { text += part.text; break; } case "tool-call": { toolCalls.push({ id: part.toolCallId, type: "function", function: { name: part.toolName, arguments: JSON.stringify(part.args) } }); break; } } } messages.push({ role: "assistant", content: toolCalls.length > 0 ? void 0 : text, tool_calls: toolCalls.length > 0 ? toolCalls : void 0, tool_plan: void 0 }); break; } case "tool": { messages.push( ...content.map((toolResult) => ({ role: "tool", content: JSON.stringify(toolResult.result), tool_call_id: toolResult.toolCallId })) ); break; } default: { const _exhaustiveCheck = role; throw new Error(`Unsupported role: ${_exhaustiveCheck}`); } } } return messages; } // src/map-cohere-finish-reason.ts function mapCohereFinishReason(finishReason) { switch (finishReason) { case "COMPLETE": case "STOP_SEQUENCE": return "stop"; case "MAX_TOKENS": return "length"; case "ERROR": return "error"; case "TOOL_CALL": return "tool-calls"; default: return "unknown"; } } // src/cohere-prepare-tools.ts var import_provider2 = require("@ai-sdk/provider"); function prepareTools(mode) { var _a; const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0; const toolWarnings = []; if (tools == null) { return { tools: void 0, toolChoice: void 0, toolWarnings }; } const cohereTools = []; for (const tool of tools) { if (tool.type === "provider-defined") { toolWarnings.push({ type: "unsupported-tool", tool }); } else { cohereTools.push({ type: "function", function: { name: tool.name, description: tool.description, parameters: tool.parameters } }); } } const toolChoice = mode.toolChoice; if (toolChoice == null) { return { tools: cohereTools, toolChoice: void 0, toolWarnings }; } const type = toolChoice.type; switch (type) { case "auto": return { tools: cohereTools, toolChoice: void 0, toolWarnings }; case "none": return { tools: cohereTools, toolChoice: "NONE", toolWarnings }; case "required": return { tools: cohereTools, toolChoice: "REQUIRED", toolWarnings }; case "tool": return { tools: cohereTools.filter( (tool) => tool.function.name === toolChoice.toolName ), toolChoice: "REQUIRED", toolWarnings }; default: { const _exhaustiveCheck = type; throw new import_provider2.UnsupportedFunctionalityError({ functionality: `Unsupported tool choice type: ${_exhaustiveCheck}` }); } } } // src/cohere-chat-language-model.ts var CohereChatLanguageModel = class { constructor(modelId, settings, config) { this.specificationVersion = "v1"; this.defaultObjectGenerationMode = "json"; this.modelId = modelId; this.settings = settings; this.config = config; } get provider() { return this.config.provider; } getArgs({ mode, prompt, maxTokens, temperature, topP, topK, frequencyPenalty, presencePenalty, stopSequences, responseFormat, seed }) { var _a; const type = mode.type; const chatPrompt = convertToCohereChatPrompt(prompt); const baseArgs = { // model id: model: this.modelId, // standardized settings: frequency_penalty: frequencyPenalty, presence_penalty: presencePenalty, max_tokens: maxTokens, temperature, p: topP, k: topK, seed, stop_sequences: stopSequences, // response format: response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0, // messages: messages: chatPrompt }; switch (type) { case "regular": { const { tools, toolChoice, toolWarnings } = prepareTools(mode); return { args: { ...baseArgs, tools, tool_choice: toolChoice }, warnings: toolWarnings }; } case "object-json": { return { args: { ...baseArgs, response_format: mode.schema == null ? { type: "json_object" } : { type: "json_object", json_schema: mode.schema } }, warnings: [] }; } case "object-tool": { return { args: { ...baseArgs, tools: [ { type: "function", function: { name: mode.tool.name, description: (_a = mode.tool.description) != null ? _a : "", parameters: mode.tool.parameters } } ], tool_choice: "REQUIRED" }, warnings: [] }; } default: { const _exhaustiveCheck = type; throw new import_provider3.UnsupportedFunctionalityError({ functionality: `Unsupported mode: ${_exhaustiveCheck}` }); } } } async doGenerate(options) { var _a, _b, _c, _d; const { args, warnings } = this.getArgs(options); const { responseHeaders, value: response, rawValue: rawResponse } = await (0, import_provider_utils2.postJsonToApi)({ url: `${this.config.baseURL}/chat`, headers: (0, import_provider_utils2.combineHeaders)(this.config.headers(), options.headers), body: args, failedResponseHandler: cohereFailedResponseHandler, successfulResponseHandler: (0, import_provider_utils2.createJsonResponseHandler)( cohereChatResponseSchema ), abortSignal: options.abortSignal, fetch: this.config.fetch }); const { messages, ...rawSettings } = args; const text = (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : ""; return { text, toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({ toolCallId: toolCall.id, toolName: toolCall.function.name, // Cohere sometimes returns `null` for tool call arguments for tools // defined as having no arguments. args: toolCall.function.arguments.replace(/^null$/, "{}"), toolCallType: "function" })) : [], finishReason: mapCohereFinishReason(response.finish_reason), usage: { promptTokens: response.usage.tokens.input_tokens, completionTokens: response.usage.tokens.output_tokens }, rawCall: { rawPrompt: { messages }, rawSettings }, response: { id: (_d = response.generation_id) != null ? _d : void 0 }, rawResponse: { headers: responseHeaders, body: rawResponse }, warnings, request: { body: JSON.stringify(args) } }; } async doStream(options) { const { args, warnings } = this.getArgs(options); const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({ url: `${this.config.baseURL}/chat`, headers: (0, import_provider_utils2.combineHeaders)(this.config.headers(), options.headers), body: { ...args, stream: true }, failedResponseHandler: cohereFailedResponseHandler, successfulResponseHandler: (0, import_provider_utils2.createEventSourceResponseHandler)( cohereChatChunkSchema ), abortSignal: options.abortSignal, fetch: this.config.fetch }); const { messages, ...rawSettings } = args; let finishReason = "unknown"; let usage = { promptTokens: Number.NaN, completionTokens: Number.NaN }; let pendingToolCallDelta = { toolCallId: "", toolName: "", argsTextDelta: "" }; return { stream: response.pipeThrough( new TransformStream({ transform(chunk, controller) { var _a, _b; if (!chunk.success) { finishReason = "error"; controller.enqueue({ type: "error", error: chunk.error }); return; } const value = chunk.value; const type = value.type; switch (type) { case "content-delta": { controller.enqueue({ type: "text-delta", textDelta: value.delta.message.content.text }); return; } case "tool-call-start": { pendingToolCallDelta = { toolCallId: value.delta.message.tool_calls.id, toolName: value.delta.message.tool_calls.function.name, argsTextDelta: value.delta.message.tool_calls.function.arguments }; controller.enqueue({ type: "tool-call-delta", toolCallId: pendingToolCallDelta.toolCallId, toolName: pendingToolCallDelta.toolName, toolCallType: "function", argsTextDelta: pendingToolCallDelta.argsTextDelta }); return; } case "tool-call-delta": { pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments; controller.enqueue({ type: "tool-call-delta", toolCallId: pendingToolCallDelta.toolCallId, toolName: pendingToolCallDelta.toolName, toolCallType: "function", argsTextDelta: value.delta.message.tool_calls.function.arguments }); return; } case "tool-call-end": { controller.enqueue({ type: "tool-call", toolCallId: pendingToolCallDelta.toolCallId, toolName: pendingToolCallDelta.toolName, toolCallType: "function", args: JSON.stringify( JSON.parse( ((_a = pendingToolCallDelta.argsTextDelta) == null ? void 0 : _a.trim()) || "{}" ) ) }); pendingToolCallDelta = { toolCallId: "", toolName: "", argsTextDelta: "" }; return; } case "message-start": { controller.enqueue({ type: "response-metadata", id: (_b = value.id) != null ? _b : void 0 }); return; } case "message-end": { finishReason = mapCohereFinishReason(value.delta.finish_reason); const tokens = value.delta.usage.tokens; usage = { promptTokens: tokens.input_tokens, completionTokens: tokens.output_tokens }; } default: { return; } } }, flush(controller) { controller.enqueue({ type: "finish", finishReason, usage }); } }) ), rawCall: { rawPrompt: { messages }, rawSettings }, rawResponse: { headers: responseHeaders }, warnings, request: { body: JSON.stringify({ ...args, stream: true }) } }; } }; var cohereChatResponseSchema = import_zod2.z.object({ generation_id: import_zod2.z.string().nullish(), message: import_zod2.z.object({ role: import_zod2.z.string(), content: import_zod2.z.array( import_zod2.z.object({ type: import_zod2.z.string(), text: import_zod2.z.string() }) ).nullish(), tool_plan: import_zod2.z.string().nullish(), tool_calls: import_zod2.z.array( import_zod2.z.object({ id: import_zod2.z.string(), type: import_zod2.z.literal("function"), function: import_zod2.z.object({ name: import_zod2.z.string(), arguments: import_zod2.z.string() }) }) ).nullish() }), finish_reason: import_zod2.z.string(), usage: import_zod2.z.object({ billed_units: import_zod2.z.object({ input_tokens: import_zod2.z.number(), output_tokens: import_zod2.z.number() }), tokens: import_zod2.z.object({ input_tokens: import_zod2.z.number(), output_tokens: import_zod2.z.number() }) }) }); var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [ import_zod2.z.object({ type: import_zod2.z.literal("citation-start") }), import_zod2.z.object({ type: import_zod2.z.literal("citation-end") }), import_zod2.z.object({ type: import_zod2.z.literal("content-start") }), import_zod2.z.object({ type: import_zod2.z.literal("content-delta"), delta: import_zod2.z.object({ message: import_zod2.z.object({ content: import_zod2.z.object({ text: import_zod2.z.string() }) }) }) }), import_zod2.z.object({ type: import_zod2.z.literal("content-end") }), import_zod2.z.object({ type: import_zod2.z.literal("message-start"), id: import_zod2.z.string().nullish() }), import_zod2.z.object({ type: import_zod2.z.literal("message-end"), delta: import_zod2.z.object({ finish_reason: import_zod2.z.string(), usage: import_zod2.z.object({ tokens: import_zod2.z.object({ input_tokens: import_zod2.z.number(), output_tokens: import_zod2.z.number() }) }) }) }), // https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling import_zod2.z.object({ type: import_zod2.z.literal("tool-plan-delta"), delta: import_zod2.z.object({ message: import_zod2.z.object({ tool_plan: import_zod2.z.string() }) }) }), import_zod2.z.object({ type: import_zod2.z.literal("tool-call-start"), delta: import_zod2.z.object({ message: import_zod2.z.object({ tool_calls: import_zod2.z.object({ id: import_zod2.z.string(), type: import_zod2.z.literal("function"), function: import_zod2.z.object({ name: import_zod2.z.string(), arguments: import_zod2.z.string() }) }) }) }) }), // A single tool call's `arguments` stream in chunks and must be accumulated // in a string and so the full tool object info can only be parsed once we see // `tool-call-end`. import_zod2.z.object({ type: import_zod2.z.literal("tool-call-delta"), delta: import_zod2.z.object({ message: import_zod2.z.object({ tool_calls: import_zod2.z.object({ function: import_zod2.z.object({ arguments: import_zod2.z.string() }) }) }) }) }), import_zod2.z.object({ type: import_zod2.z.literal("tool-call-end") }) ]); // src/cohere-embedding-model.ts var import_provider4 = require("@ai-sdk/provider"); var import_provider_utils3 = require("@ai-sdk/provider-utils"); var import_zod3 = require("zod"); var CohereEmbeddingModel = class { constructor(modelId, settings, config) { this.specificationVersion = "v1"; this.maxEmbeddingsPerCall = 96; this.supportsParallelCalls = true; this.modelId = modelId; this.settings = settings; this.config = config; } get provider() { return this.config.provider; } async doEmbed({ values, headers, abortSignal }) { var _a; if (values.length > this.maxEmbeddingsPerCall) { throw new import_provider4.TooManyEmbeddingValuesForCallError({ provider: this.provider, modelId: this.modelId, maxEmbeddingsPerCall: this.maxEmbeddingsPerCall, values }); } const { responseHeaders, value: response } = await (0, import_provider_utils3.postJsonToApi)({ url: `${this.config.baseURL}/embed`, headers: (0, import_provider_utils3.combineHeaders)(this.config.headers(), headers), body: { model: this.modelId, // The AI SDK only supports 'float' embeddings which are also the only ones // the Cohere API docs state are supported for all models. // https://docs.cohere.com/v2/reference/embed#request.body.embedding_types embedding_types: ["float"], texts: values, input_type: (_a = this.settings.inputType) != null ? _a : "search_query", truncate: this.settings.truncate }, failedResponseHandler: cohereFailedResponseHandler, successfulResponseHandler: (0, import_provider_utils3.createJsonResponseHandler)( cohereTextEmbeddingResponseSchema ), abortSignal, fetch: this.config.fetch }); return { embeddings: response.embeddings.float, usage: { tokens: response.meta.billed_units.input_tokens }, rawResponse: { headers: responseHeaders } }; } }; var cohereTextEmbeddingResponseSchema = import_zod3.z.object({ embeddings: import_zod3.z.object({ float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())) }), meta: import_zod3.z.object({ billed_units: import_zod3.z.object({ input_tokens: import_zod3.z.number() }) }) }); // src/cohere-provider.ts function createCohere(options = {}) { var _a; const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v2"; const getHeaders = () => ({ Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({ apiKey: options.apiKey, environmentVariableName: "COHERE_API_KEY", description: "Cohere" })}`, ...options.headers }); const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, { provider: "cohere.chat", baseURL, headers: getHeaders, fetch: options.fetch }); const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, { provider: "cohere.textEmbedding", baseURL, headers: getHeaders, fetch: options.fetch }); const provider = function(modelId, settings) { if (new.target) { throw new Error( "The Cohere model function cannot be called with the new keyword." ); } return createChatModel(modelId, settings); }; provider.languageModel = createChatModel; provider.embedding = createTextEmbeddingModel; provider.textEmbeddingModel = createTextEmbeddingModel; return provider; } var cohere = createCohere(); // Annotate the CommonJS export names for ESM import in node: 0 && (module.exports = { cohere, createCohere }); //# sourceMappingURL=index.js.map