askui
Version:
Reliable, automated end-to-end testing that depends on what is shown on your screen instead of the technology you are running on.
147 lines (146 loc) • 7.36 kB
JavaScript
"use strict";
// Compiler-emitted helper that drives a generator function as an asynchronous routine.
// An already-defined `this.__awaiter` is reused when present; otherwise this implementation is used.
//   thisArg     - `this` value the generator function is invoked with
//   _arguments  - argument list forwarded to the generator function (an empty list when falsy)
//   P           - promise constructor to use; falls back to the global Promise when falsy
//   generator   - generator function whose `yield`s act as await points
// Returns a promise that resolves with the generator's return value and rejects with
// anything the generator throws.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
// Wrap a yielded value in a P promise unless it already is an instance of P.
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with the fulfilled value of the awaited promise.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Throw the rejection reason into the generator at the pending `yield`.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Finish when the generator is done; otherwise wait for the yielded value and continue.
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
// `generator` is rebound from the generator function to the iterator it produces, then started.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
// Interop helper for default imports: a module object flagged with `__esModule` is
// handed back untouched, anything else is wrapped so it can be read via `.default`.
// An already-defined `this.__importDefault` takes precedence over this implementation.
var __importDefault = (this && this.__importDefault) || function (mod) {
    if (mod && mod.__esModule) {
        return mod;
    }
    return { "default": mod };
};
// Flag this CommonJS module with `__esModule`; interop helpers such as `__importDefault`
// check that flag to decide whether a module needs wrapping.
Object.defineProperty(exports, "__esModule", { value: true });
// Declare the export up front; the class itself is assigned at the end of the file.
exports.InferenceClient = void 0;
const url_join_1 = __importDefault(require("url-join"));
const ui_control_commands_1 = require("../core/ui-control-commands");
const annotation_1 = require("../core/annotation/annotation");
const transformations_1 = require("../utils/transformations");
const inference_response_error_1 = require("./inference-response-error");
const config_error_1 = require("./config-error");
const logger_1 = require("../lib/logger");
/**
 * Client for the inference HTTP API.
 *
 * All endpoint URLs are derived once in the constructor from
 * `<baseUrl>/api/<apiVersion>` (plus `/workspaces/<workspaceId>` when a
 * workspace id is given) and stored in `this.urls`.
 */
class InferenceClient {
    /**
     * @param {string} baseUrl - Base URL of the inference server.
     * @param {object} httpClient - Object exposing `post(url, body)` that resolves to
     *   a response with `body` and `headers`. Its `urlsToRetry` property is overwritten
     *   with this client's endpoint URLs.
     * @param {number} [resize] - Optional positive number, rounded up to an integer and
     *   handed to `resizeBase64ImageWithSameRatio` before inference.
     *   NOTE(review): presumably a target image size in pixels — confirm against the helper.
     * @param {string} [workspaceId] - When truthy, endpoints are scoped under
     *   `workspaces/<workspaceId>`.
     * @param {Array} [modelComposition] - Model composition sent with inference requests
     *   when a call does not supply a non-empty one itself.
     * @param {string} [apiVersion='v1'] - API version segment of the URL.
     * @throws {ConfigurationError} If `resize` is provided but is not greater than zero
     *   (this includes `NaN` and values that do not coerce to a number).
     */
    constructor(baseUrl, httpClient, resize, workspaceId, modelComposition, apiVersion = 'v1') {
        this.baseUrl = baseUrl;
        this.httpClient = httpClient;
        this.resize = resize;
        this.workspaceId = workspaceId;
        this.modelComposition = modelComposition;
        this.apiVersion = apiVersion;
        const versionedBaseUrl = (0, url_join_1.default)(this.baseUrl, 'api', this.apiVersion);
        const url = workspaceId
            ? (0, url_join_1.default)(versionedBaseUrl, 'workspaces', workspaceId)
            : versionedBaseUrl;
        this.urls = {
            actEndpoint: (0, url_join_1.default)(url, 'act', 'inference'),
            inference: (0, url_join_1.default)(url, 'inference'),
            isImageRequired: (0, url_join_1.default)(url, 'instruction', 'is-image-required'),
            vqaInference: (0, url_join_1.default)(url, 'vqa', 'inference'),
        };
        this.httpClient.urlsToRetry = Object.values(this.urls);
        if (this.resize !== undefined) {
            // `!(x > 0)` instead of `x <= 0`: every comparison with NaN is false, so the
            // latter would silently accept NaN and non-numeric values.
            if (!(this.resize > 0)) {
                throw new config_error_1.ConfigurationError(`Resize must be a positive number. The current resize value "${this.resize}" is not valid.`);
            }
            this.resize = Math.ceil(this.resize);
        }
    }
    /**
     * Asks the server whether the given instruction needs an image.
     *
     * @param {string} instruction - Instruction to check.
     * @returns {Promise<*>} The `isImageRequired` field of the response body.
     */
    isImageRequired(instruction) {
        return __awaiter(this, void 0, void 0, function* () {
            const response = yield this.httpClient.post(this.urls.isImageRequired, {
                instruction,
            });
            return response.body.isImageRequired;
        });
    }
    /**
     * Resizes the image with the configured `resize` value when applicable.
     *
     * The image is passed through unchanged (ratio 1) when there is no image, when
     * custom elements are present, or when no `resize` value is configured.
     *
     * @param {Array} customElements - Custom elements of the request.
     * @param {string} [image] - Base64 encoded image.
     * @returns {Promise<{base64Image: (string|undefined), resizeRatio: number}>}
     *   Either the untouched image with ratio 1 or the result of
     *   `resizeBase64ImageWithSameRatio(image, this.resize)`.
     */
    resizeIfNeeded(customElements, image) {
        return __awaiter(this, void 0, void 0, function* () {
            if (!image || customElements.length > 0 || this.resize === undefined) {
                return { base64Image: image, resizeRatio: 1 };
            }
            return (0, transformations_1.resizeBase64ImageWithSameRatio)(image, this.resize);
        });
    }
    /**
     * Sends an inference request and converts the response body with
     * `InferenceResponse.fromJson`.
     *
     * Parameters (read from `arguments`): `customElements = []`, `image`,
     * `instruction`, `modelComposition = []`. An empty `modelComposition` falls back to
     * the one given to the constructor. When the inference URL contains
     * `v4-experimental`, a reduced payload (`image`, `instruction`, `tasks: ['OCR']`)
     * is sent instead.
     *
     * @returns {Promise<*>} Result of `InferenceResponse.fromJson`, which receives the
     *   response body, the resize ratio and the original (not resized) image.
     */
    inference() {
        return __awaiter(this, arguments, void 0, function* (customElements = [], image, instruction, modelComposition = []) {
            const resizedImage = yield this.resizeIfNeeded(customElements, image);
            const response = yield this.httpClient.post(this.urls.inference, this.urls.inference.includes('v4-experimental') ? {
                image: resizedImage.base64Image,
                instruction,
                tasks: ['OCR'],
            }
                : {
                    customElements,
                    image: resizedImage.base64Image,
                    instruction,
                    modelComposition: modelComposition.length > 0 ? modelComposition : this.modelComposition,
                });
            InferenceClient.logMetaInformation(response.headers);
            return ui_control_commands_1.InferenceResponse.fromJson(response.body, resizedImage.resizeRatio, image);
        });
    }
    /**
     * Sends a VQA inference request.
     *
     * @param {string} image - Image sent as-is (no resizing is applied here).
     * @param {string} prompt - Prompt for the request.
     * @param {*} [config] - Configuration forwarded unchanged in the payload.
     * @returns {Promise<*>} The raw response body.
     */
    vqaInference(image, prompt, config) {
        return __awaiter(this, void 0, void 0, function* () {
            const response = yield this.httpClient.post(this.urls.vqaInference, {
                config,
                image,
                prompt,
            });
            InferenceClient.logMetaInformation(response.headers);
            return response.body;
        });
    }
    /**
     * Logs the `askui-usage-warnings` response header as a warning when it is set.
     *
     * @param {object} [headers] - Response headers; a missing headers object is ignored.
     */
    static logMetaInformation(headers) {
        // Tolerate responses without a headers object instead of throwing a TypeError.
        const usageWarnings = headers == null ? undefined : headers['askui-usage-warnings'];
        if (usageWarnings !== undefined) {
            logger_1.logger.warn(usageWarnings);
        }
    }
    /**
     * Runs an inference and expects a control command as the result.
     *
     * Parameters (read from `arguments`): `instruction`, `modelComposition`,
     * `customElements = []`, `image`.
     *
     * @returns {Promise<*>} The `ControlCommand` produced by the inference.
     * @throws {InferenceResponseError} If the inference result is not a `ControlCommand`.
     */
    predictControlCommand(instruction_1, modelComposition_1) {
        return __awaiter(this, arguments, void 0, function* (instruction, modelComposition, customElements = [], image) {
            const inferenceResponse = yield this.inference(customElements, image, instruction, modelComposition);
            if (!(inferenceResponse instanceof ui_control_commands_1.ControlCommand)) {
                throw new inference_response_error_1.InferenceResponseError('Internal Error. Can not execute command');
            }
            return inferenceResponse;
        });
    }
    /**
     * Runs an inference and returns the detected elements of the resulting annotation.
     *
     * Parameters (read from `arguments`): `instruction`, `image`, `customElements = []`.
     *
     * @returns {Promise<*>} The `detected_elements` of the returned `Annotation`.
     * @throws {InferenceResponseError} If the inference result is not an `Annotation`.
     */
    getDetectedElements(instruction_1, image_1) {
        return __awaiter(this, arguments, void 0, function* (instruction, image, customElements = []) {
            const inferenceResponse = yield this.inference(customElements, image, instruction);
            if (!(inferenceResponse instanceof annotation_1.Annotation)) {
                throw new inference_response_error_1.InferenceResponseError('Internal Error. Unable to get the detected elements');
            }
            return inferenceResponse.detected_elements;
        });
    }
    /**
     * Runs an inference without an instruction and expects an annotation as the result.
     *
     * Parameters (read from `arguments`): `image`, `customElements = []`.
     *
     * @returns {Promise<*>} The `Annotation` produced by the inference.
     * @throws {InferenceResponseError} If the inference result is not an `Annotation`.
     */
    predictImageAnnotation(image_1) {
        return __awaiter(this, arguments, void 0, function* (image, customElements = []) {
            const inferenceResponse = yield this.inference(customElements, image);
            if (!(inferenceResponse instanceof annotation_1.Annotation)) {
                throw new inference_response_error_1.InferenceResponseError('Internal Error. Can not execute annotation');
            }
            return inferenceResponse;
        });
    }
    /**
     * Runs a VQA inference and returns its answer.
     *
     * @param {string} prompt - Prompt for the request.
     * @param {string} image - Image for the request.
     * @param {*} [config] - Configuration forwarded to `vqaInference`.
     * @returns {Promise<*>} `data.response` of the VQA response body, parsed as JSON
     *   when possible; otherwise the unparsed value.
     */
    predictVQAAnswer(prompt, image, config) {
        return __awaiter(this, void 0, void 0, function* () {
            const inferenceResponse = yield this.vqaInference(image, prompt, config);
            const { response } = inferenceResponse.data;
            try {
                return JSON.parse(response);
            }
            catch (error) {
                // Deliberate best effort: an answer that is not JSON is returned as-is below.
                logger_1.logger.warn(`Response is no valid JSON: ${response}`);
            }
            return response;
        });
    }
    /**
     * Posts the given parameters to the act endpoint.
     *
     * @param {*} params - Request payload, forwarded unchanged.
     * @returns {Promise<*>} The raw response body.
     */
    predictActResponse(params) {
        return __awaiter(this, void 0, void 0, function* () {
            const response = yield this.httpClient.post(this.urls.actEndpoint, params);
            InferenceClient.logMetaInformation(response.headers);
            return response.body;
        });
    }
}
exports.InferenceClient = InferenceClient;