UNPKG

askui

Version:

Reliable, automated end-to-end testing that depends on what is shown on your screen instead of the technology you are running on.

147 lines (146 loc) 7.36 kB
"use strict"; var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.InferenceClient = void 0; const url_join_1 = __importDefault(require("url-join")); const ui_control_commands_1 = require("../core/ui-control-commands"); const annotation_1 = require("../core/annotation/annotation"); const transformations_1 = require("../utils/transformations"); const inference_response_error_1 = require("./inference-response-error"); const config_error_1 = require("./config-error"); const logger_1 = require("../lib/logger"); class InferenceClient { constructor(baseUrl, httpClient, resize, workspaceId, modelComposition, apiVersion = 'v1') { this.baseUrl = baseUrl; this.httpClient = httpClient; this.resize = resize; this.workspaceId = workspaceId; this.modelComposition = modelComposition; this.apiVersion = apiVersion; const versionedBaseUrl = (0, url_join_1.default)(this.baseUrl, 'api', this.apiVersion); const url = workspaceId ? 
(0, url_join_1.default)(versionedBaseUrl, 'workspaces', workspaceId) : versionedBaseUrl; this.urls = { actEndpoint: (0, url_join_1.default)(url, 'act', 'inference'), inference: (0, url_join_1.default)(url, 'inference'), isImageRequired: (0, url_join_1.default)(url, 'instruction', 'is-image-required'), vqaInference: (0, url_join_1.default)(url, 'vqa', 'inference'), }; this.httpClient.urlsToRetry = Object.values(this.urls); if (this.resize !== undefined && this.resize <= 0) { throw new config_error_1.ConfigurationError(`Resize must be a positive number. The current resize value "${this.resize}" is not valid.`); } this.resize = this.resize ? Math.ceil(this.resize) : this.resize; } isImageRequired(instruction) { return __awaiter(this, void 0, void 0, function* () { const response = yield this.httpClient.post(this.urls.isImageRequired, { instruction, }); return response.body.isImageRequired; }); } // eslint-disable-next-line class-methods-use-this resizeIfNeeded(customElements, image) { return __awaiter(this, void 0, void 0, function* () { if (!image || customElements.length > 0 || this.resize === undefined) { return { base64Image: image, resizeRatio: 1 }; } return (0, transformations_1.resizeBase64ImageWithSameRatio)(image, this.resize); }); } inference() { return __awaiter(this, arguments, void 0, function* (customElements = [], image, instruction, modelComposition = []) { const resizedImage = yield this.resizeIfNeeded(customElements, image); const response = yield this.httpClient.post(this.urls.inference, this.urls.inference.includes('v4-experimental') ? { image: resizedImage.base64Image, instruction, tasks: ['OCR'], } : { customElements, image: resizedImage.base64Image, instruction, modelComposition: modelComposition.length > 0 ? 
modelComposition : this.modelComposition, }); InferenceClient.logMetaInformation(response.headers); return ui_control_commands_1.InferenceResponse.fromJson(response.body, resizedImage.resizeRatio, image); }); } vqaInference(image, prompt, config) { return __awaiter(this, void 0, void 0, function* () { const response = yield this.httpClient.post(this.urls.vqaInference, { config, image, prompt, }); InferenceClient.logMetaInformation(response.headers); return response.body; }); } static logMetaInformation(headers) { if (headers['askui-usage-warnings'] !== undefined) { logger_1.logger.warn(headers['askui-usage-warnings']); } } predictControlCommand(instruction_1, modelComposition_1) { return __awaiter(this, arguments, void 0, function* (instruction, modelComposition, customElements = [], image) { const inferenceResponse = yield this.inference(customElements, image, instruction, modelComposition); if (!(inferenceResponse instanceof ui_control_commands_1.ControlCommand)) { throw new inference_response_error_1.InferenceResponseError('Internal Error. Can not execute command'); } return inferenceResponse; }); } getDetectedElements(instruction_1, image_1) { return __awaiter(this, arguments, void 0, function* (instruction, image, customElements = []) { const inferenceResponse = yield this.inference(customElements, image, instruction); if (!(inferenceResponse instanceof annotation_1.Annotation)) { throw new inference_response_error_1.InferenceResponseError('Internal Error. Unable to get the detected elements'); } return inferenceResponse.detected_elements; }); } predictImageAnnotation(image_1) { return __awaiter(this, arguments, void 0, function* (image, customElements = []) { const inferenceResponse = yield this.inference(customElements, image); if (!(inferenceResponse instanceof annotation_1.Annotation)) { throw new inference_response_error_1.InferenceResponseError('Internal Error. 
Can not execute annotation'); } return inferenceResponse; }); } predictVQAAnswer(prompt, image, config) { return __awaiter(this, void 0, void 0, function* () { const inferenceResponse = yield this.vqaInference(image, prompt, config); const { response } = inferenceResponse.data; try { return JSON.parse(response); } catch (error) { logger_1.logger.warn(`Response is no valid JSON: ${response}`); } return response; }); } predictActResponse(params) { return __awaiter(this, void 0, void 0, function* () { const response = yield this.httpClient.post(this.urls.actEndpoint, params); InferenceClient.logMetaInformation(response.headers); return response.body; }); } } exports.InferenceClient = InferenceClient;