UNPKG

autosnippet

Version:

Extract code patterns into a knowledge base for AI coding assistants

416 lines (415 loc) 17.6 kB
/**
 * GoogleGeminiProvider - Google Gemini AI provider.
 * Calls the Gemini REST API directly (no SDK dependency).
 *
 * v3: unified message format — chatWithTools() accepts provider-agnostic
 * messages and converts them internally into Gemini-native `contents` /
 * `functionDeclarations`.
 * Supports toolChoice: 'auto' | 'required' | 'none'.
 *
 * Authentication: the API key is sent in the `x-goog-api-key` request header
 * instead of a `?key=` query parameter, so it is not embedded in request URLs.
 */
import Logger from '#infra/logging/Logger.js';
import { AiProvider } from '../AiProvider.js';

const GEMINI_BASE = 'https://generativelanguage.googleapis.com/v1beta';
const DEFAULT_EMBED_MODEL = 'models/gemini-embedding-001';
/** Fallback concurrency when neither config nor env provides a usable value. */
const DEFAULT_MAX_CONCURRENCY = 2;
/** Number of texts sent per batchEmbedContents call. */
const EMBED_BATCH_SIZE = 100;
/** Each text is truncated to this many UTF-16 code units before embedding. */
const EMBED_MAX_CHARS = 8000;
/**
 * JSON Schema keys stripped before sending schemas to Gemini, whose
 * OpenAPI-subset Schema object does not accept them.
 */
const UNSUPPORTED_SCHEMA_KEYS = ['default', 'examples', '$schema', 'additionalProperties'];

/**
 * Resolve the concurrency limit: an explicit (truthy) config value wins, then
 * ASD_GEMINI_MAX_CONCURRENCY / ASD_AI_MAX_CONCURRENCY, then 2.
 * A non-numeric or non-positive env value falls back to the default rather
 * than passing NaN to the base class.
 */
function resolveMaxConcurrency(config) {
  if (config.maxConcurrency) {
    return config.maxConcurrency;
  }
  const fromEnv = Number(
    process.env.ASD_GEMINI_MAX_CONCURRENCY || process.env.ASD_AI_MAX_CONCURRENCY || DEFAULT_MAX_CONCURRENCY,
  );
  return Number.isFinite(fromEnv) && fromEnv > 0 ? fromEnv : DEFAULT_MAX_CONCURRENCY;
}

/**
 * Parse an HTTP Retry-After header (delta-seconds or HTTP-date) into milliseconds.
 * Returns 0 when the header is absent or cannot be interpreted.
 */
function parseRetryAfterMs(header) {
  if (!header) {
    return 0;
  }
  const seconds = Number(header);
  if (Number.isFinite(seconds) && seconds > 0) {
    return seconds * 1000;
  }
  const when = Date.parse(header);
  return Number.isFinite(when) ? Math.max(0, when - Date.now()) : 0;
}

/**
 * Best-effort extraction of a human-readable error detail from a failed response.
 * Uses `error.message` from a JSON body, otherwise the (truncated) raw body text.
 * Never throws; returns '' when the body cannot be read.
 */
async function readErrorDetail(res) {
  let raw;
  try {
    raw = await res.text();
  } catch {
    return '';
  }
  try {
    const json = JSON.parse(raw);
    return json?.error?.message || JSON.stringify(json).slice(0, 300);
  } catch {
    // Non-JSON body (e.g. an HTML error page from a proxy): keep a short excerpt.
    return raw.trim().slice(0, 300);
  }
}

/**
 * Append a content entry, merging its parts into the previous entry when both
 * have the same role. Gemini expects alternating user/model turns; merging
 * avoids 400 errors for consecutive same-role messages.
 * Mutates `contents` (and the last entry's `parts`) in place.
 */
function appendContent(contents, entry) {
  const last = contents[contents.length - 1];
  if (last && last.role === entry.role) {
    last.parts.push(...entry.parts);
  } else {
    contents.push(entry);
  }
}

/**
 * Concatenate every non-thought text part of the first candidate.
 * A reply may be split across several parts; reading only parts[0] would
 * truncate it (or yield '' when parts[0] is not a text part).
 */
function extractText(data) {
  const parts = data?.candidates?.[0]?.content?.parts;
  if (!Array.isArray(parts)) {
    return '';
  }
  return parts
    .filter((part) => typeof part?.text === 'string' && !part.thought)
    .map((part) => part.text)
    .join('');
}

/**
 * Normalize Gemini `usageMetadata` into {inputTokens, outputTokens, totalTokens}.
 * totalTokens prefers the API-reported totalTokenCount (which may include tokens
 * not counted in prompt/candidates, e.g. thinking tokens) and falls back to
 * prompt + candidates. Returns null when the response carries no usage metadata.
 */
function extractUsage(data) {
  const meta = data?.usageMetadata;
  if (!meta) {
    return null;
  }
  const inputTokens = meta.promptTokenCount || 0;
  const outputTokens = meta.candidatesTokenCount || 0;
  return {
    inputTokens,
    outputTokens,
    totalTokens: meta.totalTokenCount || inputTokens + outputTokens,
  };
}

export class GoogleGeminiProvider extends AiProvider {
  /** Fully-qualified embedding model name, always prefixed with "models/". */
  #embedModel;

  constructor(config = {}) {
    super({ ...config, maxConcurrency: resolveMaxConcurrency(config) });
    this.name = 'google-gemini';
    this.model = config.model || 'gemini-3-flash-preview';
    this.apiKey = config.apiKey || process.env.ASD_GOOGLE_API_KEY || '';
    this.#embedModel = config.embedModel
      ? `models/${config.embedModel.replace(/^models\//, '')}`
      : DEFAULT_EMBED_MODEL;
    this.logger = Logger.getInstance();
  }

  /** Whether this provider supports native structured function calling. */
  get supportsNativeToolCalling() {
    return true;
  }

  /**
   * Plain chat completion.
   *
   * @param {string} prompt The new user message.
   * @param {object} [context]
   * @param {Array<{role: string, content: string}>} [context.history] Prior turns;
   *   role 'assistant' maps to Gemini 'model', every other role maps to 'user'.
   * @param {number} [context.temperature=0.7]
   * @param {number} [context.maxTokens=8192] Sent as generationConfig.maxOutputTokens.
   * @param {string} [context.systemPrompt] Sent as systemInstruction.
   * @param {AbortSignal} [context.abortSignal] Optional external cancellation.
   * @returns {Promise<string>} The reply text, or '' when the response has none.
   */
  async chat(prompt, context = {}) {
    return this._withRetry(async () => {
      const { history = [], temperature = 0.7, maxTokens = 8192, systemPrompt, abortSignal } = context;

      const contents = [];
      for (const h of history) {
        appendContent(contents, {
          role: h.role === 'assistant' ? 'model' : 'user',
          parts: [{ text: h.content || '' }],
        });
      }
      appendContent(contents, { role: 'user', parts: [{ text: prompt }] });

      const body = {
        contents,
        generationConfig: {
          temperature,
          maxOutputTokens: maxTokens,
        },
      };
      if (systemPrompt) {
        body.systemInstruction = { parts: [{ text: systemPrompt }] };
      }

      const data = await this._post(this.#endpoint('generateContent'), body, abortSignal);

      const usage = extractUsage(data);
      if (usage) {
        this._emitTokenUsage(usage);
      }
      return extractText(data);
    });
  }

  /**
   * Chat with tool declarations — Gemini native function calling.
   *
   * Accepts the unified message format and converts it to Gemini `contents`.
   *
   * @param {string} prompt Fallback prompt used when `opts.messages` is empty or absent.
   * @param {object} [opts]
   * @param {Array<object>} [opts.messages] Unified-format messages (user / assistant / tool).
   * @param {Array<{name: string, description?: string, parameters?: object}>} [opts.toolSchemas]
   * @param {'auto'|'required'|'none'} [opts.toolChoice='auto']
   * @param {string} [opts.systemPrompt]
   * @param {number} [opts.temperature=0.7]
   * @param {number} [opts.maxTokens=8192]
   * @param {AbortSignal} [opts.abortSignal] Optional external cancellation.
   * @returns {Promise<{text: string|null, functionCalls: Array<{id: string, name: string, args: object, thoughtSignature?: string}>|null, usage: object|null}>}
   */
  async chatWithTools(prompt, opts = {}) {
    return this._withRetry(async () => {
      const {
        messages,
        toolSchemas,
        toolChoice = 'auto',
        systemPrompt,
        temperature = 0.7,
        maxTokens = 8192,
        abortSignal,
      } = opts;

      const contents = messages && messages.length > 0
        ? this.#convertMessages(messages)
        : [{ role: 'user', parts: [{ text: prompt }] }];

      const body = {
        contents,
        generationConfig: {
          temperature,
          maxOutputTokens: maxTokens,
        },
      };

      if (toolSchemas && toolSchemas.length > 0) {
        body.tools = [
          { functionDeclarations: toolSchemas.map((s) => this.#toFunctionDeclaration(s)) },
        ];
        // toolConfig is only set alongside tool declarations: setting it with no
        // tools can lead to empty responses.
        body.toolConfig = {
          functionCallingConfig: { mode: this.#toGeminiMode(toolChoice) },
        };
      }

      if (systemPrompt) {
        body.systemInstruction = { parts: [{ text: systemPrompt }] };
      }

      const data = await this._post(this.#endpoint('generateContent'), body, abortSignal);
      return this.#parseToolResponse(data);
    });
  }

  // ─── Internal conversion helpers ──────────────────────

  /** Build a generative-model endpoint URL; tolerates a "models/" prefix on this.model. */
  #endpoint(method) {
    const model = String(this.model).replace(/^models\//, '');
    return `${GEMINI_BASE}/models/${model}:${method}`;
  }

  /**
   * Unified message format → Gemini contents.
   * - user      → {role: 'user',  parts: [{text}]}
   * - assistant → {role: 'model', parts: [{text}, {functionCall}...]}
   * - tool      → grouped into {role: 'user', parts: [{functionResponse}...]}
   * Messages with any other role are skipped.
   *
   * Consecutive same-role entries (e.g. summaries from L2/L3 context compression)
   * are merged into one entry to avoid 400 errors.
   */
  #convertMessages(messages) {
    const contents = [];
    let pendingToolResults = [];

    for (const msg of messages) {
      if (msg.role === 'tool') {
        // Buffer consecutive tool results; flushed before the next non-tool message or at the end.
        pendingToolResults.push({
          functionResponse: {
            name: msg.name || '',
            response: { result: msg.content || '' },
          },
        });
        continue;
      }

      if (pendingToolResults.length > 0) {
        appendContent(contents, { role: 'user', parts: pendingToolResults });
        pendingToolResults = [];
      }

      if (msg.role === 'user') {
        appendContent(contents, { role: 'user', parts: [{ text: msg.content || '' }] });
      } else if (msg.role === 'assistant') {
        const parts = [];
        if (msg.content) {
          parts.push({ text: msg.content });
        }
        for (const tc of msg.toolCalls || []) {
          const fcPart = { functionCall: { name: tc.name, args: tc.args || {} } };
          // Gemini 3+: echo back thoughtSignature (the first functionCall must carry it).
          if (tc.thoughtSignature) {
            fcPart.thoughtSignature = tc.thoughtSignature;
          }
          parts.push(fcPart);
        }
        if (parts.length > 0) {
          appendContent(contents, { role: 'model', parts });
        }
      }
    }

    if (pendingToolResults.length > 0) {
      appendContent(contents, { role: 'user', parts: pendingToolResults });
    }
    return contents;
  }

  /** toolChoice → Gemini functionCallingConfig mode. */
  #toGeminiMode(toolChoice) {
    switch (toolChoice) {
      case 'required':
        return 'ANY';
      case 'none':
        return 'NONE';
      default:
        return 'AUTO';
    }
  }

  /** Standard tool schema → Gemini functionDeclaration. */
  #toFunctionDeclaration(schema) {
    return {
      name: schema.name,
      description: schema.description || '',
      parameters: this.#sanitizeSchemaForGemini(schema.parameters),
    };
  }

  /**
   * Recursively clean a JSON Schema so it fits Gemini's OpenAPI subset.
   * Strips unsupported keys (default, examples, $schema, additionalProperties),
   * defaults a missing type to 'object', recurses into `properties`, and gives
   * arrays an `items` schema (Gemini requires one; falls back to string).
   * Returns a new object; the input schema is not mutated.
   */
  #sanitizeSchemaForGemini(schema) {
    if (!schema || typeof schema !== 'object') {
      return { type: 'object', properties: {} };
    }
    const cleaned = { ...schema };
    for (const key of UNSUPPORTED_SCHEMA_KEYS) {
      delete cleaned[key];
    }
    if (!cleaned.type) {
      cleaned.type = 'object';
    }

    if (cleaned.properties) {
      const props = {};
      for (const [key, val] of Object.entries(cleaned.properties)) {
        props[key] = this.#sanitizeSchemaForGemini(val);
      }
      cleaned.properties = props;
    }

    if (cleaned.type === 'array') {
      cleaned.items = cleaned.items && typeof cleaned.items === 'object'
        ? this.#sanitizeSchemaForGemini(cleaned.items)
        : { type: 'string' };
    }
    return cleaned;
  }

  /**
   * Parse a Gemini generateContent response into the unified tool-call format.
   * Function calls get a generated id; text parts are joined with '\n'.
   * `text` is null when function calls are present but no text accompanies them.
   */
  #parseToolResponse(data) {
    const content = data?.candidates?.[0]?.content;
    const usage = extractUsage(data);

    if (!content?.parts?.length) {
      return { text: '', functionCalls: null, usage };
    }

    const functionCalls = [];
    const textParts = [];
    let fcIndex = 0;
    for (const part of content.parts) {
      if (part.functionCall) {
        functionCalls.push({
          id: `gemini_fc_${Date.now()}_${fcIndex++}`,
          name: part.functionCall.name,
          args: part.functionCall.args || {},
          // Gemini 3+: thoughtSignature must be echoed back verbatim, otherwise later requests get a 400.
          thoughtSignature: part.thoughtSignature || undefined,
        });
      } else if (part.text && !part.thought) {
        textParts.push(part.text);
      }
    }

    if (functionCalls.length > 0) {
      this.logger?.debug(`[GeminiProvider] native function calls: ${functionCalls.map((fc) => fc.name).join(', ')}`);
      return {
        text: textParts.length > 0 ? textParts.join('\n') : null,
        functionCalls,
        usage,
      };
    }
    return { text: textParts.join('\n'), functionCalls: null, usage };
  }

  /**
   * Produce a structured summary of a code snippet via JSON-mode output.
   * @returns {Promise<object>} Parsed summary, or {title: '', description: ''} when nothing usable came back.
   */
  async summarize(code) {
    const prompt = `请对以下代码生成结构化摘要,返回 JSON 格式 {title, description, language, patterns: [], keyAPIs: []}:\n\n${code}`;
    return (
      (await this.chatWithStructuredOutput(prompt, { temperature: 0.3, maxTokens: 8192 })) || {
        title: '',
        description: '',
      }
    );
  }

  /**
   * Structured output — Gemini native JSON mode.
   *
   * Sets responseMimeType 'application/json' so Gemini returns JSON; an optional
   * `schema` is sanitized and sent as responseSchema.
   *
   * @param {string} prompt
   * @param {object} [opts]
   * @param {object} [opts.schema] JSON Schema for the response.
   * @param {number} [opts.temperature=0.3]
   * @param {number} [opts.maxTokens=32768]
   * @param {string} [opts.systemPrompt]
   * @param {string} [opts.openChar='{'] / [opts.closeChar='}'] Delimiters for the extractJSON fallback.
   * @param {AbortSignal} [opts.abortSignal] Optional external cancellation.
   * @returns {Promise<any|null>} Parsed JSON, the extractJSON fallback result, or null for an empty reply.
   */
  async chatWithStructuredOutput(prompt, opts = {}) {
    return this._withRetry(async () => {
      const { schema, temperature = 0.3, maxTokens = 32768, systemPrompt, abortSignal } = opts;

      const generationConfig = {
        temperature,
        maxOutputTokens: maxTokens,
        responseMimeType: 'application/json',
      };
      if (schema) {
        generationConfig.responseSchema = this.#sanitizeSchemaForGemini(schema);
      }

      const body = {
        contents: [{ role: 'user', parts: [{ text: prompt }] }],
        generationConfig,
      };
      if (systemPrompt) {
        body.systemInstruction = { parts: [{ text: systemPrompt }] };
      }

      const data = await this._post(this.#endpoint('generateContent'), body, abortSignal);

      const usage = extractUsage(data);
      if (usage) {
        this._emitTokenUsage(usage);
      }

      const text = extractText(data);
      if (!text) {
        return null;
      }
      try {
        return JSON.parse(text);
      } catch {
        // JSON mode occasionally wraps the payload in extra text; fall back to delimiter-based extraction.
        const openChar = opts.openChar || '{';
        const closeChar = opts.closeChar || '}';
        return this.extractJSON(text, openChar, closeChar);
      }
    });
  }

  /**
   * Embed one text or an array of texts via batchEmbedContents.
   * Each text is truncated to EMBED_MAX_CHARS. For array input the result is
   * index-aligned with the input; an input with no returned vector gets [].
   * @param {string|string[]} text
   * @returns {Promise<number[]|number[][]>}
   */
  async embed(text) {
    const texts = Array.isArray(text) ? text : [text];
    const url = `${GEMINI_BASE}/${this.#embedModel}:batchEmbedContents`;
    const results = [];

    for (let i = 0; i < texts.length; i += EMBED_BATCH_SIZE) {
      const batch = texts.slice(i, i + EMBED_BATCH_SIZE);
      const requests = batch.map((t) => ({
        model: this.#embedModel,
        content: { parts: [{ text: t.slice(0, EMBED_MAX_CHARS) }] },
      }));
      const data = await this._post(url, { requests });
      const embeddings = Array.isArray(data?.embeddings) ? data.embeddings : [];
      // One slot per input so results[k] always corresponds to texts[k].
      for (let j = 0; j < batch.length; j++) {
        results.push(embeddings[j]?.values || []);
      }
    }
    return Array.isArray(text) ? results : results[0] || [];
  }

  /**
   * POST a JSON body to a Gemini endpoint and return the parsed JSON response.
   *
   * The request is aborted after `this.timeout` ms or when `externalSignal` aborts
   * (including a signal that was already aborted before the call).
   *
   * @param {string} url Absolute endpoint URL (without the API key).
   * @param {object} body Request payload, serialized with JSON.stringify.
   * @param {AbortSignal} [externalSignal]
   * @throws {Error} code 'API_KEY_MISSING' when no key is configured.
   * @throws {Error} `status` (and `retryAfterMs` when Retry-After is present) on non-2xx responses.
   */
  async _post(url, body, externalSignal) {
    if (!this.apiKey) {
      const err = new Error('Google Gemini API Key 未配置。请在 .env 中设置 ASD_GOOGLE_API_KEY,或运行 asd setup 完成配置。');
      err.code = 'API_KEY_MISSING';
      throw err;
    }

    const headers = { 'Content-Type': 'application/json' };
    // Authenticate via header; skip it if a legacy caller already put ?key= in the URL.
    if (!new URL(url).searchParams.has('key')) {
      headers['x-goog-api-key'] = this.apiKey;
    }

    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), this.timeout);

    // Link the external signal to the local controller. An already-aborted signal
    // never dispatches 'abort' again, so it must be propagated explicitly.
    const onExternalAbort = () => controller.abort();
    if (externalSignal?.aborted) {
      controller.abort();
    } else {
      externalSignal?.addEventListener('abort', onExternalAbort, { once: true });
    }

    try {
      const res = await this._fetch(url, {
        method: 'POST',
        headers,
        body: JSON.stringify(body),
        signal: controller.signal,
      });

      if (!res.ok) {
        const retryAfterMs = parseRetryAfterMs(res.headers.get('retry-after'));
        const detail = await readErrorDetail(res);
        throw Object.assign(new Error(`Gemini API error: ${res.status}${detail ? ` — ${detail}` : ''}`), {
          status: res.status,
          ...(retryAfterMs > 0 ? { retryAfterMs } : {}),
        });
      }
      return await res.json();
    } finally {
      clearTimeout(timer);
      externalSignal?.removeEventListener('abort', onExternalAbort);
    }
  }
}

export default GoogleGeminiProvider;