@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
113 lines • 3.86 kB
JavaScript
/**
 * Helper that drives a generator function as an asynchronous routine and
 * returns a promise for its final result. An existing `this.__awaiter` is
 * reused when one is already defined on the enclosing `this`.
 *
 * @param thisArg - `this` value the generator function is invoked with.
 * @param _arguments - Arguments passed to the generator function (`[]` when falsy).
 * @param P - Promise constructor to use; falls back to the global `Promise`.
 * @param generator - Generator function; every value it yields is waited on
 *   before the generator is resumed.
 * @returns A `P` instance that resolves with the generator's return value, or
 *   rejects with any error that escapes the generator.
 */
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
// Pass through values that are already instances of P; wrap anything else in a P resolved with it.
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with a resolved value; an error thrown by the generator rejects the outer promise.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Throw the rejection reason into the generator so it can handle it at the yield point.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Finish when the generator is done; otherwise wait on the yielded value and continue.
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
// Instantiate the generator (replacing the `generator` binding with the iterator) and run it to its first yield.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
import { Variable } from '../autograd/variable';
import Tensor from '../types';
import { toCPU, toWASM, toGPU } from '../util/convert';
/**
* A module is a self-contained unit that transforms
* a list of inputs when forward is called.
*
* It can be in two modes, training and inference.
* In training mode, gradients will be tracked, while
* in inference mode, only the forward pass will be calculated
*/
export class Module {
    /**
     * Creates a module that starts on the 'CPU' backend in 'train' mode.
     */
    constructor() {
        this.backend = 'CPU';
        this.mode = 'train';
    }

    /**
     * Lists the direct child modules of this module.
     *
     * Only own enumerable properties whose value is a `Module` instance are
     * considered; modules nested inside arrays or plain objects are not found.
     *
     * @returns {Module[]} Child modules in property order.
     */
    getSubModules() {
        return Object.keys(this)
            .map((key) => this[key])
            .filter((value) => value instanceof Module);
    }

    /**
     * Collects the parameters of this module and of all nested modules.
     *
     * Own enumerable properties holding a `Variable` come first, followed by
     * the parameters of each child module (recursively), in property order.
     *
     * @returns {Variable[]} All parameters reachable from this module.
     */
    getParameters() {
        const ownParameters = Object.keys(this)
            .map((key) => this[key])
            .filter((value) => value instanceof Variable);
        return this.getSubModules().reduce(
            (collected, child) => collected.concat(child.getParameters()),
            ownParameters
        );
    }

    /**
     * Dispatches to the conversion method matching the given backend name.
     *
     * @param {string} backend - 'CPU' or 'WASM'; any other value selects the
     *   GPU conversion.
     * @returns {Promise<void>} The promise returned by the selected conversion.
     */
    toBackend(backend) {
        switch (backend) {
            case 'CPU':
                return this.toCPU();
            case 'WASM':
                return this.toWASM();
            default:
                return this.toGPU();
        }
    }

    /**
     * Converts every child module, then every own `Tensor` property, with the
     * imported `toCPU` helper, and finally records 'CPU' as the backend.
     * Conversions run one after another, not in parallel.
     *
     * @returns {Promise<void>} Resolves once all conversions have finished.
     */
    toCPU() {
        return __awaiter(this, void 0, void 0, function* () {
            // Children first, so nested tensors are converted before this module's own.
            for (const child of this.getSubModules()) {
                yield child.toCPU();
            }
            for (const key of Object.keys(this)) {
                const value = this[key];
                if (value instanceof Tensor) {
                    // Inside the method body `toCPU` is the imported helper, not this method.
                    this[key] = yield toCPU(value);
                }
            }
            this.backend = 'CPU';
        });
    }

    /**
     * Converts every child module, then every own `Tensor` property, with the
     * imported `toWASM` helper, and finally records 'WASM' as the backend.
     * Conversions run one after another, not in parallel.
     *
     * @returns {Promise<void>} Resolves once all conversions have finished.
     */
    toWASM() {
        return __awaiter(this, void 0, void 0, function* () {
            // Children first, so nested tensors are converted before this module's own.
            for (const child of this.getSubModules()) {
                yield child.toWASM();
            }
            for (const key of Object.keys(this)) {
                const value = this[key];
                if (value instanceof Tensor) {
                    // Inside the method body `toWASM` is the imported helper, not this method.
                    this[key] = yield toWASM(value);
                }
            }
            this.backend = 'WASM';
        });
    }

    /**
     * Converts every child module, then every own `Tensor` property, with the
     * imported `toGPU` helper, and finally records 'GPU' as the backend.
     * Conversions run one after another, not in parallel.
     *
     * @returns {Promise<void>} Resolves once all conversions have finished.
     */
    toGPU() {
        return __awaiter(this, void 0, void 0, function* () {
            // Children first, so nested tensors are converted before this module's own.
            for (const child of this.getSubModules()) {
                yield child.toGPU();
            }
            for (const key of Object.keys(this)) {
                const value = this[key];
                if (value instanceof Tensor) {
                    // Inside the method body `toGPU` is the imported helper, not this method.
                    this[key] = yield toGPU(value);
                }
            }
            this.backend = 'GPU';
        });
    }
}
//# sourceMappingURL=module.js.map