UNPKG

@tensorflow/tfjs-node

Version:

This package provides native TensorFlow execution for backend JavaScript applications running on Node.js, accelerated under the hood by the TensorFlow C library. It exposes the same API as [TensorFlow.js](https://js.tensorflow.org/api/latest/).

268 lines (244 loc) 9.92 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';
import * as fs from 'fs';
import {dirname, join, resolve} from 'path';
import {promisify} from 'util';

import {toArrayBuffer} from './io_utils';

const stat = promisify(fs.stat);
const writeFile = promisify(fs.writeFile);
const readFile = promisify(fs.readFile);
const mkdir = promisify(fs.mkdir);

/**
 * Builds a `.catch()` handler for file-system promises.
 *
 * An `ENOENT` error is rethrown as an `Error` whose message names the missing
 * path. Every other error is rethrown unchanged.
 *
 * @param name Human-readable label for the path being accessed; used as the
 *     prefix of the error message (e.g. 'Weight file').
 */
function doesNotExistHandler(name: string):
    (e: NodeJS.ErrnoException) => never {
  return e => {
    switch (e.code) {
      case 'ENOENT':
        throw new Error(`${name} ${e.path} does not exist: loading failed`);
      default:
        throw e;
    }
  };
}

/**
 * `tf.io.IOHandler` that saves and loads models using the Node.js file system.
 */
export class NodeFileSystem implements tf.io.IOHandler {
  static readonly URL_SCHEME = 'file://';

  /**
   * The absolute path(s) this handler operates on, resolved in the
   * constructor. A single path is used for JSON models and for saving; an
   * Array of two paths is used for binary (GraphDef) models.
   */
  protected readonly path: string|string[];

  readonly MODEL_JSON_FILENAME = 'model.json';
  readonly WEIGHTS_BINARY_FILENAME = 'weights.bin';
  readonly MODEL_BINARY_FILENAME = 'tensorflowjs.pb';

  /**
   * Constructor of the NodeFileSystem IOHandler.
   *
   * Every path is passed through `path.resolve`, so relative paths are
   * resolved against the current working directory.
   *
   * @param path A single path or an Array of paths.
   *   For saving: expects a single path pointing to an existing or nonexistent
   *     directory. If the directory does not exist, it will be created. Its
   *     parent directory must already exist, because the directory is created
   *     with a non-recursive `mkdir`.
   *   For loading:
   *     - If the model has JSON topology (e.g., `tf.Model`), a single path
   *       pointing to the JSON file (usually named `model.json`) is expected.
   *       The JSON file is expected to contain `modelTopology` and/or
   *       `weightsManifest`. If `weightsManifest` exists, the weight values
   *       are loaded from the paths it lists. Those paths are relative to the
   *       directory containing `model.json`.
   *     - If the model has binary (protocol buffer GraphDef) topology,
   *       an Array of two paths is expected. The first path should point to
   *       the .pb file; the second should point to the weight manifest JSON
   *       file.
   * @throws If `path` is an Array whose length is not 2 (via `tf.util.assert`).
   */
  constructor(path: string|string[]) {
    if (Array.isArray(path)) {
      tf.util.assert(
          path.length === 2,
          () => 'file paths must have a length of 2, ' +
              `(actual length is ${path.length}).`);
      this.path = path.map(p => resolve(p));
    } else {
      this.path = resolve(path);
    }
  }

  /**
   * Saves a model with JSON topology into the directory at `this.path`.
   *
   * Two files are written: `model.json` (topology, weights manifest and
   * metadata) and `weights.bin` (the raw weight data). The directory is
   * created if it does not exist.
   *
   * All input validation and conversion happens before anything touches the
   * disk. As a result, an unsupported model does not leave behind an empty
   * directory or a `model.json` without its `weights.bin`.
   *
   * @param modelArtifacts The model artifacts to save.
   * @returns Summary information about the saved artifacts.
   * @throws If `this.path` is an Array of paths, if the model topology is
   *     binary (an `ArrayBuffer`, not supported yet), or if `this.path` exists
   *     as a regular file.
   */
  async save(modelArtifacts: tf.io.ModelArtifacts): Promise<tf.io.SaveResult> {
    if (Array.isArray(this.path)) {
      throw new Error('Cannot perform saving to multiple paths.');
    }

    if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
      // TODO(cais, nkreeger): Implement this. See
      // https://github.com/tensorflow/tfjs/issues/343
      throw new Error(
          'NodeFileSystem.save() does not support saving model topology ' +
          'in binary format yet.');
    }

    // All weights are written to a single shard, weights.bin.
    const weightsManifest = [{
      paths: [this.WEIGHTS_BINARY_FILENAME],
      weights: modelArtifacts.weightSpecs
    }];
    const modelJSON: tf.io.ModelJSON = {
      modelTopology: modelArtifacts.modelTopology,
      weightsManifest,
      format: modelArtifacts.format,
      generatedBy: modelArtifacts.generatedBy,
      convertedBy: modelArtifacts.convertedBy
    };
    if (modelArtifacts.trainingConfig != null) {
      modelJSON.trainingConfig = modelArtifacts.trainingConfig;
    }
    if (modelArtifacts.signature != null) {
      modelJSON.signature = modelArtifacts.signature;
    }
    if (modelArtifacts.userDefinedMetadata != null) {
      modelJSON.userDefinedMetadata = modelArtifacts.userDefinedMetadata;
    }

    // Convert the weights up front. If this throws (e.g. `weightData` is
    // missing), nothing has been written to disk yet.
    const weightsBuffer = Buffer.from(modelArtifacts.weightData);

    await this.createOrVerifyDirectory();

    const modelJSONPath = join(this.path, this.MODEL_JSON_FILENAME);
    const weightsBinPath = join(this.path, this.WEIGHTS_BINARY_FILENAME);
    await writeFile(modelJSONPath, JSON.stringify(modelJSON), 'utf8');
    await writeFile(weightsBinPath, weightsBuffer, 'binary');

    return {
      modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(modelArtifacts),
    };
  }

  /**
   * Loads model artifacts from `this.path`.
   *
   * An Array of paths is loaded as a binary topology plus a weight manifest.
   * A single path is loaded as a `model.json` file.
   */
  async load(): Promise<tf.io.ModelArtifacts> {
    return Array.isArray(this.path) ? this.loadBinaryModel() :
                                      this.loadJSONModel();
  }

  /**
   * Loads a model whose topology is stored in a binary (.pb) file.
   *
   * `this.path[0]` is the topology file and `this.path[1]` is the weight
   * manifest JSON file. Weight file paths in the manifest are resolved
   * relative to the manifest's directory.
   *
   * @throws If either path does not exist or is not a regular file.
   */
  protected async loadBinaryModel(): Promise<tf.io.ModelArtifacts> {
    const [topologyPath, weightManifestPath] = this.path as string[];
    const topology =
        await stat(topologyPath).catch(doesNotExistHandler('Topology Path'));
    const weightManifest = await stat(weightManifestPath)
                               .catch(doesNotExistHandler('Weight Manifest Path'));

    // Both paths must point to regular files, not directories.
    if (!topology.isFile()) {
      throw new Error('File specified for topology is not a file!');
    }
    if (!weightManifest.isFile()) {
      throw new Error('File specified for the weight manifest is not a file!');
    }

    // The topology is passed through as the Node.js Buffer returned by
    // readFile; it is not converted to an ArrayBuffer here.
    const modelTopology = await readFile(topologyPath);
    const weightsManifest =
        JSON.parse(await readFile(weightManifestPath, 'utf8'));

    const [weightSpecs, weightData] =
        await this.loadWeights(weightsManifest, weightManifestPath);

    const modelArtifacts: tf.io.ModelArtifacts = {
      modelTopology,
      weightSpecs,
      weightData,
    };
    return modelArtifacts;
  }

  /**
   * Loads a model from the `model.json` file at `this.path`.
   *
   * Weight file paths in its manifest are resolved relative to the directory
   * containing that file.
   *
   * @throws If the path does not exist or is not a regular file (for example,
   *     a directory).
   */
  protected async loadJSONModel(): Promise<tf.io.ModelArtifacts> {
    const path = this.path as string;
    const info = await stat(path).catch(doesNotExistHandler('Path'));

    // Only a path to the JSON file itself is accepted. Anything else, such as
    // a directory, is rejected rather than searched for a model.json.
    if (!info.isFile()) {
      throw new Error(
          'The path to load from must be a file. Loading from a directory ' +
          'is not supported.');
    }

    const modelJSON = JSON.parse(await readFile(path, 'utf8'));
    return tf.io.getModelArtifactsForJSON(
        modelJSON,
        (weightsManifest) => this.loadWeights(weightsManifest, path));
  }

  /**
   * Reads every weight file listed in `weightsManifest`.
   *
   * Each listed path is joined onto the directory of `path`. Files are read
   * one at a time, in manifest order, and their bytes are concatenated into a
   * single ArrayBuffer via `toArrayBuffer`.
   *
   * @param weightsManifest The manifest groups to load.
   * @param path Path of the file that contained the manifest; only its
   *     directory is used.
   * @returns The concatenated weight specs of all groups, and the weight data.
   * @throws If any weight file does not exist.
   */
  private async loadWeights(
      weightsManifest: tf.io.WeightsManifestConfig,
      path: string): Promise<[tf.io.WeightsManifestEntry[], ArrayBuffer]> {
    const dirName = dirname(path);
    const buffers: Buffer[] = [];
    const weightSpecs: tf.io.WeightsManifestEntry[] = [];
    for (const group of weightsManifest) {
      for (const relativePath of group.paths) {
        const weightFilePath = join(dirName, relativePath);
        const buffer = await readFile(weightFilePath)
                           .catch(doesNotExistHandler('Weight file'));
        buffers.push(buffer);
      }
      weightSpecs.push(...group.weights);
    }
    return [weightSpecs, toArrayBuffer(buffers)];
  }

  /**
   * For each item in `this.path`, creates a directory at the path or verifies
   * that the path already exists and is not a regular file.
   *
   * `mkdir` is called without options, i.e. non-recursively, so each path's
   * parent directory must already exist.
   *
   * @throws If a path exists as a regular file, or if `mkdir` fails for any
   *     reason other than `EEXIST`.
   */
  protected async createOrVerifyDirectory() {
    const paths = Array.isArray(this.path) ? this.path : [this.path];
    for (const path of paths) {
      try {
        await mkdir(path);
      } catch (e) {
        if ((e as NodeJS.ErrnoException).code !== 'EEXIST') {
          throw e;
        }
        // The path already exists. Reject it only if it is a regular file;
        // otherwise the existing entry is reused.
        if ((await stat(path)).isFile()) {
          throw new Error(
              `Path ${path} exists as a file. The path must be ` +
              `nonexistent or point to a directory.`);
        }
      }
    }
  }
}

/**
 * IO router for `file://` URLs.
 *
 * A single URL is handled when it starts with `file://`. An Array is handled
 * when every element does. The scheme is stripped before the path(s) are
 * passed to `NodeFileSystem`; an Array whose length is not 2 therefore fails
 * that constructor's assertion. Returns null for URLs without the scheme.
 */
export const nodeFileSystemRouter = (url: string|string[]) => {
  if (Array.isArray(url)) {
    if (url.every(
            urlElement => urlElement.startsWith(NodeFileSystem.URL_SCHEME))) {
      return new NodeFileSystem(url.map(
          urlElement => urlElement.slice(NodeFileSystem.URL_SCHEME.length)));
    } else {
      return null;
    }
  } else {
    if (url.startsWith(NodeFileSystem.URL_SCHEME)) {
      return new NodeFileSystem(url.slice(NodeFileSystem.URL_SCHEME.length));
    } else {
      return null;
    }
  }
};
// Registration of `nodeFileSystemRouter` is done in index.ts.

/**
 * Factory function for Node.js native file system IO Handler.
 *
 * Every path is passed through `path.resolve`, so relative paths are
 * resolved against the current working directory.
 *
 * @param path A single path or an Array of paths.
 *   For saving: expects a single path pointing to an existing or nonexistent
 *     directory. If the directory does not exist, it will be created. Its
 *     parent directory must already exist.
 *   For loading:
 *     - If the model has JSON topology (e.g., `tf.Model`), a single path
 *       pointing to the JSON file (usually named `model.json`) is expected.
 *       The JSON file is expected to contain `modelTopology` and/or
 *       `weightsManifest`. If `weightsManifest` exists, the weight values
 *       are loaded from the paths it lists. Those paths are relative to the
 *       directory containing `model.json`.
 *     - If the model has binary (protocol buffer GraphDef) topology,
 *       an Array of two paths is expected. The first path should point to
 *       the .pb file; the second should point to the weight manifest JSON
 *       file.
 */
export function fileSystem(path: string|string[]): NodeFileSystem {
  return new NodeFileSystem(path);
}