UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

352 lines (319 loc) 12.2 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENV} from '../environment';
import {assert} from '../util';
import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils';
import {ModelStoreManagerRegistry} from './model_management';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';

// Layout of a local-storage key:
//   PATH_PREFIX + PATH_SEPARATOR + <model path> + PATH_SEPARATOR + <suffix>
// e.g. 'tensorflowjs_models/my/model/1/info'.

/** Separator placed between the segments of a local-storage key. */
const PATH_SEPARATOR = '/';
/** First segment of every local-storage key written by this module. */
const PATH_PREFIX = 'tensorflowjs_models';
/** Last key segment of the item holding the model's artifacts info. */
const INFO_SUFFIX = 'info';
/** Last key segment of the item holding the model topology. */
const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
/** Last key segment of the item holding the weight specs. */
const WEIGHT_SPECS_SUFFIX = 'weight_specs';
/** Last key segment of the item holding the base64-encoded weight data. */
const WEIGHT_DATA_SUFFIX = 'weight_data';
/** Last key segment of the item holding format/generatedBy/convertedBy. */
const MODEL_METADATA_SUFFIX = 'model_metadata';

/**
 * Purge all tensorflow.js-saved model artifacts from local storage.
 *
 * @returns Paths of the models purged.
 * @throws Error if the environment is not a browser or has no local storage.
*/ export function purgeLocalStorageArtifacts(): string[] { if (!ENV.getBool('IS_BROWSER') || typeof window.localStorage === 'undefined') { throw new Error( 'purgeLocalStorageModels() cannot proceed because local storage is ' + 'unavailable in the current environment.'); } const LS = window.localStorage; const purgedModelPaths: string[] = []; for (let i = 0; i < LS.length; ++i) { const key = LS.key(i); const prefix = PATH_PREFIX + PATH_SEPARATOR; if (key.startsWith(prefix) && key.length > prefix.length) { LS.removeItem(key); const modelName = getModelPathFromKey(key); if (purgedModelPaths.indexOf(modelName) === -1) { purgedModelPaths.push(modelName); } } } return purgedModelPaths; } function getModelKeys(path: string): { info: string, topology: string, weightSpecs: string, weightData: string, modelMetadata: string } { return { info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) }; } /** * Get model path from a local-storage key. * * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1' * * @param key */ function getModelPathFromKey(key: string) { const items = key.split(PATH_SEPARATOR); if (items.length < 3) { throw new Error(`Invalid key format: ${key}`); } return items.slice(1, items.length - 1).join(PATH_SEPARATOR); } function maybeStripScheme(key: string) { return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? key.slice(BrowserLocalStorage.URL_SCHEME.length) : key; } declare type LocalStorageKeys = { info: string, topology: string, weightSpecs: string, weightData: string, modelMetadata: string }; /** * IOHandler subclass: Browser Local Storage. * * See the doc string to `browserLocalStorage` for more details. 
*/ export class BrowserLocalStorage implements IOHandler { protected readonly LS: Storage; protected readonly modelPath: string; protected readonly keys: LocalStorageKeys; static readonly URL_SCHEME = 'localstorage://'; constructor(modelPath: string) { if (!ENV.getBool('IS_BROWSER') || typeof window.localStorage === 'undefined') { // TODO(cais): Add more info about what IOHandler subtypes are // available. // Maybe point to a doc page on the web and/or automatically determine // the available IOHandlers and print them in the error message. throw new Error( 'The current environment does not support local storage.'); } this.LS = window.localStorage; if (modelPath == null || !modelPath) { throw new Error( 'For local storage, modelPath must not be null, undefined or empty.'); } this.modelPath = modelPath; this.keys = getModelKeys(this.modelPath); } /** * Save model artifacts to browser local storage. * * See the documentation to `browserLocalStorage` for details on the saved * artifacts. * * @param modelArtifacts The model artifacts to be stored. * @returns An instance of SaveResult. 
*/ async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error( 'BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.'); } else { const topology = JSON.stringify(modelArtifacts.modelTopology); const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); const modelArtifactsInfo: ModelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem( this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); this.LS.setItem(this.keys.modelMetadata, JSON.stringify({ format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy })); return {modelArtifactsInfo}; } catch (err) { // If saving failed, clean up all items saved so far. this.LS.removeItem(this.keys.info); this.LS.removeItem(this.keys.topology); this.LS.removeItem(this.keys.weightSpecs); this.LS.removeItem(this.keys.weightData); this.LS.removeItem(this.keys.modelMetadata); throw new Error( `Failed to save model '${this.modelPath}' to local storage: ` + `size quota being exceeded is a possible cause of this failure: ` + `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` + `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` + `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`); } } } /** * Load a model from local storage. * * See the documentation to `browserLocalStorage` for details on the saved * artifacts. * * @returns The loaded model (if loading succeeds). 
*/ async load(): Promise<ModelArtifacts> { const info = JSON.parse(this.LS.getItem(this.keys.info)) as ModelArtifactsInfo; if (info == null) { throw new Error( `In local storage, there is no model with name '${this.modelPath}'`); } if (info.modelTopologyType !== 'JSON') { throw new Error( 'BrowserLocalStorage does not support loading non-JSON model ' + 'topology yet.'); } const out: ModelArtifacts = {}; // Load topology. const topology = JSON.parse(this.LS.getItem(this.keys.topology)); if (topology == null) { throw new Error( `In local storage, the topology of model '${this.modelPath}' ` + `is missing.`); } out.modelTopology = topology; // Load weight specs. const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); if (weightSpecs == null) { throw new Error( `In local storage, the weight specs of model '${this.modelPath}' ` + `are missing.`); } out.weightSpecs = weightSpecs; // Load meta-data fields. const metadataString = this.LS.getItem(this.keys.modelMetadata); if (metadataString != null) { const metadata = JSON.parse(metadataString) as {format: string, generatedBy: string, convertedBy: string}; out.format = metadata['format']; out.generatedBy = metadata['generatedBy']; out.convertedBy = metadata['convertedBy']; } // Load weight data. 
const weightDataBase64 = this.LS.getItem(this.keys.weightData); if (weightDataBase64 == null) { throw new Error( `In local storage, the binary weight values of model ` + `'${this.modelPath}' are missing.`); } out.weightData = base64StringToArrayBuffer(weightDataBase64); return out; } } export const localStorageRouter: IORouter = (url: string|string[]) => { if (!ENV.getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { return browserLocalStorage( url.slice(BrowserLocalStorage.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(localStorageRouter); IORouterRegistry.registerLoadRouter(localStorageRouter); /** * Factory function for local storage IOHandler. * * This `IOHandler` supports both `save` and `load`. * * For each model's saved artifacts, four items are saved to local storage. * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the * model, such as date saved, type of the topology, size in bytes, etc. * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras- * style models, this is a stringized JSON. * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the * model, can be used to decode the saved binary weight values (see * item below). * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary * weight values, stored as a base64-encoded string. * * Saving may throw an `Error` if the total size of the artifacts exceed the * browser-specific quota. * * @param modelPath A unique identifier for the model to be saved. Must be a * non-empty string. * @returns An instance of `IOHandler`, which can be used with, e.g., * `tf.Model.save`. 
*/ export function browserLocalStorage(modelPath: string): IOHandler { return new BrowserLocalStorage(modelPath); } export class BrowserLocalStorageManager implements ModelStoreManager { private readonly LS: Storage; constructor() { assert( ENV.getBool('IS_BROWSER'), () => 'Current environment is not a web browser'); assert( typeof window.localStorage !== 'undefined', () => 'Current browser does not appear to support localStorage'); this.LS = window.localStorage; } async listModels(): Promise<{[path: string]: ModelArtifactsInfo}> { const out: {[path: string]: ModelArtifactsInfo} = {}; const prefix = PATH_PREFIX + PATH_SEPARATOR; const suffix = PATH_SEPARATOR + INFO_SUFFIX; for (let i = 0; i < this.LS.length; ++i) { const key = this.LS.key(i); if (key.startsWith(prefix) && key.endsWith(suffix)) { const modelPath = getModelPathFromKey(key); out[modelPath] = JSON.parse(this.LS.getItem(key)) as ModelArtifactsInfo; } } return out; } async removeModel(path: string): Promise<ModelArtifactsInfo> { path = maybeStripScheme(path); const keys = getModelKeys(path); if (this.LS.getItem(keys.info) == null) { throw new Error(`Cannot find model at path '${path}'`); } const info = JSON.parse(this.LS.getItem(keys.info)) as ModelArtifactsInfo; this.LS.removeItem(keys.info); this.LS.removeItem(keys.topology); this.LS.removeItem(keys.weightSpecs); this.LS.removeItem(keys.weightData); return info; } } if (ENV.getBool('IS_BROWSER')) { // Wrap the construction and registration, to guard against browsers that // don't support Local Storage. try { ModelStoreManagerRegistry.registerManager( BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); } catch (err) { } }