UNPKG

@focusconsulting/auto-a11y

Version:

A powerful tool that combines AI with accessibility-first element selection for Playwright tests

403 lines (402 loc) 18.9 kB
"use strict"; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.A11yAILocator = void 0; exports.createA11yAILocator = createA11yAILocator; const ollama_1 = require("ollama"); const sdk_1 = __importDefault(require("@anthropic-ai/sdk")); const openai_1 = __importDefault(require("openai")); const generative_ai_1 = require("@google/generative-ai"); const snapshot_manager_1 = require("./snapshot-manager"); const prompt_1 = require("./prompt"); const sanitize_html_1 = require("./sanitize-html"); const zod_to_json_schema_1 = __importDefault(require("zod-to-json-schema")); class A11yAILocator { constructor(page, testInfo, options) { this.ollama = null; this.anthropic = null; this.openai = null; this.googleAI = null; this.useSimplifiedHtml = false; this.snapshotFilePath = null; this.cachedBodyContent = null; this.lastHtml = null; this.useSimplifiedHtml = options.useSimplifiedHtml || false; this.testInstance = options.testInstance || null; this.page = page; this.aiProvider = options.provider; // Set default models based on provider if not specified if (!options.model) { switch (options.provider) { case "anthropic": this.model = "claude-3-haiku-20240307"; break; case "openai": this.model = "gpt-4o-mini"; break; case "gemini": this.model = "gemini-2.5-pro-exp-03-25"; break; case "deepseek": this.model = "deepseek-chat"; break; case "bedrock": throw new Error("Model must be specified for Bedrock provider"); case "ollama": throw new Error("Model must be specified for Ollama provider"); default: throw new Error(`Unknown provider: ${options.provider}`); } } else { this.model = options.model; } // Initialize the appropriate client based on the provider switch (this.aiProvider) { case "anthropic": const anthropicApiKey = options.apiKey || process.env.ANTHROPIC_API_KEY; if (!anthropicApiKey) { throw new Error("Anthropic API key is required. Provide it via options.apiKey or ANTHROPIC_API_KEY environment variable."); } this.anthropic = new sdk_1.default({ apiKey: anthropicApiKey, }); break; case "openai": const openaiApiKey = options.apiKey || process.env.OPENAI_API_KEY; if (!openaiApiKey) { throw new Error("OpenAI API key is required. Provide it via options.apiKey or OPENAI_API_KEY environment variable."); } this.openai = new openai_1.default({ apiKey: openaiApiKey, baseURL: options.baseUrl, // Allow overriding for Azure OpenAI etc. }); break; case "gemini": const geminiApiKey = options.apiKey || process.env.GEMINI_API_KEY; if (!geminiApiKey) { throw new Error("Gemini API key is required. Provide it via options.apiKey or GEMINI_API_KEY environment variable."); } this.googleAI = new generative_ai_1.GoogleGenerativeAI(geminiApiKey); break; case "deepseek": // DeepSeek uses OpenAI compatible API const deepseekApiKey = options.apiKey || process.env.DEEPSEEK_API_KEY; if (!deepseekApiKey) { throw new Error("DeepSeek API key is required. Provide it via options.apiKey or DEEPSEEK_API_KEY environment variable."); } this.openai = new openai_1.default({ apiKey: deepseekApiKey, baseURL: options.baseUrl || "https://api.deepseek.com/v1", // Default DeepSeek API endpoint }); break; case "bedrock": if (!options.model) { throw new Error("Model must be specified for Bedrock provider"); } const bedrockApiKey = options.apiKey || process.env.AWS_ACCESS_KEY_ID; const bedrockSecretKey = process.env.AWS_SECRET_ACCESS_KEY; if (!bedrockApiKey || !bedrockSecretKey) { throw new Error("AWS credentials are required for Bedrock. Provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."); } // Bedrock uses OpenAI compatible client this.openai = new openai_1.default({ apiKey: bedrockApiKey, baseURL: options.baseUrl || `https://bedrock-runtime.us-east-1.amazonaws.com/model/${this.model}`, }); break; case "ollama": if (!options.model) { throw new Error("Model must be specified for Ollama provider"); } this.ollama = new ollama_1.Ollama({ host: options.baseUrl || "http://localhost:11434", // Ollama host }); break; } // Default to test name if available, otherwise use provided path or null if (options.snapshotFilePath) { this.snapshotFilePath = options.snapshotFilePath; } else { this.snapshotFilePath = snapshot_manager_1.SnapshotManager.createSnapshotPath(testInfo); } // Initialize the snapshot manager this.snapshotManager = new snapshot_manager_1.SnapshotManager(this.snapshotFilePath); } /** * Creates a locator using AI to determine the best Testing Library query * @param description Human description of the element to find * @returns Playwright Locator for the element */ async locate(description) { // Check if we have a saved snapshot for this description const snapshots = this.snapshotManager.readSnapshots(); if (snapshots[description]) { const locatorQuery = snapshots[description]; const locator = this.executeTestingLibraryQuery(locatorQuery); const validSnapshotLocator = await this.testInstance.step(`auto-a11y: attempting to use locator snapshot: ${locatorQuery.query}, ${locatorQuery.params.join(",")}`, async () => { // Verify the locator exists on the page const count = await locator.count(); if (count > 0) { return locator; } else { return null; } }); if (validSnapshotLocator) { return validSnapshotLocator; } } // Get the current page HTML const html = await this.page.content(); // Extract and sanitize only the body content let bodyContent; // Use cached body content if HTML hasn't changed if (this.lastHtml === html && this.cachedBodyContent) { bodyContent = this.cachedBodyContent; } else { bodyContent = (0, sanitize_html_1.extractBodyContent)(html); // Cache the results this.lastHtml = html; this.cachedBodyContent = bodyContent; } // Create the prompt with the description and body content const prompt = (0, prompt_1.createLocatorPrompt)(description, bodyContent); try { const locatorQuery = this.testInstance ? await this.testInstance.step(`auto-a11y locating: ${description}`, async () => { // Make the AI request const queryInfo = await this.executePrompt(prompt, { systemPrompt: "You must always return the COMPLETE text content for getByText queries, never partial matches. For example, if the element contains 'Yes, you can', you must return the entire text 'Yes, you can', not just 'Yes'.", }); return prompt_1.LocatorQuerySchema.parse(JSON.parse(queryInfo)); }) : // If no test instance is provided, execute without the step wrapper prompt_1.LocatorQuerySchema.parse(JSON.parse(await this.executePrompt(prompt, { systemPrompt: "You must always return the COMPLETE text content for getByText queries, never partial matches. For example, if the element contains 'Yes, you can', you must return the entire text 'Yes, you can', not just 'Yes'.", }))); // Save the snapshot for future use this.snapshotManager.saveSnapshot(description, locatorQuery); // Execute the appropriate Testing Library query return this.executeTestingLibraryQuery(locatorQuery); } catch (error) { console.warn(`AI request failed: ${error}.`); // Check if simplified HTML approach should be used if (this.useSimplifiedHtml) { try { const locatorQuery = await this.locateWithSimplifiedHTML(description, html); // Save the snapshot for future use this.snapshotManager.saveSnapshot(description, locatorQuery); // Execute the appropriate Testing Library query return this.executeTestingLibraryQuery(locatorQuery); } catch (fallbackError) { console.error(`Simplified HTML approach also failed: ${fallbackError}`); } } // Last resort: try a simple text search console.warn(`Falling back to simple text search for: "${description}"`); return this.page.getByText(description, { exact: false }); } } /** * Attempts to locate an element using simplified HTML when the main approach times out * @param description The element description * @param html The full HTML content * @returns Object containing query name and parameters */ async locateWithSimplifiedHTML(description, html) { // Create a much more simplified version of the HTML const simplifiedHTML = (0, sanitize_html_1.simplifyHtml)(html); // Create a simplified prompt const prompt = (0, prompt_1.createSimpleLocatorPrompt)(description, simplifiedHTML); // Get the query suggestion from the appropriate AI provider const queryInfo = await this.executePrompt(prompt, { systemPrompt: "Return only the query name and parameters. Be concise.", }); const locatorQuery = prompt_1.LocatorQuerySchema.parse(JSON.parse(queryInfo)); return locatorQuery; } /** * Executes a prompt against the configured AI provider * @param prompt The prompt to send to the AI * @param options Additional options for the AI request * @returns The AI response as a string */ async executePrompt(prompt, options = {}) { if (this.aiProvider === "anthropic" && this.anthropic) { const responsePromise = this.anthropic.messages.create({ model: this.model, max_tokens: 1024, system: options.systemPrompt || "Return only the query name and parameters. Be concise.", messages: [ { role: "user", content: prompt }, { role: "assistant", content: "{" }, ], }); const response = await responsePromise; const textContent = response.content.find((item) => item.type === "text"); if (textContent && "text" in textContent) { return textContent.text.trim(); } else { throw new Error("No text content found in Anthropic response"); } } else if (this.aiProvider === "openai" && this.openai) { const responsePromise = this.openai.responses.create({ model: this.model, text: { format: { type: "json_schema", name: "locatorQuerySchema", schema: (0, zod_to_json_schema_1.default)(prompt_1.LocatorQuerySchema) }, }, input: [ { role: "system", content: options.systemPrompt || "Return only the query name and parameters. Be concise.", }, { role: "user", content: prompt }, { role: "assistant", content: "{" }, ], }); const response = await responsePromise; return response.output[0].type === "message" && response.output[0].content[0].type === "output_text" ? response.output[0].content[0].text : ""; } else if (this.aiProvider === "deepseek" && this.openai) { // DeepSeek uses OpenAI compatible API this.openai.chat.completions.create({ model: this.model, messages: [ { role: "system", content: options.systemPrompt || "Return only the query name and parameters. Be concise.", }, { role: "user", content: prompt }, { role: "assistant", content: "{" }, ], response_format: { type: "json_object" } }); const responsePromise = this.openai.chat.completions.create({ model: this.model, messages: [ { role: "system", content: options.systemPrompt || "Return only the query name and parameters. Be concise.", }, { role: "user", content: prompt }, { role: "assistant", content: "{" }, ], response_format: { type: "json_object" } }); const response = await responsePromise; return response.choices[0]?.message?.content || ""; } else if (this.aiProvider === "gemini" && this.googleAI) { const genAI = this.googleAI; const model = genAI.getGenerativeModel({ model: this.model, safetySettings: [ { category: generative_ai_1.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold: generative_ai_1.HarmBlockThreshold.BLOCK_NONE, }, { category: generative_ai_1.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold: generative_ai_1.HarmBlockThreshold.BLOCK_NONE, }, { category: generative_ai_1.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold: generative_ai_1.HarmBlockThreshold.BLOCK_NONE, }, { category: generative_ai_1.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: generative_ai_1.HarmBlockThreshold.BLOCK_NONE, }, ], }); const responsePromise = model.generateContent({ contents: [ { role: "user", parts: [{ text: prompt }] }, { role: "model", parts: [{ text: "{" }] }, ], generationConfig: { responseMimeType: "application/json", }, }); const response = await responsePromise; console.log(response.response.text().trim()); return response.response.text().trim(); } else if (this.ollama) { const responsePromise = this.ollama.chat({ model: this.model, format: (0, zod_to_json_schema_1.default)(prompt_1.LocatorQuerySchema), options: { num_ctx: 8192 + prompt.length, }, messages: [ { role: "user", content: prompt }, { role: "assistant", content: "{" }, ], }); const response = await responsePromise; return response.message.content.trim(); } else { throw new Error("No AI provider configured"); } } /** * Executes the appropriate Testing Library query based on the AI suggestion * @param queryName The name of the Testing Library query * @param params The parameters for the query * @returns Playwright Locator */ executeTestingLibraryQuery(query) { switch (query.query.toLowerCase()) { case "getbyrole": // First param is role, second is name (optional) if (query.params.length > 1) { return this.page.getByRole(query.params[0], { name: query.params[1], }); } return this.page.getByRole(query.params[0]); case "getbytext": return this.page.getByText(query.params[0], { exact: false }); case "getbylabeltext": return this.page.getByLabel(query.params[0]); case "getbyplaceholdertext": return this.page.getByPlaceholder(query.params[0]); case "getbytestid": return this.page.getByTestId(query.params[0]); case "getbyalttext": return this.page.getByAltText(query.params[0]); default: // Fallback to a basic text search if the query type is not recognized return this.page.getByText(query.params[0]); } } } exports.A11yAILocator = A11yAILocator; // Helper function to create an A11yAILocator instance function createA11yAILocator(page, testInfo, testInstance, options) { return new A11yAILocator(page, testInfo, { ...options, testInstance }); }