// @lobehub/chat — Google (Gemini / Vertex AI) model runtime.
// Lobe Chat: an open-source, high-performance chatbot framework supporting
// speech synthesis, multimodal input, and an extensible Function Call plugin
// system, with one-click free deployment of a private ChatGPT/LLM web app.
// Listing metadata: 530 lines (452 loc) • 17.5 kB • text/typescript
import {
GenerateContentConfig,
Tool as GoogleFunctionCallTool,
GoogleGenAI,
HttpOptions,
ThinkingConfig,
} from '@google/genai';
import debug from 'debug';
import { LobeRuntimeAI } from '../../core/BaseAI';
import { buildGoogleMessages, buildGoogleTools } from '../../core/contextBuilders/google';
import { GoogleGenerativeAIStream, VertexAIStream } from '../../core/streams';
import { LOBE_ERROR_KEY } from '../../core/streams/google';
import {
ChatCompletionTool,
ChatMethodOptions,
ChatStreamPayload,
GenerateObjectOptions,
GenerateObjectPayload,
} from '../../types';
import { AgentRuntimeErrorType } from '../../types/error';
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
import { AgentRuntimeError } from '../../utils/createError';
import { debugStream } from '../../utils/debugStream';
import { getModelPricing } from '../../utils/getModelPricing';
import { parseGoogleErrorMessage } from '../../utils/googleErrorParser';
import { StreamingResponse } from '../../utils/response';
import { createGoogleImage } from './createImage';
import { createGoogleGenerateObject, createGoogleGenerateObjectWithTools } from './generateObject';
// Namespaced debug logger; enable with DEBUG=model-runtime:google
const log = debug('model-runtime:google');
// Models that only accept the 'OFF' harm-block threshold instead of BLOCK_NONE.
const modelsOffSafetySettings = new Set(['gemini-2.0-flash-exp']);

// Chat models capable of emitting images alongside text; these both need the
// explicit Text/Image responseModalities and reject system instructions.
const IMAGE_OUTPUT_MODELS = [
  'gemini-2.0-flash-exp',
  'gemini-2.0-flash-exp-image-generation',
  'gemini-2.0-flash-preview-image-generation',
  'gemini-2.5-flash-image-preview',
  'gemini-2.5-flash-image',
];

const modelsWithModalities = new Set(IMAGE_OUTPUT_MODELS);

// Models that do not support a systemInstruction (image-output models plus the
// gemma family). Name typo ("Instuction") kept for compatibility with callers.
const modelsDisableInstuction = new Set([
  ...IMAGE_OUTPUT_MODELS,
  'gemma-3-1b-it',
  'gemma-3-4b-it',
  'gemma-3-12b-it',
  'gemma-3-27b-it',
  'gemma-3n-e4b-it',
]);
// Thinking-budget bounds per Gemini model family.
// https://ai.google.dev/gemini-api/docs/thinking#set-budget
const PRO_THINKING_MIN = 128;
const PRO_THINKING_MAX = 32_768;
const FLASH_THINKING_MAX = 24_576;
const FLASH_LITE_THINKING_MIN = 512;
const FLASH_LITE_THINKING_MAX = 24_576;

// Restrict a value to the inclusive range [min, max].
const clamp = (value: number, min: number, max: number) => Math.min(Math.max(value, min), max);

type ThinkingModelCategory = 'pro' | 'flash' | 'flashLite' | 'robotics' | 'other';

// Ordered substring rules; first match wins, so flash-lite must precede flash.
const THINKING_CATEGORY_RULES: [ThinkingModelCategory, string[]][] = [
  ['robotics', ['robotics-er-1.5-preview']],
  ['flashLite', ['-2.5-flash-lite', 'flash-lite-latest']],
  ['flash', ['-2.5-flash', 'flash-latest']],
  ['pro', ['-2.5-pro', 'pro-latest']],
];

// Classify a model id into a thinking-budget family ('other' when unknown).
const getThinkingModelCategory = (model?: string): ThinkingModelCategory => {
  if (!model) return 'other';
  const id = model.toLowerCase();
  const hit = THINKING_CATEGORY_RULES.find(([, needles]) =>
    needles.some((needle) => id.includes(needle)),
  );
  return hit ? hit[0] : 'other';
};

/**
 * Normalize a user-supplied thinking budget for a given Gemini model.
 *
 * Semantics: -1 means "dynamic thinking", 0 means "thinking disabled"; any
 * other value is clamped to the family's supported range. When no budget is
 * given, pro/flash default to dynamic (-1), flash-lite/robotics default to
 * disabled (0), and unknown families return undefined.
 */
export const resolveModelThinkingBudget = (
  model: string,
  thinkingBudget?: number | null,
): number | undefined => {
  const unset = thinkingBudget === undefined || thinkingBudget === null;
  switch (getThinkingModelCategory(model)) {
    case 'pro': {
      // Pro models always think; dynamic by default.
      if (unset) return -1;
      if (thinkingBudget === -1) return -1;
      return clamp(thinkingBudget, PRO_THINKING_MIN, PRO_THINKING_MAX);
    }
    case 'flash': {
      if (unset) return -1;
      // Dynamic (-1) and disabled (0) pass through untouched.
      if (thinkingBudget === -1 || thinkingBudget === 0) return thinkingBudget;
      return clamp(thinkingBudget, 0, FLASH_THINKING_MAX);
    }
    case 'flashLite':
    case 'robotics': {
      // Thinking is off by default for these families.
      if (unset) return 0;
      if (thinkingBudget === -1 || thinkingBudget === 0) return thinkingBudget;
      return clamp(thinkingBudget, FLASH_LITE_THINKING_MIN, FLASH_LITE_THINKING_MAX);
    }
    default: {
      if (unset) return undefined;
      // Unknown family: only cap the upper bound.
      return Math.min(thinkingBudget, FLASH_THINKING_MAX);
    }
  }
};
/**
 * Minimal shape of a model entry returned by the Generative Language
 * `models` list endpoint — only the fields this runtime consumes.
 */
export interface GoogleModelCard {
  displayName: string;
  inputTokenLimit: number;
  // Fully-qualified id, e.g. "models/gemini-2.5-pro"; prefix stripped in models().
  name: string;
  outputTokenLimit: number;
}
// Local copies of the SDK safety enums, limited to the members this file uses.
// Kept as string enums so the serialized values match the Google API wire format.
enum HarmCategory {
  HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT',
  HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT',
  HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH',
  HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
}
// Only BLOCK_NONE is needed here; getThreshold special-cases the 'OFF' value.
enum HarmBlockThreshold {
  BLOCK_NONE = 'BLOCK_NONE',
}
/**
 * Pick the harm-block threshold for a model.
 * Most models accept BLOCK_NONE; a few experimental ones only honor the
 * (undocumented) 'OFF' value — see https://discuss.ai.google.dev/t/59352
 */
function getThreshold(model: string): HarmBlockThreshold {
  return modelsOffSafetySettings.has(model)
    ? ('OFF' as HarmBlockThreshold)
    : HarmBlockThreshold.BLOCK_NONE;
}
// Public Gemini API endpoint used when no custom baseURL is supplied.
const DEFAULT_BASE_URL = 'https://generativelanguage.googleapis.com';

/** Construction options for LobeGoogleAI. */
interface LobeGoogleAIParams {
  apiKey?: string;
  // Custom endpoint (e.g. a proxy); not applied when `client` is injected.
  baseURL?: string;
  // Pre-built SDK client (e.g. Vertex AI auth); takes precedence over apiKey.
  client?: GoogleGenAI;
  defaultHeaders?: Record<string, any>;
  // Provider id override; defaults to 'vertexai' or 'google'.
  id?: string;
  isVertexAi?: boolean;
}
/**
 * Heuristically decide whether an error represents a cancelled request.
 * Matches the DOM `AbortError` name plus the cancellation phrasings the
 * Google SDK surfaces in error messages.
 */
const isAbortError = (error: Error): boolean => {
  if (error.name === 'AbortError') return true;
  const message = error.message.toLowerCase();
  const cancellationHints = [
    'aborted',
    'cancelled',
    'error reading from the stream',
    'abort',
  ];
  return cancellationHints.some((hint) => message.includes(hint));
};
/**
 * Model runtime for Google Gemini (Generative Language API) and Vertex AI.
 * Wraps the @google/genai SDK to provide streaming chat, image generation,
 * structured output, and model listing behind the shared LobeRuntimeAI
 * interface.
 */
export class LobeGoogleAI implements LobeRuntimeAI {
  private client: GoogleGenAI;
  // Whether requests target Vertex AI rather than the public Gemini API;
  // selects the stream transformer and the debug env-var name.
  private isVertexAi: boolean;
  baseURL?: string;
  apiKey?: string;
  provider: string;
  constructor({
    apiKey,
    baseURL,
    client,
    isVertexAi,
    id,
    defaultHeaders,
  }: LobeGoogleAIParams = {}) {
    if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
    // Custom endpoint/headers are only wired when a baseURL is provided.
    const httpOptions = baseURL
      ? ({ baseUrl: baseURL, headers: defaultHeaders } as HttpOptions)
      : undefined;
    this.apiKey = apiKey;
    // An injected, pre-configured client takes precedence over constructing
    // one from the apiKey; in that case baseURL is left unset.
    this.client = client ? client : new GoogleGenAI({ apiKey, httpOptions });
    this.baseURL = client ? undefined : baseURL || DEFAULT_BASE_URL;
    this.isVertexAi = isVertexAi || false;
    this.provider = id || (isVertexAi ? 'vertexai' : 'google');
  }
  /**
   * Stream a chat completion from Gemini / Vertex AI.
   *
   * Builds the request config (thinking budget, safety settings, modalities,
   * tools), wraps the SDK stream in cancellation-aware error handling, and
   * returns an SSE StreamingResponse.
   *
   * @throws AgentRuntimeError.chat on cancellation or any provider error.
   */
  async chat(rawPayload: ChatStreamPayload, options?: ChatMethodOptions) {
    try {
      const payload = this.buildPayload(rawPayload);
      const { model, thinkingBudget } = payload;
      // https://ai.google.dev/gemini-api/docs/thinking#set-budget
      const resolvedThinkingBudget = resolveModelThinkingBudget(model, thinkingBudget);
      const thinkingConfig: ThinkingConfig = {
        // Request thought summaries for 2.5/"thinking" models or when a budget
        // was supplied — unless thinking is disabled (resolved budget of 0).
        includeThoughts:
          (!!thinkingBudget ||
            (model && (model.includes('-2.5-') || model.includes('thinking')))) &&
          resolvedThinkingBudget !== 0
            ? true
            : undefined,
        thinkingBudget: resolvedThinkingBudget,
      };
      const contents = await buildGoogleMessages(payload.messages);
      // Mirror the caller's signal onto a local controller so the enhanced
      // stream can observe cancellation independently of the SDK request.
      const controller = new AbortController();
      const originalSignal = options?.signal;
      if (originalSignal) {
        if (originalSignal.aborted) {
          controller.abort();
        } else {
          originalSignal.addEventListener('abort', () => {
            controller.abort();
          });
        }
      }
      const config: GenerateContentConfig = {
        abortSignal: originalSignal,
        maxOutputTokens: payload.max_tokens,
        // Image-capable chat models must opt into both output modalities.
        responseModalities: modelsWithModalities.has(model) ? ['Text', 'Image'] : undefined,
        // avoid wide sensitive words
        // refs: https://github.com/lobehub/lobe-chat/pull/1418
        safetySettings: [
          {
            category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold: getThreshold(model),
          },
          {
            category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold: getThreshold(model),
          },
          {
            category: HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold: getThreshold(model),
          },
          {
            category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold: getThreshold(model),
          },
        ],
        // Some models (gemma family, image-output variants) reject system
        // instructions entirely.
        systemInstruction: modelsDisableInstuction.has(model)
          ? undefined
          : (payload.system as string),
        temperature: payload.temperature,
        thinkingConfig:
          modelsDisableInstuction.has(model) || model.toLowerCase().includes('learnlm')
            ? undefined
            : thinkingConfig,
        tools: this.buildGoogleToolsWithSearch(payload.tools, payload),
        topP: payload.top_p,
      };
      const inputStartAt = Date.now();
      const geminiStreamResponse = await this.client.models.generateContentStream({
        config,
        contents,
        model,
      });
      const googleStream = this.createEnhancedStream(geminiStreamResponse, controller.signal);
      // Tee so one branch feeds the response and the other optional debugging.
      const [prod, useForDebug] = googleStream.tee();
      const key = this.isVertexAi
        ? 'DEBUG_VERTEX_AI_CHAT_COMPLETION'
        : 'DEBUG_GOOGLE_CHAT_COMPLETION';
      if (process.env[key] === '1') {
        // NOTE(review): .catch() without a handler still leaves the rejection
        // unobserved on some runtimes — confirm this is intentional.
        debugStream(useForDebug).catch();
      }
      // Convert the response into a friendly text-stream
      const pricing = await getModelPricing(model, this.provider);
      const Stream = this.isVertexAi ? VertexAIStream : GoogleGenerativeAIStream;
      const stream = Stream(prod, {
        callbacks: options?.callback,
        inputStartAt,
        payload: { model, pricing, provider: this.provider },
      });
      // Respond with the stream
      return StreamingResponse(stream, { headers: options?.headers });
    } catch (e) {
      const err = e as Error;
      // The earlier silent handling was removed: always rethrow as a runtime error.
      if (isAbortError(err)) {
        log('Request was cancelled');
        throw AgentRuntimeError.chat({
          error: { message: 'Request was cancelled' },
          errorType: AgentRuntimeErrorType.ProviderBizError,
          provider: this.provider,
        });
      }
      log('Error: %O', err);
      const { errorType, error } = parseGoogleErrorMessage(err.message);
      throw AgentRuntimeError.chat({ error, errorType, provider: this.provider });
    }
  }
  /**
   * Generate images using Google AI Imagen API or Gemini Chat Models
   * @see https://ai.google.dev/gemini-api/docs/image-generation#imagen
   */
  async createImage(payload: CreateImagePayload): Promise<CreateImageResponse> {
    return createGoogleImage(this.client, this.provider, payload);
  }
  /**
   * Generate structured output using Google Gemini API
   * @see https://ai.google.dev/gemini-api/docs/structured-output
   * @see https://ai.google.dev/gemini-api/docs/function-calling
   *
   * Returns undefined when the payload carries neither tools nor a schema.
   */
  async generateObject(payload: GenerateObjectPayload, options?: GenerateObjectOptions) {
    // Convert OpenAI messages to Google format
    const contents = await buildGoogleMessages(payload.messages);
    // Handle tools-based structured output
    if (payload.tools && payload.tools.length > 0) {
      return createGoogleGenerateObjectWithTools(
        this.client,
        { contents, model: payload.model, tools: payload.tools },
        options,
      );
    }
    // Handle schema-based structured output
    if (payload.schema) {
      return createGoogleGenerateObject(
        this.client,
        { contents, model: payload.model, schema: payload.schema },
        options,
      );
    }
    return undefined;
  }
  /**
   * Wrap the raw SDK async-iterable in a ReadableStream that handles
   * cancellation and parse failures by injecting LOBE_ERROR_KEY markers,
   * which the downstream google-ai transformer converts into SSE error
   * events instead of letting the stream end unexpectedly.
   */
  private createEnhancedStream(originalStream: any, signal: AbortSignal): ReadableStream {
    // capture provider for error payloads inside the stream closure
    const provider = this.provider;
    return new ReadableStream({
      async start(controller) {
        let hasData = false;
        try {
          for await (const chunk of originalStream) {
            if (signal.aborted) {
              // If output was already emitted, close gracefully rather than throwing.
              if (hasData) {
                log('Stream cancelled gracefully, preserving existing output');
                // Explicitly inject a cancellation error so the SSE layer does
                // not fall back to its unexpected_end handling.
                controller.enqueue({
                  [LOBE_ERROR_KEY]: {
                    body: { name: 'Stream cancelled', provider, reason: 'aborted' },
                    message: 'Stream cancelled',
                    name: 'Stream cancelled',
                    type: AgentRuntimeErrorType.StreamChunkError,
                  },
                });
                controller.close();
                return;
              } else {
                // No output yet: just close; the downstream SSE layer emits
                // the error event during its flush phase.
                log('Stream cancelled before any output');
                controller.close();
                return;
              }
            }
            hasData = true;
            controller.enqueue(chunk);
          }
        } catch (error) {
          const err = error as Error;
          // Handle all errors uniformly, including abort errors.
          if (isAbortError(err) || signal.aborted) {
            // If output was already emitted, close the stream gracefully.
            if (hasData) {
              log('Stream reading cancelled gracefully, preserving existing output');
              // Explicitly inject a cancellation error so the SSE layer does
              // not fall back to its unexpected_end handling.
              controller.enqueue({
                [LOBE_ERROR_KEY]: {
                  body: { name: 'Stream cancelled', provider, reason: 'aborted' },
                  message: 'Stream cancelled',
                  name: 'Stream cancelled',
                  type: AgentRuntimeErrorType.StreamChunkError,
                },
              });
              controller.close();
              return;
            } else {
              log('Stream reading cancelled before any output');
              // Inject an error marker with details; the downstream google-ai
              // transformer turns it into an SSE error event.
              controller.enqueue({
                [LOBE_ERROR_KEY]: {
                  body: {
                    message: err.message,
                    name: 'AbortError',
                    provider,
                    stack: err.stack,
                  },
                  message: err.message || 'Request was cancelled',
                  name: 'AbortError',
                  type: AgentRuntimeErrorType.StreamChunkError,
                },
              });
              controller.close();
              return;
            }
          } else {
            // Handle other stream parsing errors.
            log('Stream parsing error: %O', err);
            // Try to parse the Google error and extract code/message/status.
            const { error: parsedError, errorType } = parseGoogleErrorMessage(
              err?.message || String(err),
            );
            // Inject an error marker with details; the downstream google-ai
            // transformer turns it into an SSE error event.
            controller.enqueue({
              [LOBE_ERROR_KEY]: {
                body: { ...parsedError, provider },
                message: parsedError?.message || err.message || 'Stream parsing error',
                name: 'Stream parsing error',
                type: errorType ?? AgentRuntimeErrorType.StreamChunkError,
              },
            });
            controller.close();
            return;
          }
        }
        controller.close();
      },
    });
  }
  /**
   * List available models from the Generative Language REST API and normalize
   * them into the shared model-card shape (id, displayName, context window).
   * NOTE(review): this hits baseURL/apiKey directly and bypasses an injected
   * client — confirm the intended behavior for Vertex AI instances.
   */
  async models(options?: { signal?: AbortSignal }) {
    try {
      const url = `${this.baseURL}/v1beta/models?key=${this.apiKey}`;
      const response = await fetch(url, {
        method: 'GET',
        signal: options?.signal,
      });
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const json = await response.json();
      const modelList: GoogleModelCard[] = json.models;
      const processedModels = modelList.map((model) => {
        // Strip the "models/" prefix to get the plain model id.
        const id = model.name.replace(/^models\//, '');
        return {
          contextWindowTokens: (model.inputTokenLimit || 0) + (model.outputTokenLimit || 0),
          displayName: model.displayName || id,
          id,
          maxOutput: model.outputTokenLimit || undefined,
        };
      });
      const { MODEL_LIST_CONFIGS, processModelList } = await import('../../utils/modelParse');
      return processModelList(processedModels, MODEL_LIST_CONFIGS.google, 'google');
    } catch (error) {
      log('Failed to fetch Google models: %O', error);
      throw error;
    }
  }
  // Split the system message out of the message list so it can be sent via
  // Gemini's dedicated systemInstruction field.
  private buildPayload(payload: ChatStreamPayload) {
    const system_message = payload.messages.find((m) => m.role === 'system');
    const user_messages = payload.messages.filter((m) => m.role !== 'system');
    return {
      ...payload,
      messages: user_messages,
      system: system_message?.content,
    };
  }
  /**
   * Decide which Google tools to send. Gemini cannot combine the built-in
   * search/urlContext tools with function declarations in one request, so
   * this picks one group based on the payload.
   */
  private buildGoogleToolsWithSearch(
    tools: ChatCompletionTool[] | undefined,
    payload?: ChatStreamPayload,
  ): GoogleFunctionCallTool[] | undefined {
    const hasToolCalls = payload?.messages?.some((m) => m.tool_calls?.length);
    const hasSearch = payload?.enabledSearch;
    const hasUrlContext = payload?.urlContext;
    const hasFunctionTools = tools && tools.length > 0;
    // If tool_calls already exist in the history, function declarations win.
    if (hasToolCalls && hasFunctionTools) {
      return buildGoogleTools(tools);
    }
    // Build and return search-related tools (search tools cannot be used
    // together with FunctionCall).
    if (hasUrlContext && hasSearch) {
      return [{ urlContext: {} }, { googleSearch: {} }];
    }
    if (hasUrlContext) {
      return [{ urlContext: {} }];
    }
    if (hasSearch) {
      return [{ googleSearch: {} }];
    }
    // Finally fall back to function declarations.
    return buildGoogleTools(tools);
  }
}
// Default export kept for backward compatibility with existing imports.
export default LobeGoogleAI;