UNPKG

@tensorflow/tfjs-converter

Version:

Tensorflow model converter for javascript

406 lines 50 kB
/** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import { io, Tensor } from '@tensorflow/tfjs-core'; import { OperationMapper } from '../operations/operation_mapper'; import { GraphExecutor } from './graph_executor'; import { ResourceManager } from './resource_manager'; export const TFHUB_SEARCH_PARAM = '?tfjs-format=file'; export const DEFAULT_MODEL_NAME = 'model.json'; /** * A `tf.GraphModel` is a directed, acyclic graph built from a * SavedModel GraphDef and allows inference execution. * * A `tf.GraphModel` can only be created by loading from a model converted from * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using * the command line converter tool and loaded via `tf.loadGraphModel`. * * @doc {heading: 'Models', subheading: 'Classes'} */ export class GraphModel { /** * @param modelUrl url for the model, or an `io.IOHandler`. * @param weightManifestUrl url for the weight file generated by * scripts/convert.py script. * @param requestOption options for Request, which allows to send credentials * and custom headers. * @param onProgress Optional, progress callback function, fired periodically * before the load is completed. */ constructor(modelUrl, loadOptions = {}) { this.modelUrl = modelUrl; this.loadOptions = loadOptions; this.version = 'n/a'; if (loadOptions == null) { this.loadOptions = {}; } this.resourceManager = new ResourceManager(); } // Returns the version information for the tensorflow model GraphDef. get modelVersion() { return this.version; } get inputNodes() { return this.executor.inputNodes; } get outputNodes() { return this.executor.outputNodes; } get inputs() { return this.executor.inputs; } get outputs() { return this.executor.outputs; } get weights() { return this.executor.weightMap; } get metadata() { return this.artifacts.userDefinedMetadata; } get modelSignature() { return this.signature; } findIOHandler() { const path = this.modelUrl; if (path.load != null) { // Path is an IO Handler. this.handler = path; } else if (this.loadOptions.requestInit != null) { this.handler = io.browserHTTPRequest(path, this.loadOptions); } else { const handlers = io.getLoadHandlers(path, this.loadOptions); if (handlers.length === 0) { // For backward compatibility: if no load handler can be found, // assume it is a relative http path. handlers.push(io.browserHTTPRequest(path, this.loadOptions)); } else if (handlers.length > 1) { throw new Error(`Found more than one (${handlers.length}) load handlers for ` + `URL '${[path]}'`); } this.handler = handlers[0]; } } /** * Loads the model and weight files, construct the in memory weight map and * compile the inference graph. */ async load() { this.findIOHandler(); if (this.handler.load == null) { throw new Error('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.'); } const artifacts = await this.handler.load(); return this.loadSync(artifacts); } /** * Synchronously construct the in memory weight map and * compile the inference graph. Also initialize hashtable if any. * * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ loadSync(artifacts) { this.artifacts = artifacts; const graph = this.artifacts.modelTopology; let signature; if (this.artifacts.userDefinedMetadata != null && this.artifacts.userDefinedMetadata.signature != null) { signature = // tslint:disable-next-line:no-any this.artifacts.userDefinedMetadata.signature; } else { signature = this.artifacts.signature; } this.signature = signature; this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`; const weightMap = io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs); this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, this.signature)); this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap); // Attach a model-level resourceManager to each executor to share resources, // such as `HashTable`. this.executor.resourceManager = this.resourceManager; if (artifacts.modelInitializer != null && artifacts.modelInitializer.node != null) { const initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer); this.initializer = new GraphExecutor(initializer); this.initializer.weightMap = this.executor.weightMap; // Attach a model-level resourceManager to the initializer, the // hashTables created from when executing the initializer will be stored // in the resourceManager. this.initializer.resourceManager = this.resourceManager; this.initializer.executeAsync({}, []); } return true; } /** * Save the configuration and/or weights of the GraphModel. * * An `IOHandler` is an object that has a `save` method of the proper * signature defined. The `save` method manages the storing or * transmission of serialized data ("artifacts") that represent the * model's topology and weights onto or via a specific medium, such as * file downloads, local storage, IndexedDB in the web browser and HTTP * requests to a server. TensorFlow.js provides `IOHandler` * implementations for a number of frequently used saving mediums, such as * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io` * for more details. * * This method also allows you to refer to certain types of `IOHandler`s * as URL-like string shortcuts, such as 'localstorage://' and * 'indexeddb://'. * * Example 1: Save `model`'s topology and weights to browser [local * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage); * then load it back. * * ```js * const modelUrl = * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json'; * const model = await tf.loadGraphModel(modelUrl); * const zeros = tf.zeros([1, 224, 224, 3]); * model.predict(zeros).print(); * * const saveResults = await model.save('localstorage://my-model-1'); * * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1'); * console.log('Prediction from loaded model:'); * model.predict(zeros).print(); * ``` * * @param handlerOrURL An instance of `IOHandler` or a URL-like, * scheme-based string shortcut for `IOHandler`. * @param config Options for saving the model. * @returns A `Promise` of `SaveResult`, which summarizes the result of * the saving, such as byte sizes of the saved artifacts for the model's * topology and weight values. * * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ async save(handlerOrURL, config) { if (typeof handlerOrURL === 'string') { const handlers = io.getSaveHandlers(handlerOrURL); if (handlers.length === 0) { throw new Error(`Cannot find any save handlers for URL '${handlerOrURL}'`); } else if (handlers.length > 1) { throw new Error(`Found more than one (${handlers.length}) save handlers for ` + `URL '${handlerOrURL}'`); } handlerOrURL = handlers[0]; } if (handlerOrURL.save == null) { throw new Error('GraphModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.'); } return handlerOrURL.save(this.artifacts); } /** * Execute the inference for the input tensors. * * @param input The input tensors, when there is single input for the model, * inputs param should be a `tf.Tensor`. For models with mutliple inputs, * inputs params should be in either `tf.Tensor`[] if the input order is * fixed, or otherwise NamedTensorMap format. * * For model with multiple inputs, we recommend you use NamedTensorMap as the * input type, if you use `tf.Tensor`[], the order of the array needs to * follow the * order of inputNodes array. @see {@link GraphModel.inputNodes} * * You can also feed any intermediate nodes using the NamedTensorMap as the * input type. For example, given the graph * InputNode => Intermediate => OutputNode, * you can execute the subgraph Intermediate => OutputNode by calling * model.execute('IntermediateNode' : tf.tensor(...)); * * This is useful for models that uses tf.dynamic_rnn, where the intermediate * state needs to be fed manually. * * For batch inference execution, the tensors for each input need to be * concatenated together. For example with mobilenet, the required input shape * is [1, 244, 244, 3], which represents the [batch, height, width, channel]. * If we are provide a batched data of 100 images, the input tensor should be * in the shape of [100, 244, 244, 3]. * * @param config Prediction configuration for specifying the batch size and * output node names. Currently the batch size option is ignored for graph * model. * * @returns Inference result tensors. The output would be single `tf.Tensor` * if model has single output node, otherwise Tensor[] or NamedTensorMap[] * will be returned for model with multiple outputs. * * @doc {heading: 'Models', subheading: 'Classes'} */ predict(inputs, config) { return this.execute(inputs, this.outputNodes); } normalizeInputs(inputs) { if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) { // The input is already a NamedTensorMap. return inputs; } inputs = Array.isArray(inputs) ? inputs : [inputs]; if (inputs.length !== this.inputNodes.length) { throw new Error('Input tensor count mismatch,' + `the graph model has ${this.inputNodes.length} placeholders, ` + `while there are ${inputs.length} input tensors.`); } return this.inputNodes.reduce((map, inputName, i) => { map[inputName] = inputs[i]; return map; }, {}); } normalizeOutputs(outputs) { outputs = outputs || this.outputNodes; return !Array.isArray(outputs) ? [outputs] : outputs; } /** * Executes inference for the model for given input tensors. * @param inputs tensor, tensor array or tensor map of the inputs for the * model, keyed by the input node names. * @param outputs output node name from the Tensorflow model, if no * outputs are specified, the default outputs of the model would be used. * You can inspect intermediate nodes of the model by adding them to the * outputs array. * * @returns A single tensor if provided with a single output or no outputs * are provided and there is only one default output, otherwise return a * tensor array. The order of the tensor array is the same as the outputs * if provided, otherwise the order of outputNodes attribute of the model. * * @doc {heading: 'Models', subheading: 'Classes'} */ execute(inputs, outputs) { inputs = this.normalizeInputs(inputs); outputs = this.normalizeOutputs(outputs); const result = this.executor.execute(inputs, outputs); return result.length > 1 ? result : result[0]; } /** * Executes inference for the model for given input tensors in async * fashion, use this method when your model contains control flow ops. * @param inputs tensor, tensor array or tensor map of the inputs for the * model, keyed by the input node names. * @param outputs output node name from the Tensorflow model, if no outputs * are specified, the default outputs of the model would be used. You can * inspect intermediate nodes of the model by adding them to the outputs * array. * * @returns A Promise of single tensor if provided with a single output or * no outputs are provided and there is only one default output, otherwise * return a tensor map. * * @doc {heading: 'Models', subheading: 'Classes'} */ async executeAsync(inputs, outputs) { inputs = this.normalizeInputs(inputs); outputs = this.normalizeOutputs(outputs); const result = await this.executor.executeAsync(inputs, outputs); return result.length > 1 ? result : result[0]; } /** * Get intermediate tensors for model debugging mode (flag * KEEP_INTERMEDIATE_TENSORS is true). * * @doc {heading: 'Models', subheading: 'Classes'} */ getIntermediateTensors() { return this.executor.getIntermediateTensors(); } /** * Dispose intermediate tensors for model debugging mode (flag * KEEP_INTERMEDIATE_TENSORS is true). * * @doc {heading: 'Models', subheading: 'Classes'} */ disposeIntermediateTensors() { this.executor.disposeIntermediateTensors(); } convertTensorMapToTensorsMap(map) { return Object.keys(map).reduce((newMap, key) => { newMap[key] = [map[key]]; return newMap; }, {}); } /** * Releases the memory used by the weight tensors and resourceManager. * * @doc {heading: 'Models', subheading: 'Classes'} */ dispose() { this.executor.dispose(); if (this.initializer) { this.initializer.dispose(); } this.resourceManager.dispose(); } } /** * Load a graph model given a URL to the model definition. * * Example of loading MobileNetV2 from a URL and making a prediction with a * zeros input: * * ```js * const modelUrl = * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json'; * const model = await tf.loadGraphModel(modelUrl); * const zeros = tf.zeros([1, 224, 224, 3]); * model.predict(zeros).print(); * ``` * * Example of loading MobileNetV2 from a TF Hub URL and making a prediction with * a zeros input: * * ```js * const modelUrl = * 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2'; * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true}); * const zeros = tf.zeros([1, 224, 224, 3]); * model.predict(zeros).print(); * ``` * @param modelUrl The url or an `io.IOHandler` that loads the model. * @param options Options for the HTTP request, which allows to send credentials * and custom headers. * * @doc {heading: 'Models', subheading: 'Loading'} */ export async function loadGraphModel(modelUrl, options = {}) { if (modelUrl == null) { throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' + 'or an IOHandler that loads the model'); } if (options == null) { options = {}; } if (options.fromTFHub) { if (modelUrl.load == null) { if (!modelUrl.endsWith('/')) { modelUrl = modelUrl + '/'; } modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`; } } const model = new GraphModel(modelUrl, options); await model.load(); return model; } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"graph_model.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/executor/graph_model.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAiB,EAAE,EAAsC,MAAM,EAAC,MAAM,uBAAuB,CAAC;AAIrG,OAAO,EAAC,eAAe,EAAC,MAAM,gCAAgC,CAAC;AAE/D,OAAO,EAAC,aAAa,EAAC,MAAM,kBAAkB,CAAC;AAC/C,OAAO,EAAC,eAAe,EAAC,MAAM,oBAAoB,CAAC;AAEnD,MAAM,CAAC,MAAM,kBAAkB,GAAG,mBAAmB,CAAC;AACtD,MAAM,CAAC,MAAM,kBAAkB,GAAG,YAAY,CAAC;AAC/C;;;;;;;;;GASG;AACH,MAAM,OAAO,UAAU;IA0CrB;;;;;;;;OAQG;IACH,YACY,QAA6B,EAC7B,cAA8B,EAAE;QADhC,aAAQ,GAAR,QAAQ,CAAqB;QAC7B,gBAAW,GAAX,WAAW,CAAqB;QAnDpC,YAAO,GAAG,KAAK,CAAC;QAoDtB,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC;SACvB;QACD,IAAI,CAAC,eAAe,GAAG,IAAI,eAAe,EAAE,CAAC;IAC/C,CAAC;IAjDD,qEAAqE;IACrE,IAAI,YAAY;QACd,OAAO,IAAI,CAAC,OAAO,CAAC;IACtB,CAAC;IAED,IAAI,UAAU;QACZ,OAAO,IAAI,CAAC,QAAQ,CAAC,UAAU,CAAC;IAClC,CAAC;IAED,IAAI,WAAW;QACb,OAAO,IAAI,CAAC,QAAQ,CAAC,WAAW,CAAC;IACnC,CAAC;IAED,IAAI,MAAM;QACR,OAAO,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC;IAC9B,CAAC;IAED,IAAI,OAAO;QACT,OAAO,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC;IAC/B,CAAC;IAED,IAAI,OAAO;QACT,OAAO,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC;IACjC,CAAC;IAED,IAAI,QAAQ;QACV,OAAO,IAAI,CAAC,SAAS,CAAC,mBAAmB,CAAC;IAC5C,CAAC;IAED,IAAI,cAAc;QAChB,OAAO,IAAI,CAAC,SAAS,CAAC;IACxB,CAAC;IAoBO,aAAa;QACnB,MAAM,IAAI,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC3B,IAAK,IAAqB,CAAC,IAAI,IAAI,IAAI,EAAE;YACvC,yBAAyB;YACzB,IAAI,CAAC,OAAO,GAAG,IAAoB,CAAC;SACrC;aAAM,IAAI,IAAI,CAAC,WAAW,CAAC,WAAW,IAAI,IAAI,EAAE;YAC/C,IAAI,CAAC,OAAO,GAAG,EAAE,CAAC,kBAAkB,CAAC,IAAc,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;SACxE;aAAM;YACL,MAAM,QAAQ,GAAG,EAAE,CAAC,eAAe,CAAC,IAAc,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;YACtE,IAAI,QAAQ,CAAC,MAAM,KAAK,CAAC,EAAE;gBACzB,+DAA+D;gBAC/D,qCAAqC;gBACrC,QAAQ,CAAC,IAAI,CAAC,EAAE,CAAC,kBAAkB,CAAC,IAAc,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC;aACxE;iBAAM,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE;gBAC9B,MAAM,IAAI,KAAK,CACX,wBAAwB,QAAQ,CAAC,MAAM,sBAAsB;oBAC7D,QAAQ,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;aACxB;YACD,IAAI,CAAC,OAAO,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;SAC5B;IACH,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,IAAI;QACR,IAAI,CAAC,aAAa,EAAE,CAAC;QACrB,IAAI,IAAI,CAAC,OAAO,CAAC,IAAI,IAAI,IAAI,EAAE;YAC7B,MAAM,IAAI,KAAK,CACX,mEAAmE;gBACnE,8CAA8C,CAAC,CAAC;SACrD;QACD,MAAM,SAAS,GAAG,MAAM,IAAI,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC;QAE5C,OAAO,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC,CAAC;IAClC,CAAC;IAED;;;;;OAKG;IACH,QAAQ,CAAC,SAA4B;QACnC,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,MAAM,KAAK,GAAG,IAAI,CAAC,SAAS,CAAC,aAAqC,CAAC;QAEnE,IAAI,SAAS,CAAC;QACd,IAAI,IAAI,CAAC,SAAS,CAAC,mBAAmB,IAAI,IAAI;YAC1C,IAAI,CAAC,SAAS,CAAC,mBAAmB,CAAC,SAAS,IAAI,IAAI,EAAE;YACxD,SAAS,GAAI,kCAAkC;gBAC1C,IAAI,CAAC,SAAS,CAAC,mBAA2B,CAAC,SACpB,CAAC;SAC9B;aAAM;YACL,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC,SAAS,CAAC;SACtC;QACD,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAE3B,IAAI,CAAC,OAAO,GAAG,GAAG,KAAK,CAAC,QAAQ,CAAC,QAAQ,IAAI,KAAK,CAAC,QAAQ,CAAC,WAAW,EAAE,CAAC;QAC1E,MAAM,SAAS,GACX,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,UAAU,EAAE,IAAI,CAAC,SAAS,CAAC,WAAW,CAAC,CAAC;QAC5E,IAAI,CAAC,QAAQ,GAAG,IAAI,aAAa,CAC7B,eAAe,CAAC,QAAQ,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC,CAAC;QACpE,IAAI,CAAC,QAAQ,CAAC,SAAS,GAAG,IAAI,CAAC,4BAA4B,CAAC,SAAS,CAAC,CAAC;QACvE,4EAA4E;QAC5E,uBAAuB;QACvB,IAAI,CAAC,QAAQ,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;QAErD,IAAI,SAAS,CAAC,gBAAgB,IAAI,IAAI;YACjC,SAAS,CAAC,gBAAyC,CAAC,IAAI,IAAI,IAAI,EAAE;YACrE,MAAM,WAAW,GACb,eAAe,CAAC,QAAQ,CAAC,cAAc,CAAC,SAAS,CAAC,gBAAgB,CAAC,CAAC;YACxE,IAAI,CAAC,WAAW,GAAG,IAAI,aAAa,CAAC,WAAW,CAAC,CAAC;YAClD,IAAI,CAAC,WAAW,CAAC,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC;YACrD,+DAA+D;YAC/D,wEAAwE;YACxE,0BAA0B;YAC1B,IAAI,CAAC,WAAW,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;YACxD,IAAI,CAAC,WAAW,CAAC,YAAY,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC;SACvC;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;OA2CG;IACH,KAAK,CAAC,IAAI,CAAC,YAAiC,EAAE,MAAsB;QAElE,IAAI,OAAO,YAAY,KAAK,QAAQ,EAAE;YACpC,MAAM,QAAQ,GAAG,EAAE,CAAC,eAAe,CAAC,YAAY,CAAC,CAAC;YAClD,IAAI,QAAQ,CAAC,MAAM,KAAK,CAAC,EAAE;gBACzB,MAAM,IAAI,KAAK,CACX,0CAA0C,YAAY,GAAG,CAAC,CAAC;aAChE;iBAAM,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE;gBAC9B,MAAM,IAAI,KAAK,CACX,wBAAwB,QAAQ,CAAC,MAAM,sBAAsB;oBAC7D,QAAQ,YAAY,GAAG,CAAC,CAAC;aAC9B;YACD,YAAY,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;SAC5B;QACD,IAAI,YAAY,CAAC,IAAI,IAAI,IAAI,EAAE;YAC7B,MAAM,IAAI,KAAK,CACX,yDAAyD;gBACzD,sDAAsD,CAAC,CAAC;SAC7D;QAED,OAAO,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC3C,CAAC;IAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;OAqCG;IACH,OAAO,CAAC,MAAsC,EAAE,MAA2B;QAEzE,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;IAChD,CAAC;IAEO,eAAe,CAAC,MACc;QACpC,IAAI,CAAC,CAAC,MAAM,YAAY,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YACzD,yCAAyC;YACzC,OAAO,MAAM,CAAC;SACf;QACD,MAAM,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QACnD,IAAI,MAAM,CAAC,MAAM,KAAK,IAAI,CAAC,UAAU,CAAC,MAAM,EAAE;YAC5C,MAAM,IAAI,KAAK,CACX,8BAA8B;gBAC9B,uBAAuB,IAAI,CAAC,UAAU,CAAC,MAAM,iBAAiB;gBAC9D,mBAAmB,MAAM,CAAC,MAAM,iBAAiB,CAAC,CAAC;SACxD;QACD,OAAO,IAAI,CAAC,UAAU,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,SAAS,EAAE,CAAC,EAAE,EAAE;YAClD,GAAG,CAAC,SAAS,CAAC,GAAI,MAAmB,CAAC,CAAC,CAAC,CAAC;YACzC,OAAO,GAAG,CAAC;QACb,CAAC,EAAE,EAAoB,CAAC,CAAC;IAC3B,CAAC;IAEO,gBAAgB,CAAC,OAAwB;QAC/C,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC,WAAW,CAAC;QACtC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC;IACvD,CAAC;IAED;;;;;;;;;;;;;;;OAeG;IACH,OAAO,CAAC,MAAsC,EAAE,OAAyB;QAEvE,MAAM,GAAG,IAAI,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,IAAI,CAAC,gBAAgB,CAAC,OAAO,CAAC,CAAC;QACzC,MAAM,MAAM,GAAG,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;QACtD,OAAO,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;IAChD,CAAC;IACD;;;;;;;;;;;;;;;OAeG;IACH,KAAK,CAAC,YAAY,CACd,MAAsC,EACtC,OAAyB;QAC3B,MAAM,GAAG,IAAI,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,IAAI,CAAC,gBAAgB,CAAC,OAAO,CAAC,CAAC;QACzC,MAAM,MAAM,GAAG,MAAM,IAAI,CAAC,QAAQ,CAAC,YAAY,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;QACjE,OAAO,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;IAChD,CAAC;IAED;;;;;OAKG;IACH,sBAAsB;QACpB,OAAO,IAAI,CAAC,QAAQ,CAAC,sBAAsB,EAAE,CAAC;IAChD,CAAC;IAED;;;;;OAKG;IACH,0BAA0B;QACxB,IAAI,CAAC,QAAQ,CAAC,0BAA0B,EAAE,CAAC;IAC7C,CAAC;IAEO,4BAA4B,CAAC,GAAmB;QACtD,OAAO,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,MAAM,CAAC,CAAC,MAAuB,EAAE,GAAG,EAAE,EAAE;YAC9D,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACzB,OAAO,MAAM,CAAC;QAChB,CAAC,EAAE,EAAE,CAAC,CAAC;IACT,CAAC;IAED;;;;OAIG;IACH,OAAO;QACL,IAAI,CAAC,QAAQ,CAAC,OAAO,EAAE,CAAC;QAExB,IAAI,IAAI,CAAC,WAAW,EAAE;YACpB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;SAC5B;QAED,IAAI,CAAC,eAAe,CAAC,OAAO,EAAE,CAAC;IACjC,CAAC;CACF;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAChC,QAA6B,EAC7B,UAA0B,EAAE;IAC9B,IAAI,QAAQ,IAAI,IAAI,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,oEAAoE;YACpE,sCAAsC,CAAC,CAAC;KAC7C;IACD,IAAI,OAAO,IAAI,IAAI,EAAE;QACnB,OAAO,GAAG,EAAE,CAAC;KACd;IAED,IAAI,OAAO,CAAC,SAAS,EAAE;QACrB,IAAK,QAAyB,CAAC,IAAI,IAAI,IAAI,EAAE;YAC3C,IAAI,CAAE,QAAmB,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE;gBACvC,QAAQ,GAAI,QAAmB,GAAG,GAAG,CAAC;aACvC;YACD,QAAQ,GAAG,GAAG,QAAQ,GAAG,kBAAkB,GAAG,kBAAkB,EAAE,CAAC;SACpE;KACF;IACD,MAAM,KAAK,GAAG,IAAI,UAAU,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC;IAChD,MAAM,KAAK,CAAC,IAAI,EAAE,CAAC;IACnB,OAAO,KAAK,CAAC;AACf,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor} from '@tensorflow/tfjs-core';\n\nimport * as tensorflow from '../data/compiled_api';\nimport {NamedTensorsMap, TensorInfo} from '../data/types';\nimport {OperationMapper} from '../operations/operation_mapper';\n\nimport {GraphExecutor} from './graph_executor';\nimport {ResourceManager} from './resource_manager';\n\nexport const TFHUB_SEARCH_PARAM = '?tfjs-format=file';\nexport const DEFAULT_MODEL_NAME = 'model.json';\n/**\n * A `tf.GraphModel` is a directed, acyclic graph built from a\n * SavedModel GraphDef and allows inference execution.\n *\n * A `tf.GraphModel` can only be created by loading from a model converted from\n * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using\n * the command line converter tool and loaded via `tf.loadGraphModel`.\n *\n * @doc {heading: 'Models', subheading: 'Classes'}\n */\nexport class GraphModel implements InferenceModel {\n  private executor: GraphExecutor;\n  private version = 'n/a';\n  private handler: io.IOHandler;\n  private artifacts: io.ModelArtifacts;\n  private initializer: GraphExecutor;\n  private resourceManager: ResourceManager;\n  private signature: tensorflow.ISignatureDef;\n\n  // Returns the version information for the tensorflow model GraphDef.\n  get modelVersion(): string {\n    return this.version;\n  }\n\n  get inputNodes(): string[] {\n    return this.executor.inputNodes;\n  }\n\n  get outputNodes(): string[] {\n    return this.executor.outputNodes;\n  }\n\n  get inputs(): TensorInfo[] {\n    return this.executor.inputs;\n  }\n\n  get outputs(): TensorInfo[] {\n    return this.executor.outputs;\n  }\n\n  get weights(): NamedTensorsMap {\n    return this.executor.weightMap;\n  }\n\n  get metadata(): {} {\n    return this.artifacts.userDefinedMetadata;\n  }\n\n  get modelSignature(): {} {\n    return this.signature;\n  }\n\n  /**\n   * @param modelUrl url for the model, or an `io.IOHandler`.\n   * @param weightManifestUrl url for the weight file generated by\n   * scripts/convert.py script.\n   * @param requestOption options for Request, which allows to send credentials\n   * and custom headers.\n   * @param onProgress Optional, progress callback function, fired periodically\n   * before the load is completed.\n   */\n  constructor(\n      private modelUrl: string|io.IOHandler,\n      private loadOptions: io.LoadOptions = {}) {\n    if (loadOptions == null) {\n      this.loadOptions = {};\n    }\n    this.resourceManager = new ResourceManager();\n  }\n\n  private findIOHandler() {\n    const path = this.modelUrl;\n    if ((path as io.IOHandler).load != null) {\n      // Path is an IO Handler.\n      this.handler = path as io.IOHandler;\n    } else if (this.loadOptions.requestInit != null) {\n      this.handler = io.browserHTTPRequest(path as string, this.loadOptions);\n    } else {\n      const handlers = io.getLoadHandlers(path as string, this.loadOptions);\n      if (handlers.length === 0) {\n        // For backward compatibility: if no load handler can be found,\n        // assume it is a relative http path.\n        handlers.push(io.browserHTTPRequest(path as string, this.loadOptions));\n      } else if (handlers.length > 1) {\n        throw new Error(\n            `Found more than one (${handlers.length}) load handlers for ` +\n            `URL '${[path]}'`);\n      }\n      this.handler = handlers[0];\n    }\n  }\n\n  /**\n   * Loads the model and weight files, construct the in memory weight map and\n   * compile the inference graph.\n   */\n  async load(): Promise<boolean> {\n    this.findIOHandler();\n    if (this.handler.load == null) {\n      throw new Error(\n          'Cannot proceed with model loading because the IOHandler provided ' +\n          'does not have the `load` method implemented.');\n    }\n    const artifacts = await this.handler.load();\n\n    return this.loadSync(artifacts);\n  }\n\n  /**\n   * Synchronously construct the in memory weight map and\n   * compile the inference graph. Also initialize hashtable if any.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}\n   */\n  loadSync(artifacts: io.ModelArtifacts) {\n    this.artifacts = artifacts;\n    const graph = this.artifacts.modelTopology as tensorflow.IGraphDef;\n\n    let signature;\n    if (this.artifacts.userDefinedMetadata != null &&\n        this.artifacts.userDefinedMetadata.signature != null) {\n      signature =  // tslint:disable-next-line:no-any\n          (this.artifacts.userDefinedMetadata as any).signature as\n          tensorflow.ISignatureDef;\n    } else {\n      signature = this.artifacts.signature;\n    }\n    this.signature = signature;\n\n    this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;\n    const weightMap =\n        io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);\n    this.executor = new GraphExecutor(\n        OperationMapper.Instance.transformGraph(graph, this.signature));\n    this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);\n    // Attach a model-level resourceManager to each executor to share resources,\n    // such as `HashTable`.\n    this.executor.resourceManager = this.resourceManager;\n\n    if (artifacts.modelInitializer != null &&\n        (artifacts.modelInitializer as tensorflow.IGraphDef).node != null) {\n      const initializer =\n          OperationMapper.Instance.transformGraph(artifacts.modelInitializer);\n      this.initializer = new GraphExecutor(initializer);\n      this.initializer.weightMap = this.executor.weightMap;\n      // Attach a model-level resourceManager to the initializer, the\n      // hashTables created from when executing the initializer will be stored\n      // in the resourceManager.\n      this.initializer.resourceManager = this.resourceManager;\n      this.initializer.executeAsync({}, []);\n    }\n\n    return true;\n  }\n\n  /**\n   * Save the configuration and/or weights of the GraphModel.\n   *\n   * An `IOHandler` is an object that has a `save` method of the proper\n   * signature defined. The `save` method manages the storing or\n   * transmission of serialized data (\"artifacts\") that represent the\n   * model's topology and weights onto or via a specific medium, such as\n   * file downloads, local storage, IndexedDB in the web browser and HTTP\n   * requests to a server. TensorFlow.js provides `IOHandler`\n   * implementations for a number of frequently used saving mediums, such as\n   * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`\n   * for more details.\n   *\n   * This method also allows you to refer to certain types of `IOHandler`s\n   * as URL-like string shortcuts, such as 'localstorage://' and\n   * 'indexeddb://'.\n   *\n   * Example 1: Save `model`'s topology and weights to browser [local\n   * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);\n   * then load it back.\n   *\n   * ```js\n   * const modelUrl =\n   *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';\n   * const model = await tf.loadGraphModel(modelUrl);\n   * const zeros = tf.zeros([1, 224, 224, 3]);\n   * model.predict(zeros).print();\n   *\n   * const saveResults = await model.save('localstorage://my-model-1');\n   *\n   * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');\n   * console.log('Prediction from loaded model:');\n   * model.predict(zeros).print();\n   * ```\n   *\n   * @param handlerOrURL An instance of `IOHandler` or a URL-like,\n   * scheme-based string shortcut for `IOHandler`.\n   * @param config Options for saving the model.\n   * @returns A `Promise` of `SaveResult`, which summarizes the result of\n   * the saving, such as byte sizes of the saved artifacts for the model's\n   *   topology and weight values.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}\n   */\n  async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):\n      Promise<io.SaveResult> {\n    if (typeof handlerOrURL === 'string') {\n      const handlers = io.getSaveHandlers(handlerOrURL);\n      if (handlers.length === 0) {\n        throw new Error(\n            `Cannot find any save handlers for URL '${handlerOrURL}'`);\n      } else if (handlers.length > 1) {\n        throw new Error(\n            `Found more than one (${handlers.length}) save handlers for ` +\n            `URL '${handlerOrURL}'`);\n      }\n      handlerOrURL = handlers[0];\n    }\n    if (handlerOrURL.save == null) {\n      throw new Error(\n          'GraphModel.save() cannot proceed because the IOHandler ' +\n          'provided does not have the `save` attribute defined.');\n    }\n\n    return handlerOrURL.save(this.artifacts);\n  }\n\n  /**\n   * Execute the inference for the input tensors.\n   *\n   * @param input The input tensors, when there is single input for the model,\n   * inputs param should be a `tf.Tensor`. For models with mutliple inputs,\n   * inputs params should be in either `tf.Tensor`[] if the input order is\n   * fixed, or otherwise NamedTensorMap format.\n   *\n   * For model with multiple inputs, we recommend you use NamedTensorMap as the\n   * input type, if you use `tf.Tensor`[], the order of the array needs to\n   * follow the\n   * order of inputNodes array. @see {@link GraphModel.inputNodes}\n   *\n   * You can also feed any intermediate nodes using the NamedTensorMap as the\n   * input type. For example, given the graph\n   *    InputNode => Intermediate => OutputNode,\n   * you can execute the subgraph Intermediate => OutputNode by calling\n   *    model.execute('IntermediateNode' : tf.tensor(...));\n   *\n   * This is useful for models that uses tf.dynamic_rnn, where the intermediate\n   * state needs to be fed manually.\n   *\n   * For batch inference execution, the tensors for each input need to be\n   * concatenated together. For example with mobilenet, the required input shape\n   * is [1, 244, 244, 3], which represents the [batch, height, width, channel].\n   * If we are provide a batched data of 100 images, the input tensor should be\n   * in the shape of [100, 244, 244, 3].\n   *\n   * @param config Prediction configuration for specifying the batch size and\n   * output node names. Currently the batch size option is ignored for graph\n   * model.\n   *\n   * @returns Inference result tensors. The output would be single `tf.Tensor`\n   * if model has single output node, otherwise Tensor[] or NamedTensorMap[]\n   * will be returned for model with multiple outputs.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  predict(inputs: Tensor|Tensor[]|NamedTensorMap, config?: ModelPredictConfig):\n      Tensor|Tensor[]|NamedTensorMap {\n    return this.execute(inputs, this.outputNodes);\n  }\n\n  private normalizeInputs(inputs: Tensor|Tensor[]|\n                          NamedTensorMap): NamedTensorMap {\n    if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {\n      // The input is already a NamedTensorMap.\n      return inputs;\n    }\n    inputs = Array.isArray(inputs) ? inputs : [inputs];\n    if (inputs.length !== this.inputNodes.length) {\n      throw new Error(\n          'Input tensor count mismatch,' +\n          `the graph model has ${this.inputNodes.length} placeholders, ` +\n          `while there are ${inputs.length} input tensors.`);\n    }\n    return this.inputNodes.reduce((map, inputName, i) => {\n      map[inputName] = (inputs as Tensor[])[i];\n      return map;\n    }, {} as NamedTensorMap);\n  }\n\n  private normalizeOutputs(outputs: string|string[]): string[] {\n    outputs = outputs || this.outputNodes;\n    return !Array.isArray(outputs) ? [outputs] : outputs;\n  }\n\n  /**\n   * Executes inference for the model for given input tensors.\n   * @param inputs tensor, tensor array or tensor map of the inputs for the\n   * model, keyed by the input node names.\n   * @param outputs output node name from the Tensorflow model, if no\n   * outputs are specified, the default outputs of the model would be used.\n   * You can inspect intermediate nodes of the model by adding them to the\n   * outputs array.\n   *\n   * @returns A single tensor if provided with a single output or no outputs\n   * are provided and there is only one default output, otherwise return a\n   * tensor array. The order of the tensor array is the same as the outputs\n   * if provided, otherwise the order of outputNodes attribute of the model.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  execute(inputs: Tensor|Tensor[]|NamedTensorMap, outputs?: string|string[]):\n      Tensor|Tensor[] {\n    inputs = this.normalizeInputs(inputs);\n    outputs = this.normalizeOutputs(outputs);\n    const result = this.executor.execute(inputs, outputs);\n    return result.length > 1 ? result : result[0];\n  }\n  /**\n   * Executes inference for the model for given input tensors in async\n   * fashion, use this method when your model contains control flow ops.\n   * @param inputs tensor, tensor array or tensor map of the inputs for the\n   * model, keyed by the input node names.\n   * @param outputs output node name from the Tensorflow model, if no outputs\n   * are specified, the default outputs of the model would be used. You can\n   * inspect intermediate nodes of the model by adding them to the outputs\n   * array.\n   *\n   * @returns A Promise of single tensor if provided with a single output or\n   * no outputs are provided and there is only one default output, otherwise\n   * return a tensor map.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  async executeAsync(\n      inputs: Tensor|Tensor[]|NamedTensorMap,\n      outputs?: string|string[]): Promise<Tensor|Tensor[]> {\n    inputs = this.normalizeInputs(inputs);\n    outputs = this.normalizeOutputs(outputs);\n    const result = await this.executor.executeAsync(inputs, outputs);\n    return result.length > 1 ? result : result[0];\n  }\n\n  /**\n   * Get intermediate tensors for model debugging mode (flag\n   * KEEP_INTERMEDIATE_TENSORS is true).\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  getIntermediateTensors(): NamedTensorsMap {\n    return this.executor.getIntermediateTensors();\n  }\n\n  /**\n   * Dispose intermediate tensors for model debugging mode (flag\n   * KEEP_INTERMEDIATE_TENSORS is true).\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  disposeIntermediateTensors() {\n    this.executor.disposeIntermediateTensors();\n  }\n\n  private convertTensorMapToTensorsMap(map: NamedTensorMap): NamedTensorsMap {\n    return Object.keys(map).reduce((newMap: NamedTensorsMap, key) => {\n      newMap[key] = [map[key]];\n      return newMap;\n    }, {});\n  }\n\n  /**\n   * Releases the memory used by the weight tensors and resourceManager.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  dispose() {\n    this.executor.dispose();\n\n    if (this.initializer) {\n      this.initializer.dispose();\n    }\n\n    this.resourceManager.dispose();\n  }\n}\n\n/**\n * Load a graph model given a URL to the model definition.\n *\n * Example of loading MobileNetV2 from a URL and making a prediction with a\n * zeros input:\n *\n * ```js\n * const modelUrl =\n *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';\n * const model = await tf.loadGraphModel(modelUrl);\n * const zeros = tf.zeros([1, 224, 224, 3]);\n * model.predict(zeros).print();\n * ```\n *\n * Example of loading MobileNetV2 from a TF Hub URL and making a prediction with\n * a zeros input:\n *\n * ```js\n * const modelUrl =\n *    'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';\n * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});\n * const zeros = tf.zeros([1, 224, 224, 3]);\n * model.predict(zeros).print();\n * ```\n * @param modelUrl The url or an `io.IOHandler` that loads the model.\n * @param options Options for the HTTP request, which allows to send credentials\n *    and custom headers.\n *\n * @doc {heading: 'Models', subheading: 'Loading'}\n */\nexport async function loadGraphModel(\n    modelUrl: string|io.IOHandler,\n    options: io.LoadOptions = {}): Promise<GraphModel> {\n  if (modelUrl == null) {\n    throw new Error(\n        'modelUrl in loadGraphModel() cannot be null. Please provide a url ' +\n        'or an IOHandler that loads the model');\n  }\n  if (options == null) {\n    options = {};\n  }\n\n  if (options.fromTFHub) {\n    if ((modelUrl as io.IOHandler).load == null) {\n      if (!(modelUrl as string).endsWith('/')) {\n        modelUrl = (modelUrl as string) + '/';\n      }\n      modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;\n    }\n  }\n  const model = new GraphModel(modelUrl, options);\n  await model.load();\n  return model;\n}\n"]}