UNPKG

skymel-adk-js-beta

Version:

Skymel Agent Development Kit using Javascript - A JavaScript SDK for creating and managing intelligent agents

234 lines (208 loc) 8.71 kB
import {CommonValidators} from "./common_validators.js";
import * as tf from "./tfjs+esm.js";
// const tf = require ("./tfjs.js");
// importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs");
// importScripts("./tfjs.js", "./common_validators.js");

/**
 * Worker-side runner that loads a TensorFlow.js graph model (optionally cached
 * in IndexedDB) and executes predictions on plain-object inputs received from
 * the main script.
 */
class TFJSWebWorkerModelRunnerWorker {
  /**
   * @param {number} workerIndex - Index echoed back in every response; stored
   *   as -1 when `CommonValidators.isNumber` rejects it.
   * @param {Object} config - Optional keys: `tfJSModelUrl` (default null),
   *   `predictUsingExecuteAsync` (default false),
   *   `saveAndLoadModelUsingIndexedDB` (default true).
   */
  constructor(workerIndex, config) {
    // A missing/non-object config falls back to the defaults instead of
    // throwing on the `in` operator.
    const safeConfig = (config !== null && typeof config === 'object') ? config : {};
    this.tfJSModelUrl = ('tfJSModelUrl' in safeConfig) ? safeConfig['tfJSModelUrl'] : null;
    this.predictUsingExecuteAsync = ('predictUsingExecuteAsync' in safeConfig) ? safeConfig['predictUsingExecuteAsync'] : false;
    this.saveAndLoadModelUsingIndexedDB = ('saveAndLoadModelUsingIndexedDB' in safeConfig) ? safeConfig['saveAndLoadModelUsingIndexedDB'] : true;
    this.loadedTFJSModel = null;
    this.workerIndex = CommonValidators.isNumber(workerIndex) ? workerIndex : -1;
    console.log(tf);
  }

  /**
   * Saves the currently loaded model under the given IndexedDB path.
   * @param {string} modelIndexedDbPath - Destination passed to `model.save`.
   * @returns {Promise<boolean>} true on success; false when no model is
   *   loaded or saving failed.
   */
  async saveModelToIndexedDB(modelIndexedDbPath) {
    try {
      if (this.loadedTFJSModel === null) {
        return false;
      }
      await this.loadedTFJSModel.save(modelIndexedDbPath);
      return true;
    } catch (error) {
      console.log(`Encountered error while saving TensorFlowJS Model to IndexedDB : ${error}`);
      return false;
    }
  }

  /**
   * Loads a graph model from the given IndexedDB path into this worker.
   * @param {string} modelIndexedDbPath
   * @returns {Promise<boolean>} true when the model loaded, false otherwise.
   */
  async loadModelFromIndexedDB(modelIndexedDbPath) {
    try {
      this.loadedTFJSModel = await tf.loadGraphModel(modelIndexedDbPath);
      return true;
    } catch (error) {
      console.log(`Encountered error while loading TensorFlowJS Model from IndexedDB : ${error}`);
      console.log(error.stack);
      return false;
    }
  }

  /**
   * Loads a graph model from the given web URL into this worker.
   * @param {string} modelWebUrl
   * @returns {Promise<boolean>} true when the model loaded, false otherwise.
   */
  async loadModelFromWebUrl(modelWebUrl) {
    try {
      this.loadedTFJSModel = await tf.loadGraphModel(modelWebUrl);
      return true;
    } catch (error) {
      console.log(`Encountered error while loading TensorFlowJS Model from Web: ${error}`);
      console.log(error.stack);
      return false;
    }
  }

  /**
   * Derives the IndexedDB cache path for a model URL.
   * @param {string} modelWebUrl
   * @returns {string} "indexeddb://" followed by the URL without its scheme.
   */
  makeIndexedDbModelPathFromWebUrl(modelWebUrl) {
    // NOTE(review): a string pattern in replace() only affects the FIRST "/",
    // so later slashes stay in the key. This is kept as-is deliberately:
    // changing the key format would orphan models already cached under it.
    return "indexeddb://" + modelWebUrl.replace("https://", "").replace("http://", "").replace("/", "_");
  }

  /**
   * Loads the configured model, preferring the IndexedDB cache when enabled
   * and falling back to the web URL (then caching the result).
   * @returns {Promise<boolean>} true when a model ended up loaded.
   */
  async load() {
    if (typeof this.tfJSModelUrl !== 'string') {
      // Without a URL neither the cache key nor the download can work.
      console.log("Cannot load model: 'tfJSModelUrl' was not provided in the worker config.");
      return false;
    }

    if (this.saveAndLoadModelUsingIndexedDB) {
      const modelIndexedDBUrl = this.makeIndexedDbModelPathFromWebUrl(this.tfJSModelUrl);
      console.log("Attempting to load model from IndexedDB");
      console.log(modelIndexedDBUrl);
      const didModelLoadFromIndexedDB = await this.loadModelFromIndexedDB(modelIndexedDBUrl);
      if (didModelLoadFromIndexedDB) {
        return true;
      }
      console.log("Could not load model from IndexedDB. Trying to load from web url.");
      const didModelLoadFromWebUrl = await this.loadModelFromWebUrl(this.tfJSModelUrl);
      if (didModelLoadFromWebUrl) {
        console.log("Model loaded from web url. Saving to IndexedDB.");
        // Caching is best-effort: a failed save does not fail the load.
        await this.saveModelToIndexedDB(modelIndexedDBUrl);
      }
      return didModelLoadFromWebUrl;
    }

    console.log("Attempting to load model from web url.");
    const didModelLoadFromWebUrl = await this.loadModelFromWebUrl(this.tfJSModelUrl);
    if (!didModelLoadFromWebUrl) {
      console.log("Could not load model from web url.");
    } else {
      console.log("Model loaded from web url.");
    }
    return didModelLoadFromWebUrl;
  }

  /**
   * Builds a tensor from a `{data, shape, dtype}` description.
   * @param {{data: *, shape: *, dtype: *}} input
   * @returns {*} Result of `tf.tensor(input.data, input.shape, input.dtype)`.
   */
  convertInputToTensor(input) {
    return tf.tensor(input.data, input.shape, input.dtype);
  }

  /**
   * Converts every own entry of a name -> input description map to a tensor.
   * @param {Object} dict
   * @returns {Object} Map with the same keys and tensor values.
   */
  convertDictionaryToTensorDict(dict) {
    const output = {};
    for (const [name, input] of Object.entries(dict)) {
      output[name] = this.convertInputToTensor(input);
    }
    return output;
  }

  /**
   * Converts an array of input descriptions to an array of tensors.
   * @param {Array} array
   * @returns {Array} Tensors in the same order as the inputs.
   */
  convertArrayToTensor(array) {
    return array.map((input) => this.convertInputToTensor(input));
  }

  /**
   * Converts a feed (array or name-keyed object) of input descriptions into
   * the matching container of tensors.
   * @param {Array|Object} feedDict
   * @returns {Array|Object}
   */
  convertFeedDictToTensor(feedDict) {
    console.log(feedDict);
    if (Array.isArray(feedDict)) {
      return this.convertArrayToTensor(feedDict);
    }
    return this.convertDictionaryToTensorDict(feedDict);
  }

  /**
   * Extracts a tensor's values and metadata into a plain object.
   * @param {*} tensor
   * @returns {Promise<{data: *, shape: *, dtype: *}>}
   */
  async convertTensorToJSON(tensor) {
    const output = {};
    // data() is the asynchronous read; the previous code awaited the
    // synchronous dataSync(), which gained nothing from the await.
    output['data'] = await tensor.data();
    output['shape'] = tensor.shape;
    output['dtype'] = tensor.dtype;
    return output;
  }

  /**
   * Duck-type check: an object with own `dataId`, `shape` and `dtype`.
   * @param {*} input
   * @returns {boolean} false for null, undefined and non-objects.
   */
  isTensor(input) {
    if (input === null || typeof input !== 'object') {
      return false;
    }
    const hasOwn = (key) => Object.prototype.hasOwnProperty.call(input, key);
    return hasOwn('dataId') && hasOwn('shape') && hasOwn('dtype');
  }

  /**
   * Converts an array of tensors to an array of plain objects, in order.
   * @param {Array} tensorArray
   * @returns {Promise<Array>}
   */
  async convertTensorArrayToJSON(tensorArray) {
    const output = [];
    for (const tensor of tensorArray) {
      output.push(await this.convertTensorToJSON(tensor));
    }
    return output;
  }

  /**
   * Converts a name -> tensor map to a name -> plain object map.
   * @param {Object} tensorDict
   * @returns {Promise<Object>}
   */
  async convertTensorDictToJSON(tensorDict) {
    const output = {};
    for (const [name, tensor] of Object.entries(tensorDict)) {
      output[name] = await this.convertTensorToJSON(tensor);
    }
    return output;
  }

  /**
   * Converts a prediction result (single tensor, array of tensors, or
   * name-keyed map of tensors) into plain objects.
   * @param {*} predictionOutput
   * @returns {Promise<*>} null when there is no output.
   */
  async convertPredictionOutputToJSON(predictionOutput) {
    if (predictionOutput === null || predictionOutput === undefined) {
      return null;
    }
    if (this.isTensor(predictionOutput)) {
      return this.convertTensorToJSON(predictionOutput);
    }
    if (Array.isArray(predictionOutput)) {
      return this.convertTensorArrayToJSON(predictionOutput);
    }
    return this.convertTensorDictToJSON(predictionOutput);
  }

  /**
   * Runs the loaded model on the given feed and returns the result as plain
   * objects.
   * @param {Array|Object} feedDict - Input descriptions (`{data, shape, dtype}`).
   * @returns {Promise<*>} Converted output, or null when anything failed
   *   (including no model being loaded).
   */
  async predict(feedDict) {
    let feedDictConvertedToTensor = null;
    let predictionOutput = null;
    try {
      feedDictConvertedToTensor = this.convertFeedDictToTensor(feedDict);
      console.log(feedDictConvertedToTensor);
      if (this.predictUsingExecuteAsync) {
        predictionOutput = await this.loadedTFJSModel.executeAsync(feedDictConvertedToTensor);
      } else {
        predictionOutput = await this.loadedTFJSModel.predict(feedDictConvertedToTensor);
      }
      return await this.convertPredictionOutputToJSON(predictionOutput);
    } catch (error) {
      console.log(`Encountered error while predicting using loaded TensorFlowJS Model : ${error}`);
      return null;
    } finally {
      // Tensors are not garbage collected; release inputs and outputs once
      // their values have been copied out. tf.dispose walks arrays/objects
      // and ignores null.
      tf.dispose(feedDictConvertedToTensor);
      tf.dispose(predictionOutput);
    }
  }

  /**
   * Releases the loaded model, if any.
   * @returns {Promise<boolean>} true when nothing was loaded or disposal
   *   succeeded; false when disposal threw.
   */
  async disposeLoadedModel() {
    if (CommonValidators.isEmpty(this.loadedTFJSModel)) {
      return true;
    }
    try {
      // Models here come from tf.loadGraphModel; a graph model is released
      // through its own dispose() and has no `layers` list to iterate.
      this.loadedTFJSModel.dispose();
      this.loadedTFJSModel = null;
    } catch (error) {
      console.log(`Error encountered while unloading TensorflowJS model : ${error}`);
      return false;
    }
    return true;
  }

  /**
   * Wraps response data with this worker's index.
   * @param {*} responseDataObject
   * @returns {{workerIndex: number, responseData: *}}
   */
  prepareResponseToMainScript(responseDataObject) {
    return {workerIndex: this.workerIndex, responseData: responseDataObject};
  }
}

let workerObject = null;

// Posts a response in the usual envelope; before 'create' has run there is no
// worker object, so the envelope is built by hand with the same -1 index the
// class uses for an unknown worker index.
const postResponseToMainScript = (responseData) => {
  if (workerObject !== null) {
    postMessage(workerObject.prepareResponseToMainScript(responseData));
  } else {
    postMessage({workerIndex: -1, responseData: responseData});
  }
};

// Command dispatcher. Every message gets exactly one response so the main
// script never waits forever, even when a command fails unexpectedly.
onmessage = async (event) => {
  try {
    const command = event.data.command;

    if (command !== 'create' && workerObject === null) {
      console.log("Command sent to worker before it was created!");
      postResponseToMainScript(false);
      return;
    }

    switch (command) {
      case 'create': {
        const workerConfig = event.data.config;
        const workerIndex = event.data.workerIndex;
        workerObject = new TFJSWebWorkerModelRunnerWorker(workerIndex, workerConfig);
        console.log('Worker created');
        console.log(workerObject);
        postResponseToMainScript(true);
        break;
      }
      case 'load': {
        // Report the real outcome instead of an unconditional `true`.
        const didModelLoad = await workerObject.load();
        postResponseToMainScript(didModelLoad);
        break;
      }
      case 'predict': {
        const predictResult = await workerObject.predict(event.data.feedDict);
        postResponseToMainScript(predictResult);
        break;
      }
      case 'disposeLoadedModel': {
        console.log('Disposing loded model');
        const disposeModelResult = await workerObject.disposeLoadedModel();
        console.log(disposeModelResult);
        postResponseToMainScript(disposeModelResult);
        break;
      }
      default:
        console.log("Unknown command sent to worker!");
        postResponseToMainScript(false);
        break;
    }
  } catch (error) {
    console.log(`Encountered error while handling worker command : ${error}`);
    postResponseToMainScript(false);
  }
};