UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

141 lines 4.55 kB
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; import { Variable } from '../autograd/variable'; import { CPUTensor } from '../tensor/cpu/tensor'; import { normal } from '../util/math'; import { Module } from './module'; /** * Linear layer calculates y=xW + b * * W is initialized with Xavier initialization, while the bias is * initialized to zeros */ export class Linear extends Module { /** * Creates a linear layer * @param dimIn Feature dimension of the input * @param dimOut Feature dimension of the output * @param bias Wether a bias should be added or not. Defaults to true */ constructor(dimIn, dimOut, bias) { super(); bias = bias === undefined ? true : bias; const weightVals = normal(dimIn * dimOut, 0, 2 / (dimIn + dimOut)); const tensor = new CPUTensor([dimIn, dimOut], weightVals); this.weights = new Variable(tensor); if (bias) { const biasVals = new Array(dimOut).fill(0); const tensorBias = new CPUTensor([1, dimOut], biasVals); this.bias = new Variable(tensorBias); } } forward(inputs) { return __awaiter(this, void 0, void 0, function* () { return [inputs[0].gemm(this.weights, false, false, 1, this.bias)]; }); } } /** * Rectified linear unit, calculates y = max(x,0) */ export class Relu extends Module { forward(inputs) { return __awaiter(this, void 0, void 0, function* () { return [inputs[0].clip(0)]; }); } } /** * Sequence of modules. 
Passes the input sequentially into the specified modules */ export class Sequential extends Module { constructor(modules) { super(); this.modules = modules; } forward(inputs) { return __awaiter(this, void 0, void 0, function* () { let x = inputs; for (let i = 0; i < this.modules.length; i++) { const oldX = x; x = yield this.modules[i].forward(x); if (this.mode === 'inference' && i > 0) { for (let j = 0; j < oldX.length; j++) { oldX[j].delete(); } } } return x; }); } getSubModules() { const modules = super.getSubModules(); return modules.concat(this.modules); } } /** * Dictionary of modules. Use this if you want to store submodules in a dictionary */ export class ModuleDict extends Module { constructor(modules = {}) { super(); this.modules = modules; } // eslint-disable-next-line @typescript-eslint/no-unused-vars forward(inputs) { return __awaiter(this, void 0, void 0, function* () { throw new Error('Module dict does not support forward'); }); } getSubModules() { const modules = []; for (const k in this.modules) { modules.push(this.modules[k]); } return modules; } get(key) { return this.modules[key]; } set(key, module) { this.modules[key] = module; } } /** * List of modules. Use this if you want to store submodules in a list */ export class ModuleList extends Module { constructor(modules = []) { super(); this.modules = modules; } // eslint-disable-next-line @typescript-eslint/no-unused-vars forward(inputs) { return __awaiter(this, void 0, void 0, function* () { throw new Error('Module list does not support forward'); }); } getSubModules() { return this.modules; } get(index) { return this.modules[index]; } set(index, module) { this.modules[index] = module; } push(module) { this.modules.push(module); } pop() { return this.modules.pop(); } } //# sourceMappingURL=basic.js.map