/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {deprecationWarn, io} from '@tensorflow/tfjs-core';
import {DEFAULT_MANIFEST_NAME, FrozenModel, loadFrozenModel as loadFrozenModelPB, loadTfHubModule} from './executor/frozen_model';
import {loadFrozenModel as loadFrozenModelJSON} from './executor/frozen_model_json';
export {FrozenModel, loadTfHubModule} from './executor/frozen_model';
export {FrozenModel as GraphModel} from './executor/frozen_model';
export {FrozenModel as FrozenModelJSON} from './executor/frozen_model_json';
export {version as version_converter} from './version';
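// Note: `GraphModel` above is simply an alias of `FrozenModel`, kept so that
// code written against the newer graph-model naming continues to type-check
// while the frozen-model API is deprecated. A minimal consumption sketch (the
// import path is an assumption; adjust it to however this module is published):
//
//   import {GraphModel, loadGraphModel} from '@tensorflow/tfjs-converter';
//
//   const model: GraphModel = await loadGraphModel(modelUrl);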
/**
* Deprecated. Use `tf.loadGraphModel`.
*
 * Loads a frozen model from a URL.
 *
 * Example of loading the MobileNetV2 model and making a prediction with a
 * zeros input:
*
* ```js
* const GOOGLE_CLOUD_STORAGE_DIR =
* 'https://storage.googleapis.com/tfjs-models/savedmodel/';
* const MODEL_URL = 'mobilenet_v2_1.0_224/tensorflowjs_model.pb';
* const WEIGHTS_URL =
* 'mobilenet_v2_1.0_224/weights_manifest.json';
* const model = await tf.loadFrozenModel(GOOGLE_CLOUD_STORAGE_DIR + MODEL_URL,
* GOOGLE_CLOUD_STORAGE_DIR + WEIGHTS_URL);
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
*
 * @param modelUrl url for the model file generated by the scripts/convert.py
 * script.
 * @param weightsManifestUrl url for the weights manifest file generated by the
 * scripts/convert.py script.
 * @param requestOption options for Request, which allows sending credentials
 * and custom headers.
* @param onProgress Optional, progress callback function, fired periodically
* before the load is completed.
*/
/** @doc {heading: 'Models', subheading: 'Loading'} */
export function loadFrozenModel(
modelUrl: string, weightsManifestUrl?: string, requestOption?: RequestInit,
onProgress?: Function): Promise<FrozenModel> {
deprecationWarn(
'tf.loadFrozenModel() is going away. ' +
'Use tf.loadGraphModel() instead, and note the positional argument changes.');
if (modelUrl && modelUrl.endsWith('.json')) {
return (loadFrozenModelJSON(modelUrl, requestOption, onProgress) as
// tslint:disable-next-line:no-any
Promise<any>) as Promise<FrozenModel>;
}
  // If callers use the newer loadGraphModel-style API, weightsManifestUrl is
  // omitted; build it from the model URL path and the default manifest file
  // name.
if (modelUrl != null && weightsManifestUrl == null) {
weightsManifestUrl = getWeightsManifestUrl(modelUrl);
}
return loadFrozenModelPB(
modelUrl, weightsManifestUrl, requestOption, onProgress);
}
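// Migration sketch (illustrative only, reusing the MobileNetV2 URLs from the
// doc comment above): the deprecated two-URL signature maps onto a single
// model URL in `loadGraphModel`, with the request options and the progress
// callback moving into the `io.LoadOptions` bag.
//
//   // Deprecated:
//   const model = await loadFrozenModel(
//       GOOGLE_CLOUD_STORAGE_DIR + MODEL_URL,
//       GOOGLE_CLOUD_STORAGE_DIR + WEIGHTS_URL,
//       {credentials: 'include'}, fraction => console.log(fraction));
//
//   // Replacement (the weights manifest URL is derived from the model URL):
//   const model = await loadGraphModel(GOOGLE_CLOUD_STORAGE_DIR + MODEL_URL, {
//     requestInit: {credentials: 'include'},
//     onProgress: fraction => console.log(fraction),
//   });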
function getWeightsManifestUrl(modelUrl: string): string {
let weightsManifestUrl: string;
if (modelUrl != null) {
const path = modelUrl.substr(0, modelUrl.lastIndexOf('/'));
weightsManifestUrl = path + '/' + DEFAULT_MANIFEST_NAME;
}
return weightsManifestUrl;
}
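// For example (assuming the default manifest name is 'weights_manifest.json',
// as in the examples above), a model URL of
//   'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/tensorflowjs_model.pb'
// yields a manifest URL of
//   'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/weights_manifest.json'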
/**
* Load a graph model given a URL to the model definition.
*
* Example of loading MobileNetV2 from a URL and making a prediction with a
* zeros input:
*
* ```js
* const modelUrl =
* 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/tensorflowjs_model.pb';
* const model = await tf.loadGraphModel(modelUrl);
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
*
* Example of loading MobileNetV2 from a TF Hub URL and making a prediction with
* a zeros input:
*
* ```js
* const modelUrl =
* 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
* const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
 * @param modelUrl url for the model file generated by the scripts/convert.py
 * script, or a TF Hub url.
 * @param options options for the Request, which allows sending credentials
 * and custom headers.
*/
/** @doc {heading: 'Models', subheading: 'Loading'} */
export function loadGraphModel(
modelUrl: string, options: io.LoadOptions = {}): Promise<FrozenModel> {
if (options == null) {
options = {};
}
if (options.fromTFHub) {
return loadTfHubModule(modelUrl, options.requestInit, options.onProgress);
}
let weightsManifestUrl: string = undefined;
if (modelUrl && modelUrl.endsWith('.json')) {
return (loadFrozenModelJSON(
modelUrl, options.requestInit, options.onProgress) as
// tslint:disable-next-line:no-any
Promise<any>) as Promise<FrozenModel>;
}
  // The weights manifest URL is not part of the loadGraphModel API; build it
  // from the model URL path and the default manifest file name.
if (modelUrl != null && weightsManifestUrl == null) {
weightsManifestUrl = getWeightsManifestUrl(modelUrl);
}
return loadFrozenModelPB(
modelUrl, weightsManifestUrl, options.requestInit, options.onProgress);
}
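// Usage sketch (illustrative, not part of the exported API surface): loading a
// converted model with custom request credentials and a progress callback via
// `io.LoadOptions`, and loading a TF Hub module via `fromTFHub` (URLs taken
// from the doc comments above).
//
//   const model = await loadGraphModel(
//       'https://storage.googleapis.com/tfjs-models/savedmodel/' +
//           'mobilenet_v2_1.0_224/tensorflowjs_model.pb',
//       {requestInit: {credentials: 'include'},
//        onProgress: fraction => console.log(`loaded ${fraction * 100}%`)});
//
//   const hubModel = await loadGraphModel(
//       'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2',
//       {fromTFHub: true});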