UNPKG

@unified-llm/core

Version:

Unified LLM interface (in-memory).

615 lines 27.7 kB
/* -----------------------------------------------------------
 * GeminiProvider
 *  - Adapts Google Generative AI (Gemini) to the unified LLM provider
 *    interface.
 *  - The provider is treated as a "local assistant": identifiers are
 *    pseudo values generated on the client and everything is kept in
 *    memory.
 * ---------------------------------------------------------- */
import { GoogleGenerativeAI } from '@google/generative-ai';
import BaseProvider from '../base-provider';
import { validateChatRequest } from '../../utils/validation';
import { ResponseFormat } from '../../response-format';
// type ChatHistory = { role: 'user' | 'assistant'; content: string }[];
/**
 * Provider that talks to Gemini models through the `@google/generative-ai`
 * SDK and converts requests/responses to and from the unified format.
 *
 * `this.model`, `this.tools`, `normalizeContent()` and `generateMessageId()`
 * are not defined in this file; they are presumably supplied by
 * `BaseProvider` — verify against the base class.
 */
export class GeminiProvider extends BaseProvider {
    /**
     * @param {object} options
     * @param {string} options.apiKey - API key handed to `GoogleGenerativeAI`.
     * @param {string} [options.model] - Model name; defaults to `'gemini-pro'`.
     * @param {Array} [options.tools] - Locally executable tools (each with
     *   `function`, `handler` and optional default `args`).
     */
    constructor({ apiKey, model, tools }) {
        super({ model: model || 'gemini-pro', tools });
        this.client = new GoogleGenerativeAI(apiKey);
    }
    /**
     * Sends a chat request and resolves with the unified response.
     *
     * While the model keeps asking for function calls that match one of the
     * provider's own tools, the handlers are executed locally and their
     * results are sent back to the model.
     *
     * @param {object} request - Unified chat request (`messages`, optional
     *   `model`, `tools`, `generationConfig`).
     * @returns {Promise<object>} Unified response (see `convertFromGeminiFormat`).
     * @throws {object} The plain error object built by `handleError`.
     */
    async chat(request) {
        validateChatRequest(request);
        try {
            const { chat, prompt } = await this.openGeminiChat(request);
            let result = await chat.sendMessage(prompt);
            let response = await result.response;
            // When tool calls are present, run them and feed the results back.
            while (this.hasFunctionCalls(response) && this.tools) {
                const functionResults = await this.runGeminiFunctionCalls(this.extractFunctionCalls(response));
                if (functionResults.length === 0) {
                    // None of the requested functions is a local tool; stop looping.
                    break;
                }
                // Convert the results to Gemini's functionResponse parts.
                const parts = functionResults.map(funcResult => ({
                    functionResponse: {
                        name: funcResult.name,
                        response: funcResult.response
                    }
                }));
                result = await chat.sendMessage(parts);
                response = await result.response;
            }
            return this.convertFromGeminiFormat(response, result);
        }
        catch (error) {
            throw this.handleError(error);
        }
    }
    /**
     * Streams a chat response as unified chunks.
     *
     * The SDK stream is fully buffered first so that function calls can be
     * detected. Without function calls the buffered chunks are replayed (a
     * single chunk is split in two to simulate streaming). With function
     * calls, the matching local tools are executed and their serialised
     * results are streamed as text; the results are NOT sent back to the
     * model. If no local tool matches, nothing is yielded.
     *
     * Unlike `chat`, errors are not passed through `handleError`.
     *
     * @param {object} request - Unified chat request.
     * @yields {object} Unified stream chunks (see `convertStreamChunk`).
     */
    async *stream(request) {
        validateChatRequest(request);
        const { chat, prompt } = await this.openGeminiChat(request);
        const result = await chat.sendMessageStream(prompt);
        // Collect all chunks first to detect whether there are function calls.
        const chunks = [];
        for await (const chunk of result.stream) {
            chunks.push(chunk);
        }
        const completeResponse = await result.response;
        if (!this.hasFunctionCalls(completeResponse)) {
            if (chunks.length === 1) {
                // Only one chunk: split it so consumers still observe a stream.
                yield* this.simulateGeminiStream(chunks[0].text(), 2);
            }
            else {
                for (const chunk of chunks) {
                    yield this.convertStreamChunk(chunk);
                }
            }
            return;
        }
        const functionResults = await this.runGeminiFunctionCalls(this.extractFunctionCalls(completeResponse));
        if (functionResults.length === 0) {
            return;
        }
        const describeResult = ({ response }) => {
            if ('error' in response) {
                // A failed handler has no `result`; surface the error payload
                // instead of letting it collapse into an empty string.
                return JSON.stringify(response);
            }
            return typeof response.result === 'string'
                ? response.result
                : JSON.stringify(response.result);
        };
        const resultText = functionResults.map(describeResult).join('\n');
        yield* this.simulateGeminiStream(resultText, 3);
    }
    /**
     * Builds the Gemini chat session and the prompt for a request.
     *
     * System messages become the `systemInstruction`; messages containing
     * `tool_result` content are dropped. All remaining messages except the
     * last one form the history, and the last one becomes the prompt.
     *
     * @private
     * @param {object} request - Unified chat request.
     * @returns {Promise<{chat: object, prompt: (string|Array)}>}
     * @throws {Error} If no model is configured or no message is left to send.
     */
    async openGeminiChat(request) {
        const model = request.model || this.model;
        if (!model) {
            throw new Error('Model is required for Gemini chat');
        }
        const modelInstance = this.client.getGenerativeModel({ model });
        const tools = this.convertToolsToGeminiFormat(request.tools, this.tools);
        // Extract the system prompt from the messages.
        const systemMessage = request.messages.find(m => m.role === 'system');
        const systemInstruction = systemMessage
            ? this.extractTextFromContent(systemMessage.content)
            : undefined;
        // Exclude tool results and system messages; they are handled separately.
        const filteredMessages = request.messages.filter(msg => {
            const content = this.normalizeContent(msg.content);
            return !content.some(c => c.type === 'tool_result') && msg.role !== 'system';
        });
        if (filteredMessages.length === 0) {
            throw new Error('Gemini chat requires at least one non-system message without tool results');
        }
        const history = await this.convertToGeminiHistory(filteredMessages.slice(0, -1));
        const chatConfig = {
            history,
            generationConfig: this.convertGenerationConfig(request.generationConfig),
            tools: tools.length > 0 ? tools : undefined,
        };
        if (systemInstruction) {
            chatConfig.systemInstruction = {
                parts: [{ text: systemInstruction }],
                role: 'user'
            };
        }
        const chat = modelInstance.startChat(chatConfig);
        const lastMessage = filteredMessages[filteredMessages.length - 1];
        const prompt = this.extractPromptFromMessage(lastMessage);
        return { chat, prompt };
    }
    /**
     * Executes, sequentially, every function call that matches one of the
     * provider's own tools. Calls without a matching tool are skipped.
     *
     * @private
     * @param {Array<{name: string, args: object}>} functionCalls
     * @returns {Promise<Array<{name: string, response: object}>>} One entry
     *   per executed call; `response` is `{ result }` or `{ error }`.
     */
    async runGeminiFunctionCalls(functionCalls) {
        const functionResults = [];
        for (const call of functionCalls) {
            const customFunction = this.tools == null
                ? undefined
                : this.tools.find(func => func.function.name === call.name);
            if (!customFunction) {
                continue;
            }
            try {
                // The tool's default args are overridden by the model's args.
                const mergedArgs = { ...(customFunction.args || {}), ...call.args };
                const callResult = await customFunction.handler(mergedArgs);
                functionResults.push({
                    name: call.name,
                    response: { result: callResult },
                });
            }
            catch (error) {
                functionResults.push({
                    name: call.name,
                    response: { error: error instanceof Error ? error.message : 'Unknown error' },
                });
            }
        }
        return functionResults;
    }
    /**
     * Splits `text` on single spaces into roughly `pieces` groups of words
     * and yields each group as a unified stream chunk built from a mock
     * Gemini chunk.
     *
     * @private
     * @param {string} text - Text to replay.
     * @param {number} pieces - Target number of chunks.
     * @yields {object} Unified stream chunks.
     */
    *simulateGeminiStream(text, pieces) {
        const words = text.split(' ');
        const chunkSize = Math.max(1, Math.floor(words.length / pieces));
        for (let i = 0; i < words.length; i += chunkSize) {
            const chunkWords = words.slice(i, i + chunkSize);
            // Re-attach the separating space to every chunk except the last.
            const chunkText = chunkWords.join(' ') + (i + chunkSize < words.length ? ' ' : '');
            const mockChunk = {
                text: () => chunkText,
                candidates: [{ content: { parts: [{ text: chunkText }] } }]
            };
            yield this.convertStreamChunk(mockChunk);
        }
    }
    /**
     * Returns the `content.parts` of the first candidate, or an empty array
     * when the response carries none.
     *
     * @private
     * @param {object} response - Gemini response.
     * @returns {Array} Parts of the first candidate.
     */
    firstCandidateParts(response) {
        if (response.candidates && response.candidates.length > 0) {
            const candidate = response.candidates[0];
            if (candidate.content && candidate.content.parts) {
                return candidate.content.parts;
            }
        }
        return [];
    }
    /**
     * Merges request-level and provider-level tools into Gemini's format.
     *
     * @param {Array} [requestTools] - Tools supplied with the request.
     * @param {Array} [providerTools] - Tools configured on the provider.
     * @returns {Array} `[]` when there are no tools, otherwise a one-element
     *   array holding `{ functionDeclarations }` (request tools first).
     */
    convertToolsToGeminiFormat(requestTools, providerTools) {
        const toDeclaration = (tool) => ({
            name: tool.function.name,
            description: tool.function.description || '',
            parameters: tool.function.parameters || { type: 'object', properties: {} }
        });
        const allTools = [];
        if (requestTools && requestTools.length > 0) {
            allTools.push(...requestTools.map(toDeclaration));
        }
        if (providerTools && providerTools.length > 0) {
            allTools.push(...providerTools.map(toDeclaration));
        }
        if (allTools.length === 0) {
            return [];
        }
        // Gemini expects a single object with a functionDeclarations array.
        return [{ functionDeclarations: allTools }];
    }
    /**
     * Tells whether a Gemini response requests at least one function call.
     *
     * @param {object} response - Gemini response.
     * @returns {boolean}
     */
    hasFunctionCalls(response) {
        try {
            const functionCalls = response.functionCalls();
            return Boolean(functionCalls && functionCalls.length > 0);
        }
        catch (_a) {
            // `functionCalls()` is unavailable or threw; inspect the candidates.
            return this.firstCandidateParts(response).some((part) => Boolean(part.functionCall));
        }
    }
    /**
     * Extracts the function calls requested by a Gemini response.
     *
     * @param {object} response - Gemini response.
     * @returns {Array<{name: string, args: object}>} Possibly empty.
     */
    extractFunctionCalls(response) {
        try {
            const functionCalls = response.functionCalls();
            return functionCalls || [];
        }
        catch (_a) {
            // `functionCalls()` is unavailable or threw; inspect the candidates.
            const calls = [];
            this.firstCandidateParts(response).forEach((part) => {
                if (part.functionCall) {
                    calls.push({
                        name: part.functionCall.name,
                        args: part.functionCall.args || {}
                    });
                }
            });
            return calls;
        }
    }
    /**
     * Converts unified messages into Gemini history entries.
     *
     * NOTE(review): a `tool_use` item whose name matches a provider tool is
     * EXECUTED here and replaced with its `functionResponse`, so converting
     * history can re-run handlers with side effects — confirm this is intended.
     * NOTE(review): for `tool_result` items `functionResponse.response` is a
     * string — confirm the Gemini API accepts a non-object value there.
     *
     * @param {Array} messages - Unified messages.
     * @returns {Promise<Array<{role: string, parts: Array}>>}
     */
    async convertToGeminiHistory(messages) {
        return Promise.all(messages.map(async (msg) => {
            const content = this.normalizeContent(msg.content);
            const parts = await Promise.all(content.map(async (c) => {
                switch (c.type) {
                    case 'text':
                        return { text: c.text };
                    case 'image':
                        return {
                            inlineData: {
                                mimeType: c.source.mediaType || 'image/jpeg',
                                data: c.source.data || '',
                            },
                        };
                    case 'tool_use': {
                        // Look the tool up among the provider's tools and run it.
                        const customFunction = this.tools == null
                            ? undefined
                            : this.tools.find(func => func.function.name === c.name);
                        if (customFunction) {
                            try {
                                // The tool's default args are overridden by the tool_use input.
                                const mergedArgs = { ...(customFunction.args || {}), ...c.input };
                                const result = await customFunction.handler(mergedArgs);
                                return {
                                    functionResponse: {
                                        name: c.name,
                                        response: { result: typeof result === 'string' ? result : JSON.stringify(result) }
                                    }
                                };
                            }
                            catch (error) {
                                return {
                                    functionResponse: {
                                        name: c.name,
                                        response: { error: error instanceof Error ? error.message : 'Unknown error' }
                                    }
                                };
                            }
                        }
                        return {
                            functionCall: {
                                name: c.name,
                                args: c.input
                            }
                        };
                    }
                    case 'tool_result': {
                        const responseContent = Array.isArray(c.content)
                            ? c.content.map(item => {
                                if (item.type === 'text') {
                                    // Try to parse a JSON result to extract the actual value.
                                    try {
                                        const parsed = JSON.parse(item.text);
                                        return typeof parsed === 'string' ? parsed : JSON.stringify(parsed);
                                    }
                                    catch (_a) {
                                        // Not JSON: use the raw text.
                                        return item.text;
                                    }
                                }
                                return '[Non-text content]';
                            }).join('\n')
                            : '[Tool result]';
                        return {
                            functionResponse: {
                                name: c.functionName || c.toolUseId,
                                response: responseContent
                            }
                        };
                    }
                    default:
                        return { text: '[Unsupported content type]' };
                }
            }));
            // Messages carrying a functionResponse part use the 'function' role;
            // otherwise 'assistant' maps to 'model' and everything else to 'user'.
            const hasFunctionResponse = parts.some(part => 'functionResponse' in part);
            let role;
            if (hasFunctionResponse) {
                role = 'function';
            }
            else {
                role = msg.role === 'assistant' ? 'model' : 'user';
            }
            return {
                role,
                parts,
            };
        }));
    }
    /**
     * Converts the message to send into a Gemini prompt.
     *
     * @param {object} message - Unified message.
     * @returns {string|Array} The bare text when the message is a single text
     *   item, otherwise an array of Gemini parts.
     */
    extractPromptFromMessage(message) {
        const content = this.normalizeContent(message.content);
        if (content.length === 1 && content[0].type === 'text') {
            return content[0].text;
        }
        return content.map(c => {
            switch (c.type) {
                case 'text':
                    return { text: c.text };
                case 'image':
                    return {
                        inlineData: {
                            mimeType: c.source.mediaType || 'image/jpeg',
                            data: c.source.data || '',
                        },
                    };
                default:
                    return { text: '[Unsupported content type]' };
            }
        });
    }
    /**
     * Maps the unified generation config onto Gemini's `generationConfig`.
     *
     * @param {object} [config] - Unified config (`temperature`, `top_p`,
     *   `top_k`, `max_tokens`, `stopSequences`, `responseFormat`).
     * @returns {object|undefined} `undefined` when no config is given.
     */
    convertGenerationConfig(config) {
        if (!config) {
            return undefined;
        }
        const result = {
            temperature: config.temperature,
            topP: config.top_p,
            topK: config.top_k,
            maxOutputTokens: config.max_tokens,
            stopSequences: config.stopSequences,
        };
        const { responseFormat } = config;
        if (responseFormat) {
            if (responseFormat instanceof ResponseFormat) {
                // ResponseFormat instances know their own Google representation.
                const googleFormat = responseFormat.toGoogle();
                result.responseMimeType = googleFormat.responseMimeType;
                result.responseSchema = googleFormat.responseSchema;
            }
            else if (responseFormat.type === 'json_object') {
                // Legacy plain-object format.
                result.responseMimeType = 'application/json';
                if (responseFormat.schema) {
                    result.responseSchema = this.convertToGoogleSchema(responseFormat.schema);
                }
            }
        }
        return result;
    }
    /**
     * Recursively converts a JSON-Schema-like object into Gemini's schema
     * format (upper-case type names; `description`, `properties`, `required`,
     * `items` and `enum` are carried over).
     *
     * @param {object} schema - JSON-Schema-like definition.
     * @returns {object} Gemini schema.
     */
    convertToGoogleSchema(schema) {
        const converted = {
            type: this.mapToGoogleType(schema.type)
        };
        if (schema.description) {
            converted.description = schema.description;
        }
        if (schema.type === 'object' && schema.properties) {
            converted.properties = {};
            for (const [key, value] of Object.entries(schema.properties)) {
                converted.properties[key] = this.convertToGoogleSchema(value);
            }
            if (schema.required) {
                converted.required = schema.required;
            }
        }
        if (schema.type === 'array' && schema.items) {
            converted.items = this.convertToGoogleSchema(schema.items);
        }
        if (schema.enum) {
            converted.enum = schema.enum;
        }
        return converted;
    }
    /**
     * Maps a JSON-Schema type name to Gemini's type name.
     *
     * @param {string} type - JSON-Schema type.
     * @returns {string} Gemini type; unknown types fall back to `'STRING'`.
     */
    mapToGoogleType(type) {
        const typeMap = {
            'object': 'OBJECT',
            'array': 'ARRAY',
            'string': 'STRING',
            'number': 'NUMBER',
            // Without this entry integer fields were silently declared as STRING.
            'integer': 'INTEGER',
            'boolean': 'BOOLEAN',
            'null': 'NULL'
        };
        return typeMap[type] || 'STRING';
    }
    /**
     * Converts a Gemini response into the unified response.
     *
     * Content is read from the first candidate's parts; if that yields
     * nothing, the `text()` / `functionCalls()` accessors are tried, and as a
     * last resort a placeholder text item is used.
     *
     * @param {object} response - Gemini response.
     * @param {object} _result - Unused; kept for interface compatibility.
     * @returns {object} Unified response (`id`, `model`, `provider`,
     *   `message`, `text`, `usage`, `finish_reason`, `createdAt`, `rawResponse`).
     * @throws {Error} If `this.model` is not set.
     */
    convertFromGeminiFormat(response, _result) {
        if (!this.model) {
            throw new Error('Model is required for Gemini response conversion');
        }
        const content = [];
        this.firstCandidateParts(response).forEach((part) => {
            if (part.text) {
                content.push({ type: 'text', text: part.text });
            }
            else if (part.functionCall) {
                content.push({
                    type: 'tool_use',
                    id: this.generateMessageId(),
                    name: part.functionCall.name,
                    input: part.functionCall.args || {}
                });
            }
        });
        // Fall back to the accessor methods if the candidates gave nothing.
        if (content.length === 0) {
            try {
                const text = response.text();
                if (text) {
                    content.push({ type: 'text', text });
                }
            }
            catch (_e) {
                // Best effort: text() is unavailable or threw.
            }
            try {
                const functionCalls = response.functionCalls();
                if (functionCalls && functionCalls.length > 0) {
                    functionCalls.forEach((call) => {
                        content.push({
                            type: 'tool_use',
                            id: this.generateMessageId(),
                            name: call.name,
                            input: call.args || {}
                        });
                    });
                }
            }
            catch (_e) {
                // Best effort: functionCalls() is unavailable or threw.
            }
        }
        // Add a placeholder when there is still no content.
        if (content.length === 0) {
            content.push({ type: 'text', text: '[No content from Gemini]' });
        }
        const unifiedMessage = {
            id: this.generateMessageId(),
            role: 'assistant',
            content,
            createdAt: new Date(),
        };
        // Gemini reports usage statistics under different field names.
        const usageMetadata = response.usageMetadata;
        const usage = usageMetadata ? {
            inputTokens: usageMetadata.promptTokenCount || 0,
            outputTokens: usageMetadata.candidatesTokenCount || 0,
            totalTokens: usageMetadata.totalTokenCount || 0,
        } : undefined;
        // Convenience field: the text of the first text item.
        const textContent = content.find((c) => c.type === 'text');
        const firstCandidate = response.candidates == null ? undefined : response.candidates[0];
        const finishReason = firstCandidate == null ? undefined : firstCandidate.finishReason;
        return {
            id: this.generateMessageId(),
            model: this.model,
            provider: 'google',
            message: unifiedMessage,
            text: (textContent ? textContent.text : undefined) || '',
            usage,
            finish_reason: this.mapFinishReason(finishReason),
            createdAt: new Date(),
            rawResponse: response,
        };
    }
    /**
     * Converts one Gemini stream chunk into a unified stream chunk.
     *
     * @param {object} chunk - Gemini chunk exposing `text()`.
     * @returns {object} Unified chunk (`id`, `model`, `provider`, `message`,
     *   `text`, `createdAt`, `rawResponse`).
     * @throws {Error} If `this.model` is not set.
     */
    convertStreamChunk(chunk) {
        if (!this.model) {
            throw new Error('Model is required for Gemini stream');
        }
        const text = chunk.text();
        const unifiedMessage = {
            id: this.generateMessageId(),
            role: 'assistant',
            content: [{ type: 'text', text }],
            createdAt: new Date(),
        };
        return {
            id: this.generateMessageId(),
            model: this.model,
            provider: 'google',
            message: unifiedMessage,
            text: text || '',
            createdAt: new Date(),
            rawResponse: chunk,
        };
    }
    /**
     * Maps Gemini's finish reason to the unified value.
     *
     * @param {string} [reason] - Gemini finish reason.
     * @returns {string|null} `'stop'`, `'length'`, `'content_filter'`, or
     *   `null` for anything else.
     */
    mapFinishReason(reason) {
        switch (reason) {
            case 'STOP':
                return 'stop';
            case 'MAX_TOKENS':
                return 'length';
            case 'SAFETY':
                return 'content_filter';
            default:
                return null;
        }
    }
    /**
     * Returns the plain text of a message content value.
     *
     * @param {string|Array} content - A string, or an array of content items.
     * @returns {string} The string itself, or the `text` of all text items
     *   joined with newlines.
     */
    extractTextFromContent(content) {
        if (typeof content === 'string') {
            return content;
        }
        return content
            .filter(c => c.type === 'text')
            .map(c => c.text)
            .join('\n') || '';
    }
    /**
     * Wraps an error into the provider's error shape.
     *
     * NOTE(review): this is a plain object, not an `Error` instance, and
     * `chat` throws it as-is; kept because callers may depend on the shape.
     *
     * @param {*} error - The original error.
     * @returns {{code: string, message: string, type: string, provider: string, details: *}}
     */
    handleError(error) {
        return {
            code: error.code || 'gemini_error',
            message: error.message || 'Unknown error occurred',
            type: 'api_error',
            provider: 'google',
            details: error,
        };
    }
}
//# sourceMappingURL=provider.js.map