@shumai/shumai
Version:
A fast, network-connected, differentiable tensor library for TypeScript (and JavaScript). Built with bun + flashlight for software engineers and researchers alike.
28 lines (23 loc) • 697 B
text/typescript
import type { Tensor } from '../tensor'
import * as ops from '../tensor/tensor_ops'
import { Module } from './module'
// Plain-object shallow copy of the enumerable exports of '../tensor/tensor_ops',
// used below under the short `sm` alias.
const sm = Object.assign({}, ops)
/**
 * Module that applies `x.matmul(weight).add(bias)` to its input.
 *
 * Both `weight` and `bias` are flagged with `requires_grad = true` at
 * construction time.
 */
export class Linear extends Module {
  /**
   * Created by `sm.xavier_uniform` with the shape argument `[inp_dim, out_dim]`.
   * NOTE(review): presumably the resulting tensor has exactly that shape — confirm
   * against `xavier_uniform`.
   */
  weight: Tensor
  /**
   * Created by `sm.full([out_dim], 0)`.
   * NOTE(review): presumably a zero-filled tensor of length `out_dim` — confirm
   * against `full`.
   */
  bias: Tensor

  /**
   * @param inp_dim - first entry of the weight shape; also forwarded to
   *   `xavier_uniform` as its second argument
   * @param out_dim - second entry of the weight shape and the only entry of the
   *   bias shape; also forwarded to `xavier_uniform` as its third argument
   */
  constructor(inp_dim: number, out_dim: number) {
    super()
    // NOTE(review): the trailing arguments look like (fan_in, fan_out, gain) with
    // gain = sqrt(2) — confirm against the `xavier_uniform` signature.
    this.weight = sm.xavier_uniform([inp_dim, out_dim], inp_dim, out_dim, Math.sqrt(2))
    this.bias = sm.full([out_dim], 0)
    this.weight.requires_grad = true
    this.bias.requires_grad = true
  }

  /**
   * Multiplies the input by `weight`, then adds `bias`.
   *
   * @param x - input tensor; assumes its trailing dimension is compatible with
   *   `weight` under `matmul` — TODO confirm
   * @returns the result of `x.matmul(this.weight).add(this.bias)`
   */
  forward(x: Tensor): Tensor {
    return x.matmul(this.weight).add(this.bias)
  }
}
/**
 * Factory shorthand: equivalent to `new Linear(inp_dim, out_dim)`.
 *
 * @param inp_dim - forwarded unchanged as the first `Linear` constructor argument
 * @param out_dim - forwarded unchanged as the second `Linear` constructor argument
 * @returns a newly constructed `Linear` instance
 */
export function linear(inp_dim: number, out_dim: number): Linear {
  return new Linear(inp_dim, out_dim)
}