UNPKG

ai.libx.js

Version:

Unified API bridge for various AI models (LLMs, image/video generation, TTS, STT) - stateless, edge-compatible

160 lines (136 loc) 4.4 kB
import { BaseAdapter } from './base/BaseAdapter';
import { ChatOptions, ChatResponse, StreamChunk, Message } from '../types';
import { streamLines } from '../utils/stream';
import { handleProviderError } from '../utils/errors';
import { contentToString } from '../utils/content-helpers';

/**
 * One entry of the `chat_history` array sent to Cohere's `/chat` endpoint.
 * This adapter only emits `USER` and `CHATBOT` entries.
 */
interface CohereChatMessage {
  role: 'USER' | 'CHATBOT' | 'SYSTEM';
  message: string;
}

/**
 * JSON body POSTed to `{baseUrl}/chat`.
 */
interface CohereRequest {
  model?: string;
  message: string;
  chat_history?: CohereChatMessage[];
  preamble?: string;
  temperature?: number;
  max_tokens?: number;
  p?: number;
  stop_sequences?: string[];
  stream?: boolean;
}

/**
 * Parses one line of the streaming response body as JSON.
 *
 * @param line - A single line yielded by `streamLines`.
 * @returns The parsed object, or `undefined` for blank lines, invalid JSON,
 *   or JSON values that are not objects (e.g. `null`, numbers).
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- event payload is provider-defined and read loosely below
function parseStreamEvent(line: string): Record<string, any> | undefined {
  const trimmed = line.trim();
  if (!trimmed) return undefined;

  let parsed: unknown;
  try {
    parsed = JSON.parse(trimmed);
  } catch {
    // Invalid JSON lines are skipped, matching the previous behavior.
    return undefined;
  }
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  return typeof parsed === 'object' && parsed !== null ? (parsed as Record<string, any>) : undefined;
}

/**
 * Cohere API adapter.
 *
 * Translates `ChatOptions` into a request for Cohere's `/chat` endpoint
 * (default base URL `https://api.cohere.ai/v1`). It maps the JSON response
 * to a `ChatResponse`, or maps the line-delimited streaming response to
 * `StreamChunk`s.
 */
export class CohereAdapter extends BaseAdapter {
  get name(): string {
    return 'cohere';
  }

  /**
   * Sends a chat request to Cohere.
   *
   * @param options - Provider-agnostic chat options. A leading `cohere/`
   *   prefix on `options.model` is stripped.
   * @returns A `ChatResponse` when `options.stream` is falsy, otherwise an
   *   async iterable of `StreamChunk`s.
   * @throws Whatever `handleProviderError` raises for request/HTTP failures
   *   (assumed to always throw — TODO confirm its return type is `never`).
   *   Errors that occur while the stream is being consumed are not routed
   *   through it.
   */
  async chat(options: ChatOptions): Promise<ChatResponse | AsyncIterable<StreamChunk>> {
    try {
      const apiKey = this.getApiKey(options);
      const baseUrl = this.getBaseUrl('https://api.cohere.ai/v1');

      // Accept both "cohere/<model>" and bare "<model>" identifiers.
      const model = options.model.replace(/^cohere\//, '');
      const stream = Boolean(options.stream);

      const request = this.buildRequest(options, model, stream);

      const response = await this.fetchWithErrorHandling(
        `${baseUrl}/chat`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'Authorization': `Bearer ${apiKey}`,
          },
          body: JSON.stringify(request),
        },
        this.name
      );

      if (stream) {
        return this.handleStreamResponse(response, model);
      }
      return this.handleNonStreamResponse(await response.json(), model);
    } catch (error) {
      handleProviderError(error, this.name);
    }
  }

  /**
   * Builds the Cohere request body from the generic chat options.
   *
   * Optional fields are only set when the caller supplied them.
   * `providerOptions` is merged last so it can override any field except
   * `stream`, which must match how the response is parsed in `chat`.
   */
  private buildRequest(options: ChatOptions, model: string, stream: boolean): CohereRequest {
    const { message, chatHistory, preamble } = this.transformMessages(options.messages);

    const request: CohereRequest = { message, stream };

    if (model) request.model = model;
    if (chatHistory.length > 0) request.chat_history = chatHistory;
    if (preamble) request.preamble = preamble;

    if (options.temperature !== undefined) request.temperature = options.temperature;
    if (options.maxTokens !== undefined) request.max_tokens = options.maxTokens;
    if (options.topP !== undefined) request.p = options.topP;

    // Cohere expects an array; a single stop string is wrapped rather than dropped.
    if (Array.isArray(options.stop)) {
      request.stop_sequences = options.stop;
    } else if (typeof options.stop === 'string' && options.stop) {
      request.stop_sequences = [options.stop];
    }

    if (options.providerOptions) {
      Object.assign(request, options.providerOptions);
    }
    // The response handler is chosen from `options.stream`, so the request must agree with it
    // even if providerOptions tried to override `stream`.
    request.stream = stream;

    return request;
  }

  /**
   * Splits generic messages into Cohere's request shape.
   *
   * - The content of every system message becomes `preamble`. Non-empty
   *   contents are joined with a blank line, in their original order.
   * - The last non-system message becomes `message`. It is `''` when there is none.
   * - Earlier non-system messages become `chatHistory`. `user` maps to `USER`;
   *   every other role maps to `CHATBOT`.
   */
  private transformMessages(messages: Message[]): {
    message: string;
    chatHistory: CohereChatMessage[];
    preamble?: string;
  } {
    const systemParts: string[] = [];
    const conversation: Message[] = [];

    for (const msg of messages) {
      if (msg.role === 'system') {
        const text = contentToString(msg.content);
        if (text) systemParts.push(text);
      } else {
        conversation.push(msg);
      }
    }

    const lastMessage = conversation.length > 0 ? conversation[conversation.length - 1] : undefined;
    const historyMessages = conversation.slice(0, -1);

    return {
      message: lastMessage ? contentToString(lastMessage.content) : '',
      chatHistory: historyMessages.map(
        (msg): CohereChatMessage => ({
          role: msg.role === 'user' ? 'USER' : 'CHATBOT',
          message: contentToString(msg.content),
        })
      ),
      preamble: systemParts.length > 0 ? systemParts.join('\n\n') : undefined,
    };
  }

  /**
   * Maps a non-streaming Cohere `/chat` JSON body to a `ChatResponse`.
   * Token usage is reported only when `meta.tokens` is present. Missing
   * counts default to 0.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- raw provider payload
  private handleNonStreamResponse(data: any, model: string): ChatResponse {
    const tokens = data.meta?.tokens;
    const promptTokens: number = tokens?.input_tokens ?? 0;
    const completionTokens: number = tokens?.output_tokens ?? 0;

    return {
      content: data.text ?? '',
      finishReason: data.finish_reason,
      usage: tokens
        ? {
            promptTokens,
            completionTokens,
            totalTokens: promptTokens + completionTokens,
          }
        : undefined,
      model,
      raw: data,
    };
  }

  /**
   * Converts a streaming `/chat` response into `StreamChunk`s.
   *
   * Each line is parsed as JSON. `text-generation` events yield their text.
   * A `stream-end` event yields an empty chunk carrying `finish_reason`.
   * Other events and unparseable lines are ignored.
   *
   * @throws Error if the response has no body.
   */
  private async *handleStreamResponse(response: Response, model: string): AsyncIterable<StreamChunk> {
    if (!response.body) {
      throw new Error('No response body for streaming');
    }

    for await (const line of streamLines(response.body)) {
      // Only parsing is guarded; `yield` stays outside any try/catch so errors
      // raised at the yield point are not silently swallowed.
      const event = parseStreamEvent(line);
      if (!event) continue;

      if (event.event_type === 'text-generation') {
        yield { content: event.text || '' };
      } else if (event.event_type === 'stream-end') {
        yield { content: '', finishReason: event.finish_reason };
      }
    }
  }
}