UNPKG

ai-sdk-provider-gemini-cli

Version:

Community AI SDK provider for Google Gemini using the official CLI/SDK

701 lines (689 loc) 22 kB
"use strict";
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
var __export = (target, all) => {
  for (var name in all)
    __defProp(target, name, { get: all[name], enumerable: true });
};
var __copyProps = (to, from, except, desc) => {
  if (from && typeof from === "object" || typeof from === "function") {
    for (let key of __getOwnPropNames(from))
      if (!__hasOwnProp.call(to, key) && key !== except)
        __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
  }
  return to;
};
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);

// src/index.ts
var index_exports = {};
__export(index_exports, {
  createGeminiCliCoreProvider: () => createGeminiProvider,
  createGeminiProvider: () => createGeminiProvider
});
module.exports = __toCommonJS(index_exports);

// src/gemini-language-model.ts
var import_node_crypto = require("crypto");

// src/client.ts
var import_gemini_cli_core = require("@google/gemini-cli-core");

/**
 * Creates a gemini-cli-core content generator for one model.
 *
 * @param {object} options - Provider options as returned by validateAuthOptions.
 * @param {string} modelId - Model identifier handed to gemini-cli-core.
 * @returns {Promise<{client: object, config: object}>} The content generator and the
 *   content-generator config it was built from.
 */
async function initializeGeminiClient(options, modelId) {
  let authType;
  if (options.authType === "api-key" || options.authType === "gemini-api-key") {
    authType = import_gemini_cli_core.AuthType.USE_GEMINI;
  } else if (options.authType === "vertex-ai") {
    authType = import_gemini_cli_core.AuthType.USE_VERTEX_AI;
  } else if (options.authType === "oauth" || options.authType === "oauth-personal") {
    authType = import_gemini_cli_core.AuthType.LOGIN_WITH_GOOGLE;
  } else if (options.authType === "google-auth-library") {
    // NOTE(review): options.googleAuth is validated but never forwarded; this path
    // currently configures plain Gemini API auth — confirm this is intended.
    authType = import_gemini_cli_core.AuthType.USE_GEMINI;
  }

  // Minimal stand-in for gemini-cli-core's Config: only the accessors used here are provided.
  const configMock = {
    getModel: () => modelId,
    // Gemini endpoints are HTTPS, so HTTPS_PROXY wins over HTTP_PROXY when both are set.
    getProxy: () => options.proxy || process.env.HTTPS_PROXY || process.env.HTTP_PROXY || void 0
  };
  const config = (0, import_gemini_cli_core.createContentGeneratorConfig)(configMock, authType);

  if ((options.authType === "api-key" || options.authType === "gemini-api-key") && options.apiKey) {
    config.apiKey = options.apiKey;
  } else if (options.authType === "vertex-ai" && options.vertexAI) {
    // NOTE(review): vertexAI.projectId / location are not copied onto config here;
    // presumably gemini-cli-core reads them from the environment — confirm.
    config.vertexai = true;
  }

  const client = await (0, import_gemini_cli_core.createContentGenerator)(config, configMock);
  return { client, config };
}

// src/message-mapper.ts

/**
 * In object-json mode, returns a copy of the prompt whose last user message carries
 * the JSON-schema instruction. Otherwise returns the prompt unchanged.
 * The caller's prompt array and message objects are never mutated.
 */
function withJsonSchemaInstruction(options) {
  const messages = options.prompt;
  if (options.mode?.type !== "object-json" || !options.mode.schema || messages.length === 0) {
    return messages;
  }
  const lastMessage = messages[messages.length - 1];
  if (lastMessage.role !== "user" || !Array.isArray(lastMessage.content)) {
    return messages;
  }

  const schemaPrompt = `\n\nYou must respond with a JSON object that exactly matches this schema:\n${JSON.stringify(options.mode.schema, null, 2)}\n\nIMPORTANT: Use the exact field names from the schema. Do not add extra fields.`;

  const lastContent = [...lastMessage.content];
  let textIndex = -1;
  for (let i = lastContent.length - 1; i >= 0; i--) {
    if (lastContent[i].type === "text") {
      textIndex = i;
      break;
    }
  }
  if (textIndex === -1) {
    // No text part (e.g. an image-only message): add one so the instruction is not lost.
    lastContent.push({ type: "text", text: schemaPrompt.trimStart() });
  } else {
    const textPart = lastContent[textIndex];
    lastContent[textIndex] = { ...textPart, text: textPart.text + schemaPrompt };
  }
  return [...messages.slice(0, -1), { ...lastMessage, content: lastContent }];
}

/**
 * Converts an AI SDK (LanguageModelV1) prompt into Gemini `contents` plus an optional
 * system instruction.
 *
 * @param {object} options - Call options; reads `prompt` and `mode`.
 * @returns {{contents: object[], systemInstruction: (object|undefined)}}
 */
function mapPromptToGeminiFormat(options) {
  const messages = withJsonSchemaInstruction(options);
  const contents = [];
  let systemInstruction;

  for (const message of messages) {
    switch (message.role) {
      case "system":
        // Gemini takes a single system instruction, so later system messages are appended
        // as extra parts instead of overwriting earlier ones.
        if (systemInstruction) {
          systemInstruction.parts.push({ text: message.content });
        } else {
          systemInstruction = { role: "user", parts: [{ text: message.content }] };
        }
        break;
      case "user":
        contents.push(mapUserMessage(message));
        break;
      case "assistant":
        contents.push(mapAssistantMessage(message));
        break;
      case "tool":
        contents.push({
          role: "user",
          parts: message.content.map((part) => mapToolResultPart(part))
        });
        break;
    }
  }
  return { contents, systemInstruction };
}

/** Maps a user message's text, image and file parts to Gemini parts. */
function mapUserMessage(message) {
  const parts = [];
  for (const part of message.content) {
    switch (part.type) {
      case "text":
        parts.push({ text: part.text });
        break;
      case "image":
        parts.push(mapImagePart(part));
        break;
      case "file":
        parts.push(mapFilePart(part));
        break;
    }
  }
  return { role: "user", parts };
}

/** Maps an assistant message's text and tool-call parts to a Gemini "model" turn. */
function mapAssistantMessage(message) {
  const parts = [];
  for (const part of message.content) {
    switch (part.type) {
      case "text":
        parts.push({ text: part.text });
        break;
      case "tool-call":
        parts.push({ functionCall: { name: part.toolName, args: part.args || {} } });
        break;
    }
  }
  return { role: "model", parts };
}

/**
 * Base64-encodes binary payloads. Strings are assumed to already be base64.
 * Returns undefined for any other type.
 */
function toBase64Payload(data) {
  if (typeof data === "string") {
    return data;
  }
  if (data instanceof Uint8Array) {
    return Buffer.from(data).toString("base64");
  }
  return void 0;
}

/**
 * Converts an image part to Gemini inline data.
 * @throws {Error} For URL images (not supported) and unknown payload types.
 */
function mapImagePart(part) {
  if (part.image instanceof URL) {
    throw new Error(
      "URL images are not supported by Gemini CLI Core. Please provide base64-encoded image data."
    );
  }
  const mimeType = part.mimeType || "image/jpeg";
  const base64Data = toBase64Payload(part.image);
  if (base64Data === void 0) {
    throw new Error("Unsupported image format");
  }
  return { inlineData: { mimeType, data: base64Data } };
}

/**
 * Converts a file part (e.g. a PDF) to Gemini inline data.
 * @throws {Error} For URL files (not supported) and unknown payload types.
 */
function mapFilePart(part) {
  if (part.data instanceof URL) {
    throw new Error(
      "URL files are not supported by Gemini CLI Core. Please provide base64-encoded file data."
    );
  }
  const base64Data = toBase64Payload(part.data);
  if (base64Data === void 0) {
    throw new Error("Unsupported file format");
  }
  return { inlineData: { mimeType: part.mimeType, data: base64Data } };
}

/**
 * Converts a tool-result part into a Gemini functionResponse.
 * Gemini requires `response` to be a JSON object, so primitive and array results are
 * wrapped as `{ result }`.
 */
function mapToolResultPart(part) {
  const { result } = part;
  const isPlainObject = typeof result === "object" && result !== null && !Array.isArray(result);
  return {
    functionResponse: {
      name: part.toolName,
      response: isPlainObject ? result : { result }
    }
  };
}

// src/tool-mapper.ts
var import_zod_to_json_schema = require("zod-to-json-schema");

/** Wraps AI SDK function tools into a single Gemini `functionDeclarations` tool entry. */
function mapToolsToGeminiFormat(tools) {
  const functionDeclarations = tools.map((tool) => ({
    name: tool.name,
    description: tool.description,
    parameters: convertToolParameters(tool.parameters)
  }));
  return [{ functionDeclarations }];
}

/**
 * Normalizes tool parameters (JSON Schema or zod schema) to a Gemini-compatible schema.
 * Any other value is returned unchanged.
 */
function convertToolParameters(parameters) {
  if (isJsonSchema(parameters)) {
    return cleanJsonSchema(parameters);
  }
  if (isZodSchema(parameters)) {
    // Inline every sub-schema: cleanJsonSchema strips $ref/$defs, so referenced
    // definitions would otherwise be lost.
    const jsonSchema = (0, import_zod_to_json_schema.zodToJsonSchema)(parameters, { $refStrategy: "none" });
    return cleanJsonSchema(jsonSchema);
  }
  return parameters;
}

function isJsonSchema(obj) {
  return typeof obj === "object" && obj !== null && ("type" in obj || "properties" in obj || "$schema" in obj);
}

function isZodSchema(obj) {
  return typeof obj === "object" && obj !== null && "_def" in obj && typeof obj._def === "object";
}

/**
 * Recursively copies a JSON Schema, dropping `$schema`, `$ref`, `$defs` and `definitions`.
 * Recurses into properties, items, additionalProperties and allOf/anyOf/oneOf.
 * Arrays (e.g. tuple-style `items`) are mapped element-wise instead of being spread into objects.
 */
function cleanJsonSchema(schema) {
  if (Array.isArray(schema)) {
    return schema.map((item) => cleanJsonSchema(item));
  }
  if (typeof schema !== "object" || schema === null) {
    return schema;
  }

  const cleaned = { ...schema };
  delete cleaned.$schema;
  delete cleaned.$ref;
  delete cleaned.$defs;
  delete cleaned.definitions;

  if (cleaned.properties && typeof cleaned.properties === "object") {
    cleaned.properties = Object.fromEntries(
      Object.entries(cleaned.properties).map(([key, value]) => [key, cleanJsonSchema(value)])
    );
  }
  if (cleaned.items) {
    cleaned.items = cleanJsonSchema(cleaned.items);
  }
  if (cleaned.additionalProperties && typeof cleaned.additionalProperties === "object") {
    cleaned.additionalProperties = cleanJsonSchema(cleaned.additionalProperties);
  }
  for (const key of ["allOf", "anyOf", "oneOf"]) {
    if (Array.isArray(cleaned[key])) {
      cleaned[key] = cleaned[key].map((item) => cleanJsonSchema(item));
    }
  }
  return cleaned;
}

// src/error.ts
var import_provider = require("@ai-sdk/provider");

/** Builds an APICallError carrying the given status and retryability. */
function createGeminiApiError(cause, message, statusCode, isRetryable) {
  return new import_provider.APICallError({
    url: "gemini-cli-core",
    requestBodyValues: {},
    statusCode,
    responseHeaders: {},
    message,
    cause,
    data: {},
    isRetryable
  });
}

/**
 * Classifies an arbitrary thrown value into an APICallError by keyword-matching its message.
 * Anything unrecognised becomes a retryable 500.
 */
function mapGeminiError(error) {
  if (!(error instanceof Error)) {
    return createGeminiApiError(error, "An unknown error occurred", 500, true);
  }
  const message = error.message.toLowerCase();
  if (message.includes("rate limit") || message.includes("quota")) {
    return createGeminiApiError(error, error.message, 429, true);
  }
  if (message.includes("unauthorized") || message.includes("authentication") || message.includes("api key")) {
    return createGeminiApiError(error, error.message, 401, false);
  }
  if (message.includes("invalid") || message.includes("bad request")) {
    return createGeminiApiError(error, error.message, 400, false);
  }
  // Only "not found" signals a 404. Merely mentioning "model" (e.g. "The model is overloaded")
  // must not make an error non-retryable.
  if (message.includes("not found")) {
    return createGeminiApiError(error, error.message, 404, false);
  }
  return createGeminiApiError(error, error.message, 500, true);
}

// src/extract-json.ts

/**
 * Best-effort extraction of a JSON value from model output.
 *
 * Strips markdown code fences and a leading `const|let|var x =` assignment, then parses
 * from the first `{` or `[`. If that fails, it retries on each balanced prefix, longest
 * first. Returns compact re-serialized JSON, or the original `text` if nothing parses.
 *
 * @param {string} text
 * @returns {string}
 */
function extractJson(text) {
  let content = text.trim();

  const fenceMatch = /```(?:json)?\s*([\s\S]*?)\s*```/i.exec(content);
  if (fenceMatch) {
    content = fenceMatch[1];
  }

  const varMatch = /^\s*(?:const|let|var)\s+\w+\s*=\s*([\s\S]*)/i.exec(content);
  if (varMatch) {
    content = varMatch[1];
    if (content.trim().endsWith(";")) {
      content = content.trim().slice(0, -1);
    }
  }

  const firstObj = content.indexOf("{");
  const firstArr = content.indexOf("[");
  if (firstObj === -1 && firstArr === -1) {
    return text;
  }
  const start = firstArr === -1 ? firstObj : firstObj === -1 ? firstArr : Math.min(firstObj, firstArr);
  content = content.slice(start);

  try {
    return JSON.stringify(JSON.parse(content));
  } catch {
    // Fall through to balanced-prefix scanning (e.g. trailing prose after the JSON).
  }

  const openChar = content[0];
  const closeChar = openChar === "{" ? "}" : "]";
  const closingPositions = [];
  let depth = 0;
  let inString = false;
  let escapeNext = false;

  for (let i = 0; i < content.length; i++) {
    const char = content[i];
    if (inString) {
      // Backslash escapes only have meaning inside JSON strings.
      if (escapeNext) {
        escapeNext = false;
      } else if (char === "\\") {
        escapeNext = true;
      } else if (char === '"') {
        inString = false;
      }
      continue;
    }
    if (char === '"') {
      inString = true;
    } else if (char === openChar) {
      depth++;
    } else if (char === closeChar) {
      depth--;
      if (depth === 0) {
        closingPositions.push(i + 1);
      }
    }
  }

  for (let i = closingPositions.length - 1; i >= 0; i--) {
    try {
      return JSON.stringify(JSON.parse(content.slice(0, closingPositions[i])));
    } catch {
      // Try the next shorter balanced prefix.
    }
  }
  return text;
}

// src/gemini-language-model.ts

/**
 * Maps a Gemini finishReason string to an AI SDK finish reason.
 *
 * @param {string|undefined} geminiReason
 * @param {boolean} [hasToolCalls=false] - When the model stopped normally after emitting
 *   function calls, report "tool-calls".
 */
function mapGeminiFinishReason(geminiReason, hasToolCalls = false) {
  switch (geminiReason) {
    case "STOP":
      return hasToolCalls ? "tool-calls" : "stop";
    case "MAX_TOKENS":
      return "length";
    case "SAFETY":
    case "RECITATION":
    case "BLOCKLIST":
    case "PROHIBITED_CONTENT":
    case "SPII":
    case "IMAGE_SAFETY":
      return "content-filter";
    case "MALFORMED_FUNCTION_CALL":
      return "error";
    case "OTHER":
      return "other";
    default:
      return "unknown";
  }
}

/** Converts Gemini usageMetadata to AI SDK usage; missing counts become 0. */
function mapGeminiUsage(usageMetadata) {
  return {
    promptTokens: usageMetadata?.promptTokenCount ?? 0,
    completionTokens: usageMetadata?.candidatesTokenCount ?? 0
  };
}

/** Converts a Gemini functionCall part into an AI SDK tool-call record (args as JSON text). */
function mapFunctionCallPart(part) {
  return {
    toolCallType: "function",
    toolCallId: (0, import_node_crypto.randomUUID)(),
    toolName: part.functionCall.name || "",
    args: JSON.stringify(part.functionCall.args || {})
  };
}

/**
 * AI SDK LanguageModelV1 implementation backed by a gemini-cli-core content generator.
 * The generator is created lazily on first use.
 */
var GeminiLanguageModel = class {
  constructor(options) {
    this.specificationVersion = "v1";
    this.provider = "gemini-cli-core";
    this.defaultObjectGenerationMode = "json";
    // CLI Core uses base64 data, not URLs
    this.supportsImageUrls = false;
    this.supportsStructuredOutputs = true;
    this.modelId = options.modelId;
    this.providerOptions = options.providerOptions;
    this.settings = options.settings;
  }

  /**
   * Returns the initialized content generator and config, creating them at most once.
   * Concurrent callers share one in-flight initialization. A failed initialization is
   * not cached, so the next call retries.
   */
  async ensureInitialized() {
    if (this.contentGenerator && this.config) {
      return { contentGenerator: this.contentGenerator, config: this.config };
    }
    if (!this.initPromise) {
      this.initPromise = this.initialize().catch((error) => {
        this.initPromise = void 0;
        throw error;
      });
    }
    await this.initPromise;
    return { contentGenerator: this.contentGenerator, config: this.config };
  }

  async initialize() {
    try {
      const { client, config } = await initializeGeminiClient(this.providerOptions, this.modelId);
      this.contentGenerator = client;
      this.config = config;
    } catch (error) {
      throw new Error(`Failed to initialize Gemini model: ${String(error)}`, { cause: error });
    }
  }

  /**
   * Builds the Gemini request shared by doGenerate and doStream.
   * Generation settings, system instruction and tools all go inside `config`, which is
   * where the @google/genai GenerateContentParameters type carries them.
   */
  buildRequest(options) {
    const { contents, systemInstruction } = mapPromptToGeminiFormat(options);
    const isObjectJsonMode = options.mode.type === "object-json";
    const generationConfig = {
      temperature: options.temperature,
      topP: options.topP,
      topK: options.topK,
      // Default to 65536 (64K) - max supported by Gemini 2.5 models
      maxOutputTokens: options.maxTokens || 65536,
      stopSequences: options.stopSequences,
      responseMimeType: isObjectJsonMode ? "application/json" : "text/plain"
    };

    let tools;
    if (options.mode.type === "regular" && options.mode.tools) {
      const functionTools = options.mode.tools.filter((tool) => tool.type === "function");
      if (functionTools.length > 0) {
        tools = mapToolsToGeminiFormat(functionTools);
      }
    }

    const requestConfig = { ...generationConfig };
    if (systemInstruction) {
      requestConfig.systemInstruction = systemInstruction;
    }
    if (tools) {
      requestConfig.tools = tools;
    }
    const request = { model: this.modelId, contents, config: requestConfig };

    return { contents, systemInstruction, generationConfig, tools, request, isObjectJsonMode };
  }

  /**
   * Non-streaming generation method
   */
  async doGenerate(options) {
    try {
      const { contentGenerator } = await this.ensureInitialized();
      const { contents, systemInstruction, generationConfig, tools, request, isObjectJsonMode } =
        this.buildRequest(options);

      const response = await contentGenerator.generateContent(request);
      const candidate = response.candidates?.[0];
      const parts = candidate?.content?.parts ?? [];

      // The answer may be split across several parts, or follow a function call,
      // so join every text part rather than reading only the first.
      let text = parts.map((part) => part.text || "").join("");
      if (isObjectJsonMode && text) {
        text = extractJson(text);
      }

      const toolCalls = parts.filter((part) => !!part.functionCall).map(mapFunctionCallPart);

      return {
        text,
        toolCalls: toolCalls.length ? toolCalls : void 0,
        finishReason: mapGeminiFinishReason(candidate?.finishReason, toolCalls.length > 0),
        usage: mapGeminiUsage(response.usageMetadata),
        rawCall: {
          rawPrompt: { contents, systemInstruction, generationConfig, tools },
          rawSettings: generationConfig
        },
        rawResponse: { body: response },
        response: {
          id: (0, import_node_crypto.randomUUID)(),
          timestamp: /* @__PURE__ */ new Date(),
          modelId: this.modelId
        }
      };
    } catch (error) {
      throw mapGeminiError(error);
    }
  }

  /**
   * Streaming generation method
   */
  async doStream(options) {
    try {
      const { contentGenerator } = await this.ensureInitialized();
      const { contents, systemInstruction, generationConfig, tools, request, isObjectJsonMode } =
        this.buildRequest(options);

      const streamResponse = await contentGenerator.generateContentStream(request);

      const stream = new ReadableStream({
        async start(controller) {
          // In object-json mode text is buffered so it can be cleaned by extractJson once.
          let accumulatedText = "";
          let hasToolCalls = false;
          let finishReason;
          let usageMetadata;
          try {
            for await (const chunk of streamResponse) {
              const candidate = chunk.candidates?.[0];
              for (const part of candidate?.content?.parts ?? []) {
                if (part.text) {
                  if (isObjectJsonMode) {
                    accumulatedText += part.text;
                  } else {
                    controller.enqueue({ type: "text-delta", textDelta: part.text });
                  }
                } else if (part.functionCall) {
                  hasToolCalls = true;
                  controller.enqueue({ type: "tool-call", ...mapFunctionCallPart(part) });
                }
              }
              finishReason = candidate?.finishReason ?? finishReason;
              usageMetadata = chunk.usageMetadata ?? usageMetadata;
            }

            // Emit buffered JSON and exactly one finish event after the stream ends.
            // This happens even if no chunk carried a finishReason, so consumers never wait forever.
            if (isObjectJsonMode && accumulatedText) {
              controller.enqueue({ type: "text-delta", textDelta: extractJson(accumulatedText) });
            }
            controller.enqueue({
              type: "finish",
              finishReason: mapGeminiFinishReason(finishReason, hasToolCalls),
              usage: mapGeminiUsage(usageMetadata)
            });
            controller.close();
          } catch (error) {
            controller.error(mapGeminiError(error));
          }
        }
      });

      return {
        stream,
        rawCall: {
          rawPrompt: { contents, systemInstruction, generationConfig, tools },
          rawSettings: generationConfig
        }
      };
    } catch (error) {
      throw mapGeminiError(error);
    }
  }
};

// src/validation.ts

/**
 * Validates provider options for the selected auth type (default "oauth-personal").
 * Returns a shallow copy with authType filled in.
 * @throws {Error} When required credentials/config for the auth type are missing or the type is unknown.
 */
function validateAuthOptions(options = {}) {
  const authType = options.authType || "oauth-personal";
  switch (authType) {
    case "api-key":
    case "gemini-api-key":
      if (!("apiKey" in options) || !options.apiKey) {
        throw new Error(`API key is required for ${authType} auth type`);
      }
      return { ...options, authType };
    case "vertex-ai":
      if ("vertexAI" in options && options.vertexAI) {
        if (!options.vertexAI.projectId || options.vertexAI.projectId.trim() === "") {
          throw new Error("Project ID is required for vertex-ai auth type");
        }
        if (!options.vertexAI.location || options.vertexAI.location.trim() === "") {
          throw new Error("Location is required for vertex-ai auth type");
        }
      } else {
        throw new Error("Vertex AI configuration is required for vertex-ai auth type");
      }
      return { ...options, authType };
    case "oauth":
    case "oauth-personal":
      return { ...options, authType };
    case "google-auth-library":
      if (!("googleAuth" in options) || !options.googleAuth) {
        throw new Error("Google Auth Library instance is required for google-auth-library auth type");
      }
      return { ...options, authType };
    default:
      throw new Error(`Invalid auth type: ${String(authType)}`);
  }
}

// src/gemini-provider.ts

/**
 * Creates the provider function. Calling it (or `.languageModel` / `.chat`) with a model
 * id returns a GeminiLanguageModel. Options are validated eagerly.
 */
function createGeminiProvider(options = {}) {
  const validatedOptions = validateAuthOptions(options);
  const createLanguageModel = (modelId, settings) => {
    return new GeminiLanguageModel({ modelId, providerOptions: validatedOptions, settings });
  };
  const provider = function(modelId, settings) {
    if (new.target) {
      throw new Error("The provider function cannot be called with the new keyword.");
    }
    return createLanguageModel(modelId, settings);
  };
  provider.languageModel = createLanguageModel;
  provider.chat = createLanguageModel;
  return provider;
}

// Annotate the CommonJS export names for ESM import in node:
0 && (module.exports = {
  createGeminiCliCoreProvider,
  createGeminiProvider
});
//# sourceMappingURL=index.js.map