@ai-sdk/cohere
Version:
The **[Cohere provider](https://sdk.vercel.ai/providers/ai-sdk-providers/cohere)** for the [AI SDK](https://sdk.vercel.ai/docs) contains language model support for the Cohere API.
697 lines (687 loc) • 21.9 kB
JavaScript
"use strict";
// esbuild CommonJS interop helpers.
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;

/**
 * Installs one enumerable getter on `target` for every key enumerated by
 * `for...in` over `all`; the getter is the function stored under that key.
 */
var __export = (target, all) => {
  for (const exportName in all) {
    const getter = all[exportName];
    __defProp(target, exportName, { get: getter, enumerable: true });
  }
};

/**
 * Mirrors the own string-keyed properties of `from` onto `to` as live getters.
 * Keys that `to` already owns, and the key named by `except`, are skipped.
 * Enumerability follows the source property. Non-object, non-function `from`
 * values (including null) leave `to` untouched.
 *
 * @returns {object} `to`
 */
var __copyProps = (to, from, except, desc) => {
  const isObjectLike = typeof from === "function" || (typeof from === "object" && from !== null);
  if (!isObjectLike) {
    return to;
  }
  for (const key of __getOwnPropNames(from)) {
    const alreadyPresent = __hasOwnProp.call(to, key);
    if (alreadyPresent || key === except) {
      continue;
    }
    const sourceDescriptor = __getOwnPropDesc(from, key);
    __defProp(to, key, {
      get: () => from[key],
      enumerable: !sourceDescriptor || sourceDescriptor.enumerable
    });
  }
  return to;
};

/**
 * Wraps an ES-module namespace object as a CommonJS exports object. The result
 * is flagged with a non-enumerable `__esModule: true` marker.
 */
var __toCommonJS = (mod) => {
  const exportsObject = __defProp({}, "__esModule", { value: true });
  return __copyProps(exportsObject, mod);
};
// src/index.ts
// Public entry point. `__export` installs `cohere` and `createCohere` as getters
// on `src_exports`; the arrow functions resolve the bindings only when a getter
// is read, so it does not matter that both are declared further down the file.
// `__toCommonJS` then copies those getters onto a fresh object flagged with
// `__esModule: true`, and that object becomes this module's CommonJS exports.
var src_exports = {};
__export(src_exports, {
  cohere: () => cohere,
  createCohere: () => createCohere
});
module.exports = __toCommonJS(src_exports);
// src/cohere-provider.ts
var import_provider_utils4 = require("@ai-sdk/provider-utils");
// src/cohere-chat-language-model.ts
var import_provider3 = require("@ai-sdk/provider");
var import_provider_utils2 = require("@ai-sdk/provider-utils");
var import_zod2 = require("zod");
// src/cohere-error.ts
var import_provider_utils = require("@ai-sdk/provider-utils");
var import_zod = require("zod");
/**
 * Zod schema for the JSON body of a failed Cohere API call; only the
 * `message` string is required (other keys are not declared).
 */
var cohereErrorDataSchema = import_zod.z.object({
  message: import_zod.z.string()
});

/**
 * Error handler passed as `failedResponseHandler` by the chat and embedding
 * models. It is built with `createJsonErrorResponseHandler` around
 * `cohereErrorDataSchema` and reports the body's `message` as the error text.
 */
var cohereFailedResponseHandler = (0, import_provider_utils.createJsonErrorResponseHandler)({
  errorToMessage: ({ message }) => message,
  errorSchema: cohereErrorDataSchema
});
// src/convert-to-cohere-chat-prompt.ts
var import_provider = require("@ai-sdk/provider");
/**
 * Converts an AI SDK prompt into the `messages` array sent to Cohere's chat API.
 *
 * - `system` messages are passed through with their string content.
 * - `user` messages are flattened into a single string by concatenating their
 *   text parts. Any other part type (image, file, ...) is rejected rather than
 *   silently dropped.
 * - `assistant` messages carry either their concatenated text or, when tool
 *   calls are present, `tool_calls` (with `content` left undefined).
 * - Each tool result becomes its own `tool` message whose content is the
 *   JSON-encoded result.
 *
 * @param {Array<{role: string, content: any}>} prompt - standardized AI SDK prompt.
 * @returns {Array<object>} Cohere chat messages.
 * @throws {UnsupportedFunctionalityError} for a user content part whose type is
 *   not "text" (functionality `"<type>-part"`, e.g. "image-part").
 * @throws {Error} for an unknown message role.
 */
function convertToCohereChatPrompt(prompt) {
  const messages = [];
  for (const { role, content } of prompt) {
    switch (role) {
      case "system": {
        messages.push({ role: "system", content });
        break;
      }
      case "user": {
        messages.push({
          role: "user",
          content: content.map((part) => {
            if (part.type === "text") {
              return part.text;
            }
            // Only plain text can be forwarded here. Reject every other part
            // type explicitly so it cannot turn into an empty string via join().
            throw new import_provider.UnsupportedFunctionalityError({
              functionality: `${part.type}-part`
            });
          }).join("")
        });
        break;
      }
      case "assistant": {
        let text = "";
        const toolCalls = [];
        for (const part of content) {
          switch (part.type) {
            case "text": {
              text += part.text;
              break;
            }
            case "tool-call": {
              toolCalls.push({
                id: part.toolCallId,
                type: "function",
                function: {
                  name: part.toolName,
                  arguments: JSON.stringify(part.args)
                }
              });
              break;
            }
          }
        }
        // NOTE(review): when tool calls are present, any accompanying assistant
        // text is not sent (content is undefined and tool_plan is unset).
        // Confirm that this matches the intended Cohere semantics.
        messages.push({
          role: "assistant",
          content: toolCalls.length > 0 ? void 0 : text,
          tool_calls: toolCalls.length > 0 ? toolCalls : void 0,
          tool_plan: void 0
        });
        break;
      }
      case "tool": {
        messages.push(
          ...content.map((toolResult) => ({
            role: "tool",
            content: JSON.stringify(toolResult.result),
            tool_call_id: toolResult.toolCallId
          }))
        );
        break;
      }
      default: {
        const _exhaustiveCheck = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }
  return messages;
}
// src/map-cohere-finish-reason.ts
// Cohere finish reasons and their AI SDK equivalents. A Map is used (rather
// than a plain object) so inherited keys such as "toString" never match.
var COHERE_FINISH_REASONS = new Map([
  ["COMPLETE", "stop"],
  ["STOP_SEQUENCE", "stop"],
  ["MAX_TOKENS", "length"],
  ["ERROR", "error"],
  ["TOOL_CALL", "tool-calls"]
]);

/**
 * Maps a Cohere `finish_reason` string to the AI SDK finish reason.
 *
 * @param {string | null | undefined} finishReason - value reported by Cohere.
 * @returns {string} "stop", "length", "error", "tool-calls", or "unknown" for
 *   anything not in the table.
 */
function mapCohereFinishReason(finishReason) {
  const mapped = COHERE_FINISH_REASONS.get(finishReason);
  return mapped === void 0 ? "unknown" : mapped;
}
// src/cohere-prepare-tools.ts
var import_provider2 = require("@ai-sdk/provider");
/**
 * Translates the tool configuration of a "regular" mode call into Cohere's
 * `tools` / `tool_choice` request fields.
 *
 * - No tools (missing or empty list) yields `tools: undefined`.
 * - Provider-defined tools are skipped and reported as `unsupported-tool`
 *   warnings; every other tool is sent as a `function` tool.
 * - Tool choice "auto" (or none given) leaves `toolChoice` undefined, "none"
 *   maps to "NONE", and "required" maps to "REQUIRED".
 * - Tool choice "tool" narrows `tools` to the named one and forces "REQUIRED".
 *
 * @param {{tools?: Array<object>, toolChoice?: {type: string, toolName?: string}}} mode
 * @returns {{tools: Array<object>|undefined, toolChoice: string|undefined, toolWarnings: Array<object>}}
 * @throws {UnsupportedFunctionalityError} for an unknown tool choice type.
 */
function prepareTools(mode) {
  const requestedTools = mode.tools;
  const toolWarnings = [];
  if (requestedTools == null || !requestedTools.length) {
    return { tools: void 0, toolChoice: void 0, toolWarnings };
  }

  const cohereTools = [];
  for (const tool of requestedTools) {
    if (tool.type === "provider-defined") {
      toolWarnings.push({ type: "unsupported-tool", tool });
      continue;
    }
    const { name, description, parameters } = tool;
    cohereTools.push({
      type: "function",
      function: { name, description, parameters }
    });
  }

  const { toolChoice } = mode;
  const withChoice = (choice) => ({ tools: cohereTools, toolChoice: choice, toolWarnings });
  if (toolChoice == null) {
    return withChoice(void 0);
  }

  const choiceType = toolChoice.type;
  if (choiceType === "auto") {
    return withChoice(void 0);
  }
  if (choiceType === "none") {
    return withChoice("NONE");
  }
  if (choiceType === "required") {
    return withChoice("REQUIRED");
  }
  if (choiceType === "tool") {
    const selected = cohereTools.filter(
      (candidate) => candidate.function.name === toolChoice.toolName
    );
    return { tools: selected, toolChoice: "REQUIRED", toolWarnings };
  }
  throw new import_provider2.UnsupportedFunctionalityError({
    functionality: `Unsupported tool choice type: ${choiceType}`
  });
}
// src/cohere-chat-language-model.ts
/**
 * Normalizes the raw `arguments` string of a Cohere tool call.
 *
 * Cohere sometimes returns `null` as the arguments of a tool declared without
 * parameters, and a streamed call may accumulate no argument text at all. Both
 * cases map to `"{}"`. Any other value is returned trimmed but otherwise
 * unchanged.
 *
 * @param {string | null | undefined} rawArgs - arguments text from Cohere.
 * @returns {string} normalized arguments text.
 */
function normalizeCohereToolCallArgs(rawArgs) {
  const trimmed = rawArgs == null ? "" : rawArgs.trim();
  return trimmed === "" || trimmed === "null" ? "{}" : trimmed;
}

/**
 * AI SDK language model (specification "v1") backed by the Cohere `/chat`
 * endpoint under `config.baseURL`.
 */
var CohereChatLanguageModel = class {
  /**
   * @param {string} modelId - Cohere model id, sent as `model`.
   * @param {object} settings - model settings (stored as `this.settings`).
   * @param {{provider: string, baseURL: string, headers: () => object, fetch?: Function}} config
   */
  constructor(modelId, settings, config) {
    this.specificationVersion = "v1";
    this.defaultObjectGenerationMode = "json";
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  /** Provider id from the config (e.g. "cohere.chat"). */
  get provider() {
    return this.config.provider;
  }

  /**
   * Builds the Cohere request body (without `stream`) for a call.
   *
   * Standard settings are renamed to Cohere's fields (topP -> p, topK -> k,
   * ...). The mode decides tools and response format: "regular" uses
   * `prepareTools`, "object-json" forces a JSON response format, and
   * "object-tool" forces the single given tool.
   *
   * @returns {{args: object, warnings: Array<object>}}
   * @throws {UnsupportedFunctionalityError} for an unknown mode type.
   */
  getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    topK,
    frequencyPenalty,
    presencePenalty,
    stopSequences,
    responseFormat,
    seed
  }) {
    const type = mode.type;
    const chatPrompt = convertToCohereChatPrompt(prompt);
    const baseArgs = {
      // model id:
      model: this.modelId,
      // standardized settings:
      frequency_penalty: frequencyPenalty,
      presence_penalty: presencePenalty,
      max_tokens: maxTokens,
      temperature,
      p: topP,
      k: topK,
      seed,
      stop_sequences: stopSequences,
      // response format:
      response_format: responseFormat != null && responseFormat.type === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
      // messages:
      messages: chatPrompt
    };
    switch (type) {
      case "regular": {
        const { tools, toolChoice, toolWarnings } = prepareTools(mode);
        return {
          args: {
            ...baseArgs,
            tools,
            tool_choice: toolChoice
          },
          warnings: toolWarnings
        };
      }
      case "object-json": {
        return {
          args: {
            ...baseArgs,
            response_format: mode.schema == null ? { type: "json_object" } : { type: "json_object", json_schema: mode.schema }
          },
          warnings: []
        };
      }
      case "object-tool": {
        return {
          args: {
            ...baseArgs,
            tools: [
              {
                type: "function",
                function: {
                  name: mode.tool.name,
                  description: mode.tool.description != null ? mode.tool.description : "",
                  parameters: mode.tool.parameters
                }
              }
            ],
            tool_choice: "REQUIRED"
          },
          warnings: []
        };
      }
      default: {
        const _exhaustiveCheck = type;
        throw new import_provider3.UnsupportedFunctionalityError({
          functionality: `Unsupported mode: ${_exhaustiveCheck}`
        });
      }
    }
  }

  /**
   * Performs a non-streaming chat call.
   *
   * The returned text concatenates every `text` item of the response message.
   * Tool-call arguments go through `normalizeCohereToolCallArgs`, so a `null`
   * or empty argument string becomes `"{}"`.
   */
  async doGenerate(options) {
    const { args, warnings } = this.getArgs(options);
    const {
      responseHeaders,
      value: response,
      rawValue: rawResponse
    } = await (0, import_provider_utils2.postJsonToApi)({
      url: `${this.config.baseURL}/chat`,
      headers: (0, import_provider_utils2.combineHeaders)(this.config.headers(), options.headers),
      body: args,
      failedResponseHandler: cohereFailedResponseHandler,
      successfulResponseHandler: (0, import_provider_utils2.createJsonResponseHandler)(
        cohereChatResponseSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages, ...rawSettings } = args;
    const contentItems = response.message.content == null ? [] : response.message.content;
    // Join all text items instead of taking only the first one, so multi-part
    // answers are not truncated.
    const text = contentItems.filter((item) => item.type === "text").map((item) => item.text).join("");
    const rawToolCalls = response.message.tool_calls == null ? [] : response.message.tool_calls;
    return {
      text,
      toolCalls: rawToolCalls.map((toolCall) => ({
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        args: normalizeCohereToolCallArgs(toolCall.function.arguments),
        toolCallType: "function"
      })),
      finishReason: mapCohereFinishReason(response.finish_reason),
      usage: {
        promptTokens: response.usage.tokens.input_tokens,
        completionTokens: response.usage.tokens.output_tokens
      },
      rawCall: {
        rawPrompt: {
          messages
        },
        rawSettings
      },
      response: {
        id: response.generation_id != null ? response.generation_id : void 0
      },
      rawResponse: {
        headers: responseHeaders,
        body: rawResponse
      },
      warnings,
      request: { body: JSON.stringify(args) }
    };
  }

  /**
   * Performs a streaming chat call and maps Cohere stream events to AI SDK
   * stream parts: text-delta, tool-call-delta, tool-call, response-metadata,
   * error, and a final finish part emitted on flush.
   *
   * Tool-call argument text is accumulated between `tool-call-start` and
   * `tool-call-end`. At the end it is normalized like in `doGenerate` and
   * round-tripped through JSON, so invalid JSON errors the stream.
   */
  async doStream(options) {
    const { args, warnings } = this.getArgs(options);
    const body = { ...args, stream: true };
    const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({
      url: `${this.config.baseURL}/chat`,
      headers: (0, import_provider_utils2.combineHeaders)(this.config.headers(), options.headers),
      body,
      failedResponseHandler: cohereFailedResponseHandler,
      successfulResponseHandler: (0, import_provider_utils2.createEventSourceResponseHandler)(
        cohereChatChunkSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages, ...rawSettings } = args;
    let finishReason = "unknown";
    let usage = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN
    };
    let pendingToolCallDelta = {
      toolCallId: "",
      toolName: "",
      argsTextDelta: ""
    };
    return {
      stream: response.pipeThrough(
        new TransformStream({
          transform(chunk, controller) {
            if (!chunk.success) {
              finishReason = "error";
              controller.enqueue({ type: "error", error: chunk.error });
              return;
            }
            const value = chunk.value;
            const type = value.type;
            switch (type) {
              case "content-delta": {
                controller.enqueue({
                  type: "text-delta",
                  textDelta: value.delta.message.content.text
                });
                return;
              }
              case "tool-call-start": {
                pendingToolCallDelta = {
                  toolCallId: value.delta.message.tool_calls.id,
                  toolName: value.delta.message.tool_calls.function.name,
                  argsTextDelta: value.delta.message.tool_calls.function.arguments
                };
                controller.enqueue({
                  type: "tool-call-delta",
                  toolCallId: pendingToolCallDelta.toolCallId,
                  toolName: pendingToolCallDelta.toolName,
                  toolCallType: "function",
                  argsTextDelta: pendingToolCallDelta.argsTextDelta
                });
                return;
              }
              case "tool-call-delta": {
                pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments;
                controller.enqueue({
                  type: "tool-call-delta",
                  toolCallId: pendingToolCallDelta.toolCallId,
                  toolName: pendingToolCallDelta.toolName,
                  toolCallType: "function",
                  argsTextDelta: value.delta.message.tool_calls.function.arguments
                });
                return;
              }
              case "tool-call-end": {
                controller.enqueue({
                  type: "tool-call",
                  toolCallId: pendingToolCallDelta.toolCallId,
                  toolName: pendingToolCallDelta.toolName,
                  toolCallType: "function",
                  // Same null/empty handling as doGenerate; the JSON round-trip
                  // validates and canonicalizes the accumulated text.
                  args: JSON.stringify(
                    JSON.parse(normalizeCohereToolCallArgs(pendingToolCallDelta.argsTextDelta))
                  )
                });
                pendingToolCallDelta = {
                  toolCallId: "",
                  toolName: "",
                  argsTextDelta: ""
                };
                return;
              }
              case "message-start": {
                controller.enqueue({
                  type: "response-metadata",
                  id: value.id != null ? value.id : void 0
                });
                return;
              }
              case "message-end": {
                finishReason = mapCohereFinishReason(value.delta.finish_reason);
                const tokens = value.delta.usage.tokens;
                usage = {
                  promptTokens: tokens.input_tokens,
                  completionTokens: tokens.output_tokens
                };
                return;
              }
              default: {
                // citation-*, content-start/-end and tool-plan-delta events are
                // accepted by the chunk schema but not forwarded.
                return;
              }
            }
          },
          flush(controller) {
            controller.enqueue({
              type: "finish",
              finishReason,
              usage
            });
          }
        })
      ),
      rawCall: {
        rawPrompt: {
          messages
        },
        rawSettings
      },
      rawResponse: { headers: responseHeaders },
      warnings,
      request: { body: JSON.stringify(body) }
    };
  }
};
// Token counts as reported by Cohere; reused for `billed_units` and `tokens`.
var cohereChatTokenCountsSchema = import_zod2.z.object({
  input_tokens: import_zod2.z.number(),
  output_tokens: import_zod2.z.number()
});

// One entry of `message.tool_calls` in a non-streaming chat response.
var cohereChatResponseToolCallSchema = import_zod2.z.object({
  id: import_zod2.z.string(),
  type: import_zod2.z.literal("function"),
  function: import_zod2.z.object({
    name: import_zod2.z.string(),
    arguments: import_zod2.z.string()
  })
});

// One entry of `message.content` in a non-streaming chat response.
var cohereChatResponseContentItemSchema = import_zod2.z.object({
  type: import_zod2.z.string(),
  text: import_zod2.z.string()
});

/**
 * Zod schema for the JSON body of a successful non-streaming `/chat` response.
 * `content`, `tool_plan`, `tool_calls` and `generation_id` may be null/absent;
 * `finish_reason` and both token-count objects under `usage` are required.
 */
var cohereChatResponseSchema = import_zod2.z.object({
  generation_id: import_zod2.z.string().nullish(),
  message: import_zod2.z.object({
    role: import_zod2.z.string(),
    content: import_zod2.z.array(cohereChatResponseContentItemSchema).nullish(),
    tool_plan: import_zod2.z.string().nullish(),
    tool_calls: import_zod2.z.array(cohereChatResponseToolCallSchema).nullish()
  }),
  finish_reason: import_zod2.z.string(),
  usage: import_zod2.z.object({
    billed_units: cohereChatTokenCountsSchema,
    tokens: cohereChatTokenCountsSchema
  })
});
// Builds one member of the stream-event union: an object whose `type` is the
// given literal, followed by the keys in `shape` (if any).
var cohereChatChunkEvent = (type, shape) => import_zod2.z.object({
  type: import_zod2.z.literal(type),
  ...shape
});

/**
 * Zod schema for one server-sent event of a streaming `/chat` response,
 * discriminated on its `type` field.
 */
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [
  cohereChatChunkEvent("citation-start"),
  cohereChatChunkEvent("citation-end"),
  cohereChatChunkEvent("content-start"),
  cohereChatChunkEvent("content-delta", {
    delta: import_zod2.z.object({
      message: import_zod2.z.object({
        content: import_zod2.z.object({
          text: import_zod2.z.string()
        })
      })
    })
  }),
  cohereChatChunkEvent("content-end"),
  cohereChatChunkEvent("message-start", {
    id: import_zod2.z.string().nullish()
  }),
  cohereChatChunkEvent("message-end", {
    delta: import_zod2.z.object({
      finish_reason: import_zod2.z.string(),
      usage: import_zod2.z.object({
        tokens: import_zod2.z.object({
          input_tokens: import_zod2.z.number(),
          output_tokens: import_zod2.z.number()
        })
      })
    })
  }),
  // https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling
  cohereChatChunkEvent("tool-plan-delta", {
    delta: import_zod2.z.object({
      message: import_zod2.z.object({
        tool_plan: import_zod2.z.string()
      })
    })
  }),
  cohereChatChunkEvent("tool-call-start", {
    delta: import_zod2.z.object({
      message: import_zod2.z.object({
        tool_calls: import_zod2.z.object({
          id: import_zod2.z.string(),
          type: import_zod2.z.literal("function"),
          function: import_zod2.z.object({
            name: import_zod2.z.string(),
            arguments: import_zod2.z.string()
          })
        })
      })
    })
  }),
  // A single tool call's `arguments` arrive in chunks. They must be accumulated
  // into a string, so the complete tool call can only be parsed once
  // `tool-call-end` is seen.
  cohereChatChunkEvent("tool-call-delta", {
    delta: import_zod2.z.object({
      message: import_zod2.z.object({
        tool_calls: import_zod2.z.object({
          function: import_zod2.z.object({
            arguments: import_zod2.z.string()
          })
        })
      })
    })
  }),
  cohereChatChunkEvent("tool-call-end")
]);
// src/cohere-embedding-model.ts
var import_provider4 = require("@ai-sdk/provider");
var import_provider_utils3 = require("@ai-sdk/provider-utils");
var import_zod3 = require("zod");
/**
 * AI SDK embedding model (specification "v1") backed by the Cohere `/embed`
 * endpoint under `config.baseURL`. At most 96 values are accepted per call,
 * and parallel calls are allowed.
 */
var CohereEmbeddingModel = class {
  /**
   * @param {string} modelId - Cohere embedding model id, sent as `model`.
   * @param {{inputType?: string, truncate?: string}} settings - request options.
   * @param {{provider: string, baseURL: string, headers: () => object, fetch?: Function}} config
   */
  constructor(modelId, settings, config) {
    this.specificationVersion = "v1";
    this.maxEmbeddingsPerCall = 96;
    this.supportsParallelCalls = true;
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  /** Provider id from the config (e.g. "cohere.textEmbedding"). */
  get provider() {
    return this.config.provider;
  }

  /**
   * Embeds `values` with one `/embed` request and returns the float vectors.
   *
   * @param {{values: string[], headers?: object, abortSignal?: AbortSignal}} options
   * @returns {Promise<{embeddings: number[][], usage: {tokens: number}, rawResponse: {headers: object}}>}
   * @throws {TooManyEmbeddingValuesForCallError} when more than
   *   `maxEmbeddingsPerCall` values are given.
   */
  async doEmbed({
    values,
    headers,
    abortSignal
  }) {
    const limit = this.maxEmbeddingsPerCall;
    if (values.length > limit) {
      throw new import_provider4.TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: limit,
        values
      });
    }
    const configuredInputType = this.settings.inputType;
    const requestBody = {
      model: this.modelId,
      // The AI SDK only supports 'float' embeddings, which are also the only
      // type the Cohere API docs list as supported by all models.
      // https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
      embedding_types: ["float"],
      texts: values,
      input_type: configuredInputType == null ? "search_query" : configuredInputType,
      truncate: this.settings.truncate
    };
    const { responseHeaders, value: response } = await (0, import_provider_utils3.postJsonToApi)({
      url: `${this.config.baseURL}/embed`,
      headers: (0, import_provider_utils3.combineHeaders)(this.config.headers(), headers),
      body: requestBody,
      failedResponseHandler: cohereFailedResponseHandler,
      successfulResponseHandler: (0, import_provider_utils3.createJsonResponseHandler)(
        cohereTextEmbeddingResponseSchema
      ),
      abortSignal,
      fetch: this.config.fetch
    });
    const billedInputTokens = response.meta.billed_units.input_tokens;
    return {
      embeddings: response.embeddings.float,
      usage: { tokens: billedInputTokens },
      rawResponse: { headers: responseHeaders }
    };
  }
};
// `embeddings` object of an /embed response: only the `float` vectors are read.
var cohereFloatEmbeddingsSchema = import_zod3.z.object({
  float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number()))
});

// `meta` object of an /embed response: billed input tokens are used as usage.
var cohereEmbedMetaSchema = import_zod3.z.object({
  billed_units: import_zod3.z.object({
    input_tokens: import_zod3.z.number()
  })
});

/** Zod schema for the JSON body of a successful Cohere `/embed` response. */
var cohereTextEmbeddingResponseSchema = import_zod3.z.object({
  embeddings: cohereFloatEmbeddingsSchema,
  meta: cohereEmbedMetaSchema
});
// src/cohere-provider.ts
/**
 * Creates a Cohere provider.
 *
 * The returned function creates chat models when called directly (it must not
 * be called with `new`). It also exposes `languageModel` (chat models) and
 * `embedding` / `textEmbeddingModel` (the same embedding-model factory).
 *
 * @param {object} [options]
 * @param {string} [options.baseURL] - API base URL. A trailing slash is
 *   stripped; "https://api.cohere.com/v2" is used when it resolves to nullish.
 * @param {string} [options.apiKey] - passed to `loadApiKey` together with the
 *   "COHERE_API_KEY" environment-variable name.
 * @param {object} [options.headers] - extra headers, spread after `Authorization`.
 * @param {Function} [options.fetch] - custom fetch forwarded to every model config.
 * @returns {Function} the provider function.
 */
function createCohere(options = {}) {
  const trimmedBaseURL = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL);
  const baseURL = trimmedBaseURL == null ? "https://api.cohere.com/v2" : trimmedBaseURL;

  // Invoked by the models each time they build a request, so the API key is
  // resolved then rather than when the provider is created.
  const getHeaders = () => {
    const apiKey = (0, import_provider_utils4.loadApiKey)({
      apiKey: options.apiKey,
      environmentVariableName: "COHERE_API_KEY",
      description: "Cohere"
    });
    return { Authorization: `Bearer ${apiKey}`, ...options.headers };
  };

  // Fresh per-model config sharing the base URL, header factory and fetch.
  const modelConfig = (providerId) => ({
    provider: providerId,
    baseURL,
    headers: getHeaders,
    fetch: options.fetch
  });

  const createChatModel = (modelId, settings = {}) =>
    new CohereChatLanguageModel(modelId, settings, modelConfig("cohere.chat"));
  const createTextEmbeddingModel = (modelId, settings = {}) =>
    new CohereEmbeddingModel(modelId, settings, modelConfig("cohere.textEmbedding"));

  const provider = function(modelId, settings) {
    if (new.target) {
      throw new Error(
        "The Cohere model function cannot be called with the new keyword."
      );
    }
    return createChatModel(modelId, settings);
  };

  return Object.assign(provider, {
    languageModel: createChatModel,
    embedding: createTextEmbeddingModel,
    textEmbeddingModel: createTextEmbeddingModel
  });
}
// Default provider instance using default options (base URL and the
// COHERE_API_KEY environment-variable lookup).
var cohere = createCohere({});
// Annotate the CommonJS export names for ESM import in node:
// `0 && ...` short-circuits, so this assignment never runs. It exists only so
// that static analysis of this CommonJS file can discover `cohere` and
// `createCohere` as named exports. Keep its exact shape.
0 && (module.exports = {
  cohere,
  createCohere
});
//# sourceMappingURL=index.js.map