UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

120 lines (109 loc) 4.26 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs-core';
import {io} from '@tensorflow/tfjs-core';
import {Image, ImageSourcePropType} from 'react-native';

/**
 * Load-only `IOHandler` for a model whose JSON and single weights file are
 * packaged with the app.
 */
class BundleResourceHandler implements io.IOHandler {
  /**
   * @param modelJson The parsed model JSON object.
   * @param modelWeightsId Identifier of the bundled weights file; it is
   *     passed to `Image.resolveAssetSource` when the model is loaded.
   * @throws Error if either argument is null/undefined, or if
   *     `modelWeightsId` is an array (sharded weights are not supported).
   */
  constructor(
      protected readonly modelJson: io.ModelJSON,
      protected readonly modelWeightsId: string|number) {
    if (modelJson == null || modelWeightsId == null) {
      throw new Error(
          'Must pass the model json object and the model weights path.');
    }
    // The declared type excludes arrays, but untyped JS callers could still
    // pass a list of shard ids, so reject that explicitly.
    if (Array.isArray(modelWeightsId)) {
      throw new Error(
          'Bundle resource IO handler does not currently support loading ' +
          'sharded weights');
    }
  }

  /**
   * Save model artifacts. Not supported: this IO handler cannot write to the
   * packaged bundle at runtime and is exclusively for loading a model that is
   * already packaged with the app.
   *
   * @throws Error always.
   */
  async save(): Promise<io.SaveResult> {
    throw new Error(
        'Bundle resource IO handler does not support saving. ' +
        'Consider using asyncStorageIO instead');
  }

  /**
   * Load the model from the app bundle.
   *
   * Resolves the bundled weights asset to a URI, fetches its bytes, and
   * combines them with the model JSON. The single weights group's specs
   * become `weightSpecs`.
   *
   * @returns The loaded model artifacts (if loading succeeds).
   * @throws Error if the model JSON declares more than one weights group
   *     (sharded weights) or none at all, or if the weights asset cannot be
   *     resolved.
   */
  async load(): Promise<io.ModelArtifacts> {
    const modelJson = this.modelJson;
    const weightsManifest = modelJson.weightsManifest;

    // Validate the manifest before doing any I/O so that an unsupported model
    // fails fast instead of after downloading its weights.
    if (weightsManifest != null && weightsManifest.length > 1) {
      throw new Error(
          'Bundle resource IO handler does not currently support loading ' +
          'sharded weights and the modelJson indicates that this model has ' +
          'sharded weights (more than one weights file).');
    }
    if (weightsManifest == null || weightsManifest.length === 0) {
      throw new Error(
          'Bundle resource IO handler expected the modelJson to contain ' +
          'exactly one weights manifest entry, but it contains none.');
    }

    // Load the weights.
    const weightsAsset =
        Image.resolveAssetSource(this.modelWeightsId as ImageSourcePropType);
    // Guard against an unresolved asset so the failure is descriptive rather
    // than a TypeError when reading `.uri`.
    if (weightsAsset == null) {
      throw new Error(
          'Bundle resource IO handler could not resolve the model weights ' +
          'asset. Have you wrapped your asset path in a require() statement?');
    }
    // NOTE(review): `responseType` is an XMLHttpRequest property rather than
    // an HTTP header; it is kept as-is in case the platform fetch relies on
    // it. Confirm before removing.
    const response = await tf.util.fetch(weightsAsset.uri, {
      headers: {
        responseType: 'arraybuffer',
      }
    });
    const weightData = await response.arrayBuffer();

    // Shallow-copy the model JSON, then drop `weightsManifest`, which is not
    // part of `io.ModelArtifacts`. Its specs are exposed as `weightSpecs`.
    const modelArtifacts: io.ModelArtifacts&{weightsManifest?: unknown} =
        Object.assign({}, modelJson);
    delete modelArtifacts.weightsManifest;
    modelArtifacts.weightSpecs = weightsManifest[0].weights;
    modelArtifacts.weightData = weightData;
    return modelArtifacts;
  }
}

/**
 * Factory function for BundleResource IOHandler.
 *
 * This `IOHandler` only supports `load`. It is designed to support
 * loading models that have been statically bundled (at compile time)
 * with an app.
 *
 * @param modelJson The JSON object for the serialized model. This is
 *     typically obtained with `require('./path/to/model.json')`.
 * @param modelWeightsId The numeric asset id of the model weights file, as
 *     returned by a `require('./path/to/weights.bin')` statement.
 *
 * See
 * facebook.github.io/react-native/docs/images#static-non-image-resources
 * for more details on how to include static resources into your react-native
 * app including how to configure `metro` to bundle `.bin` files.
 *
 * @returns An instance of `IOHandler`
 * @throws Error if `modelJson` is not an object or `modelWeightsId` is not a
 *     number.
 */
export function bundleResourceIO(
    modelJson: io.ModelJSON, modelWeightsId: number): io.IOHandler {
  // Catch the common mistake of passing the JSON file's path (a string)
  // instead of the object produced by require().
  if (typeof modelJson !== 'object') {
    throw new Error(
        'modelJson must be a JavaScript object (and not a string).\n' +
        'Have you wrapped yor asset path in a require() statment?');
  }
  // A bundled asset reference obtained via require() is a number.
  if (typeof modelWeightsId !== 'number') {
    throw new Error(
        'modelWeightsID must be a number.\n' +
        'Have you wrapped yor asset path in a require() statment?');
  }
  return new BundleResourceHandler(modelJson, modelWeightsId);
}