/**
 * ai.libx.js: unified API bridge for various AI models (LLMs, image/video
 * generation, TTS, STT). Stateless and edge-compatible.
 */
import { BaseAdapter } from './base/BaseAdapter';
import { ChatOptions, ChatResponse, StreamChunk, Message } from '../types';
import { streamLines } from '../utils/stream';
import { handleProviderError } from '../utils/errors';
import { contentToString } from '../utils/content-helpers';
// A single turn of Cohere chat_history (the v1 API uses uppercase roles;
// system instructions are sent separately via `preamble`).
interface CohereChatMessage {
  role: 'USER' | 'CHATBOT';
  message: string;
}
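/** Request body for Cohere's v1 /chat endpoint (only the fields this adapter sets). */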
interface CohereRequest {
model?: string;
message: string;
  chat_history?: CohereChatMessage[];
preamble?: string;
temperature?: number;
max_tokens?: number;
p?: number;
stop_sequences?: string[];
stream?: boolean;
}
/**
 * Cohere API adapter.
 *
 * Translates the unified ChatOptions into Cohere's v1 /chat request format
 * and maps both streaming and non-streaming responses back to the library's
 * common ChatResponse / StreamChunk shapes.
 */
export class CohereAdapter extends BaseAdapter {
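  /** Provider identifier; also the model prefix stripped from "cohere/..." names. */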
get name(): string {
return 'cohere';
}
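  /**
   * Sends a chat request to Cohere's v1 /chat endpoint. Resolves to a
   * ChatResponse, or to an AsyncIterable of StreamChunks when
   * `options.stream` is true.
   */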
async chat(options: ChatOptions): Promise<ChatResponse | AsyncIterable<StreamChunk>> {
try {
const apiKey = this.getApiKey(options);
const baseUrl = this.getBaseUrl('https://api.cohere.ai/v1');
// Strip provider prefix from model if present
const model = options.model.replace(/^cohere\//, '');
// Transform messages to Cohere format
const { message, chatHistory, preamble } = this.transformMessages(options.messages);
const request: CohereRequest = {
message,
stream: options.stream || false,
};
// Add model if specified
if (model) request.model = model;
// Add chat history
if (chatHistory.length > 0) {
request.chat_history = chatHistory;
}
// Add preamble (system message)
if (preamble) {
request.preamble = preamble;
}
// Add optional parameters
if (options.temperature !== undefined) request.temperature = options.temperature;
if (options.maxTokens !== undefined) request.max_tokens = options.maxTokens;
if (options.topP !== undefined) request.p = options.topP;
      if (options.stop) {
        // Cohere expects an array of stop sequences; wrap a lone string if
        // the caller passed one.
        request.stop_sequences = Array.isArray(options.stop) ? options.stop : [options.stop];
      }
// Merge provider-specific options
if (options.providerOptions) {
Object.assign(request, options.providerOptions);
}
const response = await this.fetchWithErrorHandling(
`${baseUrl}/chat`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${apiKey}`,
},
body: JSON.stringify(request),
},
this.name
);
if (options.stream) {
return this.handleStreamResponse(response, model);
}
return this.handleNonStreamResponse(await response.json(), model);
    } catch (error) {
      // Normalize and rethrow as a provider-scoped error
      // (handleProviderError is expected to throw rather than return).
      handleProviderError(error, this.name);
    }
}
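  /**
   * Splits the unified message list into Cohere's chat format: the system
   * message (if any) becomes the preamble, the final non-system message is
   * sent as the current `message`, and the rest become `chat_history` with
   * USER/CHATBOT roles.
   *
   * For example, [system, user, assistant, user] maps to preamble = system,
   * chat_history = [USER, CHATBOT], message = the last user turn.
   */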
  private transformMessages(messages: Message[]): {
    message: string;
    chatHistory: CohereChatMessage[];
    preamble?: string;
  } {
const systemMessage = messages.find((m) => m.role === 'system');
const nonSystemMessages = messages.filter((m) => m.role !== 'system');
// Last message is the current message
const lastMessage = nonSystemMessages[nonSystemMessages.length - 1];
const historyMessages = nonSystemMessages.slice(0, -1);
return {
message: lastMessage ? contentToString(lastMessage.content) : '',
chatHistory: historyMessages.map((msg) => ({
role: msg.role === 'user' ? 'USER' : 'CHATBOT',
message: contentToString(msg.content),
})),
preamble: systemMessage ? contentToString(systemMessage.content) : undefined,
};
}
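  /** Maps a non-streaming Cohere response body onto the unified ChatResponse shape. */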
private handleNonStreamResponse(data: any, model: string): ChatResponse {
return {
content: data.text || '',
finishReason: data.finish_reason,
usage: data.meta?.tokens ? {
promptTokens: data.meta.tokens.input_tokens || 0,
completionTokens: data.meta.tokens.output_tokens || 0,
totalTokens: (data.meta.tokens.input_tokens || 0) + (data.meta.tokens.output_tokens || 0),
} : undefined,
model,
raw: data,
};
}
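  /**
   * Consumes Cohere's streaming response as newline-delimited JSON events,
   * yielding text for `text-generation` events and a final empty chunk
   * carrying the finish reason on `stream-end`.
   */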
private async *handleStreamResponse(response: Response, model: string): AsyncIterable<StreamChunk> {
if (!response.body) {
throw new Error('No response body for streaming');
}
for await (const line of streamLines(response.body)) {
try {
const chunk = JSON.parse(line);
if (chunk.event_type === 'text-generation') {
yield {
content: chunk.text || '',
};
} else if (chunk.event_type === 'stream-end') {
yield {
content: '',
finishReason: chunk.finish_reason,
};
}
      } catch {
        // Skip lines that are not valid JSON (e.g., blank keep-alive lines)
        continue;
      }
}
}
}
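// Usage sketch (hedged: BaseAdapter's constructor and config shape are not
// shown in this file, so the setup below is illustrative, not prescriptive):
//
//   const adapter = new CohereAdapter({ apiKey: process.env.COHERE_API_KEY });
//
//   // Non-streaming
//   const res = await adapter.chat({
//     model: 'cohere/command-r',
//     messages: [
//       { role: 'system', content: 'Answer tersely.' },
//       { role: 'user', content: 'Name the capital of France.' },
//     ],
//   });
//
//   // Streaming
//   const stream = await adapter.chat({ model: 'cohere/command-r', messages, stream: true });
//   for await (const chunk of stream as AsyncIterable<StreamChunk>) {
//     process.stdout.write(chunk.content);
//   }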