// ggml-js: JS bindings for the ggml library.
import { Context, Module, Embedding, LayerNorm, Linear, F } from 'ggml-js/core'
import { CausalLM } from './index.js'
/**
* RWKV model
* @see https://github.com/BlinkDL/RWKV-LM
*
* based on https://johanwind.github.io/2023/03/23/rwkv_details.html
* and https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py
* and https://github.com/saharNooby/rwkv.cpp
*/
export class RWKV extends CausalLM {
#state
/**
* Create a new RWKV model.
* @param {Context} context
* @param {{ vocabSize: number, embedDim: number, numLayers: number }} hparams
*/
constructor(context, { vocabSize, embedDim, numLayers }) {
super(context)
this.emb = new Embedding(this, vocabSize, embedDim)
this.blocks = Array.from(Array(numLayers), _ => new Block(this, embedDim))
this.ln_out = new LayerNorm(this, embedDim)
this.head = new Linear(this, embedDim, vocabSize, { bias: false })
}
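  /**
   * Build the initial recurrent state: 5 tensors per block, laid out per
   * layer as [ffn x_prev, att x_prev, att numerator (aa), att denominator
   * (bb), att max exponent (pp)]. pp starts at -1e30 (effectively -inf) so
   * the first F.max() in TimeMix always picks the incoming value.
   * @param {Context} ctx
   */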
getInitialState(ctx) {
const embedDim = this.emb.weight.shape[1]
return Array.from(Array(this.blocks.length * 5), (_, i) =>
ctx.newTensor1D('f32', embedDim).setAll(i % 5 === 4 ? -1e30 : 0)
)
}
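  /**
   * Infer hyperparameters from the checkpoint header: vocab and embedding
   * size from the shape of emb.weight, layer count from the number of
   * per-block ln1.weight entries.
   */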
static loadHparams(header) {
    const [vocabSize, embedDim] = header['emb.weight'].shape
const numLayers = Object.keys(header).filter(k => k.endsWith('.ln1.weight')).length
return { vocabSize, embedDim, numLayers }
}
forward(x, state, updates = []) {
x = this.emb.forward(x)
    // Each block consumes its own 5-tensor slice of the recurrent state.
    x = this.blocks.reduce((x, block, i) => block.forward(x, state.slice(i * 5, i * 5 + 5), updates), x)
return this.head.forward(this.ln_out.forward(x))
}
}
class Block extends Module {
constructor(parentModule, dim) {
super(parentModule)
this.ln1 = new LayerNorm(this, dim)
this.att = new TimeMix(this, dim)
this.ln2 = new LayerNorm(this, dim)
this.ffn = new ChannelMix(this, dim)
}
forward(x, state, updates) {
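    // Pre-LayerNorm residual wiring: x + att(ln1(x)), then x + ffn(ln2(x)),
    // with both sublayers reading and writing this block's state slice.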
x = F.add(x, this.att.forward(this.ln1.forward(x), state, updates))
return F.add(x, this.ffn.forward(this.ln2.forward(x), state, updates))
}
}
class TimeMix extends Module {
constructor(parentModule, dim) {
super(parentModule)
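    // Learned per-channel decay, first-token bonus, and token-shift mixing
    // coefficients (loaded from the checkpoint alongside the linear weights).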
this.time_decay = this.context.newTensor1D('f32', dim)
this.time_first = this.context.newTensor1D('f32', dim)
this.time_mix_k = this.context.newTensor1D('f32', dim)
this.time_mix_v = this.context.newTensor1D('f32', dim)
this.time_mix_r = this.context.newTensor1D('f32', dim)
this.key = new Linear(this, dim, dim, { bias: false })
this.value = new Linear(this, dim, dim, { bias: false })
this.receptance = new Linear(this, dim, dim, { bias: false })
this.output = new Linear(this, dim, dim, { bias: false })
}
forward(x, [_, prev_x, aa, bb, pp], updates) {
const xk = mix(x, prev_x, this.time_mix_k)
const xv = mix(x, prev_x, this.time_mix_v)
const xr = mix(x, prev_x, this.time_mix_r)
const r = F.sigmoid(this.receptance.forward(xr))
const k = this.key.forward(xk)
const v = this.value.forward(xv)
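    // Numerically stable WKV: aa/bb accumulate the attention numerator and
    // denominator scaled by exp(-pp), where pp tracks the largest exponent
    // seen so far, so the exp() terms below never overflow.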
let ww = F.add(k, this.time_first)
let qq = F.max(pp, ww)
let e1 = F.exp(F.sub(pp, qq))
let e2 = F.exp(F.sub(ww, qq))
const a = F.add(F.mul(e1, aa), F.mul(e2, v))
const b = F.add(F.mul(e1, bb), e2)
const wkv = F.div(a, b)
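    // Recompute the exponent bookkeeping with the decay applied, then fold
    // the current token into the running state for the next step.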
ww = F.add(pp, this.time_decay)
qq = F.max(ww, k)
e1 = F.exp(F.sub(ww, qq))
e2 = F.exp(F.sub(k, qq))
updates.push(
// dest, src
[prev_x, x],
[aa, F.add(F.mul(e1, aa), F.mul(e2, v))],
[bb, F.add(F.mul(e1, bb), e2)],
[pp, qq]
)
return this.output.forward(F.mul(r, wkv))
}
}
class ChannelMix extends Module {
constructor(parentModule, dim) {
super(parentModule)
this.time_mix_k = this.context.newTensor1D('f32', dim)
this.time_mix_r = this.context.newTensor1D('f32', dim)
this.key = new Linear(this, dim, 4 * dim, { bias: false })
this.receptance = new Linear(this, dim, dim, { bias: false })
this.value = new Linear(this, 4 * dim, dim, { bias: false })
}
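  // Token-shifted FFN with a 4x hidden expansion: k = relu(key(xk))^2,
  // gated by the sigmoid "receptance" r.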
forward(x, [prev_x], updates) {
const xk = mix(x, prev_x, this.time_mix_k)
const xr = mix(x, prev_x, this.time_mix_r)
const r = F.sigmoid(this.receptance.forward(xr))
const k = F.square(F.relu(this.key.forward(xk)))
updates.push([prev_x, x])
return F.mul(r, this.value.forward(k))
}
}
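// Token-shift interpolation w * x + (1 - w) * prev_x, algebraically
// rearranged so no explicit (1 - w) tensor is needed.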
const mix = (x, prev_x, w) => F.sub(F.add(F.mul(x, w), prev_x), F.mul(prev_x, w))
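// Minimal usage sketch (hypothetical: how `ctx`, the checkpoint `header`,
// and the input token tensor are obtained depends on the rest of ggml-js
// and is not shown in this file):
//
//   const model = new RWKV(ctx, RWKV.loadHparams(header))
//   const state = model.getInitialState(ctx)
//   const updates = []            // collects [dest, src] state writes
//   const logits = model.forward(tokens, state, updates)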