// catniff — a small Torch-like deep learning framework for JavaScript (nn module).
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.nn = void 0;
const core_1 = require("./core");
// Affine map: input @ weight^T, plus `bias` when one is provided.
function linearTransform(input, weight, bias) {
    const projected = input.matmul(weight.t());
    return bias ? projected.add(bias) : projected;
}
/**
 * Fully-connected layer: forward(x) = x @ W^T + b.
 * Parameters are drawn from U(-k, k) with k = 1/sqrt(inFeatures),
 * the same scheme torch.nn.Linear uses.
 */
class Linear {
    weight;
    bias;
    constructor(inFeatures, outFeatures, bias = true, device) {
        const limit = 1 / Math.sqrt(inFeatures);
        this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -limit, limit, { requiresGrad: true, device });
        if (bias) {
            this.bias = core_1.Tensor.uniform([outFeatures], -limit, limit, { requiresGrad: true, device });
        }
    }
    // Coerces raw input to a Tensor, then applies the affine transform.
    forward(input) {
        const x = core_1.Tensor.forceTensor(input);
        return linearTransform(x, this.weight, this.bias);
    }
}
// Shared recurrent pre-activation: x @ Wih^T + h @ Whh^T, with each of the
// two optional bias terms added only when present.
function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
    let preAct = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
    for (const b of [inputBias, hiddenBias]) {
        if (b) {
            preAct = preAct.add(b);
        }
    }
    return preAct;
}
/**
 * Elman RNN cell: h' = tanh(x Wih^T + bih + h Whh^T + bhh).
 * All parameters initialized from U(-k, k), k = 1/sqrt(hiddenSize).
 */
class RNNCell {
    weightIH;
    weightHH;
    biasIH;
    biasHH;
    constructor(inputSize, hiddenSize, bias = true, device) {
        const limit = 1 / Math.sqrt(hiddenSize);
        // Local factory so every parameter shares the same init settings.
        const init = (...shape) => core_1.Tensor.uniform(shape, -limit, limit, { requiresGrad: true, device });
        this.weightIH = init(hiddenSize, inputSize);
        this.weightHH = init(hiddenSize, hiddenSize);
        if (bias) {
            this.biasIH = init(hiddenSize);
            this.biasHH = init(hiddenSize);
        }
    }
    forward(input, hidden) {
        const x = core_1.Tensor.forceTensor(input);
        const h = core_1.Tensor.forceTensor(hidden);
        const preAct = rnnTransform(x, h, this.weightIH, this.weightHH, this.biasIH, this.biasHH);
        return preAct.tanh();
    }
}
/**
 * Gated recurrent unit cell (same gate equations as torch.nn.GRUCell):
 *   r  = sigmoid(x Wir^T + bir + h Whr^T + bhr)   -- reset gate
 *   z  = sigmoid(x Wiz^T + biz + h Whz^T + bhz)   -- update gate
 *   n  = tanh(x Win^T + bin + r * (h Whn^T + bhn)) -- candidate state
 *   h' = (1 - z) * n + z * h
 * Parameters initialized from U(-k, k), k = 1/sqrt(hiddenSize).
 */
class GRUCell {
    weightIR;
    weightIZ;
    weightIN;
    weightHR;
    weightHZ;
    weightHN;
    biasIR;
    biasIZ;
    biasIN;
    biasHR;
    biasHZ;
    biasHN;
    constructor(inputSize, hiddenSize, bias = true, device) {
        const limit = 1 / Math.sqrt(hiddenSize);
        // Local factory; parameter creation order below matches the original
        // so any seeded RNG sequence is preserved.
        const init = (...shape) => core_1.Tensor.uniform(shape, -limit, limit, { requiresGrad: true, device });
        this.weightIR = init(hiddenSize, inputSize);
        this.weightIZ = init(hiddenSize, inputSize);
        this.weightIN = init(hiddenSize, inputSize);
        this.weightHR = init(hiddenSize, hiddenSize);
        this.weightHZ = init(hiddenSize, hiddenSize);
        this.weightHN = init(hiddenSize, hiddenSize);
        if (bias) {
            this.biasIR = init(hiddenSize);
            this.biasIZ = init(hiddenSize);
            this.biasIN = init(hiddenSize);
            this.biasHR = init(hiddenSize);
            this.biasHZ = init(hiddenSize);
            this.biasHN = init(hiddenSize);
        }
    }
    forward(input, hidden) {
        const x = core_1.Tensor.forceTensor(input);
        const h = core_1.Tensor.forceTensor(hidden);
        const reset = rnnTransform(x, h, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
        const update = rnnTransform(x, h, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
        // Reset gate scales only the hidden contribution of the candidate.
        const candidate = linearTransform(x, this.weightIN, this.biasIN)
            .add(reset.mul(linearTransform(h, this.weightHN, this.biasHN)))
            .tanh();
        // Blend previous hidden state and candidate by the update gate.
        return update.neg().add(1).mul(candidate).add(update.mul(h));
    }
}
/**
 * Long short-term memory cell (same gate equations as torch.nn.LSTMCell):
 *   i  = sigmoid(x Wii^T + bii + h Whi^T + bhi)   -- input gate
 *   f  = sigmoid(x Wif^T + bif + h Whf^T + bhf)   -- forget gate
 *   g  = tanh(x Wig^T + big + h Whg^T + bhg)      -- cell candidate
 *   o  = sigmoid(x Wio^T + bio + h Who^T + bho)   -- output gate
 *   c' = f * c + i * g
 *   h' = o * tanh(c')
 * Parameters initialized from U(-k, k), k = 1/sqrt(hiddenSize).
 */
class LSTMCell {
    weightII;
    weightIF;
    weightIG;
    weightIO;
    weightHI;
    weightHF;
    weightHG;
    weightHO;
    biasII;
    biasIF;
    biasIG;
    biasIO;
    biasHI;
    biasHF;
    biasHG;
    biasHO;
    constructor(inputSize, hiddenSize, bias = true, device) {
        const limit = 1 / Math.sqrt(hiddenSize);
        // Local factory; creation order below matches the original so any
        // seeded RNG sequence is preserved.
        const init = (...shape) => core_1.Tensor.uniform(shape, -limit, limit, { requiresGrad: true, device });
        this.weightII = init(hiddenSize, inputSize);
        this.weightIF = init(hiddenSize, inputSize);
        this.weightIG = init(hiddenSize, inputSize);
        this.weightIO = init(hiddenSize, inputSize);
        this.weightHI = init(hiddenSize, hiddenSize);
        this.weightHF = init(hiddenSize, hiddenSize);
        this.weightHG = init(hiddenSize, hiddenSize);
        this.weightHO = init(hiddenSize, hiddenSize);
        if (bias) {
            this.biasII = init(hiddenSize);
            this.biasIF = init(hiddenSize);
            this.biasIG = init(hiddenSize);
            this.biasIO = init(hiddenSize);
            this.biasHI = init(hiddenSize);
            this.biasHF = init(hiddenSize);
            this.biasHG = init(hiddenSize);
            this.biasHO = init(hiddenSize);
        }
    }
    // Returns the pair [nextHidden, nextCell].
    forward(input, hidden, cell) {
        const x = core_1.Tensor.forceTensor(input);
        const h = core_1.Tensor.forceTensor(hidden);
        const cPrev = core_1.Tensor.forceTensor(cell);
        const inGate = rnnTransform(x, h, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
        const forgetGate = rnnTransform(x, h, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
        const cellGate = rnnTransform(x, h, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
        const outGate = rnnTransform(x, h, this.weightIO, this.weightHO, this.biasIO, this.biasHO).sigmoid();
        const cNext = forgetGate.mul(cPrev).add(inGate.mul(cellGate));
        const hNext = outGate.mul(cNext.tanh());
        return [hNext, cNext];
    }
}
/**
 * Layer normalization over the trailing `normalizedShape` dimensions,
 * mirroring torch.nn.LayerNorm: y = (x - mean) / sqrt(var + eps) * w + b,
 * with biased variance (mean of squared deviations).
 *
 * @param normalizedShape number or number[] - trailing dims to normalize over
 * @param eps small constant added to the variance for numerical stability
 * @param elementwiseAffine when true, learnable per-element scale `weight`
 * @param bias when true (and affine), learnable per-element shift `bias`
 * @param device optional device the parameters are created on
 * @throws Error when normalizedShape is empty
 */
class LayerNorm {
    weight;
    bias;
    eps;
    normalizedShape;
    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, bias = true, device) {
        this.eps = eps;
        this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape];
        if (this.normalizedShape.length === 0) {
            throw new Error("Normalized shape cannot be empty");
        }
        if (elementwiseAffine) {
            // Identity affine at init: scale ones, shift zeros.
            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device });
            if (bias) {
                this.bias = core_1.Tensor.zeros(this.normalizedShape, { requiresGrad: true, device });
            }
        }
    }
    /**
     * Normalizes `input` over its trailing dims.
     * @throws Error when input has too few dims or a trailing dim mismatches
     */
    forward(input) {
        const x = core_1.Tensor.forceTensor(input);
        // The normalized dims are the last `normalizedDims` axes of the input.
        const normalizedDims = this.normalizedShape.length;
        const startDim = x.shape.length - normalizedDims;
        if (startDim < 0) {
            throw new Error("Input does not have enough dims to normalize");
        }
        const dims = [];
        for (let i = 0; i < normalizedDims; i++) {
            if (x.shape[startDim + i] !== this.normalizedShape[i]) {
                throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${x.shape[startDim + i]}`);
            }
            dims.push(startDim + i);
        }
        const mean = x.mean(dims, true);
        // Hoisted: the centered tensor is reused for both the variance and the
        // output, so the (x - mean) subgraph is built only once instead of twice.
        const centered = x.sub(mean);
        const variance = centered.pow(2).mean(dims, true);
        let normalized = centered.div(variance.add(this.eps).sqrt());
        if (this.weight) {
            normalized = normalized.mul(this.weight);
        }
        if (this.bias) {
            normalized = normalized.add(this.bias);
        }
        return normalized;
    }
}
/**
 * Model (de)serialization helpers. Each walks an arbitrary object graph,
 * visiting every Tensor reachable through own enumerable properties.
 * A WeakSet of visited objects guards against reference cycles.
 */
const state = {
    /** Collects every reachable Tensor (e.g. to hand to an optimizer). */
    getParameters(model, visited = new WeakSet()) {
        if (visited.has(model))
            return [];
        visited.add(model);
        const parameters = [];
        for (const key in model) {
            if (!Object.hasOwn(model, key))
                continue;
            const value = model[key];
            if (value instanceof core_1.Tensor) {
                parameters.push(value);
            }
            else if (typeof value === "object" && value !== null) {
                parameters.push(...state.getParameters(value, visited));
            }
        }
        return parameters;
    },
    /** Flattens all Tensors into { "path.to.tensor": rawValue } via val(). */
    getStateDict(model, prefix = "", visited = new WeakSet()) {
        if (visited.has(model))
            return {};
        visited.add(model);
        const stateDict = {};
        for (const key in model) {
            if (!Object.hasOwn(model, key))
                continue;
            const value = model[key];
            const fullKey = prefix ? `${prefix}.${key}` : key;
            if (value instanceof core_1.Tensor) {
                stateDict[fullKey] = value.val();
            }
            else if (typeof value === "object" && value !== null) {
                Object.assign(stateDict, state.getStateDict(value, fullKey, visited));
            }
        }
        return stateDict;
    },
    /**
     * Loads values from a state dict back into the model's Tensors in place,
     * keeping each tensor on its current device. Keys absent from the dict
     * are left untouched.
     */
    loadStateDict(model, stateDict, prefix = "", visited = new WeakSet()) {
        if (visited.has(model))
            return;
        visited.add(model);
        for (const key in model) {
            if (!Object.hasOwn(model, key))
                continue;
            const value = model[key];
            const fullKey = prefix ? `${prefix}.${key}` : key;
            // Explicit undefined check (not truthiness): stored values such as
            // a scalar 0 are valid and must still be loaded.
            if (value instanceof core_1.Tensor && stateDict[fullKey] !== undefined) {
                value.replace(new core_1.Tensor(stateDict[fullKey], { device: value.device }));
            }
            else if (typeof value === "object" && value !== null) {
                state.loadStateDict(value, stateDict, fullKey, visited);
            }
        }
    }
};
// Public `nn` namespace: layers, recurrent cells, normalization, and the
// state-dict helpers (layout mirrors torch.nn).
exports.nn = { Linear, RNNCell, GRUCell, LSTMCell, LayerNorm, state };