@shumai/shumai
Version:
A fast, network-connected, differentiable tensor library for TypeScript (and JavaScript). Built with bun + flashlight for software engineers and researchers alike.
629 lines (557 loc) • 24.1 kB
text/typescript
import type { Tensor } from '../tensor'
import * as tensor from '../tensor/tensor'
import * as ops from '../tensor/tensor_ops'
import * as util from '../util'
import { Linear } from './linear'
import { Module } from './module'
import { LayerNorm } from './norm'
import { Sequential } from './sequential'
// Local `sm` namespace for this file: merges the exports of the tensor ops module and the
// tensor module (later spread wins on a name collision, so `tensor` exports shadow `ops`),
// plus the `util` helpers, so the code below can call `sm.full`, `sm.arange`, `sm.util.range`, etc.
const sm = { ...ops, ...tensor, util }
/**
 * A module to generate the positional encoding for a Transformer of a given input dimension:
 *
 * $$ \mathrm{PE}_{i, 2z} = \sin \left( \frac{i}{10000^{2z/d}} \right) $$
 *
 * $$ \mathrm{PE}_{i, 2z + 1} = \cos \left( \frac{i}{10000^{2z/d}} \right) $$
 *
 * where $i$ is the sequence position, $2z$ and $2z+1$ are the dimensions of the input embedding, and $d$ is the dimensionality of the input embedding.
 *
 * The multiplicative factors $\frac{1}{10000^{2z/d}}$ are precomputed during object creation as they are constant for all $i$.
 *
 * The full PE is initially precomputed for all $i$ up to 256 (or `initSequenceLength` given in the constructor). If the module is called with a sequence length larger than what has already been computed, the additional PE values are also calculated and then stored.
 */
export class TransformerPositionalEncoding extends Module {
  /**
   * The default `initSequenceLength` if none is supplied in the constructor.
   */
  static readonly DEFAULT_SEQUENCE_LENGTH = 256
  /**
   * The base of the exponent in the positional encoding.
   */
  static readonly ENCODING_BASE = 10000
  private dim: number
  private sequenceLength: number
  private encodingFactors: Tensor
  private encoding: Tensor
  /**
   * @param dim - Number of dimensions of each input embedding
   * @param initSequenceLength - Initial sequence length that the positional embedding should be computed for, or {@link DEFAULT_SEQUENCE_LENGTH} if not specified
   * @throws Error if `dim <= 0` or `initSequenceLength <= 0`
   */
  constructor(dim: number, initSequenceLength?: number) {
    super()
    if (dim <= 0) {
      throw new Error(`Module dimension must be > 0: got ${dim}`)
    }
    this.dim = dim
    if (initSequenceLength === undefined) {
      this.sequenceLength = TransformerPositionalEncoding.DEFAULT_SEQUENCE_LENGTH
    } else if (initSequenceLength <= 0) {
      throw new Error(`Initial sequenceLength must be > 0: got ${initSequenceLength}`)
    } else {
      this.sequenceLength = initSequenceLength
    }
    // base and numerator must be full([1], x) instead of scalar(x)
    // Otherwise, if the other operand has shape [1], the result will be reduced to scalar
    const base = sm.full([1], TransformerPositionalEncoding.ENCODING_BASE)
    const numerator = sm.full([1], 1)
    const denominator = sm.scalar(this.dim)
    // Even embedding indices 0, 2, 4, ... below `dim`: one factor per sin/cos pair, i.e.
    // ceil(dim / 2) factors (assuming arange's end is exclusive, as the shape note below expects).
    // Using `dim + 1` as the end added a spurious extra pair whenever `dim` was even, which made
    // the encoding [length, dim + 2] instead of [length, dim].
    const evenDims = sm.arange(0, this.dim, 2)
    this.encodingFactors = numerator.div(base.power(evenDims.div(denominator))) // shape [floor((dim + 1) / 2)]
    this.encoding = this.calculate(0, this.sequenceLength) // shape [sequenceLength, dim]
  }
  /**
   * Calculate positional encodings at a given range of sequence positions.
   *
   * @param start - Start of the range to calculate
   * @param end - End of the range to calculate
   *
   * @returns a Tensor of calculated positional embeddings with shape `[end - start, dim]`
   */
  calculate(start: number, end: number): Tensor {
    const length = end - start
    const pairedDim = this.encodingFactors.shape[0]
    const pos = sm.arange(start, end).reshape([length, 1])
    // Arguments of both sin and cos; computed once and shared.
    const angles = pos.mul(this.encodingFactors) // shape [length, pairedDim]
    const evenEncoding = sm.sin(angles).reshape([length, pairedDim, 1])
    const oddEncoding = sm.cos(angles).reshape([length, pairedDim, 1])
    // Stacking sin/cos on a trailing axis and flattening interleaves them: sin, cos, sin, cos, ...
    let encoding = sm.concatenate([evenEncoding, oddEncoding], -1) // shape [length, pairedDim, 2]
    encoding = encoding.reshape([length, pairedDim * 2])
    if (pairedDim * 2 !== this.dim) {
      // Odd `dim`: the final pair only contributes its sin column, so drop the trailing cos column.
      // reshape is necessary to preserve the last axis if this.dim is 1
      encoding = encoding.index([':', `:${this.dim}`]).reshape([length, this.dim])
    }
    return encoding
  }
  /**
   * @param sequenceLength - Length of the sequence for which the positional embedding should be calculated
   * @returns a Tensor of positional embeddings with shape `[length, dim]`, using precomputed values if available
   */
  forward(sequenceLength: number): Tensor {
    if (sequenceLength > this.sequenceLength) {
      // Extend the cache with only the positions not computed yet.
      const extension = this.calculate(this.sequenceLength, sequenceLength)
      this.encoding = sm.concatenate([this.encoding, extension], 0)
      this.sequenceLength = sequenceLength
    }
    if (sequenceLength === this.sequenceLength) {
      return this.encoding
    } else {
      return this.encoding.index([`:${sequenceLength}`, ':'])
    }
  }
}
/**
 * Validates the operands of an attention call, throwing an `Error` on the first problem found.
 *
 * - queries, keys and values must have the same rank and agree on every axis except the
 *   2nd last (the token axis), where keys and values must still agree with each other;
 * - the last axis must equal `attentionDim`;
 * - an optional `mask` must be 2-D with shape `[queryTokens, keyTokens]`.
 */
function checkAttentionInputs(
  attentionDim: number,
  queries: Tensor,
  keys: Tensor,
  values: Tensor,
  mask?: Tensor
): void {
  const qShape = queries.shape
  const kShape = keys.shape
  const vShape = values.shape
  const rank = qShape.length
  const tokenAxis = rank - 2
  const shapeMismatch = () =>
    new Error(
      `Input tensors must have the same shape, except the 2nd last axis: queries shape ${qShape}, keys shape ${kShape}, values shape ${vShape}`
    )

  if (kShape.length !== rank || vShape.length !== rank) {
    throw shapeMismatch()
  }

  for (let axis = 0; axis < rank; axis++) {
    const allAgree = qShape[axis] === kShape[axis] && qShape[axis] === vShape[axis]
    if (allAgree) {
      continue
    }
    // Only the token axis may differ between queries and keys/values.
    if (axis !== tokenAxis) {
      throw shapeMismatch()
    }
    if (kShape[axis] !== vShape[axis]) {
      throw new Error(
        `Tensors keys and values must have the same shape: keys shape ${kShape}, values shape ${vShape}`
      )
    }
  }

  const lastDim = qShape[rank - 1]
  if (lastDim !== attentionDim) {
    throw new Error(
      `Last axis of input tensors (${lastDim}) must match attention dimension (${attentionDim})`
    )
  }

  if (mask === undefined) {
    return
  }
  const maskShape = mask.shape
  const queryTokens = qShape[tokenAxis]
  const keyTokens = kShape[kShape.length - 2]
  const maskFits =
    maskShape.length === 2 && maskShape[0] === queryTokens && maskShape[1] === keyTokens
  if (!maskFits) {
    throw new Error(
      `Mask shape (${maskShape}) must match sequence lengths of queries and keys: must be [${queryTokens}, ${keyTokens}]`
    )
  }
}
/**
 * Scaled dot-product mechanism as described by Vaswani et al. The {@link scaleFactor} is computed during object creation as $\frac{1}{\sqrt{d}}$, where $d$ is the dimensionality of the inputs.
 */
export class TransformerDotProductAttention extends Module {
  private dim: number
  private scaleFactor: Tensor
  /**
   * @param dim - Number of dimensions of the inputs
   */
  constructor(dim: number) {
    super()
    this.dim = dim
    this.scaleFactor = sm.scalar(1 / Math.sqrt(dim))
  }
  /**
   * Multiplies `tensor` by the precomputed $\frac{1}{\sqrt{d}}$ factor.
   */
  protected scale(tensor: Tensor): Tensor {
    return tensor.mul(this.scaleFactor)
  }
  /**
   * @param queries - Tensor of query embeddings, shape `[..., queryTokens, dim]`
   * @param keys - Tensor of key embeddings, shape `[..., keyTokens, dim]`
   * @param values - Tensor of value embeddings each corresponding to a key, shape `[..., keyTokens, dim]`
   * @param mask - Tensor mask of shape `[queryTokens, keyTokens]` where a 1 in position $(i, j)$ indicates that the $i$th query should not attend to the $j$th key
   * @returns A Tensor of shape `[..., queryTokens, dim]`
   * @throws Error if the input shapes are inconsistent (see `checkAttentionInputs`)
   */
  forward(queries: Tensor, keys: Tensor, values: Tensor, mask?: Tensor): Tensor {
    // queries shape [..., queryTokens, dim]
    // keys and values shape [..., keyTokens, dim]
    // mask shape [queryTokens, keyTokens]
    checkAttentionInputs(this.dim, queries, keys, values, mask)
    let output = queries.matmul(keys.T()) // shape [..., queryTokens, keyTokens]
    if (mask !== undefined) {
      // Work on a copy: editing the array returned by `output.shape` in place would corrupt the
      // shape later used for the -Infinity tile if the getter hands back a shared/cached array.
      const scoresShape = [...output.shape]
      if (scoresShape.length > 2) {
        // mask.shape.length is always 2; repeat it across every leading (batch/head) axis
        const tile = [...scoresShape]
        tile[tile.length - 1] = 1
        tile[tile.length - 2] = 1
        mask = mask.tile(tile)
      }
      // Masked scores become -Infinity so they receive zero weight after the softmax.
      const negativeInfinities = sm.full([1], -Infinity).tile(scoresShape)
      output = sm.where(mask, negativeInfinities, output)
    }
    output = this.scale(output).softmax(-1)
    output = output.matmul(values) // shape [..., queryTokens, dim]
    return output
  }
}
/**
 * Multi-head attention mechanism as described by Vaswani et al. The input Tensors are linearly embedded before being passed to {@link TransformerDotProductAttention | scaled dot-product attentions}.
 *
 * Queries, keys and values are each projected to `heads * attentionDim` dimensions, split into `heads` slices of `attentionDim` dimensions that are attended to independently, then concatenated and projected back to `dim`.
 */
export class TransformerMultiheadAttention extends Module {
  private dim: number
  private heads: number
  private attentionDim: number
  private queryEmbed: Linear
  private keyEmbed: Linear
  private valueEmbed: Linear
  private attention: TransformerDotProductAttention
  private concatEmbed: Linear
  /**
   * @param dim - Number of dimensions of the input embeddings
   * @param heads - Number of heads for the multi-head attention
   * @param attentionDim - Number of dimensions of the further embeddings which are passed to the scaled dot-product attention mechanisms, or `dim / heads` if not specified
   * @throws Error if `dim` is not divisible by `heads` (checked even when `attentionDim` is given)
   */
  constructor(dim: number, heads: number, attentionDim?: number) {
    super()
    if (dim % heads !== 0) {
      throw new Error(
        `Model dimensions must be divisible by the number of heads: ${dim} not divisible by ${heads}`
      )
    }
    this.dim = dim
    this.heads = heads
    if (attentionDim === undefined) {
      // Default per-head size keeps the total projected width equal to `dim`.
      this.attentionDim = dim / heads
    } else {
      this.attentionDim = attentionDim
    }
    this.queryEmbed = new Linear(dim, this.attentionDim * heads)
    this.keyEmbed = new Linear(dim, this.attentionDim * heads)
    this.valueEmbed = new Linear(dim, this.attentionDim * heads)
    this.attention = new TransformerDotProductAttention(this.attentionDim)
    this.concatEmbed = new Linear(this.attentionDim * heads, dim)
  }
  /**
   * @param queries - Tensor of query vectors, shape `[..., queryTokens, dim]`
   * @param keys - Tensor of key vectors, shape `[..., keyTokens, dim]`
   * @param values - Tensor of value vectors each corresponding to a key, shape `[..., keyTokens, dim]`
   * @param mask - Tensor mask of shape `[queryTokens, keyTokens]` for the {@link TransformerDotProductAttention}
   * @returns A Tensor of shape `[..., queryTokens, dim]`
   * @throws Error if the input shapes are inconsistent (see `checkAttentionInputs`)
   */
  forward(queries: Tensor, keys: Tensor, values: Tensor, mask?: Tensor): Tensor {
    // queries shape [..., queryTokens, dim]
    // keys and values shape [..., keyTokens, dim]
    // The mask is validated by the inner dot-product attention.
    checkAttentionInputs(this.dim, queries, keys, values)
    const originalShape = queries.shape
    const queriesReshape = [...originalShape] // shape [..., queryTokens, dim]
    queriesReshape[queriesReshape.length - 1] = this.heads
    queriesReshape.push(this.attentionDim) // shape [..., queryTokens, heads, attentionDim]
    const keysValuesReshape = [...queriesReshape]
    keysValuesReshape[keysValuesReshape.length - 3] = keys.shape[keys.shape.length - 2] // shape [..., keyTokens, heads, attentionDim]
    // Swap 2nd and 3rd last axes
    const transpose = Array.from(sm.util.range(queriesReshape.length))
    transpose[transpose.length - 3] = transpose.length - 2
    transpose[transpose.length - 2] = transpose.length - 3
    queries = this.queryEmbed(queries).reshape(queriesReshape).transpose(transpose)
    keys = this.keyEmbed(keys).reshape(keysValuesReshape).transpose(transpose)
    values = this.valueEmbed(values).reshape(keysValuesReshape).transpose(transpose)
    // embed shape [..., {key|query}Tokens, heads * attentionDim]
    // reshape shape [..., {key|query}Tokens, heads, attentionDim]
    // transpose shape [..., heads, {key|query}Tokens, attentionDim]
    // A single axis swap is its own inverse, so the same permutation undoes it.
    const reverseTranspose = transpose
    const reverseReshape = [...originalShape]
    reverseReshape[reverseReshape.length - 1] = this.heads * this.attentionDim
    let output = this.attention(queries, keys, values, mask) // shape [..., heads, queryTokens, attentionDim]
    output = output.transpose(reverseTranspose) // shape [..., queryTokens, heads, attentionDim]
    output = output.reshape(reverseReshape) // shape [..., queryTokens, heads * attentionDim]
    output = this.concatEmbed(output) // shape [..., queryTokens, dim]
    return output
  }
}
/**
 * Position-wise feed forward network: a Linear layer to `hiddenDim` with ReLU, followed by a
 * Linear layer back to `dim`. Applied independently along the last axis of its input.
 */
class FeedForward extends Module {
  private dim: number
  private hiddenDim: number
  private affineIn: Linear
  private affineOut: Linear
  /**
   * @param dim - Size of the input and output last axis
   * @param hiddenDim - Size of the hidden layer, or `dim` if not specified
   */
  constructor(dim: number, hiddenDim?: number) {
    super()
    this.dim = dim
    this.hiddenDim = hiddenDim === undefined ? dim : hiddenDim
    this.affineIn = new Linear(this.dim, this.hiddenDim)
    this.affineOut = new Linear(this.hiddenDim, this.dim)
  }
  /**
   * @param input - Tensor of shape `[..., dim]`
   * @returns A Tensor of shape `[..., dim]`
   */
  forward(input: Tensor): Tensor {
    const hidden = this.affineIn(input).relu() // shape [..., hiddenDim]
    return this.affineOut(hidden)
  }
}
/**
 * One Transformer encoder layer (Vaswani et al.): a {@link TransformerMultiheadAttention | multi-head} self-attention sublayer followed by a fully-connected feed forward network. Each sublayer's output is added to its input (residual connection) and the sum is normalised with {@link LayerNorm}.
 */
export class TransformerEncoderLayer extends Module {
  private dim: number
  private heads: number
  private attentionDim: number
  private feedForwardDim: number
  private mha: TransformerMultiheadAttention
  private mhaNorm: LayerNorm
  private ff: FeedForward
  private ffNorm: LayerNorm
  /**
   * @param dim - Number of dimensions of the input embeddings
   * @param heads - Number of heads in the multi-head attention mechanism
   * @param attentionDim - Number of dimensions of the embeddings which are passed to the scaled dot-product attention mechanisms, or `dim` if not specified
   * @param feedForwardDim - Number of dimensions in the hidden layer of the feed forward network, or `dim` if not specified
   */
  constructor(dim: number, heads: number, attentionDim?: number, feedForwardDim?: number) {
    super()
    this.dim = dim
    this.heads = heads
    this.attentionDim = attentionDim === undefined ? dim : attentionDim
    this.feedForwardDim = feedForwardDim === undefined ? dim : feedForwardDim
    // Sublayers are built in a fixed order: attention, its norm, feed forward, its norm.
    this.mha = new TransformerMultiheadAttention(this.dim, this.heads, this.attentionDim)
    this.mhaNorm = new LayerNorm([this.dim])
    this.ff = new FeedForward(this.dim, this.feedForwardDim)
    this.ffNorm = new LayerNorm([this.dim])
  }
  /**
   * @param input - Input Tensor of shape `[..., tokens, dim]`
   * @returns A Tensor of shape `[..., tokens, dim]`
   */
  forward(input: Tensor): Tensor {
    // Self-attention: queries, keys and values all come from the same input.
    const attended = this.mhaNorm(input.add(this.mha(input, input, input)))
    // Feed forward sublayer with its own residual connection.
    return this.ffNorm(attended.add(this.ff(attended)))
  }
}
/**
 * Transformer encoder (Vaswani et al.) built from a stack of {@link TransformerEncoderLayer | TransformerEncoderLayers}.
 *
 * The {@link TransformerPositionalEncoding | positional encoding} is added to the input inside this module. Turning a token sequence into vectors in the first place (e.g. with word2vec) is left to the caller.
 */
export class TransformerEncoder extends Module {
  private dim: number
  private heads: number
  private depth: number
  private attentionDim: number
  private feedForwardDim: number
  private positional: TransformerPositionalEncoding
  private layers: Sequential
  /**
   * @param dim - Number of dimensions of the input embeddings
   * @param heads - Number of heads in each multi-head attention mechanism
   * @param depth - Number of encoder layers
   * @param attentionDim - Number of dimensions of the embeddings which are passed to the scaled dot-product attention mechanisms, or `dim` if not specified
   * @param feedForwardDim - Number of dimensions in the hidden layer of each feed forward network, or `dim` if not specified
   * @param initSequenceLength - Initial sequence length that the positional encoding should be computed for, or {@link TransformerPositionalEncoding.DEFAULT_SEQUENCE_LENGTH} if not specified
   */
  constructor(
    dim: number,
    heads: number,
    depth: number,
    attentionDim?: number,
    feedForwardDim?: number,
    initSequenceLength?: number
  ) {
    super()
    this.dim = dim
    this.heads = heads
    this.depth = depth
    this.attentionDim = attentionDim === undefined ? dim : attentionDim
    this.feedForwardDim = feedForwardDim === undefined ? dim : feedForwardDim
    // The positional encoder applies its own default when initSequenceLength is undefined.
    this.positional = new TransformerPositionalEncoding(this.dim, initSequenceLength)
    const stack: TransformerEncoderLayer[] = []
    for (let layerIndex = 0; layerIndex < this.depth; layerIndex++) {
      stack.push(
        new TransformerEncoderLayer(this.dim, this.heads, this.attentionDim, this.feedForwardDim)
      )
    }
    this.layers = new Sequential(...stack)
  }
  /**
   * @param input - Input Tensor of shape `[..., tokens, dim]`
   * @returns A Tensor of shape `[..., tokens, dim]`
   */
  forward(input: Tensor): Tensor {
    const inputShape = input.shape
    const tokens = inputShape[inputShape.length - 2]
    // positional encoding has shape [tokens, dim] and broadcasts over leading axes
    const positioned = input.add(this.positional(tokens))
    return this.layers(positioned)
  }
}
/**
 * One Transformer decoder layer (Vaswani et al.), made of three sublayers in order:
 *
 * - a causally masked {@link TransformerMultiheadAttention | multi-head} self-attention;
 * - an unmasked {@link TransformerMultiheadAttention | multi-head} cross-attention over the encoder output;
 * - a fully-connected feed forward network.
 *
 * Each sublayer's output is added to its input (residual connection) and the sum is normalised with {@link LayerNorm}.
 */
export class TransformerDecoderLayer extends Module {
  private dim: number
  private heads: number
  private attentionDim: number
  private feedForwardDim: number
  private maskedSelfAttention: TransformerMultiheadAttention
  private maskedSelfAttentionNorm: LayerNorm
  private crossAttention: TransformerMultiheadAttention
  private crossAttentionNorm: LayerNorm
  private ff: FeedForward
  private ffNorm: LayerNorm
  /**
   * @param dim - Number of dimensions of the input embeddings
   * @param heads - Number of heads in each multi-head attention mechanism
   * @param attentionDim - Number of dimensions of the embeddings which are passed to the scaled dot-product attention mechanisms, or `dim` if not specified
   * @param feedForwardDim - Number of dimensions in the hidden layer of the feed forward network, or `dim` if not specified
   */
  constructor(dim: number, heads: number, attentionDim?: number, feedForwardDim?: number) {
    super()
    this.dim = dim
    this.heads = heads
    this.attentionDim = attentionDim === undefined ? dim : attentionDim
    this.feedForwardDim = feedForwardDim === undefined ? dim : feedForwardDim
    // Sublayers and their norms are created in processing order.
    this.maskedSelfAttention = new TransformerMultiheadAttention(this.dim, this.heads, this.attentionDim)
    this.maskedSelfAttentionNorm = new LayerNorm([this.dim])
    this.crossAttention = new TransformerMultiheadAttention(this.dim, this.heads, this.attentionDim)
    this.crossAttentionNorm = new LayerNorm([this.dim])
    this.ff = new FeedForward(this.dim, this.feedForwardDim)
    this.ffNorm = new LayerNorm([this.dim])
  }
  /**
   * @param sequenceLength - Length of sequence for which the mask should be generated
   * @returns A Tensor mask of shape `[sequenceLength, sequenceLength]` where row $i$ should have 0s in positions up to $i$ and 1s everywhere else
   */
  static getSelfAttentionMask(sequenceLength: number): Tensor {
    const allOnes = sm.full([sequenceLength, sequenceLength], 1).astype(sm.dtype.BoolInt8)
    // tril keeps the lower triangle (diagonal included); negating it flags the later positions.
    const allowed = allOnes.tril()
    return allowed.logicalNot()
  }
  /**
   * @param input - Tensor from the previous decoder layer, shape `[..., tokens, dim]`
   * @param encoderOutput - Tensor output by the encoder, shape `[..., encoderTokens, dim]`
   * @returns A Tensor of shape `[..., tokens, dim]`
   */
  forward(input: Tensor, encoderOutput: Tensor): Tensor {
    const inputShape = input.shape
    const causalMask = TransformerDecoderLayer.getSelfAttentionMask(inputShape[inputShape.length - 2])
    // 1. masked self-attention + residual + norm
    const selfAttended = this.maskedSelfAttentionNorm(
      input.add(this.maskedSelfAttention(input, input, input, causalMask))
    )
    // 2. cross-attention (queries from decoder, keys/values from encoder) + residual + norm
    const crossAttended = this.crossAttentionNorm(
      selfAttended.add(this.crossAttention(selfAttended, encoderOutput, encoderOutput))
    )
    // 3. feed forward + residual + norm
    return this.ffNorm(crossAttended.add(this.ff(crossAttended)))
  }
}
/**
 * Transformer decoder (Vaswani et al.) built from a stack of {@link TransformerDecoderLayer | TransformerDecoderLayers}, with the positional encoding added to the input first.
 */
export class TransformerDecoder extends Module {
  private dim: number
  private heads: number
  private depth: number
  private attentionDim: number
  private feedForwardDim: number
  private positional: TransformerPositionalEncoding
  private layers: Sequential
  /**
   * @param dim - Number of dimensions of the input embeddings
   * @param heads - Number of heads in each multi-head attention mechanism
   * @param depth - Number of decoder layers
   * @param attentionDim - Number of dimensions of the embeddings which are passed to the scaled dot-product mechanisms, or `dim` if not specified
   * @param feedForwardDim - Number of dimensions in the hidden layer of each feed forward network, or `dim` if not specified
   * @param initSequenceLength - Initial sequence length that the positional encoding should be computed for, or {@link TransformerPositionalEncoding.DEFAULT_SEQUENCE_LENGTH} if not specified
   */
  constructor(
    dim: number,
    heads: number,
    depth: number,
    attentionDim?: number,
    feedForwardDim?: number,
    initSequenceLength?: number
  ) {
    super()
    this.dim = dim
    this.heads = heads
    this.depth = depth
    this.attentionDim = attentionDim === undefined ? dim : attentionDim
    this.feedForwardDim = feedForwardDim === undefined ? dim : feedForwardDim
    // The positional encoder applies its own default when initSequenceLength is undefined.
    this.positional = new TransformerPositionalEncoding(this.dim, initSequenceLength)
    const stages: CallableFunction[] = []
    for (let stageIndex = 0; stageIndex < this.depth; stageIndex++) {
      const decoderLayer = new TransformerDecoderLayer(
        this.dim,
        this.heads,
        this.attentionDim,
        this.feedForwardDim
      )
      // Each stage returns [layerOutput, encoderOutput] so the encoder output is handed on to the
      // next stage unchanged. NOTE(review): assumes Sequential spreads an array result into the
      // next call's arguments; also, the decoder layers live only inside these closures — confirm
      // that Module's parameter discovery still reaches them.
      stages.push((input: Tensor, encoderOutput: Tensor) => [
        decoderLayer(input, encoderOutput),
        encoderOutput
      ])
    }
    this.layers = new Sequential(...stages)
  }
  /**
   * @param input - Input Tensor of shape `[..., tokens, dim]`
   * @param encoderOutput - Tensor output by the encoder, shape `[..., encoderTokens, dim]`
   * @returns A Tensor of shape `[..., tokens, dim]`
   */
  forward(input: Tensor, encoderOutput: Tensor): Tensor {
    const inputShape = input.shape
    const tokens = inputShape[inputShape.length - 2]
    const positioned = input.add(this.positional(tokens)) // shape [..., tokens, dim]
    // The last stage yields [decoderOutput, encoderOutput]; keep only the decoder output.
    const stageResult = this.layers(positioned, encoderOutput)
    return stageResult[0]
  }
}