/*
 * @juspay/neurolink
 * Version:
 * Universal AI Development Platform with working MCP integration, multi-provider support, voice (TTS/STT/realtime), and professional CLI. 58+ external MCP servers discoverable, multimodal file processing, RAG pipelines. Build, test, and deploy AI applications […]
 * 253 lines (252 loc) • 11.8 kB
 * JavaScript
 */
import { createOpenAI } from "@ai-sdk/openai";
import { stepCountIs, streamText } from "ai";
import { CohereModels } from "../constants/enums.js";
import { BaseProvider } from "../core/baseProvider.js";
import { DEFAULT_MAX_STEPS } from "../core/constants.js";
import { streamAnalyticsCollector } from "../core/streamAnalytics.js";
import { isNeuroLink } from "../neurolink.js";
import { createLoggingFetch } from "../utils/loggingFetch.js";
import { tracers, ATTR, withClientStreamSpan } from "../telemetry/index.js";
import { AuthenticationError, InvalidModelError, NetworkError, ProviderError, RateLimitError, } from "../types/index.js";
import { logger } from "../utils/logger.js";
import { createCohereConfig, getProviderModel, validateApiKey, } from "../utils/providerConfig.js";
import { composeAbortSignals, createTimeoutController, TimeoutError, } from "../utils/timeout.js";
import { emitToolEndFromStepFinish } from "../utils/toolEndEmitter.js";
import { resolveToolChoice } from "../utils/toolChoice.js";
import { toAnalyticsStreamResult } from "./providerTypeUtils.js";
/**
 * Default endpoint for the Cohere provider: the OpenAI-compatible
 * chat-completions API that Cohere serves under /compatibility/v1.
 * Embeddings and Rerank are part of Cohere's native API and are not
 * reachable through this compatibility path.
 */
const COHERE_DEFAULT_BASE_URL = "https://api.cohere.com/compatibility/v1";

/**
 * Builds the Cohere provider config and passes it through `validateApiKey`.
 *
 * @returns whatever `validateApiKey` returns for the Cohere config
 *   (presumably the validated API key string — confirm in providerConfig).
 */
function getCohereApiKey() {
  const cohereConfig = createCohereConfig();
  return validateApiKey(cohereConfig);
}

/**
 * Resolves the default chat model via `getProviderModel`, keyed by
 * "COHERE_MODEL" with Command R+ as the fallback.
 *
 * @returns the model identifier chosen by `getProviderModel`.
 */
function getDefaultCohereModel() {
  const fallbackModel = CohereModels.COMMAND_R_PLUS;
  return getProviderModel("COHERE_MODEL", fallbackModel);
}
/** Maximum number of texts sent to Cohere's /v2/embed endpoint per request. */
const COHERE_EMBED_BATCH_SIZE = 96;

/** Suggestion appended to "model not found" errors raised for chat models. */
const COHERE_CHAT_MODEL_HINT = "Use command-r, command-r-plus, or command-r7b-12-2024.";

/**
 * Returns the first candidate that is actually configured.
 *
 * `null`/`undefined` and blank strings count as "not set", so an empty
 * `COHERE_BASE_URL=` line in a .env file cannot override the default.
 * String candidates are returned trimmed; other values pass through as-is.
 *
 * @param {...unknown} candidates values in priority order
 * @returns the first configured value, or `undefined` when none is set
 */
function pickConfiguredValue(...candidates) {
  for (const candidate of candidates) {
    if (candidate == null) {
      continue;
    }
    if (typeof candidate !== "string") {
      return candidate;
    }
    const trimmed = candidate.trim();
    if (trimmed.length > 0) {
      return trimmed;
    }
  }
  return undefined;
}

/**
 * Derives the root used for native endpoints from the chat base URL by
 * dropping trailing slashes and a trailing `/compatibility/v<N>` segment.
 *
 * @param {string} baseURL base URL configured for the chat client
 * @returns {string} root URL without a trailing slash
 */
function toNativeApiRoot(baseURL) {
  return baseURL.replace(/\/+$/, "").replace(/\/compatibility\/v\d+$/, "");
}

/**
 * Reads a numeric HTTP status from an error-like value. Looks at
 * `statusCode` first, then `status`.
 *
 * @param {unknown} error thrown value
 * @returns {number | undefined} the status, or `undefined` when absent
 */
function readHttpStatus(error) {
  const candidate = error?.statusCode ?? error?.status;
  return Number.isInteger(candidate) ? candidate : undefined;
}

/**
 * Maps an arbitrary failure onto the typed provider errors.
 *
 * Matching is case-insensitive on the message text. Status codes are taken
 * from the error's numeric status when it has one; only when it has none is
 * the message searched for the code as a standalone number (so "4010 tokens"
 * is not mistaken for a 401).
 *
 * @param {unknown} error thrown value to classify
 * @param {{ modelName: string, modelHint?: string }} context model to name in
 *   "model not found" errors and the suggestion appended to that message
 * @returns {Error} a NetworkError, AuthenticationError, RateLimitError,
 *   InvalidModelError or ProviderError
 */
function toCohereError(error, { modelName, modelHint = COHERE_CHAT_MODEL_HINT }) {
  if (error instanceof TimeoutError) {
    return new NetworkError(`Request timed out: ${error.message}`, "cohere");
  }
  let message = "Unknown error";
  if (typeof error?.message === "string") {
    message = error.message;
  } else if (typeof error === "string") {
    message = error;
  }
  const lowered = message.toLowerCase();
  const status = readHttpStatus(error);
  const mentions = (...needles) => needles.some((needle) => lowered.includes(needle));
  const hasStatus = (code) => status === undefined
    ? new RegExp(`\\b${code}\\b`).test(message)
    : status === code;

  if (mentions("invalid api token", "invalid_api_token", "authentication") || hasStatus(401)) {
    return new AuthenticationError("Invalid Cohere API key. Check COHERE_API_KEY. Get one at https://dashboard.cohere.com/api-keys", "cohere");
  }
  // Checked before the generic rate-limit branch: a trial-limit failure that
  // also carries a 429 would otherwise always be reported as a rate limit.
  if (mentions("trial limit", "trial_limit")) {
    return new ProviderError("Cohere trial usage limit exceeded. Upgrade at https://dashboard.cohere.com/billing.", "cohere");
  }
  if (mentions("rate limit") || hasStatus(429)) {
    return new RateLimitError("Cohere rate limit exceeded. Back off and retry.", "cohere");
  }
  if (mentions("model_not_found") || hasStatus(404)) {
    return new InvalidModelError(`Cohere model '${modelName}' not found. ${modelHint}`, "cohere");
  }
  return new ProviderError(`Cohere error: ${message}`, "cohere");
}

/**
 * Cohere Provider
 *
 * Routes Command R / Command R+ chat completions through Cohere's OpenAI-
 * compatible endpoint. Embeddings are requested from the native /v2/embed
 * endpoint (see `embed` / `embedMany`).
 *
 * @see https://docs.cohere.com/docs/compatibility-api
 */
export class CohereProvider extends BaseProvider {
  /** Chat model handle created from the OpenAI-compatible client. */
  model;
  /** API key used for chat (instance default) and embedding requests. */
  apiKey;
  /** Base URL of the OpenAI-compatible endpoint used by this instance. */
  baseURL;

  /**
   * @param modelName model identifier forwarded to the base provider
   * @param sdk passed to the base provider only when `isNeuroLink(sdk)` holds
   * @param _region unused
   * @param credentials optional `{ apiKey, baseURL }` overrides; blank values
   *   are ignored in favour of the environment / built-in defaults
   */
  constructor(modelName, sdk, _region, credentials) {
    const validatedNeurolink = isNeuroLink(sdk) ? sdk : undefined;
    super(modelName, "cohere", validatedNeurolink);

    const overrideApiKey = credentials?.apiKey?.trim();
    this.apiKey = overrideApiKey || getCohereApiKey();
    // Blank strings are skipped so an empty override or env var cannot end up
    // as the endpoint.
    this.baseURL =
      pickConfiguredValue(credentials?.baseURL, process.env.COHERE_BASE_URL) ??
      COHERE_DEFAULT_BASE_URL;

    const cohere = createOpenAI({
      apiKey: this.apiKey,
      baseURL: this.baseURL,
      fetch: createLoggingFetch("cohere"),
    });
    this.model = cohere.chat(this.modelName);

    logger.debug("Cohere Provider initialized", {
      modelName: this.modelName,
      providerName: this.providerName,
      baseURL: this.baseURL,
    });
  }

  /**
   * Starts a streaming chat completion inside a provider telemetry span.
   *
   * @param options stream options forwarded to `executeStreamInner`
   * @param _analysisSchema unused
   * @returns the result of `executeStreamInner` with its `stream` replaced by
   *   the span-wrapped stream
   */
  async executeStream(options, _analysisSchema) {
    // withClientStreamSpan: keeps the span open until the consumer reaches
    // end-of-stream / error, so the recorded duration reflects the actual
    // stream lifetime instead of just setup.
    return withClientStreamSpan({
      name: "neurolink.provider.stream",
      tracer: tracers.provider,
      attributes: {
        [ATTR.GEN_AI_SYSTEM]: "cohere",
        [ATTR.GEN_AI_MODEL]: this.modelName,
        [ATTR.GEN_AI_OPERATION]: "stream",
        [ATTR.NL_STREAM_MODE]: true,
      },
    }, async () => this.executeStreamInner(options), (r) => r.stream, (r, wrapped) => ({ ...r, stream: wrapped }));
  }

  /**
   * Validates the options, resolves credentials and tools, and calls
   * `streamText`.
   *
   * @param options stream options (messages, tools, credentials, limits, ...)
   * @returns `{ stream, provider, model, analytics, metadata }`
   * @throws the value produced by `handleProviderError` for any failure
   *   raised inside the try block
   */
  async executeStreamInner(options) {
    this.validateStreamOptions(options);
    // Resolve per-call credentials first, then fall back to instance-level.
    const perCallCreds = options.credentials?.cohere;
    const effectiveApiKey = perCallCreds?.apiKey?.trim() || this.apiKey;
    const effectiveBaseURL = perCallCreds?.baseURL || this.baseURL;
    const startTime = Date.now();
    const timeout = this.getTimeout(options);
    const timeoutController = createTimeoutController(timeout, this.providerName, "stream");
    try {
      const shouldUseTools = !options.disableTools && this.supportsTools();
      const tools = shouldUseTools
        ? options.tools || (await this.getAllTools())
        : {};
      const messages = await this.buildMessagesForStream(options);
      // When per-call credentials differ from instance, build a fresh client.
      // NOTE(review): this branch does not go through
      // getAISDKModelWithMiddleware — confirm that is intended.
      const hasDifferentCreds = effectiveApiKey !== this.apiKey || effectiveBaseURL !== this.baseURL;
      const model = hasDifferentCreds
        ? createOpenAI({
          apiKey: effectiveApiKey,
          baseURL: effectiveBaseURL,
          fetch: createLoggingFetch("cohere"),
        }).chat(this.modelName)
        : await this.getAISDKModelWithMiddleware(options);
      const result = await streamText({
        model,
        messages,
        temperature: options.temperature,
        maxOutputTokens: options.maxTokens,
        tools,
        stopWhen: stepCountIs(options.maxSteps || DEFAULT_MAX_STEPS),
        toolChoice: resolveToolChoice(options, tools, shouldUseTools),
        abortSignal: composeAbortSignals(options.abortSignal, timeoutController?.controller.signal),
        experimental_telemetry: this.telemetryHandler.getTelemetryConfig(options),
        experimental_repairToolCall: this.getToolCallRepairFn(options),
        onStepFinish: ({ toolCalls, toolResults }) => {
          emitToolEndFromStepFinish(this.neurolink?.getEventEmitter(), toolResults);
          // Storage is best-effort: a failure is logged, never rethrown.
          this.handleToolExecutionStorage(toolCalls, toolResults, options, new Date()).catch((error) => {
            logger.warn("[CohereProvider] Failed to store tool executions", {
              provider: this.providerName,
              error: error instanceof Error ? error.message : String(error),
            });
          });
        },
      });
      // Release the timeout controller once streamText has returned.
      timeoutController?.cleanup();
      const transformedStream = this.createTextStream(result);
      const analyticsPromise = streamAnalyticsCollector.createAnalytics(this.providerName, this.modelName, toAnalyticsStreamResult(result), Date.now() - startTime, {
        requestId: `cohere-stream-${Date.now()}`,
        streamingMode: true,
      });
      return {
        stream: transformedStream,
        provider: this.providerName,
        model: this.modelName,
        analytics: analyticsPromise,
        metadata: { startTime, streamId: `cohere-${Date.now()}` },
      };
    }
    catch (error) {
      timeoutController?.cleanup();
      throw this.handleProviderError(error);
    }
  }

  /** @returns the provider name held by the base class. */
  getProviderName() {
    return this.providerName;
  }

  /** @returns the default chat model identifier. */
  getDefaultModel() {
    return getDefaultCohereModel();
  }

  /** @returns the chat model handle created in the constructor. */
  getAISDKModel() {
    return this.model;
  }

  /**
   * Converts a failure into a typed provider error, naming this instance's
   * chat model in "model not found" errors.
   *
   * @param error thrown value to classify
   * @returns a NetworkError, AuthenticationError, RateLimitError,
   *   InvalidModelError or ProviderError
   */
  formatProviderError(error) {
    return toCohereError(error, { modelName: this.modelName });
  }

  /**
   * @returns {Promise<boolean>} true when a non-blank API key is configured
   */
  async validateConfiguration() {
    return typeof this.apiKey === "string" && this.apiKey.trim().length > 0;
  }

  /** @returns the provider, model, default model and base URL in use. */
  getConfiguration() {
    return {
      provider: this.providerName,
      model: this.modelName,
      defaultModel: getDefaultCohereModel(),
      baseURL: this.baseURL,
    };
  }

  /**
   * Default embedding model for Cohere.
   */
  getDefaultEmbeddingModel() {
    return CohereModels.EMBED_ENGLISH_V3;
  }

  /**
   * Generate an embedding for a single text via Cohere's native /v2/embed
   * endpoint (documented at https://docs.cohere.com/reference/embed).
   *
   * @param text text to embed
   * @param modelName optional embedding model; defaults to
   *   `getDefaultEmbeddingModel()`
   * @returns the embedding vector for `text`
   * @throws ProviderError when the response contains no embedding
   */
  async embed(text, modelName) {
    const vectors = await this.embedMany([text], modelName);
    if (!vectors[0]) {
      throw new ProviderError("Cohere /v2/embed returned no embeddings.", "cohere");
    }
    return vectors[0];
  }

  /**
   * Batch embedding via Cohere's native /v2/embed endpoint. Inputs are sent
   * in sequential requests of at most 96 texts each.
   *
   * @param texts texts to embed; an empty array returns `[]` without a request
   * @param modelName optional embedding model; defaults to
   *   `getDefaultEmbeddingModel()`
   * @returns one vector per input, in input order
   * @throws NetworkError when the request itself fails
   * @throws a typed provider error for non-2xx responses (naming the
   *   embedding model, not the chat model)
   * @throws ProviderError when the number of vectors does not match the batch
   */
  async embedMany(texts, modelName) {
    if (texts.length === 0) {
      return [];
    }
    const model = modelName ?? this.getDefaultEmbeddingModel();
    const url = `${toNativeApiRoot(this.baseURL)}/v2/embed`;
    const results = [];
    for (let i = 0; i < texts.length; i += COHERE_EMBED_BATCH_SIZE) {
      const batch = texts.slice(i, i + COHERE_EMBED_BATCH_SIZE);
      let response;
      try {
        response = await fetch(url, {
          method: "POST",
          headers: {
            Authorization: `Bearer ${this.apiKey}`,
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            model,
            texts: batch,
            input_type: "search_document",
            embedding_types: ["float"],
          }),
        });
      } catch (error) {
        // Transport failures surface as provider errors like every other
        // failure of this class, not as a raw fetch TypeError.
        const reason = error instanceof Error ? error.message : String(error);
        throw new NetworkError(`Cohere /v2/embed request failed: ${reason}`, "cohere");
      }
      if (!response.ok) {
        const body = await response.text();
        // Attach the real status so classification does not depend on digits
        // that happen to appear in the response body.
        const failure = Object.assign(new Error(`Cohere /v2/embed failed: ${response.status} — ${body.slice(0, 500)}`), { statusCode: response.status });
        throw toCohereError(failure, {
          modelName: model,
          modelHint: `Use an embedding model such as ${this.getDefaultEmbeddingModel()}.`,
        });
      }
      const json = await response.json();
      // Accept both `{ embeddings: { float: [...] } }` and a bare array.
      const floatVecs = json.embeddings?.float ??
        (Array.isArray(json.embeddings) ? json.embeddings : undefined);
      if (!floatVecs || floatVecs.length !== batch.length) {
        throw new ProviderError(`Cohere /v2/embed returned ${floatVecs?.length ?? 0} embeddings for ${batch.length} inputs.`, "cohere");
      }
      results.push(...floatVecs);
    }
    return results;
  }
}
export default CohereProvider;