@shumai/shumai
Version:
A fast, network-connected, differentiable tensor library for TypeScript (and JavaScript). Built with bun + flashlight for software engineers and researchers alike.
59 lines (50 loc) • 1.77 kB
text/typescript
import type { Tensor } from '../tensor'
import * as tensor from '../tensor/tensor'
import * as ops from '../tensor/tensor_ops'
import { Module } from './module'
// Single namespace combining the tensor ops module and the tensor module, used
// below via `sm.scalar` / `sm.full`. Spread order means that if both modules
// export the same name, the export from `tensor` wins.
const sm = { ...ops, ...tensor }
/**
 * Layer normalization over the trailing `dims.length` axes of the input.
 *
 * `forward` subtracts the mean and divides by `sqrt(variance + eps)` across
 * those axes. It then multiplies by the learnable `gamma` (initialized to
 * ones) and adds the learnable `beta` (initialized to zeros). Both parameters
 * have shape `dims`.
 */
export class LayerNorm extends Module {
  /** Sizes of the trailing axes being normalized; also the shape of gamma/beta. */
  private readonly dims: number[]
  /** Negative axis indices [-1, -2, ..., -dims.length] addressing the trailing axes. */
  private readonly axes: number[]
  /** Scalar tensor added to the variance before the square root; validated to be > 0. */
  private readonly eps: Tensor
  // Both are assigned by resetParameters(), which the constructor always calls.
  // The definite-assignment `!` states that to the compiler, which does not
  // follow method calls when checking strictPropertyInitialization.
  private gamma!: Tensor
  private beta!: Tensor

  /**
   * @param dims - Sizes of the trailing input axes to normalize over. Must be non-empty.
   * @param eps - Positive stabilizer, given as a number or a scalar (rank-0) Tensor.
   *   Defaults to 1e-6.
   * @throws Error if `dims` is empty, if `eps` is a non-scalar Tensor, or if `eps`
   *   is not greater than 0.
   */
  constructor(dims: number[], eps?: number | Tensor) {
    super()
    if (dims.length === 0) {
      throw new Error(`LayerNorm cannot be applied to scalars; dims cannot be []`)
    }
    // Copy so a later mutation of the caller's array cannot desynchronize
    // `dims` from the precomputed `axes` and the parameter shapes.
    this.dims = [...dims]
    // axes[i] = -(i + 1): last axis first, counting back dims.length axes.
    this.axes = this.dims.map((_, i) => -1 * (i + 1))
    if (eps === undefined) {
      this.eps = sm.scalar(1e-6)
    } else if (typeof eps === 'number') {
      this.eps = sm.scalar(eps)
    } else if (eps.shape.length === 0) {
      this.eps = eps
    } else {
      throw new Error(`Parameter eps (${eps}) must be a number or scalar Tensor`)
    }
    // Elementwise comparison yields a one-element boolean tensor; 0 means "not > 0".
    if (this.eps.greaterThan(sm.scalar(0)).toUint8Array()[0] === 0) {
      throw new Error(`Parameter eps (${eps}) must be greater than 0`)
    }
    this.resetParameters()
  }

  /** (Re)create `gamma` as ones and `beta` as zeros, both of shape `dims`, with gradients enabled. */
  resetParameters(): void {
    this.gamma = sm.full(this.dims, 1).requireGrad()
    this.beta = sm.full(this.dims, 0).requireGrad()
  }

  /**
   * Normalize `tensor` over its trailing axes and apply the learned affine transform.
   *
   * Note: the parameter name shadows the `tensor` module import inside this
   * method. The body only uses the parameter, so this is harmless.
   *
   * @param tensor - Input whose final `dims.length` axis sizes must equal `dims`.
   * @returns `(tensor - mean) / sqrt(var + eps) * gamma + beta`. This relies on
   *   the keep-dims reductions and `dims`-shaped parameters broadcasting against
   *   the input (presumably yielding the input's shape; TODO confirm).
   * @throws Error if the trailing axes of `tensor` do not match `dims`, including
   *   when the input has fewer axes than `dims`.
   */
  forward(tensor: Tensor): Tensor {
    const shape = tensor.shape
    for (const negAxis of this.axes) {
      // An out-of-range index yields undefined, which also fails the comparison.
      if (shape[shape.length + negAxis] !== this.dims[this.dims.length + negAxis]) {
        throw new Error(
          `Final axes of input tensor (shape ${shape}) must match module dimensions (${this.dims})`
        )
      }
    }
    const mean = tensor.mean(this.axes, true)
    // NOTE(review): the second argument (`false`) is presumably the estimator's
    // bias flag. Confirm it selects the population (divide-by-N) variance that
    // LayerNorm conventionally uses.
    const std = tensor.variance(this.axes, false, true).add(this.eps).sqrt()
    return tensor.sub(mean).div(std).mul(this.gamma).add(this.beta)
  }
}