@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
141 lines • 4.55 kB
JavaScript
// Helper emitted by the TypeScript compiler: runs a generator function as if it were
// an async function, with each `yield` acting like an `await` (presumably emitted
// because the compile target lacks native async/await — confirm against tsconfig).
// If `this` already carries a `__awaiter`, that one is reused. In an ES module
// (this file uses `import`), top-level `this` is undefined, so the local
// definition below is always the one used.
//   thisArg    - value bound as `this` inside the generator
//   _arguments - arguments passed to the generator (treated as [] when falsy)
//   P          - Promise constructor to use; falls back to the global Promise
//   generator  - generator function whose yielded values are awaited in turn
// Returns a P-promise for the generator's return value, rejected if it throws.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
// Wrap a yielded value in a P-promise unless it already is one.
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with a fulfilled value; a synchronous throw rejects the result.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Throw a rejection reason back into the generator so it can catch it like `await`.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Finish when the generator returns; otherwise wait for the yielded value and continue.
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
// Instantiate the generator with its `this` and arguments and start it.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
import { Variable } from '../autograd/variable';
import { CPUTensor } from '../tensor/cpu/tensor';
import { normal } from '../util/math';
import { Module } from './module';
/**
 * Linear layer calculates y = xW + b
 *
 * W ([dimIn, dimOut]) is initialized with Xavier initialization, while the
 * bias ([1, dimOut], if enabled) is initialized to zeros.
 */
export class Linear extends Module {
  /**
   * Creates a linear layer
   * @param {number} dimIn Feature dimension of the input
   * @param {number} dimOut Feature dimension of the output
   * @param {boolean} [bias=true] Whether a bias should be added or not. Defaults to true
   */
  constructor(dimIn, dimOut, bias = true) {
    super();
    // NOTE(review): 2 / (dimIn + dimOut) is the Xavier/Glorot *variance*. Confirm
    // that `normal`'s third argument is a variance and not a standard deviation;
    // if it is a std, this should be Math.sqrt(2 / (dimIn + dimOut)).
    const weightVals = normal(dimIn * dimOut, 0, 2 / (dimIn + dimOut));
    this.weights = new Variable(new CPUTensor([dimIn, dimOut], weightVals));
    if (bias) {
      const biasVals = new Array(dimOut).fill(0);
      this.bias = new Variable(new CPUTensor([1, dimOut], biasVals));
    }
  }

  /**
   * Applies the layer to the first input.
   * @param inputs Array whose first element is x (presumably a tensor whose last
   *   dimension is dimIn — TODO confirm against gemm's contract); other elements are ignored
   * @returns Promise resolving to a one-element array [x·W + b]
   */
  forward(inputs) {
    return __awaiter(this, void 0, void 0, function* () {
      // When the layer was created without a bias, `this.bias` is undefined and is
      // passed through to gemm as its optional addend.
      return [inputs[0].gemm(this.weights, false, false, 1, this.bias)];
    });
  }
}
/**
 * Rectified linear unit: y = max(x, 0).
 *
 * Realized by clipping the first input from below at zero.
 */
export class Relu extends Module {
  /**
   * @param inputs Array whose first element is the tensor to rectify; any
   *   further elements are ignored
   * @returns Promise resolving to a single-element array holding the result
   */
  forward(inputs) {
    const rectify = function* () {
      const input = inputs[0];
      const lowerBound = 0;
      const rectified = input.clip(lowerBound);
      return [rectified];
    };
    return __awaiter(this, void 0, void 0, rectify);
  }
}
/**
 * Sequence of modules. Passes the input sequentially into the specified modules:
 * the outputs of module i are the inputs of module i + 1.
 *
 * In inference mode, intermediate outputs are deleted once the following module has
 * produced its outputs. Tensors the caller passed in, and tensors a module simply
 * passes through to its outputs, are never deleted.
 */
export class Sequential extends Module {
  /**
   * @param modules Modules to apply, in order (indexed and iterated by length)
   */
  constructor(modules) {
    super();
    this.modules = modules;
  }

  /**
   * Runs every module in order.
   * @param inputs Inputs of the first module
   * @returns Promise resolving to the outputs of the last module (the inputs
   *   themselves if the sequence is empty)
   */
  forward(inputs) {
    return __awaiter(this, void 0, void 0, function* () {
      // The caller still owns these, even if an early module passes them through.
      const callerInputs = new Set(inputs);
      let x = inputs;
      for (let i = 0; i < this.modules.length; i++) {
        const oldX = x;
        x = yield this.modules[i].forward(x);
        // `this.mode` is not set here; presumably maintained by Module — TODO confirm.
        // For i === 0, oldX is the caller's inputs, which must not be freed.
        if (this.mode === 'inference' && i > 0) {
          // A module may return some of its inputs unchanged (e.g. an identity); those
          // tensors are still live. Iterating a Set also frees a duplicated tensor once.
          const stillLive = new Set(x);
          for (const tensor of new Set(oldX)) {
            if (!stillLive.has(tensor) && !callerInputs.has(tensor)) {
              tensor.delete();
            }
          }
        }
      }
      return x;
    });
  }

  /**
   * @returns Whatever Module.getSubModules reports, followed by the modules of
   *   this sequence (as a new array)
   */
  getSubModules() {
    const modules = super.getSubModules();
    return modules.concat(this.modules);
  }
}
/**
 * Dictionary of modules. Use this if you want to store submodules in a dictionary
 * keyed by name.
 *
 * Only the dictionary object's own keys count as entries, so names such as
 * "toString", "constructor" or "__proto__" behave like any other key.
 */
export class ModuleDict extends Module {
  /**
   * @param modules Initial name -> module mapping; stored by reference.
   *   Defaults to a new empty object
   */
  constructor(modules = {}) {
    super();
    this.modules = modules;
  }

  /**
   * Not supported: a ModuleDict only groups submodules.
   * @returns A promise that always rejects with an Error
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  forward(inputs) {
    return __awaiter(this, void 0, void 0, function* () {
      throw new Error('Module dict does not support forward');
    });
  }

  /**
   * @returns A new array with the stored modules, in key enumeration order
   */
  getSubModules() {
    // Object.keys skips inherited enumerable properties, unlike for...in.
    return Object.keys(this.modules).map((key) => this.modules[key]);
  }

  /**
   * @param key Name of the module
   * @returns The module stored under `key`, or undefined if there is none
   */
  get(key) {
    return Object.prototype.hasOwnProperty.call(this.modules, key) ? this.modules[key] : undefined;
  }

  /**
   * Stores (or replaces) the module under `key`.
   * @param key Name of the module
   * @param module Module to store
   */
  set(key, module) {
    // defineProperty instead of assignment: `obj['__proto__'] = m` would replace the
    // dictionary's prototype rather than create an entry.
    Object.defineProperty(this.modules, key, {
      value: module,
      writable: true,
      enumerable: true,
      configurable: true,
    });
  }
}
/**
 * List of modules. Use this if you want to store submodules in a list
 */
export class ModuleList extends Module {
  /**
   * @param modules Initial modules. The array is stored by reference, so later
   *   set/push/pop calls also mutate the array passed in. Defaults to a new empty array
   */
  constructor(modules = []) {
    super();
    this.modules = modules;
  }

  /**
   * Not supported: a ModuleList only groups submodules.
   * @returns A promise that always rejects with an Error
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  forward(inputs) {
    return __awaiter(this, void 0, void 0, function* () {
      throw new Error('Module list does not support forward');
    });
  }

  /**
   * @returns A shallow copy of the stored modules, so callers cannot add or remove
   *   entries of this list by mutating the returned array
   */
  getSubModules() {
    // slice (rather than spread) keeps the array exactly as stored, including holes.
    return this.modules.slice();
  }

  /**
   * @param index Position in the list
   * @returns The module at `index`, or undefined if out of range
   */
  get(index) {
    return this.modules[index];
  }

  /**
   * Replaces the module at `index`.
   * @param index Position in the list
   * @param module Module to store
   */
  set(index, module) {
    this.modules[index] = module;
  }

  /**
   * Appends a module to the end of the list.
   * @param module Module to append
   */
  push(module) {
    this.modules.push(module);
  }

  /**
   * Removes the last module.
   * @returns The removed module, or undefined if the list was empty
   */
  pop() {
    return this.modules.pop();
  }
}
//# sourceMappingURL=basic.js.map