donobu
Version:
Create browser automations with an LLM agent and replay them as Playwright scripts.
416 lines • 17.1 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.GoogleGptClient = void 0;
const GptModelNotFoundException_1 = require("../exceptions/GptModelNotFoundException");
const GptClient_1 = require("./GptClient");
const JsonUtils_1 = require("../utils/JsonUtils");
const MiscUtils_1 = require("../utils/MiscUtils");
const Logger_1 = require("../utils/Logger");
const GptPlatformAuthenticationFailedException_1 = require("../exceptions/GptPlatformAuthenticationFailedException");
const GptPlatformInternalErrorException_1 = require("../exceptions/GptPlatformInternalErrorException");
const GptPlatformNotReachableException_1 = require("../exceptions/GptPlatformNotReachableException");
/**
* Implementation using the Google API
* @see https://ai.google.dev/api/all-methods#service:-generativelanguage.googleapis.com
*/
/**
 * GptClient implementation backed by the Google Gemini REST API
 * (`/v1beta/models/{model}` and `/v1beta/models/{model}:generateContent`).
 * @see https://ai.google.dev/api/all-methods#service:-generativelanguage.googleapis.com
 */
class GoogleGptClient extends GptClient_1.GptClient {
    /**
     * Create a new instance.
     * @param googleGeminiConfig Configuration for this client. It is passed to
     *   the base GptClient constructor; this class additionally reads `apiKey`
     *   (sent on every request in the `GoogleGptClient.API_KEY_HEADER_NAME`
     *   header) and `modelName` (see
     *   https://ai.google.dev/gemini-api/docs/models/gemini for the list of
     *   models).
     */
    constructor(googleGeminiConfig) {
        super(googleGeminiConfig);
        this.googleGeminiConfig = googleGeminiConfig;
        this.headers = new Headers({
            [GoogleGptClient.API_KEY_HEADER_NAME]: googleGeminiConfig.apiKey,
            'Content-Type': GoogleGptClient.CONTENT_TYPE,
        });
    }
    /**
     * Checks reachability, credentials, and model availability by fetching the
     * configured model's metadata.
     * @throws GptModelNotFoundException if the API responds with HTTP 404.
     * @throws The exception built by mapErrorResponseToDonobuException for any
     *   other non-200 status.
     */
    async ping() {
        const resp = await this.makeRequest(`/v1beta/models/${this.googleGeminiConfig.modelName}`);
        if (resp.status === 404) {
            throw new GptModelNotFoundException_1.GptModelNotFoundException(this.config.type, this.googleGeminiConfig.modelName);
        }
        if (resp.status !== 200) {
            throw await this.mapErrorResponseToDonobuException(resp);
        }
    }
    /**
     * Requests a free-form assistant reply.
     * @param messages The conversation history.
     * @returns An `assistant` message whose `text` is the concatenation of all
     *   non-thought text parts of the first candidate, plus token usage.
     */
    async getMessage(messages) {
        const body = GoogleGptClient.buildGenerateContentBody(messages);
        const data = await this.postGenerateContent(body);
        return {
            type: 'assistant',
            text: GoogleGptClient.extractResponseText(data),
            ...GoogleGptClient.extractTokenUsage(data),
        };
    }
    /**
     * Requests a JSON reply constrained by `jsonSchema`.
     * @param messages The conversation history.
     * @param jsonSchema JSON schema for the reply; a Google-compatible copy is
     *   sent (the original is not modified).
     * @returns A `structured_output` message with the parsed object.
     * @throws Error if the reply cannot be parsed as a JSON object.
     */
    async getStructuredOutput(messages, jsonSchema) {
        const body = GoogleGptClient.buildGenerateContentBody(messages, {
            responseMimeType: GoogleGptClient.CONTENT_TYPE,
            responseSchema: GoogleGptClient.createGoogleCompatibleJsonSchema(jsonSchema),
        });
        const data = await this.postGenerateContent(body);
        const text = GoogleGptClient.extractResponseText(data).trim();
        const respObj = JsonUtils_1.JsonUtils.jsonStringToJsonObject(text);
        if (!respObj) {
            throw new Error('Failed to parse response as JSON object');
        }
        return {
            type: 'structured_output',
            output: respObj,
            ...GoogleGptClient.extractTokenUsage(data),
        };
    }
    /**
     * Asks the model to respond with one or more calls to the given tools.
     * @param messages The conversation history.
     * @param tools Tools the model may call.
     * @returns A `proposed_tool_calls` message, one entry per function-call part.
     * @throws Error if the response holds no function calls, or names a tool
     *   that is not in `tools`.
     */
    async getToolCalls(messages, tools) {
        const body = GoogleGptClient.buildGenerateContentBody(messages);
        if (tools.length > 0) {
            body.tools = [
                {
                    functionDeclarations: tools.map(GoogleGptClient.toolChoiceFromTool),
                },
            ];
            // Force the model to answer with function calls. Only sent when
            // function declarations are present, since the forcing config is
            // meaningless without them.
            body.toolConfig = {
                functionCallingConfig: {
                    mode: 'ANY',
                },
            };
        }
        const data = await this.postGenerateContent(body);
        // Skip any parts that are not function calls (e.g. stray text) instead
        // of crashing on a missing `functionCall`.
        const functionCalls = GoogleGptClient.extractResponseParts(data)
            .filter((part) => part && part.functionCall)
            .map((part) => part.functionCall);
        if (functionCalls.length === 0) {
            throw new Error('Unable to extract tool calls from Google response');
        }
        const proposedToolCalls = functionCalls.map((functionCall) => {
            const functionName = functionCall.name;
            // toolChoiceFromTool omits `parameters` for tools without
            // properties, so `args` may be absent; normalize to an empty object.
            const parameters = functionCall.args ?? {};
            // Work around argument keys that come back as "assert1" instead of "assert".
            GoogleGptClient.fixAssertFields(parameters);
            if (!tools.some((t) => t.name === functionName)) {
                throw new Error(`Tool not found: ${functionName}`);
            }
            return {
                name: functionName,
                parameters: parameters,
                toolCallId: MiscUtils_1.MiscUtils.createAdHocToolCallId(),
            };
        });
        return {
            type: 'proposed_tool_calls',
            proposedToolCalls: proposedToolCalls,
            ...GoogleGptClient.extractTokenUsage(data),
        };
    }
    /**
     * Builds the parts of a `generateContent` request body shared by all
     * request kinds: default generation settings, conversation contents, and
     * the optional system instruction.
     * @param messages The conversation history (GptMessage objects).
     * @param extraGenerationConfig Entries merged over the default
     *   `generationConfig` (temperature 0, 8192 max output tokens).
     * @returns A fresh, mutable request body object.
     */
    static buildGenerateContentBody(messages, extraGenerationConfig = {}) {
        // Google accepts a single system instruction, so only the first system
        // message is forwarded; all system messages are excluded from contents.
        const systemMessage = messages.find((msg) => msg.type === 'system');
        const contents = messages
            .filter((msg) => msg.type !== 'system')
            .map(GoogleGptClient.chatRequestMessageFromGptMessage);
        GoogleGptClient.shenanigansUserMessageMerge(contents);
        return {
            generationConfig: {
                temperature: 0.0,
                maxOutputTokens: 8192,
                ...extraGenerationConfig,
            },
            contents: contents,
            systemInstruction: systemMessage
                ? {
                    role: 'user',
                    parts: [
                        {
                            text: systemMessage.text,
                        },
                    ],
                }
                : undefined,
        };
    }
    /**
     * POSTs `body` to the configured model's `generateContent` endpoint.
     * @returns The parsed JSON response body.
     * @throws The exception built by mapErrorResponseToDonobuException when
     *   the HTTP status is not 200.
     */
    async postGenerateContent(body) {
        const resp = await this.makeRequest(`/v1beta/models/${this.googleGeminiConfig.modelName}:generateContent`, 'POST', body);
        if (resp.status !== 200) {
            throw await this.mapErrorResponseToDonobuException(resp);
        }
        return resp.json();
    }
    /**
     * Returns the `parts` array of the first candidate in a `generateContent`
     * response.
     * @throws Error naming the reported block/finish reason when there is no
     *   candidate content (e.g. a blocked prompt), instead of an opaque
     *   TypeError from dereferencing undefined.
     */
    static extractResponseParts(data) {
        const candidate = data?.candidates?.[0];
        const parts = candidate?.content?.parts;
        if (!Array.isArray(parts)) {
            const reason = data?.promptFeedback?.blockReason ?? candidate?.finishReason ?? 'unknown';
            throw new Error(`Google response contained no candidate content (reason: ${reason})`);
        }
        return parts;
    }
    /**
     * Concatenates the text of all non-thought text parts of the first
     * candidate. Returns '' when the candidate has no text parts.
     */
    static extractResponseText(data) {
        return GoogleGptClient.extractResponseParts(data)
            .filter((part) => part && typeof part.text === 'string' && !part.thought)
            .map((part) => part.text)
            .join('');
    }
    /**
     * Reads prompt/completion token counts from `usageMetadata`, defaulting a
     * missing count to 0.
     */
    static extractTokenUsage(data) {
        const usage = data?.usageMetadata ?? {};
        return {
            promptTokensUsed: usage.promptTokenCount ?? 0,
            completionTokensUsed: usage.candidatesTokenCount ?? 0,
        };
    }
    /**
     * Recursively renames, in place, every object key containing "assert1" by
     * replacing each "assert1" occurrence with "assert". Arrays and nested
     * objects are traversed; non-objects are ignored. Each rename is logged.
     */
    static fixAssertFields(obj) {
        if (!obj || typeof obj !== 'object') {
            return;
        }
        if (Array.isArray(obj)) {
            obj.forEach((item) => GoogleGptClient.fixAssertFields(item));
            return;
        }
        const objRecord = obj;
        const keysToRename = [];
        for (const key of Object.keys(objRecord)) {
            if (key.includes('assert1')) {
                keysToRename.push({ oldKey: key, newKey: key.replace(/assert1/g, 'assert') });
            }
            // Recurse into nested values (non-objects return immediately).
            GoogleGptClient.fixAssertFields(objRecord[key]);
        }
        // Rename after iterating so the object is not modified mid-iteration.
        for (const { oldKey, newKey } of keysToRename) {
            const value = objRecord[oldKey];
            delete objRecord[oldKey];
            objRecord[newKey] = value;
            Logger_1.appLogger.info(`Fixed field: renamed "${oldKey}" to "${newKey}" in tool call parameters`);
        }
    }
    /**
     * Converts a non-success fetch Response from Google into a Donobu exception.
     * @param error The failed HTTP Response (not a thrown error).
     * @returns GptPlatformAuthenticationFailedException when any error detail
     *   has reason `API_KEY_INVALID`; otherwise GptPlatformInternalErrorException
     *   carrying Google's message, or the HTTP status when none is available.
     */
    async mapErrorResponseToDonobuException(error) {
        const statusDescription = `HTTP ${error.status}: ${error.statusText}`;
        try {
            const errorData = await error.json();
            Logger_1.appLogger.error(`Google error response: ${JSON.stringify(JsonUtils_1.JsonUtils.objectToJson(errorData))}`);
            const details = errorData?.error?.details;
            // The ErrorInfo entry carrying the reason is not guaranteed to be first.
            const isInvalidApiKey = Array.isArray(details) &&
                details.some((detail) => detail?.reason === 'API_KEY_INVALID');
            if (isInvalidApiKey) {
                return new GptPlatformAuthenticationFailedException_1.GptPlatformAuthenticationFailedException(this.config.type);
            }
            return new GptPlatformInternalErrorException_1.GptPlatformInternalErrorException(errorData?.error?.message || statusDescription);
        }
        catch (_) {
            Logger_1.appLogger.error(`Failed to parse Google error response: ${statusDescription}`);
            return new GptPlatformInternalErrorException_1.GptPlatformInternalErrorException(statusDescription);
        }
    }
    /**
     * Sends a JSON request to `GoogleGptClient.API_URL + endpoint` with the
     * client's headers and a `REQUEST_TIMEOUT_MILLISECONDS` timeout.
     * @param endpoint Path beginning with "/".
     * @param method HTTP method (default 'GET').
     * @param body Optional object serialized as the JSON request body.
     * @throws GptPlatformNotReachableException when fetch rejects with a
     *   TypeError (network failure). Other errors are rethrown unchanged.
     */
    async makeRequest(endpoint, method = 'GET', body) {
        try {
            return await fetch(`${GoogleGptClient.API_URL}${endpoint}`, {
                method,
                headers: this.headers,
                body: body ? JSON.stringify(body) : undefined,
                signal: AbortSignal.timeout(GoogleGptClient.REQUEST_TIMEOUT_MILLISECONDS),
            });
        }
        catch (error) {
            // NOTE(review): a timeout rejects with a DOMException named
            // "TimeoutError", not a TypeError, so it is rethrown as-is.
            if (error instanceof TypeError) {
                throw new GptPlatformNotReachableException_1.GptPlatformNotReachableException(this.config.type);
            }
            throw error;
        }
    }
    /**
     * Maps one GptMessage to a Gemini `Content` object. Assistant, structured
     * output, and tool-call messages use role 'model'; user and tool results
     * use role 'user'. User byte items are sent as base64 PNG inline data.
     * @throws Error for system messages (they must be filtered out first) and
     *   for unknown message types.
     */
    static chatRequestMessageFromGptMessage(gptMessage) {
        switch (gptMessage.type) {
            case 'assistant':
                return {
                    role: 'model',
                    parts: [
                        {
                            text: gptMessage.text,
                        },
                    ],
                };
            case 'structured_output':
                return {
                    role: 'model',
                    parts: [
                        {
                            text: JSON.stringify(gptMessage.output, null, 2),
                        },
                    ],
                };
            case 'proposed_tool_calls':
                return {
                    role: 'model',
                    parts: gptMessage.proposedToolCalls.map((tc) => ({
                        functionCall: {
                            name: tc.name,
                            args: JsonUtils_1.JsonUtils.objectToJson(tc.parameters),
                        },
                    })),
                };
            case 'system':
                throw new Error('Inline system messages must be filtered out for Google calls!');
            case 'user':
                return {
                    role: 'user',
                    parts: gptMessage.items.map((item) => {
                        if ('bytes' in item) {
                            // PNG
                            return {
                                inlineData: {
                                    mimeType: 'image/png',
                                    data: Buffer.from(item.bytes).toString('base64'),
                                },
                            };
                        }
                        return { text: item.text }; // Text
                    }),
                };
            case 'tool_call_result':
                return {
                    role: 'user',
                    parts: [
                        {
                            functionResponse: {
                                name: gptMessage.toolName,
                                response: {
                                    value: gptMessage.data,
                                },
                            },
                        },
                    ],
                };
            default:
                throw new Error(`Unsupported message type: ${JsonUtils_1.JsonUtils.objectToJson(gptMessage)}`);
        }
    }
    /**
     * Converts a tool definition into a Gemini function declaration. The
     * `parameters` schema is included only when the input schema has at least
     * one property.
     */
    static toolChoiceFromTool(tool) {
        const result = {
            name: tool.name,
            description: tool.description,
        };
        if (tool.inputSchema.properties &&
            Object.keys(tool.inputSchema.properties).length > 0) {
            result.parameters = GoogleGptClient.createGoogleCompatibleJsonSchema(tool.inputSchema);
        }
        return result;
    }
    /**
     * Returns a deep copy of `schema` adjusted for Google's API: every
     * "additionalProperties" key is removed and nullable type arrays are
     * collapsed to a single type. The input is not modified.
     */
    static createGoogleCompatibleJsonSchema(schema) {
        // JSON round-trip deep clone; also drops undefined-valued keys.
        const clonedSchema = JSON.parse(JSON.stringify(schema));
        GoogleGptClient.removeAdditionalProperties(clonedSchema);
        GoogleGptClient.normalizeTypeFields(clonedSchema);
        return clonedSchema;
    }
    /**
     * Recursively deletes, in place, every "additionalProperties" key from the
     * JSON schema, because Google does not support that keyword.
     **/
    static removeAdditionalProperties(node) {
        if (!node || typeof node !== 'object') {
            return;
        }
        if (Array.isArray(node)) {
            node.forEach((item) => GoogleGptClient.removeAdditionalProperties(item));
            return;
        }
        delete node['additionalProperties'];
        Object.values(node).forEach((value) => {
            if (value && typeof value === 'object') {
                GoogleGptClient.removeAdditionalProperties(value);
            }
        });
    }
    /**
     * Merges, in place, each run of consecutive 'user' messages into the first
     * message of the run (preserving part order), because Google otherwise
     * rejects the request.
     */
    static shenanigansUserMessageMerge(messages) {
        // Walk backwards so splicing does not shift indices still to be visited.
        for (let i = messages.length - 1; i > 0; i--) {
            const message = messages[i];
            const adjacentMessage = messages[i - 1];
            if (message.role === 'user' && adjacentMessage.role === 'user') {
                adjacentMessage.parts.push(...message.parts);
                messages.splice(i, 1);
            }
        }
    }
    /**
     * Recursively replaces, in place, any array-valued "type" field with its
     * first non-"null" entry, because Google otherwise rejects the schema.
     * Note: this drops nullability and keeps only the first type of a union;
     * a type array containing only "null" is left unchanged.
     **/
    static normalizeTypeFields(node) {
        if (!node || typeof node !== 'object') {
            return;
        }
        if (Array.isArray(node)) {
            node.forEach((item) => GoogleGptClient.normalizeTypeFields(item));
            return;
        }
        const obj = node;
        if (Array.isArray(obj.type)) {
            const nonNullType = obj.type.find((type) => type !== 'null');
            if (nonNullType) {
                obj.type = nonNullType;
            }
        }
        Object.values(obj).forEach((value) => {
            if (value && typeof value === 'object') {
                GoogleGptClient.normalizeTypeFields(value);
            }
        });
    }
}
// Publish the class, then attach its static configuration constants:
// the API base URL, per-request timeout (ms), API-key header name, and the
// content type used for requests and structured responses.
exports.GoogleGptClient = GoogleGptClient;
Object.assign(GoogleGptClient, {
    API_URL: 'https://generativelanguage.googleapis.com',
    REQUEST_TIMEOUT_MILLISECONDS: 120000,
    API_KEY_HEADER_NAME: 'x-goog-api-key',
    CONTENT_TYPE: 'application/json',
});
//# sourceMappingURL=GoogleGptClient.js.map