UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

184 lines (166 loc) 6.61 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {AsyncStorageStatic} from '@react-native-community/async-storage';
import {io} from '@tensorflow/tfjs-core';
import {fromByteArray, toByteArray} from 'base64-js';

/** The three AsyncStorage keys under which one model's artifacts are kept. */
type StorageKeys = {
  info: string,
  modelArtifactsWithoutWeights: string,
  weightData: string,
};

const PATH_SEPARATOR = '/';
const PATH_PREFIX = 'tensorflowjs_models';
const INFO_SUFFIX = 'info';
const MODEL_SUFFIX = 'model_without_weight';
const WEIGHT_DATA_SUFFIX = 'weight_data';

/**
 * Build the AsyncStorage keys for the model saved under `path`.
 *
 * Every key has the form `${PATH_PREFIX}/${path}/${suffix}`.
 *
 * @param path The model path supplied by the user.
 * @returns The keys for the info, topology-without-weights and weight items.
 */
function getModelKeys(path: string): StorageKeys {
  return {
    info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
    modelArtifactsWithoutWeights:
        [PATH_PREFIX, path, MODEL_SUFFIX].join(PATH_SEPARATOR),
    weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
  };
}

/**
 * Populate ModelArtifactsInfo fields for a model with JSON topology.
 *
 * @param modelArtifacts The artifacts about to be saved.
 * @returns A ModelArtifactsInfo object.
 * @throws Error if the model topology is binary (an ArrayBuffer).
 */
function getModelArtifactsInfoForJSON(modelArtifacts: io.ModelArtifacts):
    io.ModelArtifactsInfo {
  if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
    throw new Error('Expected JSON model topology, received ArrayBuffer.');
  }

  return {
    dateSaved: new Date(),
    // TODO followup on removing this from the interface
    modelTopologyType: 'JSON',
    weightDataBytes: modelArtifacts.weightData == null ?
        0 :
        modelArtifacts.weightData.byteLength,
  };
}

/**
 * IOHandler that persists a model in React Native AsyncStorage as three
 * separate items (see `getModelKeys`).
 */
class AsyncStorageHandler implements io.IOHandler {
  protected readonly keys: StorageKeys;
  protected asyncStorage: AsyncStorageStatic;

  /**
   * @param modelPath A unique, non-empty identifier for the model.
   * @throws Error if `modelPath` is null, undefined or empty.
   */
  constructor(protected readonly modelPath: string) {
    // `!modelPath` is true for null, undefined and the empty string.
    if (!modelPath) {
      throw new Error('modelPath must not be null, undefined or empty.');
    }

    this.keys = getModelKeys(this.modelPath);

    // We import this dynamically because it binds to a native library that
    // needs to be installed by the user if they use this handler. We don't
    // want users who are not using AsyncStorage to have to install this
    // library.
    this.asyncStorage =
        // tslint:disable-next-line:no-require-imports
        require('@react-native-community/async-storage').default;
  }

  /**
   * Save model artifacts to AsyncStorage.
   *
   * The returned promise resolves only after all three items have been
   * written. If any write fails, the items written so far are removed
   * (best-effort) and an Error describing the failure is thrown.
   *
   * @param modelArtifacts The model artifacts to be stored.
   * @returns An instance of SaveResult.
   * @throws Error if the topology is binary or if any write fails.
   */
  async save(modelArtifacts: io.ModelArtifacts): Promise<io.SaveResult> {
    if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
      throw new Error(
          'AsyncStorageHandler.save() does not support saving model topology ' +
          'in binary format.');
    }

    // We save three items separately for each model,
    // a ModelArtifactsInfo, a ModelArtifacts without weights
    // and the model weights.
    const modelArtifactsInfo: io.ModelArtifactsInfo =
        getModelArtifactsInfoForJSON(modelArtifacts);
    const {weightData, ...modelArtifactsWithoutWeights} = modelArtifacts;

    try {
      // setItem returns a Promise: awaiting each write is what lets a
      // rejected write reach the catch below, and keeps save() from
      // reporting success before the data is actually persisted.
      await this.asyncStorage.setItem(
          this.keys.info, JSON.stringify(modelArtifactsInfo));
      await this.asyncStorage.setItem(
          this.keys.modelArtifactsWithoutWeights,
          JSON.stringify(modelArtifactsWithoutWeights));
      await this.asyncStorage.setItem(
          this.keys.weightData, fromByteArray(new Uint8Array(weightData)));
      return {modelArtifactsInfo};
    } catch (err) {
      // If saving failed, clean up all items saved so far. Cleanup is
      // best-effort: a failing removal must not replace the original error
      // or escape as an unhandled rejection.
      const keysToRemove = [
        this.keys.info,
        this.keys.weightData,
        this.keys.modelArtifactsWithoutWeights,
      ];
      await Promise.all(keysToRemove.map(
          key => this.asyncStorage.removeItem(key).catch(() => undefined)));
      throw new Error(
          `Failed to save model '${this.modelPath}' to AsyncStorage. ` +
          `Error info ${err}`);
    }
  }

  /**
   * Load a model from AsyncStorage.
   *
   * Reads the three items written by `save()` and reassembles them into a
   * single ModelArtifacts object.
   *
   * @returns The loaded model (if loading succeeds).
   * @throws Error if the model, its topology or its weights are missing, or
   *     if the stored topology is not JSON.
   */
  async load(): Promise<io.ModelArtifacts> {
    // getItem resolves to null for a missing key; JSON.parse(null) is null.
    const info = JSON.parse(await this.asyncStorage.getItem(this.keys.info)) as
        io.ModelArtifactsInfo;
    if (info == null) {
      throw new Error(
          `In local storage, there is no model with name '${this.modelPath}'`);
    }

    if (info.modelTopologyType !== 'JSON') {
      throw new Error(
          'AsyncStorageHandler does not support loading non-JSON model ' +
          'topology yet.');
    }

    const modelArtifacts: io.ModelArtifacts = JSON.parse(
        await this.asyncStorage.getItem(this.keys.modelArtifactsWithoutWeights));
    // Without this guard, a missing topology item would surface as a
    // TypeError when weightData is assigned below.
    if (modelArtifacts == null) {
      throw new Error(
          `In local storage, the topology of model ` +
          `'${this.modelPath}' is missing.`);
    }

    // Load weight data.
    const weightDataBase64 =
        await this.asyncStorage.getItem(this.keys.weightData);
    if (weightDataBase64 == null) {
      throw new Error(
          `In local storage, the binary weight values of model ` +
          `'${this.modelPath}' are missing.`);
    }
    modelArtifacts.weightData = toByteArray(weightDataBase64).buffer;

    return modelArtifacts;
  }
}

/**
 * Factory function for AsyncStorage IOHandler.
 *
 * This `IOHandler` supports both `save` and `load`.
 *
 * For each model's saved artifacts, three items are saved to async storage.
 *   - `${PATH_PREFIX}/${modelPath}/info`: Contains meta-info about the
 *     model, such as date saved, type of the topology, size in bytes, etc.
 *   - `${PATH_PREFIX}/${modelPath}/model_without_weight`: The topology,
 *     weights_specs and all other information about the model except for the
 *     weights.
 *   - `${PATH_PREFIX}/${modelPath}/weight_data`: Concatenated binary
 *     weight values, stored as a base64-encoded string.
 *
 * @param modelPath A unique identifier for the model to be saved. Must be a
 *     non-empty string.
 * @returns An instance of `IOHandler`
 */
export function asyncStorageIO(modelPath: string): io.IOHandler {
  return new AsyncStorageHandler(modelPath);
}