@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
164 lines (147 loc) • 5.54 kB
text/typescript
import { ChatModelCard } from '@/types/llm';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatMethodOptions, ChatStreamPayload, ModelProvider } from '../types';
import {
CloudflareStreamTransformer,
DEFAULT_BASE_URL_PREFIX,
desensitizeCloudflareUrl,
fillUrl,
} from '../utils/cloudflareHelpers';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import { createCallbacksTransformer } from '../utils/streams';
/**
 * One entry of the `result` array returned by the Cloudflare
 * `ai/models/search` endpoint, in the shape `LobeCloudflareAI.models()` reads it.
 */
export interface CloudflareModelCard {
  /** Free-text description; `models()` scans it for "function call" and "vision". */
  description: string;
  /** Model name; `models()` uses it as the model id. */
  name: string;
  /**
   * Optional string-valued properties. `models()` reads the keys
   * `max_total_tokens`, `beta` and `function_calling`.
   */
  properties?: { [key: string]: string };
  /** Optional task descriptor; `models()` checks `name` for "image-to-text". */
  task?: { description?: string; name: string };
}
/** Constructor options for `LobeCloudflareAI`. */
export interface LobeCloudflareParams {
  /**
   * API key sent as a `Bearer` token. The constructor requires it when
   * `baseURLOrAccountID` is an account ID rather than a URL.
   */
  apiKey?: string;
  /**
   * Either a base URL (any value starting with `http`) or an account ID.
   * The constructor throws when it is missing or empty.
   */
  baseURLOrAccountID?: string;
}
/**
 * Runtime adapter that sends chat requests to a Cloudflare AI endpoint and
 * lists the models available to an account.
 */
export class LobeCloudflareAI implements LobeRuntimeAI {
  /** Base URL that model names are resolved against in `chat()`. */
  baseURL: string;
  /** Account ID used to build the model-search URL in `models()`. */
  accountID: string;
  /** API key sent as a `Bearer` token when present. */
  apiKey?: string;

  /**
   * @param params.apiKey - API key; mandatory when an account ID is supplied.
   * @param params.baseURLOrAccountID - Base URL (starts with `http`) or account ID.
   * @throws An `InvalidProviderAPIKey` runtime error when `baseURLOrAccountID`
   *   is missing, or when an account ID is given without an API key.
   */
  constructor({ apiKey, baseURLOrAccountID }: LobeCloudflareParams = {}) {
    if (!baseURLOrAccountID) {
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
    }

    if (baseURLOrAccountID.startsWith('http')) {
      // Normalise to a trailing slash so `new URL(model, baseURL)` appends the
      // model to the path instead of replacing its last segment.
      this.baseURL = baseURLOrAccountID.endsWith('/')
        ? baseURLOrAccountID
        : `${baseURLOrAccountID}/`;
      // Best effort: pick a 32-hex-digit path segment out of the URL. Matching
      // against the normalised URL also covers a URL that ends with the ID and
      // has no trailing slash. When no such segment exists the URL is kept as is.
      this.accountID = this.baseURL.replaceAll(/^.*\/([\dA-Fa-f]{32})\/.*$/g, '$1');
    } else {
      if (!apiKey) {
        throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
      }
      this.accountID = baseURLOrAccountID;
      this.baseURL = fillUrl(baseURLOrAccountID);
    }

    this.apiKey = apiKey;
  }

  /**
   * POSTs the chat payload to `<baseURL><model>` and returns the transformed
   * response stream.
   *
   * @param payload - Chat payload; `model` selects the endpoint and `tools`
   *   are forwarded as their `function` definitions.
   * @param options - Optional request headers, abort signal and stream callbacks.
   * @returns A streaming `Response` built from the upstream body.
   * @throws A `ProviderBizError` runtime error for any failure inside the call.
   */
  async chat(payload: ChatStreamPayload, options?: ChatMethodOptions): Promise<Response> {
    try {
      const { model, tools, ...restPayload } = payload;
      const functions = tools?.map((tool) => tool.function);

      // Build a fresh header object instead of writing into `options.headers`:
      // that object is reused for the outgoing response below, so mutating it
      // would expose the API key to the client and alter the caller's object.
      const requestHeaders = {
        'Content-Type': 'application/json',
        ...options?.headers,
        ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}),
      };

      const url = new URL(model, this.baseURL);
      const response = await fetch(url, {
        body: JSON.stringify({ tools: functions, ...restPayload }),
        headers: requestHeaders,
        method: 'POST',
        signal: options?.signal,
      });

      if (response.status === 400) {
        // Thrown inside the `try`, so the `catch` below wraps it once more.
        throw AgentRuntimeError.chat({
          endpoint: desensitizeCloudflareUrl(url.toString()),
          error: response,
          errorType: AgentRuntimeErrorType.ProviderBizError,
          provider: ModelProvider.Cloudflare,
        });
      }

      // Only tee when debugging
      let responseBody: ReadableStream;
      if (process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION === '1') {
        const [prod, useForDebug] = response.body!.tee();
        // A bare `.catch()` handles nothing; log so a failure in the debug
        // branch cannot surface as an unhandled rejection.
        debugStream(useForDebug).catch(console.error);
        responseBody = prod;
      } else {
        responseBody = response.body!;
      }

      return StreamingResponse(
        responseBody
          .pipeThrough(new TransformStream(new CloudflareStreamTransformer()))
          .pipeThrough(createCallbacksTransformer(options?.callback)),
        { headers: options?.headers },
      );
    } catch (error) {
      throw AgentRuntimeError.chat({
        endpoint: desensitizeCloudflareUrl(this.baseURL),
        error: error as any,
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Cloudflare,
      });
    }
  }

  /**
   * Fetches the account's model list and maps each entry to a chat model card,
   * filling gaps from the bundled default model list when the name matches.
   *
   * @returns One card per entry in the response's `result` array.
   * @throws `Error` when the response carries no `result` array.
   */
  async models(): Promise<ChatModelCard[]> {
    const { LOBE_DEFAULT_MODEL_LIST } = await import('@/config/aiModels');

    const url = `${DEFAULT_BASE_URL_PREFIX}/client/v4/accounts/${this.accountID}/ai/models/search`;
    const response = await fetch(url, {
      headers: {
        // Skip the header entirely rather than sending `Bearer undefined`.
        ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}),
        'Content-Type': 'application/json',
      },
      method: 'GET',
    });

    const json = await response.json();
    if (!Array.isArray(json?.result)) {
      // Previously this surfaced as an opaque TypeError on `.map`.
      throw new Error(
        `Cloudflare model list request returned no result array (status ${response.status})`,
      );
    }
    const modelList: CloudflareModelCard[] = json.result;

    return modelList.map((model) => {
      const lowerName = model.name.toLowerCase();
      const lowerDescription = model.description.toLowerCase();
      const knownModel = LOBE_DEFAULT_MODEL_LIST.find((m) => lowerName === m.id.toLowerCase());

      return {
        contextWindowTokens: model.properties?.max_total_tokens
          ? Number(model.properties.max_total_tokens)
          : knownModel?.contextWindowTokens,
        displayName:
          knownModel?.displayName ??
          (model.properties?.['beta'] === 'true' ? `${model.name} (Beta)` : undefined),
        enabled: knownModel?.enabled || false,
        functionCall:
          lowerDescription.includes('function call') ||
          model.properties?.['function_calling'] === 'true' ||
          knownModel?.abilities?.functionCall ||
          false,
        id: model.name,
        reasoning:
          lowerName.includes('deepseek-r1') || knownModel?.abilities?.reasoning || false,
        vision:
          lowerName.includes('vision') ||
          model.task?.name.toLowerCase().includes('image-to-text') ||
          lowerDescription.includes('vision') ||
          knownModel?.abilities?.vision ||
          false,
      };
    }) as ChatModelCard[];
  }
}