UNPKG

@tensorflow/tfjs-node

Version:

This repository provides native TensorFlow execution in backend JavaScript applications under the Node.js runtime, accelerated by the TensorFlow C binary under the hood. It provides the same API as [TensorFlow.js](https://js.tensorflow.org/api/latest/).

365 lines (364 loc) 18.4 kB
"use strict";
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
exports.fileSystem = exports.nodeFileSystemRouter = exports.NodeFileSystem = void 0;

const tf = require("@tensorflow/tfjs");
const fs = require("fs");
const { dirname, join, resolve } = require("path");
const { promisify } = require("util");
const { toArrayBuffer } = require("./io_utils");

const stat = promisify(fs.stat);
const writeFile = promisify(fs.writeFile);
const readFile = promisify(fs.readFile);
const mkdir = promisify(fs.mkdir);

/**
 * Builds a promise-rejection handler that turns an ENOENT error into a
 * descriptive "does not exist" error and rethrows every other error unchanged.
 *
 * @param {string} name Human-readable label for the missing path, used as the
 *     prefix of the error message.
 * @returns {function(Error): never} Handler suitable for `promise.catch(...)`.
 */
function doesNotExistHandler(name) {
  return (e) => {
    if (e.code === 'ENOENT') {
      throw new Error(`${name} ${e.path} does not exist: loading failed`);
    }
    throw e;
  };
}

/**
 * Converts model weight data into a single Buffer for writing to disk.
 *
 * A single ArrayBuffer is wrapped exactly as before. An Array of ArrayBuffers
 * (multi-chunk weight data) is concatenated in order. Passing such an array
 * straight to `Buffer.from` would treat every chunk as one byte value
 * (NaN -> 0) and silently write zeros.
 *
 * @param {ArrayBuffer|ArrayBuffer[]} weightData Weight bytes to serialize.
 * @returns {Buffer} The bytes to write to the weights file.
 */
function weightDataToBuffer(weightData) {
  if (Array.isArray(weightData)) {
    return Buffer.concat(weightData.map((chunk) => Buffer.from(chunk)));
  }
  return Buffer.from(weightData);
}

/**
 * IOHandler that saves models to, and loads models from, the local file
 * system using Node's `fs` module.
 */
class NodeFileSystem {
  /**
   * Constructor of the NodeFileSystem IOHandler.
   * @param path A single path or an Array of paths.
   *   For saving: expects a single path pointing to an existing or nonexistent
   *     directory. If the directory does not exist, it will be
   *     created (including any missing parent directories).
   *   For loading:
   *     - If the model has JSON topology (e.g., `tf.Model`), a single path
   *       pointing to the JSON file (usually named `model.json`) is expected.
   *       The JSON file is expected to contain `modelTopology` and/or
   *       `weightsManifest`. If `weightManifest` exists, the values of the
   *       weights will be loaded from relative paths (relative to the directory
   *       of `model.json`) as contained in `weightManifest`.
   *     - If the model has binary (protocol buffer GraphDef) topology,
   *       an Array of two paths is expected: the first path should point to the
   *       .pb file and the second path should point to the weight manifest
   *       JSON file.
   */
  constructor(path) {
    this.MODEL_JSON_FILENAME = 'model.json';
    this.WEIGHTS_BINARY_FILENAME = 'weights.bin';
    this.MODEL_BINARY_FILENAME = 'tensorflowjs.pb';
    if (Array.isArray(path)) {
      tf.util.assert(
          path.length === 2,
          () => 'file paths must have a length of 2, ' +
              `(actual length is ${path.length}).`);
      // Wrap `resolve` so that map's index/array arguments are not passed on
      // as extra path segments.
      this.path = path.map((p) => resolve(p));
    } else {
      this.path = resolve(path);
    }
  }

  /**
   * Writes `model.json` and `weights.bin` into the directory at `this.path`.
   *
   * @param modelArtifacts Model artifacts with a JSON (non-ArrayBuffer)
   *     `modelTopology`.
   * @returns {Promise<{modelArtifactsInfo: object}>} Info computed by
   *     `tf.io.getModelArtifactsInfoForJSON`.
   * @throws {Error} If `this.path` is an Array, if the topology is binary, or
   *     if the target path exists as a file.
   */
  async save(modelArtifacts) {
    if (Array.isArray(this.path)) {
      throw new Error('Cannot perform saving to multiple paths.');
    }

    await this.createOrVerifyDirectory();

    if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
      throw new Error(
          'NodeFileSystem.save() does not support saving model topology ' +
          'in binary format yet.');
    }

    const weightsBinPath = join(this.path, this.WEIGHTS_BINARY_FILENAME);
    // The manifest references the weights file relative to model.json.
    const weightsManifest = [{
      paths: [this.WEIGHTS_BINARY_FILENAME],
      weights: modelArtifacts.weightSpecs,
    }];
    const modelJSON = {
      modelTopology: modelArtifacts.modelTopology,
      weightsManifest,
      format: modelArtifacts.format,
      generatedBy: modelArtifacts.generatedBy,
      convertedBy: modelArtifacts.convertedBy,
    };
    // Optional fields are only emitted when present.
    if (modelArtifacts.trainingConfig != null) {
      modelJSON.trainingConfig = modelArtifacts.trainingConfig;
    }
    if (modelArtifacts.signature != null) {
      modelJSON.signature = modelArtifacts.signature;
    }
    if (modelArtifacts.userDefinedMetadata != null) {
      modelJSON.userDefinedMetadata = modelArtifacts.userDefinedMetadata;
    }

    const modelJSONPath = join(this.path, this.MODEL_JSON_FILENAME);
    await writeFile(modelJSONPath, JSON.stringify(modelJSON), 'utf8');
    await writeFile(
        weightsBinPath, weightDataToBuffer(modelArtifacts.weightData),
        'binary');

    return {
      // TODO(cais): Use explicit tf.io.ModelArtifactsInfo type below once it
      // is available.
      modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(modelArtifacts),
    };
  }

  /**
   * Loads a model: a binary GraphDef when `this.path` is an Array of two
   * paths, otherwise a `model.json` file.
   */
  async load() {
    return Array.isArray(this.path) ? this.loadBinaryModel() :
                                      this.loadJSONModel();
  }

  /**
   * Loads a binary-topology model from `[topologyPath, weightManifestPath]`.
   *
   * @returns {Promise<object>} Artifacts with `modelTopology` (the raw file
   *     contents as read by `fs.readFile`), `weightSpecs` and `weightData`.
   * @throws {Error} If either path is missing or is not a regular file.
   */
  async loadBinaryModel() {
    const [topologyPath, weightManifestPath] = this.path;
    const topology =
        await stat(topologyPath).catch(doesNotExistHandler('Topology Path'));
    const weightManifest = await stat(weightManifestPath)
                               .catch(doesNotExistHandler('Weight Manifest Path'));

    // Both inputs must be regular files; directories are rejected.
    if (!topology.isFile()) {
      throw new Error('File specified for topology is not a file!');
    }
    if (!weightManifest.isFile()) {
      throw new Error('File specified for the weight manifest is not a file!');
    }

    const modelTopology = await readFile(topologyPath);
    const weightsManifest = JSON.parse(await readFile(weightManifestPath, 'utf8'));

    const modelArtifacts = { modelTopology };
    // Weight file paths in the manifest are resolved relative to the
    // manifest's own directory.
    const [weightSpecs, weightData] =
        await this.loadWeights(weightsManifest, weightManifestPath);
    modelArtifacts.weightSpecs = weightSpecs;
    modelArtifacts.weightData = weightData;
    return modelArtifacts;
  }

  /**
   * Loads a JSON-topology model from the `model.json` file at `this.path`.
   *
   * @throws {Error} If the path does not exist or is not a regular file.
   */
  async loadJSONModel() {
    const path = this.path;
    const info = await stat(path).catch(doesNotExistHandler('Path'));

    if (!info.isFile()) {
      throw new Error(
          'The path to load from must be a file. Loading from a directory ' +
          'is not supported.');
    }

    const modelJSON = JSON.parse(await readFile(path, 'utf8'));
    return tf.io.getModelArtifactsForJSON(
        modelJSON, (weightsManifest) => this.loadWeights(weightsManifest, path));
  }

  /**
   * Reads every weight file listed in `weightsManifest`, in manifest order.
   *
   * @param weightsManifest Array of groups, each with `paths` (relative to
   *     the directory of `path`) and `weights` (weight specs).
   * @param {string} path File whose directory the weight paths are relative to.
   * @returns {Promise<Array>} `[weightSpecs, weightData]`. `weightData` is the
   *     result of `toArrayBuffer` from ./io_utils applied to the list of file
   *     Buffers (presumably their in-order concatenation; confirm there).
   */
  async loadWeights(weightsManifest, path) {
    const dirName = dirname(path);
    const buffers = [];
    const weightSpecs = [];
    for (const group of weightsManifest) {
      for (const relativePath of group.paths) {
        const weightFilePath = join(dirName, relativePath);
        // Read sequentially so buffer order matches manifest order.
        const buffer = await readFile(weightFilePath)
                           .catch(doesNotExistHandler('Weight file'));
        buffers.push(buffer);
      }
      // `apply` (rather than spread) tolerates a missing `weights` entry.
      Array.prototype.push.apply(weightSpecs, group.weights);
    }
    return [weightSpecs, toArrayBuffer(buffers)];
  }

  /**
   * For each item in `this.path`, creates a directory at the path (including
   * missing parent directories) or verifies that the path exists as a
   * directory.
   *
   * @throws {Error} If a path exists as a regular file, or if `mkdir` fails
   *     for a reason other than EEXIST.
   */
  async createOrVerifyDirectory() {
    const paths = Array.isArray(this.path) ? this.path : [this.path];
    for (const path of paths) {
      try {
        await mkdir(path, { recursive: true });
      } catch (e) {
        if (e.code !== 'EEXIST') {
          throw e;
        }
      }
      // A recursive mkdir does not error when the directory already exists,
      // so always verify that the path is not a regular file.
      if ((await stat(path)).isFile()) {
        throw new Error(
            `Path ${path} exists as a file. The path must be ` +
            'nonexistent or point to a directory.');
      }
    }
  }
}
NodeFileSystem.URL_SCHEME = 'file://';
exports.NodeFileSystem = NodeFileSystem;

/**
 * IO router for `file://` URLs.
 *
 * @param {string|string[]} url A URL, or an Array of URLs that must all use
 *     the `file://` scheme.
 * @returns {NodeFileSystem|null} A handler for the scheme-stripped path(s),
 *     or null when the scheme does not match.
 */
const nodeFileSystemRouter = (url) => {
  const scheme = NodeFileSystem.URL_SCHEME;
  if (Array.isArray(url)) {
    if (!url.every((urlElement) => urlElement.startsWith(scheme))) {
      return null;
    }
    return new NodeFileSystem(
        url.map((urlElement) => urlElement.slice(scheme.length)));
  }
  return url.startsWith(scheme) ? new NodeFileSystem(url.slice(scheme.length)) :
                                  null;
};
exports.nodeFileSystemRouter = nodeFileSystemRouter;
// Registration of `nodeFileSystemRouter` is done in index.ts.

/**
 * Factory function for Node.js native file system IO Handler.
 *
 * @param path A single path or an Array of paths.
 *   For saving: expects a single path pointing to an existing or nonexistent
 *     directory. If the directory does not exist, it will be
 *     created (including any missing parent directories).
 *   For loading:
 *     - If the model has JSON topology (e.g., `tf.Model`), a single path
 *       pointing to the JSON file (usually named `model.json`) is expected.
 *       The JSON file is expected to contain `modelTopology` and/or
 *       `weightsManifest`. If `weightManifest` exists, the values of the
 *       weights will be loaded from relative paths (relative to the directory
 *       of `model.json`) as contained in `weightManifest`.
 *     - If the model has binary (protocol buffer GraphDef) topology,
 *       an Array of two paths is expected: the first path should point to the
 *       .pb file and the second path should point to the weight manifest
 *       JSON file.
 */
function fileSystem(path) {
  return new NodeFileSystem(path);
}
exports.fileSystem = fileSystem;