UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

107 lines 3.96 kB
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
import { Variable } from '../../../autograd';
import { toCPU, toGPU, toWASM } from '../../../util/convert';
import { OnnxNode } from '../../node';
/**
 * Replaces `node.kernel` and `node.bias` (whichever are defined) with the
 * awaited result of `convert`, converting the kernel before the bias.
 * Shared by ConvNode.toCPU / toWASM / toGPU, which differ only in the converter.
 *
 * @param node The ConvNode whose stored parameters are converted in place.
 * @param convert Conversion function such as toCPU, toWASM or toGPU.
 * @returns A promise that resolves once both parameters have been replaced.
 */
function convertParameters(node, convert) {
    return __awaiter(void 0, void 0, void 0, function* () {
        if (node.kernel !== undefined) {
            node.kernel = yield convert(node.kernel);
        }
        if (node.bias !== undefined) {
            node.bias = yield convert(node.bias);
        }
    });
}
/**
 * Graph node for the ONNX `Conv` operator.
 *
 * The weight and bias may either be stored on the node (constructor
 * arguments `kernel` / `bias`) or supplied as runtime inputs.
 */
export class ConvNode extends OnnxNode {
    /**
     * @param attributes ONNX attributes of the node (forwarded to OnnxNode).
     * @param inputs Input names (forwarded to OnnxNode).
     * @param outputs Output names (forwarded to OnnxNode).
     * @param constants Constant tensors (forwarded to OnnxNode).
     * @param onnxVersion ONNX opset version (forwarded to OnnxNode).
     * @param mode Execution mode; in 'train' mode stored kernel/bias are wrapped in a Variable.
     * @param kernel Optional stored weight; when defined it is used instead of inputs[1].
     * @param bias Optional stored bias; used when forward receives at most two inputs.
     * @param activation Activation name passed through to `conv`; defaults to 'id'.
     * @throws {Error} If `auto_pad` is SAME_UPPER or SAME_LOWER (not supported).
     */
    constructor(attributes, inputs, outputs, constants, onnxVersion, mode, kernel, bias, activation) {
        super(attributes, inputs, outputs, constants, onnxVersion, mode);
        // The ONNX attribute is spelled `auto_pad` (like the other ONNX names read
        // below). Per the ONNX spec, NOTSET means "use the explicit pads" and VALID
        // means "no padding"; both are left to the explicit `pads` handling. The
        // SAME_* modes derive padding from the input shape, which is not implemented.
        // NOTE(review): VALID relies on `pads` being absent (or zero) in the model.
        const autoPad = this.getAttributeString('auto_pad');
        if (autoPad !== undefined && autoPad !== 'NOTSET' && autoPad !== 'VALID') {
            throw new Error('Autopad in conv not supported yet');
        }
        if (activation === undefined) {
            activation = 'id';
        }
        this.activation = activation;
        this.group = this.getAttributeInt('group') || 1;
        this.dilations = this.getAttributeInts('dilations');
        this.pads = this.getAttributeInts('pads');
        this.strides = this.getAttributeInts('strides');
        this.kernel = kernel;
        this.bias = bias;
        if (mode === 'train' && this.kernel !== undefined) {
            this.kernel = new Variable(this.kernel);
        }
        if (mode === 'train' && this.bias !== undefined) {
            this.bias = new Variable(this.bias);
        }
    }
    /**
     * Runs the convolution.
     *
     * @param inputs inputs[0] is the data tensor; inputs[1] the weight (ignored
     *   when a kernel is stored on the node); inputs[2] the optional bias, which
     *   takes precedence over a stored bias.
     * @returns A promise of a one-element array holding the result of `x.conv`.
     */
    forward(inputs) {
        return __awaiter(this, void 0, void 0, function* () {
            const x = inputs[0];
            const w = this.kernel !== undefined ? this.kernel : inputs[1];
            const b = inputs.length > 2 ? inputs[2] : this.bias;
            return [
                x.conv(w, b, this.dilations, this.group, this.pads, this.strides, this.activation),
            ];
        });
    }
    /** Dilations from the attribute, or `rank` ones when the attribute is absent. */
    getDilations(rank) {
        if (this.dilations !== undefined) {
            return this.dilations;
        }
        return new Array(rank).fill(1);
    }
    /** Pads from the attribute, or `2 * rank` zeros when the attribute is absent. */
    getPads(rank) {
        if (this.pads !== undefined) {
            return this.pads;
        }
        return new Array(rank * 2).fill(0);
    }
    /** Strides from the attribute, or `rank` ones when the attribute is absent. */
    getStrides(rank) {
        if (this.strides !== undefined) {
            return this.strides;
        }
        return new Array(rank).fill(1);
    }
    getType() {
        return 'Conv';
    }
    /** Converts the stored kernel/bias (if any) with `toCPU`. */
    toCPU() {
        return convertParameters(this, toCPU);
    }
    /** Converts the stored kernel/bias (if any) with `toWASM`. */
    toWASM() {
        return convertParameters(this, toWASM);
    }
    /** Converts the stored kernel/bias (if any) with `toGPU`. */
    toGPU() {
        return convertParameters(this, toGPU);
    }
    /** Calls `delete()` on the stored kernel and bias, when present. */
    delete() {
        if (this.kernel !== undefined) {
            this.kernel.delete();
        }
        if (this.bias !== undefined) {
            this.bias.delete();
        }
    }
}
//# sourceMappingURL=conv.js.map