UNPKG

skymel-adk-js-beta

Version:

Skymel Agent Development Kit for JavaScript — a JavaScript SDK for creating and managing intelligent agents

151 lines (126 loc) 4.94 kB
import {CommonValidators} from "./common_validators.js";

/**
 * Runs TensorFlow.js model inference on a pool of module Web Workers, each
 * created from './tfjs_web_worker_model_runner_worker.js'.
 *
 * Requests are dispatched round-robin, preferring workers that are not
 * currently awaiting a reply. A Worker exposes a single `onmessage` slot, so
 * messages sent to the same worker are serialized: at most one message is in
 * flight per worker, and every caller receives the reply to its own message
 * even when there are more concurrent requests than workers.
 */
export class TFJSWebWorkerModelRunner {
    /**
     * @param {Object} config - Model configuration, forwarded verbatim to each
     *     worker in its 'create' message by `load()`.
     * @param {number} [numberOfWorkers=3] - Number of workers to spawn.
     */
    constructor(config, numberOfWorkers = 3) {
        this.config = config;
        this.workers = [];
        // Indices of workers that have at least one request not yet settled.
        this.awaitingResponseFromWorkerIndices = new Set();
        // Round-robin cursor; only ever used modulo the pool size.
        this.selectedWorkerIndex = 0;
        // Per-worker number of requests sent but not yet settled.
        this.__pendingRequestCounts = [];
        // Per-worker promise that settles once every earlier request to that
        // worker has settled; new requests to a busy worker chain onto it.
        this.__workerQueueTails = [];
        for (let i = 0; i < numberOfWorkers; ++i) {
            this.workers.push(new Worker(new URL('./tfjs_web_worker_model_runner_worker.js', import.meta.url), {type: 'module'}));
        }
    }

    /** @returns {number} Number of workers with no unsettled request. */
    getCountOfCurrentlyIdleWorkers() {
        return this.workers.length - this.awaitingResponseFromWorkerIndices.size;
    }

    /** @returns {boolean} True when at least one worker is idle. */
    isReadyToAcceptInferenceTask() {
        return this.getCountOfCurrentlyIdleWorkers() > 0;
    }

    /**
     * Posts `message` to the worker at `workerIndex` and resolves with the
     * worker's reply, as unpacked by `parseResponseFromWorker`.
     *
     * If that worker still has unsettled requests, the message is queued and
     * posted only after they settle.
     *
     * @param {*} message - Structured-cloneable payload for the worker.
     * @param {number} [workerIndex=0] - Index into `this.workers`.
     * @returns {Promise<*>} Rejects on a worker `error`/`messageerror` event,
     *     on a `postMessage` failure, or on an invalid `workerIndex`.
     */
    sendMessageToWorker(message, workerIndex = 0) {
        const selectedWorker = this.workers[workerIndex];
        const pendingBefore = this.__pendingRequestCounts[workerIndex] ?? 0;
        this.__pendingRequestCounts[workerIndex] = pendingBefore + 1;
        this.awaitingResponseFromWorkerIndices.add(workerIndex);

        const exchange = () => this.__exchangeMessage(selectedWorker, message, workerIndex);
        // Post immediately when the worker is idle. Otherwise wait for the
        // worker's earlier requests; overlapping handlers would steal replies.
        const request = pendingBefore === 0
            ? exchange()
            : this.__workerQueueTails[workerIndex].then(exchange);
        // The queue tail must never reject, or one failure would block every
        // later request queued on the same worker.
        this.__workerQueueTails[workerIndex] = request.catch(() => undefined);
        return request;
    }

    /**
     * Installs one-shot handlers on `selectedWorker`, posts `message`, and
     * settles with the first reply or error. Handlers are removed and the
     * pending bookkeeping for `workerIndex` is updated exactly once, whatever
     * the outcome.
     */
    __exchangeMessage(selectedWorker, message, workerIndex) {
        return new Promise((resolve, reject) => {
            let isSettled = false;
            // Returns true only on the first call, so bookkeeping runs once.
            const finish = () => {
                if (isSettled) {
                    return false;
                }
                isSettled = true;
                if (selectedWorker) {
                    // Detach so stray later messages cannot re-run side effects.
                    selectedWorker.onmessage = null;
                    selectedWorker.onerror = null;
                    selectedWorker.onmessageerror = null;
                }
                this.__markRequestSettled(workerIndex);
                return true;
            };
            const fail = (error) => {
                if (finish()) {
                    reject(error);
                }
            };
            try {
                selectedWorker.onmessage = (event) => {
                    let response;
                    try {
                        // Must run before finish(): it deletes the responding
                        // index, and finish() restores it if requests remain.
                        response = this.parseResponseFromWorker(event.data);
                    } catch (error) {
                        fail(error);
                        return;
                    }
                    if (finish()) {
                        resolve(response);
                    }
                };
                selectedWorker.onerror = fail;
                selectedWorker.onmessageerror = fail;
                selectedWorker.postMessage(message);
            } catch (error) {
                // e.g. a DataCloneError from postMessage, or an invalid workerIndex.
                fail(error);
            }
        });
    }

    /**
     * Records that one request to `workerIndex` has settled. The worker stays
     * marked busy while it still has queued requests.
     */
    __markRequestSettled(workerIndex) {
        const remaining = Math.max((this.__pendingRequestCounts[workerIndex] ?? 1) - 1, 0);
        this.__pendingRequestCounts[workerIndex] = remaining;
        if (remaining > 0) {
            this.awaitingResponseFromWorkerIndices.add(workerIndex);
        } else {
            this.awaitingResponseFromWorkerIndices.delete(workerIndex);
        }
    }

    /**
     * Unpacks a worker reply: clears the busy flag for the reply's
     * `workerIndex` field (if present) and returns its `responseData` field.
     *
     * @param {Object} responseObject - Message data posted by the worker.
     * @returns {*} `responseData`, or null when absent.
     */
    parseResponseFromWorker(responseObject) {
        const respondingWorkerIndex = CommonValidators.getKeyValueFromDictIfKeyAbsentReturnDefault(responseObject, 'workerIndex', -1);
        this.awaitingResponseFromWorkerIndices.delete(respondingWorkerIndex);
        return CommonValidators.getKeyValueFromDictIfKeyAbsentReturnDefault(responseObject, 'responseData', null);
    }

    /**
     * Advances the round-robin cursor to the next worker not awaiting a reply.
     * When every worker is busy, returns the next worker in rotation (the
     * request is then queued on it).
     */
    __getPotentiallyFreeWorkerIndex() {
        const workerCount = this.workers.length;
        let workerIndex = ++this.selectedWorkerIndex % workerCount;
        if (this.getCountOfCurrentlyIdleWorkers() <= 0) {
            return workerIndex;
        }
        // Probe at most one full rotation so an inconsistent busy set can
        // never cause an infinite loop.
        for (let probes = 1; probes < workerCount && this.awaitingResponseFromWorkerIndices.has(workerIndex); ++probes) {
            workerIndex = ++this.selectedWorkerIndex % workerCount;
        }
        return workerIndex;
    }

    /** @returns {number} Index of the worker to use for the next request. */
    getSelectedWorkerIndex() {
        return this.__getPotentiallyFreeWorkerIndex();
    }

    /**
     * Builds a plain, structured-cloneable tensor descriptor after checking
     * that the data length equals the product of the shape's dimensions.
     *
     * @throws {Error} When flatArray.length does not match the shape.
     */
    ___makePlainTensor(flatArray, shape, dtype) {
        const expectedLength = shape.reduce((product, dimension) => product * dimension, 1);
        if (flatArray.length !== expectedLength) {
            throw new Error("Provided flatArray length doesn't equal expected tensor length");
        }
        return {data: flatArray, shape, dtype};
    }

    /** @returns {{data: *, shape: number[], dtype: string}} A 'float32' descriptor. */
    ___makeFloat32Tensor(flatArray, shape) {
        return this.___makePlainTensor(flatArray, shape, 'float32');
    }

    /**
     * Labels the data 'int32', presumably because TensorFlow.js has no int64
     * dtype. Values outside the int32 range are not checked here.
     *
     * @returns {{data: *, shape: number[], dtype: string}} An 'int32' descriptor.
     */
    ___makeInt64Tensor(flatArray, shape) {
        return this.___makePlainTensor(flatArray, shape, 'int32');
    }

    /** No-op: descriptors are plain objects with no resources to release. */
    __disposeTensor(inputTensor) {
    }

    /**
     * Sends 'create' (with the config and the worker's own index) and then
     * 'load' to each worker, one worker at a time.
     *
     * Best effort: the first failure is logged, workers after it are not
     * loaded, and the returned promise still resolves.
     */
    async load() {
        try {
            for (let i = 0; i < this.workers.length; ++i) {
                await this.sendMessageToWorker({command: 'create', config: this.config, workerIndex: i}, i);
                await this.sendMessageToWorker({command: 'load'}, i);
            }
        } catch (error) {
            console.log(`Encountered error while loading TensorFlowJS Model : ${error}`);
            console.log(error);
            console.log(error?.stack);
        }
    }

    makeInputTensorFromFlatFloat32Array(flatArray, shape) {
        return this.___makeFloat32Tensor(flatArray, shape);
    }

    makeInputTensorFromFlatInt64Array(flatArray, shape) {
        return this.___makeInt64Tensor(flatArray, shape);
    }

    disposeInputTensor(inputTensor) {
        return this.__disposeTensor(inputTensor);
    }

    /**
     * Sends a 'predict' command with `feedDict` to the selected worker.
     *
     * @param {*} feedDict - Structured-cloneable model inputs.
     * @returns {Promise<*>} The worker's `responseData`, or null when there are
     *     no workers or the request failed (the failure is logged).
     */
    async runInference(feedDict) {
        if (this.workers.length === 0) {
            return null;
        }
        try {
            return await this.sendMessageToWorker({
                command: 'predict',
                feedDict: feedDict
            }, this.getSelectedWorkerIndex());
        } catch (error) {
            console.log(`Encountered error while running inference on TensorFlowJS Model : ${error}`);
        }
        return null;
    }

    /**
     * Asks each worker, one at a time, to dispose its loaded model. The
     * command is queued behind any requests still pending on that worker. The
     * workers themselves are not terminated. Rejects on the first failure.
     */
    async dispose() {
        if (CommonValidators.isEmpty(this.workers)) {
            return;
        }
        for (let i = 0; i < this.workers.length; ++i) {
            await this.sendMessageToWorker({command: 'disposeLoadedModel'}, i);
        }
    }
}