UNPKG

donobu

Version:

Create browser automations with an LLM agent and replay them as Playwright scripts.

416 lines 17.1 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.GoogleGptClient = void 0;
const GptModelNotFoundException_1 = require("../exceptions/GptModelNotFoundException");
const GptClient_1 = require("./GptClient");
const JsonUtils_1 = require("../utils/JsonUtils");
const MiscUtils_1 = require("../utils/MiscUtils");
const Logger_1 = require("../utils/Logger");
const GptPlatformAuthenticationFailedException_1 = require("../exceptions/GptPlatformAuthenticationFailedException");
const GptPlatformInternalErrorException_1 = require("../exceptions/GptPlatformInternalErrorException");
const GptPlatformNotReachableException_1 = require("../exceptions/GptPlatformNotReachableException");
/**
 * GptClient implementation backed by the Google Generative Language (Gemini) REST API.
 * @see https://ai.google.dev/api/all-methods#service:-generativelanguage.googleapis.com
 */
class GoogleGptClient extends GptClient_1.GptClient {
    /**
     * Create a new instance.
     * @param googleGeminiConfig Client configuration, also forwarded to the GptClient base
     *   constructor. This class reads `apiKey` (sent as the `x-goog-api-key` header on every
     *   request) and `modelName` (see https://ai.google.dev/gemini-api/docs/models/gemini for
     *   the list of models). Requests are always sent to GoogleGptClient.API_URL.
     */
    constructor(googleGeminiConfig) {
        super(googleGeminiConfig);
        this.googleGeminiConfig = googleGeminiConfig;
        this.headers = new Headers({
            [GoogleGptClient.API_KEY_HEADER_NAME]: googleGeminiConfig.apiKey,
            'Content-Type': GoogleGptClient.CONTENT_TYPE,
        });
    }
    /**
     * Checks that the configured model can be fetched with the configured API key.
     * @throws GptModelNotFoundException if the API answers 404 for the model.
     * @throws The exception built by mapErrorResponseToDonobuException for any other non-200 status.
     */
    async ping() {
        const resp = await this.makeRequest(`/v1beta/models/${this.googleGeminiConfig.modelName}`);
        if (resp.status === 404) {
            throw new GptModelNotFoundException_1.GptModelNotFoundException(this.config.type, this.googleGeminiConfig.modelName);
        }
        else if (resp.status !== 200) {
            throw await this.mapErrorResponseToDonobuException(resp);
        }
    }
    /**
     * Requests a free-form text completion.
     * @param messages Conversation history; system messages become the request's systemInstruction.
     * @returns An 'assistant' message holding the concatenated text of the first candidate plus token usage.
     * @throws Error if the first candidate contains no text.
     */
    async getMessage(messages) {
        const body = GoogleGptClient.buildRequestBody(messages);
        const data = await this.requestGenerateContent(body);
        return {
            type: 'assistant',
            text: GoogleGptClient.extractText(data),
            ...GoogleGptClient.extractTokenUsage(data),
        };
    }
    /**
     * Requests a JSON response constrained by the given JSON schema.
     * @param messages Conversation history; system messages become the request's systemInstruction.
     * @param jsonSchema JSON schema for the response; converted with createGoogleCompatibleJsonSchema.
     * @returns A 'structured_output' message holding the parsed object plus token usage.
     * @throws Error if the response has no text or the text does not parse to a JSON object.
     */
    async getStructuredOutput(messages, jsonSchema) {
        const body = GoogleGptClient.buildRequestBody(messages, {
            responseMimeType: GoogleGptClient.CONTENT_TYPE,
            responseSchema: GoogleGptClient.createGoogleCompatibleJsonSchema(jsonSchema),
        });
        const data = await this.requestGenerateContent(body);
        const text = GoogleGptClient.extractText(data).trim();
        const respObj = JsonUtils_1.JsonUtils.jsonStringToJsonObject(text);
        if (!respObj) {
            throw new Error('Failed to parse response as JSON object');
        }
        return {
            type: 'structured_output',
            output: respObj,
            ...GoogleGptClient.extractTokenUsage(data),
        };
    }
    /**
     * Asks the model to call one or more of the given tools (function calling mode 'ANY').
     * @param messages Conversation history; system messages become the request's systemInstruction.
     * @param tools Tools the model may call; each needs `name`, `description` and `inputSchema`.
     * @returns A 'proposed_tool_calls' message with one entry per function call in the first candidate.
     * @throws Error if the response contains no function calls, or names a tool not in `tools`.
     */
    async getToolCalls(messages, tools) {
        const body = GoogleGptClient.buildRequestBody(messages, {}, {
            toolConfig: {
                functionCallingConfig: {
                    mode: 'ANY',
                },
            },
            tools: tools.length
                ? [{ functionDeclarations: tools.map(GoogleGptClient.toolChoiceFromTool) }]
                : undefined,
        });
        const data = await this.requestGenerateContent(body);
        // Ignore any non-function-call parts (e.g. stray text) instead of crashing on them.
        const functionCallParts = GoogleGptClient.getCandidateParts(data).filter((part) => part?.functionCall);
        if (!functionCallParts.length) {
            throw new Error(`Unable to extract tool calls from Google response${GoogleGptClient.describeMissingOutput(data)}`);
        }
        const proposedToolCalls = functionCallParts.map(({ functionCall }) => {
            const functionName = functionCall.name;
            // Calls to tools declared without parameters may arrive with no `args` at all.
            const parameters = functionCall.args ?? {};
            // Undo keys that arrive containing "assert1" instead of "assert" (see fixAssertFields).
            GoogleGptClient.fixAssertFields(parameters);
            if (!tools.some((t) => t.name === functionName)) {
                throw new Error(`Tool not found: ${functionName}`);
            }
            return {
                name: functionName,
                parameters,
                toolCallId: MiscUtils_1.MiscUtils.createAdHocToolCallId(),
            };
        });
        return {
            type: 'proposed_tool_calls',
            proposedToolCalls,
            ...GoogleGptClient.extractTokenUsage(data),
        };
    }
    /**
     * Builds a generateContent request body from generic GPT messages.
     * All system messages are collected into the single systemInstruction Google accepts
     * (one part per message); the rest are mapped to Google content and adjacent user
     * messages are merged.
     * @param messages Conversation history.
     * @param generationConfigOverrides Extra fields merged into `generationConfig`.
     * @param extraFields Extra top-level body fields (e.g. `tools`, `toolConfig`).
     */
    static buildRequestBody(messages, generationConfigOverrides = {}, extraFields = {}) {
        const systemMessages = messages.filter((msg) => msg.type === 'system');
        const contents = messages
            .filter((msg) => msg.type !== 'system')
            .map(GoogleGptClient.chatRequestMessageFromGptMessage);
        GoogleGptClient.shenanigansUserMessageMerge(contents);
        return {
            generationConfig: {
                temperature: 0.0,
                maxOutputTokens: GoogleGptClient.MAX_OUTPUT_TOKENS,
                ...generationConfigOverrides,
            },
            contents,
            systemInstruction: systemMessages.length
                ? {
                    role: 'user',
                    parts: systemMessages.map((msg) => ({ text: msg.text })),
                }
                : undefined,
            ...extraFields,
        };
    }
    /**
     * POSTs the body to the model's generateContent endpoint.
     * @returns The parsed JSON response body.
     * @throws The exception built by mapErrorResponseToDonobuException for any non-200 status.
     */
    async requestGenerateContent(body) {
        const resp = await this.makeRequest(`/v1beta/models/${this.googleGeminiConfig.modelName}:generateContent`, 'POST', body);
        if (resp.status !== 200) {
            throw await this.mapErrorResponseToDonobuException(resp);
        }
        return resp.json();
    }
    /**
     * Returns the parts of the first candidate, or an empty array when the response has
     * no candidates or the candidate has no content.
     */
    static getCandidateParts(data) {
        const parts = data?.candidates?.[0]?.content?.parts;
        return Array.isArray(parts) ? parts : [];
    }
    /**
     * Concatenates the non-thought text parts of the first candidate.
     * @throws Error if the first candidate contains no text part.
     */
    static extractText(data) {
        const textParts = GoogleGptClient.getCandidateParts(data)
            .filter((part) => typeof part?.text === 'string' && !part.thought);
        if (!textParts.length) {
            throw new Error(`Unable to extract text from Google response${GoogleGptClient.describeMissingOutput(data)}`);
        }
        return textParts.map((part) => part.text).join('');
    }
    /**
     * Builds an error-message suffix naming the first candidate's finishReason, or the prompt
     * block reason, when the response reports one; otherwise returns an empty string.
     */
    static describeMissingOutput(data) {
        const reason = data?.candidates?.[0]?.finishReason ?? data?.promptFeedback?.blockReason;
        return reason ? ` (reason: ${reason})` : '';
    }
    /**
     * Reads prompt/completion token counts from `usageMetadata`, defaulting missing counts to 0.
     */
    static extractTokenUsage(data) {
        const usage = data?.usageMetadata ?? {};
        return {
            promptTokensUsed: usage.promptTokenCount ?? 0,
            completionTokensUsed: usage.candidatesTokenCount ?? 0,
        };
    }
    /**
     * Fix the "assert1" field issue in tool call parameters.
     * Recursively renames, in place, every object key containing "assert1" by replacing
     * each "assert1" occurrence with "assert". Non-objects are ignored.
     */
    static fixAssertFields(obj) {
        if (!obj || typeof obj !== 'object') {
            return;
        }
        if (Array.isArray(obj)) {
            obj.forEach((item) => GoogleGptClient.fixAssertFields(item));
            return;
        }
        const keysToRename = [];
        for (const key of Object.keys(obj)) {
            if (key.includes('assert1')) {
                keysToRename.push({ oldKey: key, newKey: key.replaceAll('assert1', 'assert') });
            }
            // Recurse so nested objects/arrays are fixed too (non-objects return immediately).
            GoogleGptClient.fixAssertFields(obj[key]);
        }
        // Apply renames after iterating to avoid modifying the object during iteration.
        for (const { oldKey, newKey } of keysToRename) {
            const value = obj[oldKey];
            delete obj[oldKey];
            obj[newKey] = value;
            Logger_1.appLogger.info(`Fixed field: renamed "${oldKey}" to "${newKey}" in tool call parameters`);
        }
    }
    /**
     * Converts a non-2xx Google HTTP response into a Donobu exception (returned, not thrown).
     * Returns GptPlatformAuthenticationFailedException when any error detail has reason
     * API_KEY_INVALID, otherwise GptPlatformInternalErrorException with the API's message
     * (or the HTTP status line when the body is missing or not JSON).
     */
    async mapErrorResponseToDonobuException(error) {
        try {
            const errorData = await error.json();
            Logger_1.appLogger.error(`Google error response: ${JSON.stringify(JsonUtils_1.JsonUtils.objectToJson(errorData))}`);
            const details = errorData?.error?.details;
            // Scan every detail entry, not just the first: the one carrying the reason may not be first.
            const isInvalidApiKey = Array.isArray(details) &&
                details.some((detail) => detail?.reason === 'API_KEY_INVALID');
            if (isInvalidApiKey) {
                return new GptPlatformAuthenticationFailedException_1.GptPlatformAuthenticationFailedException(this.config.type);
            }
            return new GptPlatformInternalErrorException_1.GptPlatformInternalErrorException(errorData?.error?.message || `HTTP ${error.status}: ${error.statusText}`);
        }
        catch (_) {
            Logger_1.appLogger.error(`Failed to parse Google error response: HTTP ${error.status}: ${error.statusText}`);
            return new GptPlatformInternalErrorException_1.GptPlatformInternalErrorException(`HTTP ${error.status}: ${error.statusText}`);
        }
    }
    /**
     * Sends a request to GoogleGptClient.API_URL + endpoint with the API-key headers and a
     * REQUEST_TIMEOUT_MILLISECONDS timeout. The body, when given, is JSON-encoded.
     * @throws GptPlatformNotReachableException when fetch rejects with a TypeError
     *   (network-level failure); any other rejection is rethrown unchanged.
     */
    async makeRequest(endpoint, method = 'GET', body) {
        try {
            return await fetch(`${GoogleGptClient.API_URL}${endpoint}`, {
                method,
                headers: this.headers,
                body: body ? JSON.stringify(body) : undefined,
                signal: AbortSignal.timeout(GoogleGptClient.REQUEST_TIMEOUT_MILLISECONDS),
            });
        }
        catch (error) {
            if (error instanceof TypeError) {
                throw new GptPlatformNotReachableException_1.GptPlatformNotReachableException(this.config.type);
            }
            else {
                throw error;
            }
        }
    }
    /**
     * Maps a generic GPT message to a Google `Content` object.
     * Assistant, structured-output and tool-call messages map to role 'model'; user and
     * tool-result messages map to role 'user'. User items carrying `bytes` are sent as
     * base64 PNG inline data.
     * @throws Error for 'system' messages (callers must filter them out) and unknown types.
     */
    static chatRequestMessageFromGptMessage(gptMessage) {
        switch (gptMessage.type) {
            case 'assistant':
                return { role: 'model', parts: [{ text: gptMessage.text }] };
            case 'structured_output':
                return {
                    role: 'model',
                    parts: [{ text: JSON.stringify(gptMessage.output, null, 2) }],
                };
            case 'proposed_tool_calls':
                return {
                    role: 'model',
                    parts: gptMessage.proposedToolCalls.map((tc) => ({
                        functionCall: {
                            name: tc.name,
                            args: JsonUtils_1.JsonUtils.objectToJson(tc.parameters),
                        },
                    })),
                };
            case 'system':
                throw new Error('Inline system messages must be filtered out for Google calls!');
            case 'user':
                return {
                    role: 'user',
                    parts: gptMessage.items.map((item) => 'bytes' in item
                        ? {
                            // PNG
                            inlineData: {
                                mimeType: 'image/png',
                                data: Buffer.from(item.bytes).toString('base64'),
                            },
                        }
                        : { text: item.text }),
                };
            case 'tool_call_result':
                return {
                    role: 'user',
                    parts: [
                        {
                            functionResponse: {
                                name: gptMessage.toolName,
                                response: {
                                    value: gptMessage.data,
                                },
                            },
                        },
                    ],
                };
            default:
                throw new Error(`Unsupported message type: ${JsonUtils_1.JsonUtils.objectToJson(gptMessage)}`);
        }
    }
    /**
     * Maps a tool to a Google function declaration. `parameters` is only included when the
     * tool's input schema declares at least one property.
     */
    static toolChoiceFromTool(tool) {
        const result = {
            name: tool.name,
            description: tool.description,
        };
        const properties = tool.inputSchema?.properties;
        if (properties && Object.keys(properties).length > 0) {
            result.parameters = GoogleGptClient.createGoogleCompatibleJsonSchema(tool.inputSchema);
        }
        return result;
    }
    /**
     * Creates a new JSON schema compatible with Google's API. The input is deep-cloned
     * (via a JSON round trip) and left unmodified.
     */
    static createGoogleCompatibleJsonSchema(schema) {
        const clonedSchema = JSON.parse(JSON.stringify(schema));
        GoogleGptClient.removeAdditionalProperties(clonedSchema);
        GoogleGptClient.normalizeTypeFields(clonedSchema);
        return clonedSchema;
    }
    /**
     * Recursively removes, in place, all "additionalProperties" fields from the JSON
     * schema since Google does not support that keyword.
     */
    static removeAdditionalProperties(node) {
        if (!node || typeof node !== 'object') {
            return;
        }
        if (Array.isArray(node)) {
            node.forEach((item) => GoogleGptClient.removeAdditionalProperties(item));
            return;
        }
        delete node['additionalProperties'];
        Object.values(node).forEach((value) => {
            if (value && typeof value === 'object') {
                GoogleGptClient.removeAdditionalProperties(value);
            }
        });
    }
    /**
     * Merges adjacent user messages, in place, because Google will otherwise reject the
     * request. Later messages' parts are appended to the earlier message.
     */
    static shenanigansUserMessageMerge(messages) {
        for (let i = messages.length - 1; i > 0; i--) {
            const message = messages[i];
            const adjacentMessage = messages[i - 1];
            if (message.role === 'user' && adjacentMessage.role === 'user') {
                adjacentMessage.parts.push(...message.parts);
                messages.splice(i, 1);
            }
        }
    }
    /**
     * Recursively replaces, in place, array-valued "type" fields with their first non-"null"
     * entry, since Google will otherwise throw an error. When the array contained "null",
     * `nullable: true` is set so the schema still permits null values. Arrays with no
     * non-"null" entry are left unchanged.
     */
    static normalizeTypeFields(node) {
        if (!node || typeof node !== 'object') {
            return;
        }
        if (Array.isArray(node)) {
            node.forEach((item) => GoogleGptClient.normalizeTypeFields(item));
            return;
        }
        const obj = node;
        if (Array.isArray(obj.type)) {
            const nonNullType = obj.type.find((type) => type !== 'null');
            if (nonNullType) {
                if (obj.type.includes('null')) {
                    obj.nullable = true;
                }
                obj.type = nonNullType;
            }
        }
        Object.values(obj).forEach((value) => {
            if (value && typeof value === 'object') {
                GoogleGptClient.normalizeTypeFields(value);
            }
        });
    }
}
exports.GoogleGptClient = GoogleGptClient;
GoogleGptClient.API_URL = 'https://generativelanguage.googleapis.com';
GoogleGptClient.REQUEST_TIMEOUT_MILLISECONDS = 120000;
GoogleGptClient.API_KEY_HEADER_NAME = 'x-goog-api-key';
GoogleGptClient.CONTENT_TYPE = 'application/json';
GoogleGptClient.MAX_OUTPUT_TOKENS = 8192;
//# sourceMappingURL=GoogleGptClient.js.map