@tensorflow/tfjs-converter
Version:
TensorFlow model converter for JavaScript
406 lines • 50 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { io, Tensor } from '@tensorflow/tfjs-core';
import { OperationMapper } from '../operations/operation_mapper';
import { GraphExecutor } from './graph_executor';
import { ResourceManager } from './resource_manager';
// Query string appended (after DEFAULT_MODEL_NAME) to a TF Hub base URL by
// loadGraphModel() when `options.fromTFHub` is set and a string URL is given.
export const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
// Model definition file name appended to a TF Hub base URL (after ensuring a
// trailing '/') by loadGraphModel() when `options.fromTFHub` is set.
export const DEFAULT_MODEL_NAME = 'model.json';
/**
* A `tf.GraphModel` is a directed, acyclic graph built from a
* SavedModel GraphDef and allows inference execution.
*
* A `tf.GraphModel` can only be created by loading from a model converted from
* a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using
* the command line converter tool and loaded via `tf.loadGraphModel`.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
export class GraphModel {
/**
* @param modelUrl url for the model, or an `io.IOHandler`.
* @param weightManifestUrl url for the weight file generated by
* scripts/convert.py script.
* @param requestOption options for Request, which allows to send credentials
* and custom headers.
* @param onProgress Optional, progress callback function, fired periodically
* before the load is completed.
*/
constructor(modelUrl, loadOptions = {}) {
this.modelUrl = modelUrl;
this.loadOptions = loadOptions;
this.version = 'n/a';
if (loadOptions == null) {
this.loadOptions = {};
}
this.resourceManager = new ResourceManager();
}
// Returns the version information for the tensorflow model GraphDef.
get modelVersion() {
return this.version;
}
get inputNodes() {
return this.executor.inputNodes;
}
get outputNodes() {
return this.executor.outputNodes;
}
get inputs() {
return this.executor.inputs;
}
get outputs() {
return this.executor.outputs;
}
get weights() {
return this.executor.weightMap;
}
get metadata() {
return this.artifacts.userDefinedMetadata;
}
get modelSignature() {
return this.signature;
}
findIOHandler() {
const path = this.modelUrl;
if (path.load != null) {
// Path is an IO Handler.
this.handler = path;
}
else if (this.loadOptions.requestInit != null) {
this.handler = io.browserHTTPRequest(path, this.loadOptions);
}
else {
const handlers = io.getLoadHandlers(path, this.loadOptions);
if (handlers.length === 0) {
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
handlers.push(io.browserHTTPRequest(path, this.loadOptions));
}
else if (handlers.length > 1) {
throw new Error(`Found more than one (${handlers.length}) load handlers for ` +
`URL '${[path]}'`);
}
this.handler = handlers[0];
}
}
/**
* Loads the model and weight files, construct the in memory weight map and
* compile the inference graph.
*/
async load() {
this.findIOHandler();
if (this.handler.load == null) {
throw new Error('Cannot proceed with model loading because the IOHandler provided ' +
'does not have the `load` method implemented.');
}
const artifacts = await this.handler.load();
return this.loadSync(artifacts);
}
/**
* Synchronously construct the in memory weight map and
* compile the inference graph. Also initialize hashtable if any.
*
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
loadSync(artifacts) {
this.artifacts = artifacts;
const graph = this.artifacts.modelTopology;
let signature;
if (this.artifacts.userDefinedMetadata != null &&
this.artifacts.userDefinedMetadata.signature != null) {
signature = // tslint:disable-next-line:no-any
this.artifacts.userDefinedMetadata.signature;
}
else {
signature = this.artifacts.signature;
}
this.signature = signature;
this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap = io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
// Attach a model-level resourceManager to each executor to share resources,
// such as `HashTable`.
this.executor.resourceManager = this.resourceManager;
if (artifacts.modelInitializer != null &&
artifacts.modelInitializer.node != null) {
const initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
this.initializer = new GraphExecutor(initializer);
this.initializer.weightMap = this.executor.weightMap;
// Attach a model-level resourceManager to the initializer, the
// hashTables created from when executing the initializer will be stored
// in the resourceManager.
this.initializer.resourceManager = this.resourceManager;
this.initializer.executeAsync({}, []);
}
return true;
}
/**
* Save the configuration and/or weights of the GraphModel.
*
* An `IOHandler` is an object that has a `save` method of the proper
* signature defined. The `save` method manages the storing or
* transmission of serialized data ("artifacts") that represent the
* model's topology and weights onto or via a specific medium, such as
* file downloads, local storage, IndexedDB in the web browser and HTTP
* requests to a server. TensorFlow.js provides `IOHandler`
* implementations for a number of frequently used saving mediums, such as
* `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
* for more details.
*
* This method also allows you to refer to certain types of `IOHandler`s
* as URL-like string shortcuts, such as 'localstorage://' and
* 'indexeddb://'.
*
* Example 1: Save `model`'s topology and weights to browser [local
* storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
* then load it back.
*
* ```js
* const modelUrl =
* 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
* const model = await tf.loadGraphModel(modelUrl);
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
*
* const saveResults = await model.save('localstorage://my-model-1');
*
* const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
* console.log('Prediction from loaded model:');
* model.predict(zeros).print();
* ```
*
* @param handlerOrURL An instance of `IOHandler` or a URL-like,
* scheme-based string shortcut for `IOHandler`.
* @param config Options for saving the model.
* @returns A `Promise` of `SaveResult`, which summarizes the result of
* the saving, such as byte sizes of the saved artifacts for the model's
* topology and weight values.
*
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
async save(handlerOrURL, config) {
if (typeof handlerOrURL === 'string') {
const handlers = io.getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new Error(`Cannot find any save handlers for URL '${handlerOrURL}'`);
}
else if (handlers.length > 1) {
throw new Error(`Found more than one (${handlers.length}) save handlers for ` +
`URL '${handlerOrURL}'`);
}
handlerOrURL = handlers[0];
}
if (handlerOrURL.save == null) {
throw new Error('GraphModel.save() cannot proceed because the IOHandler ' +
'provided does not have the `save` attribute defined.');
}
return handlerOrURL.save(this.artifacts);
}
/**
* Execute the inference for the input tensors.
*
* @param input The input tensors, when there is single input for the model,
* inputs param should be a `tf.Tensor`. For models with mutliple inputs,
* inputs params should be in either `tf.Tensor`[] if the input order is
* fixed, or otherwise NamedTensorMap format.
*
* For model with multiple inputs, we recommend you use NamedTensorMap as the
* input type, if you use `tf.Tensor`[], the order of the array needs to
* follow the
* order of inputNodes array. @see {@link GraphModel.inputNodes}
*
* You can also feed any intermediate nodes using the NamedTensorMap as the
* input type. For example, given the graph
* InputNode => Intermediate => OutputNode,
* you can execute the subgraph Intermediate => OutputNode by calling
* model.execute('IntermediateNode' : tf.tensor(...));
*
* This is useful for models that uses tf.dynamic_rnn, where the intermediate
* state needs to be fed manually.
*
* For batch inference execution, the tensors for each input need to be
* concatenated together. For example with mobilenet, the required input shape
* is [1, 244, 244, 3], which represents the [batch, height, width, channel].
* If we are provide a batched data of 100 images, the input tensor should be
* in the shape of [100, 244, 244, 3].
*
* @param config Prediction configuration for specifying the batch size and
* output node names. Currently the batch size option is ignored for graph
* model.
*
* @returns Inference result tensors. The output would be single `tf.Tensor`
* if model has single output node, otherwise Tensor[] or NamedTensorMap[]
* will be returned for model with multiple outputs.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
predict(inputs, config) {
return this.execute(inputs, this.outputNodes);
}
normalizeInputs(inputs) {
if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
// The input is already a NamedTensorMap.
return inputs;
}
inputs = Array.isArray(inputs) ? inputs : [inputs];
if (inputs.length !== this.inputNodes.length) {
throw new Error('Input tensor count mismatch,' +
`the graph model has ${this.inputNodes.length} placeholders, ` +
`while there are ${inputs.length} input tensors.`);
}
return this.inputNodes.reduce((map, inputName, i) => {
map[inputName] = inputs[i];
return map;
}, {});
}
normalizeOutputs(outputs) {
outputs = outputs || this.outputNodes;
return !Array.isArray(outputs) ? [outputs] : outputs;
}
/**
* Executes inference for the model for given input tensors.
* @param inputs tensor, tensor array or tensor map of the inputs for the
* model, keyed by the input node names.
* @param outputs output node name from the Tensorflow model, if no
* outputs are specified, the default outputs of the model would be used.
* You can inspect intermediate nodes of the model by adding them to the
* outputs array.
*
* @returns A single tensor if provided with a single output or no outputs
* are provided and there is only one default output, otherwise return a
* tensor array. The order of the tensor array is the same as the outputs
* if provided, otherwise the order of outputNodes attribute of the model.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
execute(inputs, outputs) {
inputs = this.normalizeInputs(inputs);
outputs = this.normalizeOutputs(outputs);
const result = this.executor.execute(inputs, outputs);
return result.length > 1 ? result : result[0];
}
/**
* Executes inference for the model for given input tensors in async
* fashion, use this method when your model contains control flow ops.
* @param inputs tensor, tensor array or tensor map of the inputs for the
* model, keyed by the input node names.
* @param outputs output node name from the Tensorflow model, if no outputs
* are specified, the default outputs of the model would be used. You can
* inspect intermediate nodes of the model by adding them to the outputs
* array.
*
* @returns A Promise of single tensor if provided with a single output or
* no outputs are provided and there is only one default output, otherwise
* return a tensor map.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
async executeAsync(inputs, outputs) {
inputs = this.normalizeInputs(inputs);
outputs = this.normalizeOutputs(outputs);
const result = await this.executor.executeAsync(inputs, outputs);
return result.length > 1 ? result : result[0];
}
/**
* Get intermediate tensors for model debugging mode (flag
* KEEP_INTERMEDIATE_TENSORS is true).
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
getIntermediateTensors() {
return this.executor.getIntermediateTensors();
}
/**
* Dispose intermediate tensors for model debugging mode (flag
* KEEP_INTERMEDIATE_TENSORS is true).
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
disposeIntermediateTensors() {
this.executor.disposeIntermediateTensors();
}
convertTensorMapToTensorsMap(map) {
return Object.keys(map).reduce((newMap, key) => {
newMap[key] = [map[key]];
return newMap;
}, {});
}
/**
* Releases the memory used by the weight tensors and resourceManager.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
dispose() {
this.executor.dispose();
if (this.initializer) {
this.initializer.dispose();
}
this.resourceManager.dispose();
}
}
/**
* Load a graph model given a URL to the model definition.
*
* Example of loading MobileNetV2 from a URL and making a prediction with a
* zeros input:
*
* ```js
* const modelUrl =
* 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
* const model = await tf.loadGraphModel(modelUrl);
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
*
* Example of loading MobileNetV2 from a TF Hub URL and making a prediction with
* a zeros input:
*
* ```js
* const modelUrl =
* 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
* const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
* @param modelUrl The url or an `io.IOHandler` that loads the model.
* @param options Options for the HTTP request, which allows to send credentials
* and custom headers.
*
* @doc {heading: 'Models', subheading: 'Loading'}
*/
export async function loadGraphModel(modelUrl, options = {}) {
if (modelUrl == null) {
throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
'or an IOHandler that loads the model');
}
if (options == null) {
options = {};
}
if (options.fromTFHub) {
if (modelUrl.load == null) {
if (!modelUrl.endsWith('/')) {
modelUrl = modelUrl + '/';
}
modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
}
}
const model = new GraphModel(modelUrl, options);
await model.load();
return model;
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"graph_model.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/executor/graph_model.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAiB,EAAE,EAAsC,MAAM,EAAC,MAAM,uBAAuB,CAAC;AAIrG,OAAO,EAAC,eAAe,EAAC,MAAM,gCAAgC,CAAC;AAE/D,OAAO,EAAC,aAAa,EAAC,MAAM,kBAAkB,CAAC;AAC/C,OAAO,EAAC,eAAe,EAAC,MAAM,oBAAoB,CAAC;AAEnD,MAAM,CAAC,MAAM,kBAAkB,GAAG,mBAAmB,CAAC;AACtD,MAAM,CAAC,MAAM,kBAAkB,GAAG,YAAY,CAAC;AAC/C;;;;;;;;;GASG;AACH,MAAM,OAAO,UAAU;IA0CrB;;;;;;;;OAQG;IACH,YACY,QAA6B,EAC7B,cAA8B,EAAE;QADhC,aAAQ,GAAR,QAAQ,CAAqB;QAC7B,gBAAW,GAAX,WAAW,CAAqB;QAnDpC,YAAO,GAAG,KAAK,CAAC;QAoDtB,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC;SACvB;QACD,IAAI,CAAC,eAAe,GAAG,IAAI,eAAe,EAAE,CAAC;IAC/C,CAAC;IAjDD,qEAAqE;IACrE,IAAI,YAAY;QACd,OAAO,IAAI,CAAC,OAAO,CAAC;IACtB,CAAC;IAED,IAAI,UAAU;QACZ,OAAO,IAAI,CAAC,QAAQ,CAAC,UAAU,CAAC;IAClC,CAAC;IAED,IAAI,WAAW;QACb,OAAO,IAAI,CAAC,QAAQ,CAAC,WAAW,CAAC;IACnC,CAAC;IAED,IAAI,MAAM;QACR,OAAO,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC;IAC9B,CAAC;IAED,IAAI,OAAO;QACT,OAAO,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC;IAC/B,CAAC;IAED,IAAI,OAAO;QACT,OAAO,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC;IACjC,CAAC;IAED,IAAI,QAAQ;QACV,OAAO,IAAI,CAAC,SAAS,CAAC,mBAAmB,CAAC;IAC5C,CAAC;IAED,IAAI,cAAc;QAChB,OAAO,IAAI,CAAC,SAAS,CAAC;IACxB,CAAC;IAoBO,aAAa;QACnB,MAAM,IAAI,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC3B,IAAK,IAAqB,CAAC,IAAI,IAAI,IAAI,EAAE;YACvC,yBAAyB;YACzB,IAAI,CAAC,OAAO,GAAG,IAAoB,CAAC;SACrC;aAAM,IAAI,IAAI,CAAC,WAAW,CAAC,WAAW,IAAI,IAAI,EAAE;YAC/C,IAAI,CAAC,OAAO,GAAG,EAAE,CAAC,kBAAkB,CAAC,IAAc,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;SACxE;aAAM;YACL,MAAM,QAAQ,GAAG,EAAE,CAAC,eAAe,CAAC,IAAc,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;YACtE,IAAI,QAAQ,CAAC,MAAM,KAAK,CAAC,EAAE;gBACzB,+DAA+D;gBAC/D,qCAAqC;gBACrC,QAAQ,CAAC,IAAI,CAAC,EAAE,CAAC,kBAAkB,CAAC,IAAc,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC;aACxE;iBAAM,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE;gBAC9B,MAAM,IAAI,KAAK,CACX,wBAAwB,QAAQ,CAAC,MAAM,sBAAsB;oBAC7D,QAAQ,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;aACx
B;YACD,IAAI,CAAC,OAAO,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;SAC5B;IACH,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,IAAI;QACR,IAAI,CAAC,aAAa,EAAE,CAAC;QACrB,IAAI,IAAI,CAAC,OAAO,CAAC,IAAI,IAAI,IAAI,EAAE;YAC7B,MAAM,IAAI,KAAK,CACX,mEAAmE;gBACnE,8CAA8C,CAAC,CAAC;SACrD;QACD,MAAM,SAAS,GAAG,MAAM,IAAI,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC;QAE5C,OAAO,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC,CAAC;IAClC,CAAC;IAED;;;;;OAKG;IACH,QAAQ,CAAC,SAA4B;QACnC,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,MAAM,KAAK,GAAG,IAAI,CAAC,SAAS,CAAC,aAAqC,CAAC;QAEnE,IAAI,SAAS,CAAC;QACd,IAAI,IAAI,CAAC,SAAS,CAAC,mBAAmB,IAAI,IAAI;YAC1C,IAAI,CAAC,SAAS,CAAC,mBAAmB,CAAC,SAAS,IAAI,IAAI,EAAE;YACxD,SAAS,GAAI,kCAAkC;gBAC1C,IAAI,CAAC,SAAS,CAAC,mBAA2B,CAAC,SACpB,CAAC;SAC9B;aAAM;YACL,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC,SAAS,CAAC;SACtC;QACD,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAE3B,IAAI,CAAC,OAAO,GAAG,GAAG,KAAK,CAAC,QAAQ,CAAC,QAAQ,IAAI,KAAK,CAAC,QAAQ,CAAC,WAAW,EAAE,CAAC;QAC1E,MAAM,SAAS,GACX,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,UAAU,EAAE,IAAI,CAAC,SAAS,CAAC,WAAW,CAAC,CAAC;QAC5E,IAAI,CAAC,QAAQ,GAAG,IAAI,aAAa,CAC7B,eAAe,CAAC,QAAQ,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC,CAAC;QACpE,IAAI,CAAC,QAAQ,CAAC,SAAS,GAAG,IAAI,CAAC,4BAA4B,CAAC,SAAS,CAAC,CAAC;QACvE,4EAA4E;QAC5E,uBAAuB;QACvB,IAAI,CAAC,QAAQ,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;QAErD,IAAI,SAAS,CAAC,gBAAgB,IAAI,IAAI;YACjC,SAAS,CAAC,gBAAyC,CAAC,IAAI,IAAI,IAAI,EAAE;YACrE,MAAM,WAAW,GACb,eAAe,CAAC,QAAQ,CAAC,cAAc,CAAC,SAAS,CAAC,gBAAgB,CAAC,CAAC;YACxE,IAAI,CAAC,WAAW,GAAG,IAAI,aAAa,CAAC,WAAW,CAAC,CAAC;YAClD,IAAI,CAAC,WAAW,CAAC,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC;YACrD,+DAA+D;YAC/D,wEAAwE;YACxE,0BAA0B;YAC1B,IAAI,CAAC,WAAW,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;YACxD,IAAI,CAAC,WAAW,CAAC,YAAY,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC;SACvC;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;OA2CG;IACH,KAAK,CAAC,IAAI,CAAC,YAAiC,EAAE,MAAsB;QAElE,IAAI,OAAO,YAAY,KAAK,QAAQ,EAAE;YACpC,MAAM,QAAQ,GAAG,EAAE,CAAC,eAAe,CAAC,YAAY,CAAC,CAAC;YAClD,IAAI,QAAQ,CAAC,MAAM,KAAK,CAAC,EAAE;gBACzB,MAAM,IAAI,KAAK,CACX,0CAA0C,YAAY,GAAG,C
AAC,CAAC;aAChE;iBAAM,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE;gBAC9B,MAAM,IAAI,KAAK,CACX,wBAAwB,QAAQ,CAAC,MAAM,sBAAsB;oBAC7D,QAAQ,YAAY,GAAG,CAAC,CAAC;aAC9B;YACD,YAAY,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;SAC5B;QACD,IAAI,YAAY,CAAC,IAAI,IAAI,IAAI,EAAE;YAC7B,MAAM,IAAI,KAAK,CACX,yDAAyD;gBACzD,sDAAsD,CAAC,CAAC;SAC7D;QAED,OAAO,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC3C,CAAC;IAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;OAqCG;IACH,OAAO,CAAC,MAAsC,EAAE,MAA2B;QAEzE,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;IAChD,CAAC;IAEO,eAAe,CAAC,MACc;QACpC,IAAI,CAAC,CAAC,MAAM,YAAY,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YACzD,yCAAyC;YACzC,OAAO,MAAM,CAAC;SACf;QACD,MAAM,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QACnD,IAAI,MAAM,CAAC,MAAM,KAAK,IAAI,CAAC,UAAU,CAAC,MAAM,EAAE;YAC5C,MAAM,IAAI,KAAK,CACX,8BAA8B;gBAC9B,uBAAuB,IAAI,CAAC,UAAU,CAAC,MAAM,iBAAiB;gBAC9D,mBAAmB,MAAM,CAAC,MAAM,iBAAiB,CAAC,CAAC;SACxD;QACD,OAAO,IAAI,CAAC,UAAU,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,SAAS,EAAE,CAAC,EAAE,EAAE;YAClD,GAAG,CAAC,SAAS,CAAC,GAAI,MAAmB,CAAC,CAAC,CAAC,CAAC;YACzC,OAAO,GAAG,CAAC;QACb,CAAC,EAAE,EAAoB,CAAC,CAAC;IAC3B,CAAC;IAEO,gBAAgB,CAAC,OAAwB;QAC/C,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC,WAAW,CAAC;QACtC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC;IACvD,CAAC;IAED;;;;;;;;;;;;;;;OAeG;IACH,OAAO,CAAC,MAAsC,EAAE,OAAyB;QAEvE,MAAM,GAAG,IAAI,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,IAAI,CAAC,gBAAgB,CAAC,OAAO,CAAC,CAAC;QACzC,MAAM,MAAM,GAAG,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;QACtD,OAAO,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;IAChD,CAAC;IACD;;;;;;;;;;;;;;;OAeG;IACH,KAAK,CAAC,YAAY,CACd,MAAsC,EACtC,OAAyB;QAC3B,MAAM,GAAG,IAAI,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,IAAI,CAAC,gBAAgB,CAAC,OAAO,CAAC,CAAC;QACzC,MAAM,MAAM,GAAG,MAAM,IAAI,CAAC,QAAQ,CAAC,YAAY,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;QACjE,OAAO,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,
CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;IAChD,CAAC;IAED;;;;;OAKG;IACH,sBAAsB;QACpB,OAAO,IAAI,CAAC,QAAQ,CAAC,sBAAsB,EAAE,CAAC;IAChD,CAAC;IAED;;;;;OAKG;IACH,0BAA0B;QACxB,IAAI,CAAC,QAAQ,CAAC,0BAA0B,EAAE,CAAC;IAC7C,CAAC;IAEO,4BAA4B,CAAC,GAAmB;QACtD,OAAO,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,MAAM,CAAC,CAAC,MAAuB,EAAE,GAAG,EAAE,EAAE;YAC9D,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACzB,OAAO,MAAM,CAAC;QAChB,CAAC,EAAE,EAAE,CAAC,CAAC;IACT,CAAC;IAED;;;;OAIG;IACH,OAAO;QACL,IAAI,CAAC,QAAQ,CAAC,OAAO,EAAE,CAAC;QAExB,IAAI,IAAI,CAAC,WAAW,EAAE;YACpB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;SAC5B;QAED,IAAI,CAAC,eAAe,CAAC,OAAO,EAAE,CAAC;IACjC,CAAC;CACF;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAChC,QAA6B,EAC7B,UAA0B,EAAE;IAC9B,IAAI,QAAQ,IAAI,IAAI,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,oEAAoE;YACpE,sCAAsC,CAAC,CAAC;KAC7C;IACD,IAAI,OAAO,IAAI,IAAI,EAAE;QACnB,OAAO,GAAG,EAAE,CAAC;KACd;IAED,IAAI,OAAO,CAAC,SAAS,EAAE;QACrB,IAAK,QAAyB,CAAC,IAAI,IAAI,IAAI,EAAE;YAC3C,IAAI,CAAE,QAAmB,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE;gBACvC,QAAQ,GAAI,QAAmB,GAAG,GAAG,CAAC;aACvC;YACD,QAAQ,GAAG,GAAG,QAAQ,GAAG,kBAAkB,GAAG,kBAAkB,EAAE,CAAC;SACpE;KACF;IACD,MAAM,KAAK,GAAG,IAAI,UAAU,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC;IAChD,MAAM,KAAK,CAAC,IAAI,EAAE,CAAC;IACnB,OAAO,KAAK,CAAC;AACf,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor} from '@tensorflow/tfjs-core';\n\nimport * as tensorflow from '../data/compiled_api';\nimport {NamedTensorsMap, TensorInfo} from '../data/types';\nimport {OperationMapper} from '../operations/operation_mapper';\n\nimport {GraphExecutor} from './graph_executor';\nimport {ResourceManager} from './resource_manager';\n\nexport const TFHUB_SEARCH_PARAM = '?tfjs-format=file';\nexport const DEFAULT_MODEL_NAME = 'model.json';\n/**\n * A `tf.GraphModel` is a directed, acyclic graph built from a\n * SavedModel GraphDef and allows inference execution.\n *\n * A `tf.GraphModel` can only be created by loading from a model converted from\n * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using\n * the command line converter tool and loaded via `tf.loadGraphModel`.\n *\n * @doc {heading: 'Models', subheading: 'Classes'}\n */\nexport class GraphModel implements InferenceModel {\n  private executor: GraphExecutor;\n  private version = 'n/a';\n  private handler: io.IOHandler;\n  private artifacts: io.ModelArtifacts;\n  private initializer: GraphExecutor;\n  private resourceManager: ResourceManager;\n  private signature: tensorflow.ISignatureDef;\n\n  // Returns the version information for the tensorflow model GraphDef.\n  
get modelVersion(): string {\n    return this.version;\n  }\n\n  get inputNodes(): string[] {\n    return this.executor.inputNodes;\n  }\n\n  get outputNodes(): string[] {\n    return this.executor.outputNodes;\n  }\n\n  get inputs(): TensorInfo[] {\n    return this.executor.inputs;\n  }\n\n  get outputs(): TensorInfo[] {\n    return this.executor.outputs;\n  }\n\n  get weights(): NamedTensorsMap {\n    return this.executor.weightMap;\n  }\n\n  get metadata(): {} {\n    return this.artifacts.userDefinedMetadata;\n  }\n\n  get modelSignature(): {} {\n    return this.signature;\n  }\n\n  /**\n   * @param modelUrl url for the model, or an `io.IOHandler`.\n   * @param weightManifestUrl url for the weight file generated by\n   * scripts/convert.py script.\n   * @param requestOption options for Request, which allows to send credentials\n   * and custom headers.\n   * @param onProgress Optional, progress callback function, fired periodically\n   * before the load is completed.\n   */\n  constructor(\n      private modelUrl: string|io.IOHandler,\n      private loadOptions: io.LoadOptions = {}) {\n    if (loadOptions == null) {\n      this.loadOptions = {};\n    }\n    this.resourceManager = new ResourceManager();\n  }\n\n  private findIOHandler() {\n    const path = this.modelUrl;\n    if ((path as io.IOHandler).load != null) {\n      // Path is an IO Handler.\n      this.handler = path as io.IOHandler;\n    } else if (this.loadOptions.requestInit != null) {\n      this.handler = io.browserHTTPRequest(path as string, this.loadOptions);\n    } else {\n      const handlers = io.getLoadHandlers(path as string, this.loadOptions);\n      if (handlers.length === 0) {\n        // For backward compatibility: if no load handler can be found,\n        // assume it is a relative http path.\n        handlers.push(io.browserHTTPRequest(path as string, this.loadOptions));\n      } else if (handlers.length > 1) {\n        throw new Error(\n            `Found more than one 
(${handlers.length}) load handlers for ` +\n            `URL '${[path]}'`);\n      }\n      this.handler = handlers[0];\n    }\n  }\n\n  /**\n   * Loads the model and weight files, construct the in memory weight map and\n   * compile the inference graph.\n   */\n  async load(): Promise<boolean> {\n    this.findIOHandler();\n    if (this.handler.load == null) {\n      throw new Error(\n          'Cannot proceed with model loading because the IOHandler provided ' +\n          'does not have the `load` method implemented.');\n    }\n    const artifacts = await this.handler.load();\n\n    return this.loadSync(artifacts);\n  }\n\n  /**\n   * Synchronously construct the in memory weight map and\n   * compile the inference graph. Also initialize hashtable if any.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}\n   */\n  loadSync(artifacts: io.ModelArtifacts) {\n    this.artifacts = artifacts;\n    const graph = this.artifacts.modelTopology as tensorflow.IGraphDef;\n\n    let signature;\n    if (this.artifacts.userDefinedMetadata != null &&\n        this.artifacts.userDefinedMetadata.signature != null) {\n      signature =  // tslint:disable-next-line:no-any\n          (this.artifacts.userDefinedMetadata as any).signature as\n          tensorflow.ISignatureDef;\n    } else {\n      signature = this.artifacts.signature;\n    }\n    this.signature = signature;\n\n    this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;\n    const weightMap =\n        io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);\n    this.executor = new GraphExecutor(\n        OperationMapper.Instance.transformGraph(graph, this.signature));\n    this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);\n    // Attach a model-level resourceManager to each executor to share resources,\n    // such as `HashTable`.\n    this.executor.resourceManager = this.resourceManager;\n\n    if (artifacts.modelInitializer != null &&\n 
       (artifacts.modelInitializer as tensorflow.IGraphDef).node != null) {\n      const initializer =\n          OperationMapper.Instance.transformGraph(artifacts.modelInitializer);\n      this.initializer = new GraphExecutor(initializer);\n      this.initializer.weightMap = this.executor.weightMap;\n      // Attach a model-level resourceManager to the initializer, the\n      // hashTables created from when executing the initializer will be stored\n      // in the resourceManager.\n      this.initializer.resourceManager = this.resourceManager;\n      this.initializer.executeAsync({}, []);\n    }\n\n    return true;\n  }\n\n  /**\n   * Save the configuration and/or weights of the GraphModel.\n   *\n   * An `IOHandler` is an object that has a `save` method of the proper\n   * signature defined. The `save` method manages the storing or\n   * transmission of serialized data (\"artifacts\") that represent the\n   * model's topology and weights onto or via a specific medium, such as\n   * file downloads, local storage, IndexedDB in the web browser and HTTP\n   * requests to a server. TensorFlow.js provides `IOHandler`\n   * implementations for a number of frequently used saving mediums, such as\n   * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. 
See `tf.io`\n   * for more details.\n   *\n   * This method also allows you to refer to certain types of `IOHandler`s\n   * as URL-like string shortcuts, such as 'localstorage://' and\n   * 'indexeddb://'.\n   *\n   * Example 1: Save `model`'s topology and weights to browser [local\n   * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);\n   * then load it back.\n   *\n   * ```js\n   * const modelUrl =\n   *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';\n   * const model = await tf.loadGraphModel(modelUrl);\n   * const zeros = tf.zeros([1, 224, 224, 3]);\n   * model.predict(zeros).print();\n   *\n   * const saveResults = await model.save('localstorage://my-model-1');\n   *\n   * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');\n   * console.log('Prediction from loaded model:');\n   * model.predict(zeros).print();\n   * ```\n   *\n   * @param handlerOrURL An instance of `IOHandler` or a URL-like,\n   * scheme-based string shortcut for `IOHandler`.\n   * @param config Options for saving the model.\n   * @returns A `Promise` of `SaveResult`, which summarizes the result of\n   * the saving, such as byte sizes of the saved artifacts for the model's\n   *   topology and weight values.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}\n   */\n  async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):\n      Promise<io.SaveResult> {\n    if (typeof handlerOrURL === 'string') {\n      const handlers = io.getSaveHandlers(handlerOrURL);\n      if (handlers.length === 0) {\n        throw new Error(\n            `Cannot find any save handlers for URL '${handlerOrURL}'`);\n      } else if (handlers.length > 1) {\n        throw new Error(\n            `Found more than one (${handlers.length}) save handlers for ` +\n            `URL '${handlerOrURL}'`);\n      }\n      handlerOrURL = handlers[0];\n    }\n    if (handlerOrURL.save == null) {\n  
    throw new Error(\n          'GraphModel.save() cannot proceed because the IOHandler ' +\n          'provided does not have the `save` attribute defined.');\n    }\n\n    return handlerOrURL.save(this.artifacts);\n  }\n\n  /**\n   * Execute the inference for the input tensors.\n   *\n   * @param input The input tensors, when there is single input for the model,\n   * inputs param should be a `tf.Tensor`. For models with mutliple inputs,\n   * inputs params should be in either `tf.Tensor`[] if the input order is\n   * fixed, or otherwise NamedTensorMap format.\n   *\n   * For model with multiple inputs, we recommend you use NamedTensorMap as the\n   * input type, if you use `tf.Tensor`[], the order of the array needs to\n   * follow the\n   * order of inputNodes array. @see {@link GraphModel.inputNodes}\n   *\n   * You can also feed any intermediate nodes using the NamedTensorMap as the\n   * input type. For example, given the graph\n   *    InputNode => Intermediate => OutputNode,\n   * you can execute the subgraph Intermediate => OutputNode by calling\n   *    model.execute('IntermediateNode' : tf.tensor(...));\n   *\n   * This is useful for models that uses tf.dynamic_rnn, where the intermediate\n   * state needs to be fed manually.\n   *\n   * For batch inference execution, the tensors for each input need to be\n   * concatenated together. For example with mobilenet, the required input shape\n   * is [1, 244, 244, 3], which represents the [batch, height, width, channel].\n   * If we are provide a batched data of 100 images, the input tensor should be\n   * in the shape of [100, 244, 244, 3].\n   *\n   * @param config Prediction configuration for specifying the batch size and\n   * output node names. Currently the batch size option is ignored for graph\n   * model.\n   *\n   * @returns Inference result tensors. 
The output would be single `tf.Tensor`\n   * if model has single output node, otherwise Tensor[] or NamedTensorMap[]\n   * will be returned for model with multiple outputs.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  predict(inputs: Tensor|Tensor[]|NamedTensorMap, config?: ModelPredictConfig):\n      Tensor|Tensor[]|NamedTensorMap {\n    return this.execute(inputs, this.outputNodes);\n  }\n\n  private normalizeInputs(inputs: Tensor|Tensor[]|\n                          NamedTensorMap): NamedTensorMap {\n    if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {\n      // The input is already a NamedTensorMap.\n      return inputs;\n    }\n    inputs = Array.isArray(inputs) ? inputs : [inputs];\n    if (inputs.length !== this.inputNodes.length) {\n      throw new Error(\n          'Input tensor count mismatch,' +\n          `the graph model has ${this.inputNodes.length} placeholders, ` +\n          `while there are ${inputs.length} input tensors.`);\n    }\n    return this.inputNodes.reduce((map, inputName, i) => {\n      map[inputName] = (inputs as Tensor[])[i];\n      return map;\n    }, {} as NamedTensorMap);\n  }\n\n  private normalizeOutputs(outputs: string|string[]): string[] {\n    outputs = outputs || this.outputNodes;\n    return !Array.isArray(outputs) ? [outputs] : outputs;\n  }\n\n  /**\n   * Executes inference for the model for given input tensors.\n   * @param inputs tensor, tensor array or tensor map of the inputs for the\n   * model, keyed by the input node names.\n   * @param outputs output node name from the Tensorflow model, if no\n   * outputs are specified, the default outputs of the model would be used.\n   * You can inspect intermediate nodes of the model by adding them to the\n   * outputs array.\n   *\n   * @returns A single tensor if provided with a single output or no outputs\n   * are provided and there is only one default output, otherwise return a\n   * tensor array. 
The order of the tensor array is the same as the outputs\n   * if provided, otherwise the order of outputNodes attribute of the model.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  execute(inputs: Tensor|Tensor[]|NamedTensorMap, outputs?: string|string[]):\n      Tensor|Tensor[] {\n    inputs = this.normalizeInputs(inputs);\n    outputs = this.normalizeOutputs(outputs);\n    const result = this.executor.execute(inputs, outputs);\n    return result.length > 1 ? result : result[0];\n  }\n  /**\n   * Executes inference for the model for given input tensors in async\n   * fashion, use this method when your model contains control flow ops.\n   * @param inputs tensor, tensor array or tensor map of the inputs for the\n   * model, keyed by the input node names.\n   * @param outputs output node name from the Tensorflow model, if no outputs\n   * are specified, the default outputs of the model would be used. You can\n   * inspect intermediate nodes of the model by adding them to the outputs\n   * array.\n   *\n   * @returns A Promise of single tensor if provided with a single output or\n   * no outputs are provided and there is only one default output, otherwise\n   * return a tensor map.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  async executeAsync(\n      inputs: Tensor|Tensor[]|NamedTensorMap,\n      outputs?: string|string[]): Promise<Tensor|Tensor[]> {\n    inputs = this.normalizeInputs(inputs);\n    outputs = this.normalizeOutputs(outputs);\n    const result = await this.executor.executeAsync(inputs, outputs);\n    return result.length > 1 ? 
result : result[0];\n  }\n\n  /**\n   * Get intermediate tensors for model debugging mode (flag\n   * KEEP_INTERMEDIATE_TENSORS is true).\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  getIntermediateTensors(): NamedTensorsMap {\n    return this.executor.getIntermediateTensors();\n  }\n\n  /**\n   * Dispose intermediate tensors for model debugging mode (flag\n   * KEEP_INTERMEDIATE_TENSORS is true).\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  disposeIntermediateTensors() {\n    this.executor.disposeIntermediateTensors();\n  }\n\n  private convertTensorMapToTensorsMap(map: NamedTensorMap): NamedTensorsMap {\n    return Object.keys(map).reduce((newMap: NamedTensorsMap, key) => {\n      newMap[key] = [map[key]];\n      return newMap;\n    }, {});\n  }\n\n  /**\n   * Releases the memory used by the weight tensors and resourceManager.\n   *\n   * @doc {heading: 'Models', subheading: 'Classes'}\n   */\n  dispose() {\n    this.executor.dispose();\n\n    if (this.initializer) {\n      this.initializer.dispose();\n    }\n\n    this.resourceManager.dispose();\n  }\n}\n\n/**\n * Load a graph model given a URL to the model definition.\n *\n * Example of loading MobileNetV2 from a URL and making a prediction with a\n * zeros input:\n *\n * ```js\n * const modelUrl =\n *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';\n * const model = await tf.loadGraphModel(modelUrl);\n * const zeros = tf.zeros([1, 224, 224, 3]);\n * model.predict(zeros).print();\n * ```\n *\n * Example of loading MobileNetV2 from a TF Hub URL and making a prediction with\n * a zeros input:\n *\n * ```js\n * const modelUrl =\n *    'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';\n * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});\n * const zeros = tf.zeros([1, 224, 224, 3]);\n * model.predict(zeros).print();\n * ```\n * @param modelUrl The url or an `io.IOHandler` that 
loads the model.\n * @param options Options for the HTTP request, which allows to send credentials\n *    and custom headers.\n *\n * @doc {heading: 'Models', subheading: 'Loading'}\n */\nexport async function loadGraphModel(\n    modelUrl: string|io.IOHandler,\n    options: io.LoadOptions = {}): Promise<GraphModel> {\n  if (modelUrl == null) {\n    throw new Error(\n        'modelUrl in loadGraphModel() cannot be null. Please provide a url ' +\n        'or an IOHandler that loads the model');\n  }\n  if (options == null) {\n    options = {};\n  }\n\n  if (options.fromTFHub) {\n    if ((modelUrl as io.IOHandler).load == null) {\n      if (!(modelUrl as string).endsWith('/')) {\n        modelUrl = (modelUrl as string) + '/';\n      }\n      modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;\n    }\n  }\n  const model = new GraphModel(modelUrl, options);\n  await model.load();\n  return model;\n}\n"]}