@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
579 lines (491 loc) • 17.7 kB
text/typescript
import type { VertexAI } from '@google-cloud/vertexai';
import {
Content,
FunctionCallPart,
FunctionDeclaration,
Tool as GoogleFunctionCallTool,
GoogleGenerativeAI,
GoogleSearchRetrievalTool,
Part,
SchemaType,
} from '@google/generative-ai';
import { imageUrlToBase64 } from '@/utils/imageToBase64';
import { safeParseJSON } from '@/utils/safeParseJSON';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../error';
import {
ChatCompletionTool,
ChatMethodOptions,
ChatStreamPayload,
OpenAIChatMessage,
UserMessageContentPart,
} from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import { GoogleGenerativeAIStream, VertexAIStream } from '../utils/streams';
import { parseDataUri } from '../utils/uriParser';
/** Image-generation-capable Gemini 2.0 flash variants. */
const IMAGE_CAPABLE_MODELS = [
  'gemini-2.0-flash-exp',
  'gemini-2.0-flash-exp-image-generation',
  'gemini-2.0-flash-preview-image-generation',
];

/** Gemma instruction-tuned models (no system-instruction support). */
const GEMMA_IT_MODELS = [
  'gemma-3-1b-it',
  'gemma-3-4b-it',
  'gemma-3-12b-it',
  'gemma-3-27b-it',
  'gemma-3n-e4b-it',
];

/** Models that require the non-enum `OFF` safety threshold. */
const modelsOffSafetySettings = new Set(['gemini-2.0-flash-exp']);

/** Models that may return both Text and Image response modalities. */
const modelsWithModalities = new Set(IMAGE_CAPABLE_MODELS);

/**
 * Models that should not receive `systemInstruction` / `thinkingConfig`.
 * NOTE(review): identifier misspells "Instruction"; kept as-is because it is
 * referenced elsewhere in this file.
 */
const modelsDisableInstuction = new Set([...IMAGE_CAPABLE_MODELS, ...GEMMA_IT_MODELS]);

/** Shape of one entry in the Google `/v1beta/models` listing. */
export interface GoogleModelCard {
  displayName: string;
  inputTokenLimit: number;
  name: string;
  outputTokenLimit: number;
}

/** Harm categories passed to the safety-settings payload. */
enum HarmCategory {
  HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT',
  HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT',
  HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH',
  HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
}

/** Only the threshold value this provider ever sends from the enum. */
enum HarmBlockThreshold {
  BLOCK_NONE = 'BLOCK_NONE',
}

/**
 * Picks the safety threshold for a model: some models only accept the
 * undocumented `OFF` value instead of `BLOCK_NONE`.
 */
function getThreshold(model: string): HarmBlockThreshold {
  // 'OFF' is accepted by the API but absent from the enum — https://discuss.ai.google.dev/t/59352
  return modelsOffSafetySettings.has(model)
    ? ('OFF' as HarmBlockThreshold)
    : HarmBlockThreshold.BLOCK_NONE;
}
/** Default public endpoint of the Generative Language API. */
const DEFAULT_BASE_URL = 'https://generativelanguage.googleapis.com';

/** Constructor options for {@link LobeGoogleAI}. */
interface LobeGoogleAIParams {
  apiKey?: string;
  baseURL?: string;
  client?: GoogleGenerativeAI | VertexAI;
  id?: string;
  isVertexAi?: boolean;
}

/** Subset of the Gemini "thinking" configuration this runtime sends. */
interface GoogleAIThinkingConfig {
  includeThoughts?: boolean;
  thinkingBudget?: number;
}

/** Message fragments that indicate a user/stream cancellation rather than a real failure. */
const ABORT_MESSAGE_HINTS = ['aborted', 'cancelled', 'error reading from the stream', 'abort'];

/**
 * Heuristically detects cancellation errors by name or message content so they
 * can be reported differently from genuine provider failures.
 */
const isAbortError = (error: Error): boolean => {
  if (error.name === 'AbortError') return true;
  const lowered = error.message.toLowerCase();
  return ABORT_MESSAGE_HINTS.some((hint) => lowered.includes(hint));
};
/**
 * Runtime adapter for Google Generative AI (Gemini) and Vertex AI.
 *
 * Translates OpenAI-shaped chat payloads into the Google SDK format, streams
 * completions back as a LobeChat-compatible stream, and normalizes Google's
 * error payloads into AgentRuntimeError shapes.
 */
export class LobeGoogleAI implements LobeRuntimeAI {
  private client: GoogleGenerativeAI;
  private isVertexAi: boolean;
  baseURL?: string;
  apiKey?: string;
  provider: string;

  constructor({ apiKey, baseURL, client, isVertexAi, id }: LobeGoogleAIParams = {}) {
    if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

    this.apiKey = apiKey;
    // Prefer an injected client (e.g. a pre-configured Vertex AI client);
    // otherwise build one from the API key. (Bug fix: the original also
    // constructed a throwaway GoogleGenerativeAI instance right before this
    // line — that dead assignment has been removed.)
    this.client = client ? (client as GoogleGenerativeAI) : new GoogleGenerativeAI(apiKey);
    // An injected client manages its own endpoint, so baseURL only applies otherwise.
    this.baseURL = client ? undefined : baseURL || DEFAULT_BASE_URL;
    this.isVertexAi = isVertexAi || false;
    this.provider = id || (isVertexAi ? 'vertexai' : 'google');
  }

  /**
   * Streams a chat completion.
   *
   * @param rawPayload - OpenAI-shaped chat payload (system message is extracted).
   * @param options - Callbacks, headers, and an optional abort signal.
   * @returns A streaming HTTP response.
   * @throws AgentRuntimeError on provider failure or cancellation.
   */
  async chat(rawPayload: ChatStreamPayload, options?: ChatMethodOptions) {
    try {
      const payload = this.buildPayload(rawPayload);
      const { model, thinkingBudget } = payload;

      const thinkingConfig: GoogleAIThinkingConfig = {
        includeThoughts:
          (!!thinkingBudget ||
            (!thinkingBudget && model && (model.includes('-2.5-') || model.includes('thinking'))))
            ? true
            : undefined,
        // https://ai.google.dev/gemini-api/docs/thinking#set-budget
        thinkingBudget: (() => {
          if (thinkingBudget !== undefined && thinkingBudget !== null) {
            // Clamp the user-supplied budget into each model family's allowed range.
            if (model.includes('-2.5-flash-lite')) {
              // 0 (off) and -1 (dynamic) are passed through verbatim.
              if (thinkingBudget === 0 || thinkingBudget === -1) {
                return thinkingBudget;
              }
              return Math.max(512, Math.min(thinkingBudget, 24_576));
            } else if (model.includes('-2.5-flash')) {
              return Math.min(thinkingBudget, 24_576);
            } else if (model.includes('-2.5-pro')) {
              return Math.max(128, Math.min(thinkingBudget, 32_768));
            }
            return Math.min(thinkingBudget, 24_576);
          }

          // No explicit budget: per-family defaults. Bug fix: check
          // '-2.5-flash-lite' BEFORE '-2.5-flash' — the lite id also contains
          // '-2.5-flash', so the original order made the `return 0` branch
          // unreachable and lite models always got dynamic thinking (-1).
          if (model.includes('-2.5-flash-lite')) {
            return 0;
          } else if (model.includes('-2.5-pro') || model.includes('-2.5-flash')) {
            return -1;
          }
          return undefined;
        })(),
      };

      const contents = await this.buildGoogleMessages(payload.messages);

      const inputStartAt = Date.now();

      // Bridge the caller's signal onto a local controller so the stream can be
      // cancelled whether or not a signal was provided.
      const controller = new AbortController();
      const originalSignal = options?.signal;

      if (originalSignal) {
        if (originalSignal.aborted) {
          controller.abort();
        } else {
          originalSignal.addEventListener(
            'abort',
            () => {
              controller.abort();
            },
            // Fire once and auto-remove so the listener does not leak.
            { once: true },
          );
        }
      }

      const geminiStreamResult = await this.client
        .getGenerativeModel(
          {
            generationConfig: {
              maxOutputTokens: payload.max_tokens,
              // @ts-expect-error - Google SDK 0.24.0 doesn't have this property for now with
              response_modalities: modelsWithModalities.has(model) ? ['Text', 'Image'] : undefined,
              temperature: payload.temperature,
              topP: payload.top_p,
              ...(modelsDisableInstuction.has(model) || model.toLowerCase().includes('learnlm')
                ? {}
                : { thinkingConfig }),
            },
            model,
            // avoid wide sensitive words
            // refs: https://github.com/lobehub/lobe-chat/pull/1418
            safetySettings: [
              {
                category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                threshold: getThreshold(model),
              },
              {
                category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                threshold: getThreshold(model),
              },
              {
                category: HarmCategory.HARM_CATEGORY_HARASSMENT,
                threshold: getThreshold(model),
              },
              {
                category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                threshold: getThreshold(model),
              },
            ],
          },
          { apiVersion: 'v1beta', baseUrl: this.baseURL },
        )
        .generateContentStream(
          {
            contents,
            systemInstruction: modelsDisableInstuction.has(model)
              ? undefined
              : (payload.system as string),
            tools: this.buildGoogleTools(payload.tools, payload),
          },
          {
            signal: controller.signal,
          },
        );

      const googleStream = this.createEnhancedStream(geminiStreamResult.stream, controller.signal);
      const [prod, useForDebug] = googleStream.tee();

      const key = this.isVertexAi
        ? 'DEBUG_VERTEX_AI_CHAT_COMPLETION'
        : 'DEBUG_GOOGLE_CHAT_COMPLETION';

      if (process.env[key] === '1') {
        // Fire-and-forget debug tap. Bug fix: `.catch()` with no argument
        // registers no handler; pass a no-op so a debug failure cannot surface
        // as an unhandled rejection.
        debugStream(useForDebug).catch(() => {});
      }

      // Convert the response into a friendly text-stream
      const Stream = this.isVertexAi ? VertexAIStream : GoogleGenerativeAIStream;
      const stream = Stream(prod, { callbacks: options?.callback, inputStartAt });

      // Respond with the stream
      return StreamingResponse(stream, { headers: options?.headers });
    } catch (e) {
      const err = e as Error;

      // Cancellations are surfaced uniformly instead of being silently swallowed.
      if (isAbortError(err)) {
        console.log('Request was cancelled');
        throw AgentRuntimeError.chat({
          error: { message: 'Request was cancelled' },
          errorType: AgentRuntimeErrorType.ProviderBizError,
          provider: this.provider,
        });
      }

      console.log(err);
      const { errorType, error } = this.parseErrorMessage(err.message);
      throw AgentRuntimeError.chat({ error, errorType, provider: this.provider });
    }
  }

  /**
   * Wraps the SDK's async-iterable stream in a ReadableStream that handles
   * cancellation gracefully: once any chunk has been emitted, an abort closes
   * the stream (preserving output already sent) instead of erroring it.
   */
  private createEnhancedStream(originalStream: any, signal: AbortSignal): ReadableStream {
    return new ReadableStream({
      async start(controller) {
        let hasData = false;

        try {
          for await (const chunk of originalStream) {
            if (signal.aborted) {
              // Output already produced: close gracefully rather than throwing.
              if (hasData) {
                console.log('Stream cancelled gracefully, preserving existing output');
                controller.close();
                return;
              } else {
                // Nothing emitted yet: propagate a cancellation error.
                throw new Error('Stream cancelled');
              }
            }
            hasData = true;
            controller.enqueue(chunk);
          }
        } catch (error) {
          const err = error as Error;

          // Uniform handling for all errors, including aborts.
          if (isAbortError(err) || signal.aborted) {
            if (hasData) {
              console.log('Stream reading cancelled gracefully, preserving existing output');
              controller.close();
              return;
            } else {
              console.log('Stream reading cancelled before any output');
              controller.error(new Error('Stream cancelled'));
              return;
            }
          } else {
            // Any other stream parsing failure is propagated to the consumer.
            console.error('Stream parsing error:', err);
            controller.error(err);
            return;
          }
        }

        controller.close();
      },
    });
  }

  /**
   * Fetches and normalizes the model list from the `/v1beta/models` endpoint.
   *
   * @throws Error when the HTTP request fails (logged with context first).
   */
  async models(options?: { signal?: AbortSignal }) {
    try {
      const url = `${this.baseURL}/v1beta/models?key=${this.apiKey}`;
      const response = await fetch(url, {
        method: 'GET',
        signal: options?.signal,
      });

      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }

      const json = await response.json();

      const modelList: GoogleModelCard[] = json.models;

      const processedModels = modelList.map((model) => {
        const id = model.name.replace(/^models\//, '');

        return {
          contextWindowTokens: (model.inputTokenLimit || 0) + (model.outputTokenLimit || 0),
          displayName: model.displayName || id,
          id,
          maxOutput: model.outputTokenLimit || undefined,
        };
      });

      const { MODEL_LIST_CONFIGS, processModelList } = await import('../utils/modelParse');

      return processModelList(processedModels, MODEL_LIST_CONFIGS.google);
    } catch (error) {
      console.error('Failed to fetch Google models:', error);
      throw error;
    }
  }

  /** Splits the system message out of the message list into `payload.system`. */
  private buildPayload(payload: ChatStreamPayload) {
    const system_message = payload.messages.find((m) => m.role === 'system');
    const user_messages = payload.messages.filter((m) => m.role !== 'system');

    return {
      ...payload,
      messages: user_messages,
      system: system_message?.content,
    };
  }

  /**
   * Converts one OpenAI content part into a Google `Part`.
   * Unknown part types yield `undefined` and are filtered out by the caller.
   */
  private convertContentToGooglePart = async (
    content: UserMessageContentPart,
  ): Promise<Part | undefined> => {
    switch (content.type) {
      case 'text': {
        return { text: content.text };
      }
      case 'image_url': {
        const { mimeType, base64, type } = parseDataUri(content.image_url.url);

        if (type === 'base64') {
          if (!base64) {
            throw new TypeError("Image URL doesn't contain base64 data");
          }

          return {
            inlineData: {
              data: base64,
              mimeType: mimeType || 'image/png',
            },
          };
        }

        // Remote URLs are fetched and inlined as base64 for the SDK.
        if (type === 'url') {
          const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);

          return {
            inlineData: {
              data: base64,
              mimeType,
            },
          };
        }

        throw new TypeError(`currently we don't support image url: ${content.image_url.url}`);
      }
      default: {
        return undefined;
      }
    }
  };

  /**
   * Converts one OpenAI message to a Google `Content` entry. Assistant tool
   * calls become `functionCall` parts with role `function`.
   */
  private convertOAIMessagesToGoogleMessage = async (
    message: OpenAIChatMessage,
  ): Promise<Content> => {
    const content = message.content as string | UserMessageContentPart[];

    if (!!message.tool_calls) {
      return {
        parts: message.tool_calls.map<FunctionCallPart>((tool) => ({
          functionCall: {
            args: safeParseJSON(tool.function.arguments)!,
            name: tool.function.name,
          },
        })),
        role: 'function',
      };
    }

    const getParts = async () => {
      if (typeof content === 'string') return [{ text: content }];

      const parts = await Promise.all(
        content.map(async (c) => await this.convertContentToGooglePart(c)),
      );
      // Drop parts of unsupported types (converted to undefined above).
      return parts.filter(Boolean) as Part[];
    };

    return {
      parts: await getParts(),
      role: message.role === 'assistant' ? 'model' : 'user',
    };
  };

  // convert messages from the OpenAI format to Google GenAI SDK
  private buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promise<Content[]> => {
    const pools = messages
      .filter((message) => message.role !== 'function')
      .map(async (msg) => await this.convertOAIMessagesToGoogleMessage(msg));

    return Promise.all(pools);
  };

  /**
   * Maps a raw Google error message to a typed runtime error, trying in order:
   * a "location not supported" substring, a trailing JSON error array, then a
   * `[status message]` pattern, before falling back to a generic biz error.
   */
  private parseErrorMessage(message: string): {
    error: any;
    errorType: ILobeAgentRuntimeErrorType;
  } {
    const defaultError = {
      error: { message },
      errorType: AgentRuntimeErrorType.ProviderBizError,
    };

    if (message.includes('location is not supported'))
      return { error: { message }, errorType: AgentRuntimeErrorType.LocationNotSupportError };

    const startIndex = message.lastIndexOf('[');
    if (startIndex === -1) {
      return defaultError;
    }

    try {
      // Take the substring from the last '[' to the end...
      const jsonString = message.slice(startIndex);
      // ...and attempt to parse it as the Google error JSON array.
      const json: GoogleChatErrors = JSON.parse(jsonString);

      const bizError = json[0];

      switch (bizError.reason) {
        case 'API_KEY_INVALID': {
          return { ...defaultError, errorType: AgentRuntimeErrorType.InvalidProviderAPIKey };
        }
        default: {
          return { error: json, errorType: AgentRuntimeErrorType.ProviderBizError };
        }
      }
    } catch {
      // Not JSON — fall through to the status-code pattern below.
    }

    const errorObj = this.extractErrorObjectFromError(message);
    const { errorDetails } = errorObj;

    if (errorDetails) {
      return { error: errorDetails, errorType: AgentRuntimeErrorType.ProviderBizError };
    }

    return defaultError;
  }

  /**
   * Builds the `tools` array for the request. Google tools (e.g. googleSearch)
   * cannot be combined with function declarations, and no tools are injected
   * when the history already contains function calls.
   */
  private buildGoogleTools(
    tools: ChatCompletionTool[] | undefined,
    payload?: ChatStreamPayload,
  ): GoogleFunctionCallTool[] | undefined {
    // If the history already contains function calling, don't inject any tools.
    if (payload?.messages?.some((m) => m.tool_calls?.length)) {
      return;
    }
    if (payload?.enabledSearch) {
      return [{ googleSearch: {} } as GoogleSearchRetrievalTool];
    }

    if (!tools || tools.length === 0) return;

    return [
      {
        functionDeclarations: tools.map((tool) => this.convertToolToGoogleTool(tool)),
      },
    ];
  }

  /** Converts an OpenAI tool definition into a Google `FunctionDeclaration`. */
  private convertToolToGoogleTool = (tool: ChatCompletionTool): FunctionDeclaration => {
    const functionDeclaration = tool.function;
    const parameters = functionDeclaration.parameters;
    // Google rejects empty property objects; inject a dummy property instead.
    // refs: https://github.com/lobehub/lobe-chat/pull/5002
    const properties =
      parameters?.properties && Object.keys(parameters.properties).length > 0
        ? parameters.properties
        : { dummy: { type: 'string' } }; // dummy property to avoid empty object

    return {
      description: functionDeclaration.description,
      name: functionDeclaration.name,
      parameters: {
        description: parameters?.description,
        properties: properties,
        required: parameters?.required,
        type: SchemaType.OBJECT,
      },
    };
  };

  /**
   * Extracts a `[status message]` pattern (e.g. "[404 Not Found] ...") from a
   * raw error string into a structured object; returns null details when the
   * pattern is absent.
   */
  private extractErrorObjectFromError(message: string) {
    // Match "<prefix>[<digits> <text>]<detail>".
    const regex = /^(.*?)(\[\d+ [^\]]+])(.*)$/;
    const match = message.match(regex);

    if (match) {
      const prefix = match[1].trim();
      const statusCodeWithBrackets = match[2].trim();
      // Renamed from `message` to avoid shadowing the parameter.
      const detail = match[3].trim();

      // Pull the numeric status code out of the brackets.
      const statusCodeMatch = statusCodeWithBrackets.match(/\[(\d+)/);
      const statusCode = statusCodeMatch ? parseInt(statusCodeMatch[1], 10) : null;

      const resultJson = {
        message: detail,
        statusCode: statusCode,
        statusCodeText: statusCodeWithBrackets,
      };

      return {
        errorDetails: resultJson,
        prefix: prefix,
      };
    }

    // No match: return the original message as the prefix with no details.
    return {
      errorDetails: null,
      prefix: message,
    };
  }
}

export default LobeGoogleAI;
/** One entry of the error array embedded in Google API error messages. */
interface GoogleChatError {
  '@type': string;
  domain: string;
  metadata: {
    service: string;
  };
  reason: string;
}

/** Google returns its error payload as an array of error entries. */
type GoogleChatErrors = GoogleChatError[];