UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

450 lines 16.9 kB
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
// eslint-disable-next-line node/no-extraneous-import
import Long from 'long';
import { onnx } from 'onnx-proto';
import { Variable } from '../autograd/variable';
import { Module } from '../model/module';
import { glContext } from '../tensor/gpu/gl';
import { toCPU, toGPU, toWASM } from '../util/convert';
import { ConstantNode } from './nodes/constant';
import { defaultOptimizations } from './optimizations/default';
import { nodeResolve } from './resolve';
import { createTensor } from './util';

/**
 * A model built from a serialized ONNX graph.
 *
 * The graph is stored as a table of nodes (`this.nodes`, keyed by node id)
 * plus a table of intermediaries (`this.intermediaries`, keyed by value name)
 * recording which node ids consume each value. `forward` executes nodes as
 * soon as all of their variable inputs are available.
 */
export class OnnxModel extends Module {
    /**
     * Builds a new onnx model
     *
     * @param buffer Onnx model. An ArrayBuffer is wrapped in a Uint8Array;
     *               anything else is handed to the protobuf decoder as is.
     * @param args Optional arguments for the model. Fields read here:
     *             `noConvertConstants`, `noConvertNodes` (names/ids skipped
     *             by toCPU/toWASM/toGPU), `mode` (defaults to 'inference')
     *             and `precision` (defaults to 32).
     */
    constructor(buffer, args) {
        super();
        this.inputSet = new Set();
        this.nodes = {};
        this.nodeIds = [];
        this.defaultReady = [];
        this.intermediaries = {};
        this.constants = {};
        // Ids for nodes inserted by optimizations start here, above the
        // index based ids assigned in initNodes.
        this.nodeIdCounter = 10000;
        if (args === undefined) {
            args = {};
        }
        this.noConvertConstants = new Set(args.noConvertConstants !== undefined ? args.noConvertConstants : []);
        this.noConvertNodes = new Set(args.noConvertNodes !== undefined ? args.noConvertNodes : []);
        this.mode = args.mode || 'inference';
        this.precision = args.precision || 32;
        let arr;
        if (buffer instanceof ArrayBuffer) {
            arr = new Uint8Array(buffer);
        }
        else {
            arr = buffer;
        }
        this.modelProto = onnx.ModelProto.decode(arr);
        let ver = this.modelProto.opsetImport[0].version;
        if (Long.isLong(ver)) {
            ver = ver.toNumber();
        }
        this.version = ver;
        this.inputs = this.modelProto.graph.input;
        for (let i = 0; i < this.inputs.length; i++) {
            this.inputSet.add(this.inputs[i].name);
        }
        this.outputs = this.modelProto.graph.output.map(x => x.name);
        // Constants must be loaded first, nodes receive them on construction.
        this.initializer(this.modelProto.graph.initializer);
        this.initNodes(this.modelProto);
    }

    /**
     * Creates one node per entry of the graph and records, for every input
     * name, which node ids consume it.
     *
     * @param modelProto Decoded model whose `graph.node` list is instantiated
     * @throws Error if an operator type has no entry in `nodeResolve`
     */
    initNodes(modelProto) {
        for (let i = 0; i < modelProto.graph.node.length; i++) {
            const nodeData = modelProto.graph.node[i];
            const cls = nodeResolve[nodeData.opType];
            if (cls === undefined) {
                throw new Error(`Node operator ${nodeData.opType} can not be resolved`);
            }
            const attributes = nodeData.attribute || [];
            const inputs = nodeData.input || [];
            const outputs = nodeData.output || [];
            const node = cls(attributes, inputs, outputs, this.constants, this.version, this.mode);
            this.nodes[i] = node;
            this.nodeIds.push(i);
            for (let j = 0; j < inputs.length; j++) {
                const input = inputs[j];
                if (this.intermediaries[input] === undefined) {
                    this.intermediaries[input] = {
                        to: [],
                        deletable: true,
                    };
                }
                this.intermediaries[input].to.push(i);
            }
            // Nodes without variable inputs can run right at the start of
            // every forward pass.
            if (node.variableInputs === 0) {
                this.defaultReady.push(i);
            }
            if (nodeData.opType === 'Constant') {
                // The output of a Constant node must survive forward passes.
                const constantOutput = nodeData.output[0];
                if (this.intermediaries[constantOutput] === undefined) {
                    this.intermediaries[constantOutput] = {
                        to: [],
                        deletable: false,
                    };
                }
                else {
                    this.intermediaries[constantOutput].deletable = false;
                }
            }
        }
        for (const nodeId of this.nodeIds) {
            this.nodes[nodeId].initialize(name => this.resolveConstant(name));
        }
    }

    /**
     * Loads the graph initializers into `this.constants`, keyed by name.
     * In 'train' mode every tensor is wrapped in a Variable.
     *
     * @param initializer List of tensor protos
     */
    initializer(initializer) {
        for (let i = 0; i < initializer.length; i++) {
            const tensorProto = initializer[i];
            let tensor = createTensor(tensorProto, this.precision === 16);
            if (this.mode === 'train') {
                tensor = new Variable(tensor);
            }
            this.constants[tensorProto.name] = tensor;
        }
    }

    /**
     * Do a forward pass for the specified inputs
     *
     * @param inputs Input values, in the same order as the graph inputs
     * @param wait Number of milliseconds to wait between each layer. This
     *             is especially useful, if your model is complex and
     *             you dont want your model to block your whole application.
     * @returns Promise of the output values, in the order of the graph outputs
     */
    forward(inputs, wait) {
        return __awaiter(this, void 0, void 0, function* () {
            const intermediaryRes = {};
            // Per pass bookkeeping: how many variable inputs of each node
            // have been produced so far.
            const nodes = {};
            for (const i of this.nodeIds) {
                nodes[i] = {
                    variableInputs: 0,
                };
            }
            const nodesReady = [...this.defaultReady];
            // Graph outputs are read after the loop, so they must not be
            // freed even when another node consumes them as well.
            const outputSet = new Set(this.outputs);
            this.initializeForward(inputs, intermediaryRes, nodes, nodesReady);
            while (nodesReady.length > 0) {
                const nodeId = nodesReady.shift();
                const node = this.nodes[nodeId];
                const { inputs: nodeInputs, toDelete } = this.getInputsToNode(node, intermediaryRes);
                let nodeOutputs;
                try {
                    nodeOutputs = yield node.forward(nodeInputs);
                }
                catch (e) {
                    console.error(`Error occurred in node ${nodeId} with inputs ${node.inputs} from nodes ${node.inputs.map((x) => this.getNodeWithOutput(x))}`);
                    throw e;
                }
                glContext.flush();
                this.propagateResults(node, intermediaryRes, nodeOutputs, nodes, nodesReady);
                for (let i = 0; i < toDelete.length; i++) {
                    const name = toDelete[i];
                    // Values passed in by the caller and values returned to
                    // the caller are never released here.
                    if (this.inputSet.has(name) || outputSet.has(name)) {
                        continue;
                    }
                    const inter = intermediaryRes[name];
                    inter.value.delete();
                    delete intermediaryRes[name];
                }
                if (wait !== undefined) {
                    yield new Promise(resolve => {
                        setTimeout(resolve, wait);
                    });
                }
            }
            const results = [];
            for (let i = 0; i < this.outputs.length; i++) {
                results.push(intermediaryRes[this.outputs[i]].value);
            }
            return results;
        });
    }

    /**
     * Registers the model inputs as available values and queues every node
     * whose variable inputs are thereby complete.
     *
     * @param inputs Input values, matched by position to `this.inputs`
     * @param intermediaryRes Table of available values, filled in place
     * @param nodes Per pass input counters, updated in place
     * @param nodesReady Queue of runnable node ids, appended to in place
     */
    initializeForward(inputs, intermediaryRes, nodes, nodesReady) {
        for (let i = 0; i < inputs.length; i++) {
            const name = this.inputs[i].name;
            intermediaryRes[name] = {
                value: inputs[i],
                used: 0,
            };
            const inter = this.intermediaries[name];
            // An input that no node consumes has no entry; same guard as in
            // propagateResults.
            if (inter === undefined) {
                continue;
            }
            for (let j = 0; j < inter.to.length; j++) {
                const id = inter.to[j];
                nodes[id].variableInputs++;
                if (nodes[id].variableInputs === this.nodes[id].variableInputs) {
                    nodesReady.push(id);
                    delete nodes[id];
                }
            }
        }
    }

    /**
     * Collects the input values of a node and determines which
     * intermediaries have now been read by all of their consumers.
     *
     * @param node Node about to be executed
     * @param intermediaryRes Table of available values; `used` counters are
     *                        incremented in place
     * @returns `inputs`, the values in the order of `node.inputs`, and
     *          `toDelete`, the names that may be released after the node ran
     */
    getInputsToNode(node, intermediaryRes) {
        const inputs = [];
        const toDelete = [];
        for (let i = 0; i < node.inputs.length; i++) {
            const input = node.inputs[i];
            if (this.constants[input] !== undefined) {
                inputs.push(this.constants[input]);
            }
            else {
                const inter = intermediaryRes[input];
                inter.used++;
                if (inter.used >= this.intermediaries[input].to.length &&
                    this.intermediaries[input].deletable) {
                    toDelete.push(input);
                }
                inputs.push(inter.value);
            }
        }
        return { inputs, toDelete };
    }

    /**
     * Stores the outputs of an executed node and queues every consumer
     * whose variable inputs are now complete.
     *
     * @param node Node that was executed
     * @param intermediaryRes Table of available values, filled in place
     * @param outputs Values returned by the node, matched by position to
     *                `node.outputs`
     * @param nodes Per pass input counters, updated in place
     * @param nodesReady Queue of runnable node ids, appended to in place
     */
    propagateResults(node, intermediaryRes, outputs, nodes, nodesReady) {
        for (let i = 0; i < node.outputs.length; i++) {
            const output = node.outputs[i];
            intermediaryRes[output] = {
                value: outputs[i],
                used: 0,
            };
            const inter = this.intermediaries[output];
            if (inter !== undefined) {
                for (let j = 0; j < inter.to.length; j++) {
                    const id = inter.to[j];
                    nodes[id].variableInputs++;
                    if (nodes[id].variableInputs === this.nodes[id].variableInputs) {
                        nodesReady.push(id);
                        delete nodes[id];
                    }
                }
            }
        }
    }

    /**
     * Transfer the model to the CPU
     */
    toCPU() {
        return __awaiter(this, void 0, void 0, function* () {
            for (const i in this.constants) {
                if (!this.noConvertConstants.has(i)) {
                    this.constants[i] = yield toCPU(this.constants[i]);
                }
            }
            for (const i of this.nodeIds) {
                if (!this.noConvertNodes.has(i)) {
                    yield this.nodes[i].toCPU();
                }
            }
        });
    }

    /**
     * Transfer the model to WASM
     */
    toWASM() {
        return __awaiter(this, void 0, void 0, function* () {
            for (const i in this.constants) {
                if (!this.noConvertConstants.has(i)) {
                    this.constants[i] = yield toWASM(this.constants[i]);
                }
            }
            for (const i of this.nodeIds) {
                if (!this.noConvertNodes.has(i)) {
                    yield this.nodes[i].toWASM();
                }
            }
        });
    }

    /**
     * Transfer the model to the GPU
     */
    toGPU() {
        return __awaiter(this, void 0, void 0, function* () {
            for (const i in this.constants) {
                if (!this.noConvertConstants.has(i)) {
                    this.constants[i] = yield toGPU(this.constants[i]);
                }
            }
            for (const i of this.nodeIds) {
                if (!this.noConvertNodes.has(i)) {
                    yield this.nodes[i].toGPU();
                }
            }
        });
    }

    /**
     * Optimize the model.
     *
     * For every application of every default optimization the matched nodes
     * are replaced by the node the optimization returns; afterwards unused
     * intermediaries and nodes are pruned.
     */
    optimize() {
        for (const optimization of defaultOptimizations) {
            const applications = optimization.findApplications(this);
            for (const nodeIds of applications) {
                const nodes = nodeIds.map(x => this.nodes[x]);
                const newNode = optimization.apply(nodes, name => this.resolveConstant(name), this.constants, this.version);
                // Outputs of the replacement node stay registered.
                const outputs = new Set(newNode.outputs);
                for (const nodeId of nodeIds) {
                    this.removeNode(nodeId, outputs);
                }
                this.insertNode(newNode);
            }
        }
        this.prune();
    }

    /**
     * Repeatedly removes unused intermediaries and the nodes attached to
     * them until a round removes no node.
     *
     * @param intermediariesToDelete Optional names to remove in the first round
     */
    prune(intermediariesToDelete) {
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nodesToDelete = this.pruneIntermediaries(intermediariesToDelete);
            if (nodesToDelete.size === 0) {
                break;
            }
            // Outputs of the removed nodes are the candidates of the next round.
            let nextRound = [];
            nodesToDelete.forEach(id => {
                nextRound = nextRound.concat(this.removeNode(id, new Set()));
            });
            intermediariesToDelete = nextRound;
        }
    }

    /**
     * Deletes the given intermediaries as well as every intermediary that has
     * no consumer and is not a graph output.
     *
     * Note that names found here are pushed onto the passed array.
     *
     * @param intermediariesToDelete Optional names to delete
     * @returns Set of ids of the nodes that produce or consume a deleted
     *          intermediary
     */
    pruneIntermediaries(intermediariesToDelete) {
        const nodesToDelete = new Set();
        if (intermediariesToDelete === undefined) {
            intermediariesToDelete = [];
        }
        // Marks the first producer and the first consumer of a value.
        const markAttachedNodes = (id) => {
            const nodeOutputId = this.getNodeWithOutput(id);
            if (nodeOutputId !== undefined) {
                nodesToDelete.add(nodeOutputId);
            }
            const nodeInputId = this.getNodeWithInput(id);
            if (nodeInputId !== undefined) {
                nodesToDelete.add(nodeInputId);
            }
        };
        for (let i = 0; i < intermediariesToDelete.length; i++) {
            markAttachedNodes(intermediariesToDelete[i]);
        }
        for (const id in this.intermediaries) {
            const intermediary = this.intermediaries[id];
            if (intermediary.to.length === 0 &&
                this.outputs.find(x => x === id) === undefined) {
                intermediariesToDelete.push(id);
                markAttachedNodes(id);
            }
        }
        for (const id of intermediariesToDelete) {
            delete this.intermediaries[id];
        }
        return nodesToDelete;
    }

    /**
     * Removes a node from the model and deletes it.
     *
     * @param nodeId Id of the node to remove
     * @param preserveIntermediaries Set of output names that must be kept
     * @returns Names of intermediaries that may be deleted as a result; only
     *          the first output of the node is considered
     */
    removeNode(nodeId, preserveIntermediaries) {
        const node = this.nodes[nodeId];
        // Ids are compared as strings everywhere so that numeric and string
        // ids referring to the same node are treated alike.
        const removedId = nodeId.toString();
        const isOtherNode = (x) => x.toString() !== removedId;
        for (const input of node.inputs) {
            if (this.intermediaries[input] !== undefined) {
                this.intermediaries[input].to = this.intermediaries[input].to.filter(isOtherNode);
            }
        }
        const intermediariesToDelete = [];
        if (!preserveIntermediaries.has(node.outputs[0])) {
            intermediariesToDelete.push(node.outputs[0]);
        }
        this.nodeIds = this.nodeIds.filter(isOtherNode);
        this.nodes[nodeId].delete();
        delete this.nodes[nodeId];
        this.defaultReady = this.defaultReady.filter(isOtherNode);
        return intermediariesToDelete;
    }

    /**
     * Adds a node under a fresh id and registers it as consumer of its inputs.
     *
     * @param node Node to insert
     */
    insertNode(node) {
        const id = this.nodeIdCounter++;
        this.nodeIds.push(id);
        this.nodes[id] = node;
        for (const input of node.inputs) {
            // The input may be a name no other node has consumed so far.
            if (this.intermediaries[input] === undefined) {
                this.intermediaries[input] = {
                    to: [],
                    deletable: true,
                };
            }
            this.intermediaries[input].to.push(id);
        }
    }

    // Utility functions

    /**
     * @param output Name of a value
     * @returns Id of the first node listing the name among its outputs, or
     *          undefined if there is none
     */
    getNodeWithOutput(output) {
        for (const id of this.nodeIds) {
            if (this.nodes[id].outputs.findIndex(x => x === output) !== -1) {
                return id;
            }
        }
        return undefined;
    }

    /**
     * @param output Name of a value (despite the parameter name it is
     *               matched against node inputs)
     * @returns Id of the first node listing the name among its inputs, or
     *          undefined if there is none
     */
    getNodeWithInput(output) {
        for (const id of this.nodeIds) {
            if (this.nodes[id].inputs.findIndex(x => x === output) !== -1) {
                return id;
            }
        }
        return undefined;
    }

    /**
     * Looks up a constant by name, either among the initializers or as the
     * tensor of the ConstantNode producing that name.
     *
     * @param name Name of the value
     * @returns The constant, or undefined if the name is not a constant
     */
    resolveConstant(name) {
        if (this.constants[name] !== undefined) {
            return this.constants[name];
        }
        const nodeIdOut = this.getNodeWithOutput(name);
        const nodeOut = this.nodes[nodeIdOut];
        if (nodeOut instanceof ConstantNode) {
            return nodeOut.tensor;
        }
        return undefined;
    }

    /**
     * @returns The table of nodes, keyed by node id
     */
    getNodes() {
        return this.nodes;
    }

    /**
     * Deletes the model
     *
     * This will release the memory/framebuffers (depending on the backend)
     */
    delete() {
        for (const c in this.constants) {
            this.constants[c].delete();
        }
        for (const nodeId of this.nodeIds) {
            this.nodes[nodeId].delete();
        }
    }

    /**
     * @returns The sub modules of the parent class followed by all nodes
     */
    getSubModules() {
        const modules = super.getSubModules();
        for (const nodeId of this.nodeIds) {
            modules.push(this.nodes[nodeId]);
        }
        return modules;
    }

    /**
     * @returns The parameters of the parent class followed by every constant
     *          that is a Variable
     */
    getParameters() {
        const parameters = super.getParameters();
        for (const c in this.constants) {
            if (this.constants[c] instanceof Variable) {
                parameters.push(this.constants[c]);
            }
        }
        return parameters;
    }
}
//# sourceMappingURL=model.js.map