@shumai/shumai
Version:
A fast, network-connected, differentiable tensor library for TypeScript (and JavaScript). Built with bun + flashlight for software engineers and researchers alike.
202 lines (190 loc) • 6.59 kB
text/typescript
import { Buffer } from 'buffer'
import * as sm from '../tensor'
/**
 * `JSON.stringify` replacer that makes `bigint` values serializable.
 *
 * A bigint is emitted as the two-element array `['bigint', <decimal string>]`
 * (an array rather than an object keeps the encoded form small); every other
 * value is handed back unchanged.
 *
 * @param key - property name supplied by `JSON.stringify` (unused)
 * @param value - the value about to be serialized
 * @returns the tagged array for bigints, otherwise `value` itself
 * @private
 */
export function jsonStringifyHandler(key: string, value: any) {
  return typeof value === 'bigint' ? ['bigint', value.toString()] : value
}
/**
 * `JSON.parse` reviver that turns `['bigint', <decimal string>]` back into a `bigint`.
 *
 * Only the exact form written by `jsonStringifyHandler` is revived: an array of
 * length two whose first element is the string `'bigint'` and whose second element
 * is a (possibly negative) decimal integer string. Any other array that merely
 * starts with `'bigint'` is ordinary data and is returned unchanged instead of
 * being passed to `BigInt`, which would throw on it.
 *
 * @param key - property name supplied by `JSON.parse` (unused)
 * @param value - the already-parsed value for that property
 * @returns the revived bigint, otherwise `value` itself
 * @private
 */
export function jsonParseHandler(key: string, value: any) {
  const isTaggedBigint =
    Array.isArray(value) &&
    value.length === 2 &&
    value[0] === 'bigint' &&
    typeof value[1] === 'string' &&
    /^-?\d+$/.test(value[1])
  return isTaggedBigint ? BigInt(value[1]) : value
}
/**
 * Serializes a tensor, plus optional JSON-able properties, into one contiguous buffer.
 *
 * Layout, in order:
 *  1. five int64 header words: ndim, provenance, flags, tensor byte length, props byte length
 *  2. `ndim` int64 dimension sizes (taken from `tensor.shape64`)
 *  3. the tensor contents as produced by `tensor.toFloat32Array()`
 *  4. the props as JSON text, with bigints encoded by `jsonStringifyHandler`
 *     (nothing is written when `props` is omitted)
 *
 * All multi-byte values are written through typed arrays, i.e. in the byte order
 * of the running platform.
 *
 * @param tensor - tensor to serialize
 * @param props - optional extra data to attach; must be JSON-serializable
 * @returns a freshly allocated buffer holding the encoded message
 */
export function encodeBinary(tensor: sm.Tensor, props?: object): ArrayBuffer {
  const shape = tensor.shape64
  // Provenance is parsed as a hex string; 0xffffffff is written when there is none.
  // NOTE(review): the value is stored in a signed int64 slot, so a provenance of
  // 2^63 or more would not round-trip — confirm the width of provenance ids.
  const provenance = tensor.provenance ? BigInt('0x' + tensor.provenance) : BigInt(0xffffffff)
  // Bit 0 of flags carries requires_grad.
  const flags = Number(tensor.requires_grad) & 0x1
  const props_buf = props && Buffer.from(JSON.stringify(props, jsonStringifyHandler))
  const props_len = props_buf ? props_buf.byteLength : 0
  // Respect byteOffset/byteLength: the Float32Array may be a view into a larger
  // buffer, in which case wrapping `.buffer` alone would serialize unrelated bytes.
  const tensor_f32 = tensor.toFloat32Array()
  const tensor_buf = new Uint8Array(
    tensor_f32.buffer,
    tensor_f32.byteOffset,
    tensor_f32.byteLength
  )
  const tensor_len = tensor_buf.byteLength
  // meta_data: ndim, provenance, flags, tensor_len, props_len
  const meta_data = new BigInt64Array([
    BigInt(shape.length),
    provenance,
    BigInt(flags),
    BigInt(tensor_len),
    BigInt(props_len)
  ])
  const meta_data_buf = new Uint8Array(meta_data.buffer)
  const shape_buf = new Uint8Array(new BigInt64Array(shape).buffer)
  const buf = new Uint8Array(
    meta_data_buf.byteLength + shape_buf.byteLength + tensor_buf.byteLength + props_len
  )
  // Copy the four sections back to back.
  let byteOffset = 0
  buf.set(meta_data_buf, byteOffset)
  byteOffset += meta_data_buf.byteLength
  buf.set(shape_buf, byteOffset)
  byteOffset += shape_buf.byteLength
  buf.set(tensor_buf, byteOffset)
  byteOffset += tensor_buf.byteLength
  if (props_buf) buf.set(props_buf, byteOffset)
  return buf.buffer
}
/**
 * Decodes a buffer produced by `encodeBinary` back into a tensor and its props.
 *
 * Expected layout: five int64 header words (ndim, provenance, flags, tensor byte
 * length, props byte length), then `ndim` int64 dimension sizes, then the float32
 * tensor contents, then the props as JSON text (revived with `jsonParseHandler`).
 *
 * The returned tensor has `op` set to `'network'`, `provenance` set to the header
 * value rendered as a hex string, and `requires_grad` taken from bit 0 of flags.
 *
 * @param buf - the encoded message
 * @returns the decoded tensor, plus `props` when the message carried any
 * @throws a string describing the problem when the buffer is shorter than the
 *   header, when the header lengths are negative or not float32-aligned, or when
 *   the declared tensor length disagrees with the buffer size
 */
export function decodeBinary(buf: ArrayBuffer): { tensor: sm.Tensor; props?: object } {
  // meta_data: ndim, provenance, flags, tensor_len, props_len
  const meta_data_len = 5
  const meta_data_bytes = 8 * meta_data_len
  // The whole header must be present before it can be viewed as int64 words.
  if (buf.byteLength < meta_data_bytes) {
    throw 'buffer cannot be decoded, too short to parse'
  }
  const meta_data = new BigInt64Array(buf, 0, meta_data_len)
  let byteOffset = meta_data_bytes
  const shape_len = Number(meta_data[0])
  const provenance = meta_data[1].toString(16)
  const flags = Number(meta_data[2])
  const tensor_len = Number(meta_data[3])
  const props_len = Number(meta_data[4])
  const requires_grad = flags & 0x1
  // Reject lengths that cannot describe a real payload; together with the size
  // check below this guarantees every section lies inside the buffer.
  if (shape_len < 0 || tensor_len < 0 || props_len < 0 || tensor_len % 4 !== 0) {
    throw 'buffer cannot be decoded, invalid lengths in header'
  }
  const actual_tensor_len = buf.byteLength - meta_data_bytes - 8 * shape_len - props_len
  if (tensor_len !== actual_tensor_len) {
    throw `buffer cannot be decoded, tensor expected ${tensor_len}B but received ${actual_tensor_len}B`
  }
  const shape = new BigInt64Array(buf, byteOffset, shape_len)
  byteOffset += 8 * shape_len
  const t = sm.tensor(new Float32Array(buf, byteOffset, tensor_len / 4)).reshape(shape)
  byteOffset += tensor_len
  const props = props_len
    ? JSON.parse(Buffer.from(buf, byteOffset, props_len).toString(), jsonParseHandler)
    : void 0
  t.op = 'network'
  // NOTE(review): toString(16) never yields an empty string, so the null branch is
  // not reachable; a message encoded without provenance decodes as 'ffffffff'.
  // Confirm whether that sentinel should map back to null.
  t.provenance = provenance ? provenance : null
  t.requires_grad = !!requires_grad
  return { tensor: t, props }
}
/**
 * Base64-encodes the bytes of `buf`.
 *
 * The bytes are first turned into a binary string (one character per byte) and
 * then passed to `btoa`. The conversion is done in fixed-size chunks: handing
 * every byte to `String.fromCharCode` as a separate argument in one call exceeds
 * the engine's argument limit for large buffers and throws a RangeError.
 *
 * @param buf - the bytes to encode
 * @returns the base64 text
 * @private
 */
function encodeBase64Buffer(buf: ArrayBuffer) {
  const u8 = new Uint8Array(buf)
  // Comfortably below the argument-count limits of current engines.
  const chunk_size = 0x8000
  const pieces: string[] = []
  for (let start = 0; start < u8.length; start += chunk_size) {
    pieces.push(String.fromCharCode(...u8.subarray(start, start + chunk_size)))
  }
  return btoa(pieces.join(''))
}
/**
 * Decodes base64 text into a new ArrayBuffer.
 *
 * `atob` yields a binary string with one character per decoded byte; each
 * character code is written to the matching position of the result.
 *
 * @param s - base64 text
 * @returns a buffer whose length equals the number of decoded bytes
 * @private
 */
function decodeBase64Buffer(s) {
  const binary = atob(s)
  const out = new ArrayBuffer(binary.length)
  const bytes = new Uint8Array(out)
  let pos = 0
  for (const ch of binary) {
    bytes[pos++] = ch.charCodeAt(0)
  }
  return out
}
/**
 * Encodes a tensor as base64 text.
 *
 * The tensor is first serialized with `encodeBinary` (without props) and the
 * resulting bytes are then base64-encoded.
 *
 * @param tensor - tensor to encode
 * @returns base64 text of the binary encoding
 */
export function encodeBase64(tensor: sm.Tensor) {
  const binary = encodeBinary(tensor)
  return encodeBase64Buffer(binary)
}
/**
 * Decodes base64 text back into a tensor.
 *
 * The text is converted to bytes and those bytes are handed to `decodeBinary`,
 * whose result (tensor plus optional props) is returned as is.
 *
 * @param base64String - base64 text of a binary-encoded tensor
 * @returns whatever `decodeBinary` produces for the decoded bytes
 */
export function decodeBase64(base64String: string) {
  const binary = decodeBase64Buffer(base64String)
  return decodeBinary(binary)
}
/**
 * Renders a tensor as human-readable text of the form `<nested list>:<dtype>`,
 * e.g. `[[1, 2], [3, 4]]:<dtype name>`. A tensor with an empty shape is rendered
 * as its single value without brackets.
 *
 * Values are read through `tensor.toFloat32Array()`, so the printed numbers are
 * the float32 representation whatever the tensor's dtype; the dtype name comes
 * from `sm.typeToString`.
 *
 * @param tensor - tensor to render
 * @returns the textual encoding
 */
export function encodeReadable(tensor: sm.Tensor) {
  // Renders the sub-tensor of the given shape that starts at `offset` in the
  // row-major flat data.
  function encodeArray(
    flat: ArrayLike<number>,
    shape: readonly number[],
    offset: number
  ): string | number {
    if (shape.length === 0) {
      return flat[offset]
    }
    const inner = shape.slice(1)
    // One step along the leading dimension skips a whole inner block, i.e. the
    // product of ALL remaining dimensions (not just the next one).
    const stride = inner.reduce((acc, dim) => acc * dim, 1)
    const out: Array<string | number> = []
    for (let i = 0; i < shape[0]; ++i) {
      out.push(encodeArray(flat, inner, offset + i * stride))
    }
    return `[${out.join(', ')}]`
  }
  const array = encodeArray(tensor.toFloat32Array(), tensor.shape, 0)
  const dtype = sm.typeToString(tensor.dtype)
  return `${array}:${dtype}`
}
/**
 * Parses the textual form written by `encodeReadable` (`<nested list>:<dtype>`)
 * into a tensor.
 *
 * The part before the first `:` is a nested bracket list of numbers (or a bare
 * number for a scalar); elements are separated by commas and/or spaces. The part
 * after the `:` is turned into a dtype with `sm.stringToType`; Float32 is used
 * when there is no `:`. Values are staged in a Float32Array and converted with
 * `astype` when another dtype was requested.
 *
 * @param readableString - text to parse
 * @returns the decoded tensor
 * @throws a string describing the problem when the list does not start with `[`,
 *   is ragged, or is never closed
 */
export function decodeReadable(readableString: string) {
  const [arrayString, typeString] = readableString.split(':')
  let dtype: sm.dtype = sm.dtype.Float32
  if (typeString !== undefined) {
    dtype = sm.stringToType(typeString)
  }

  const sameShape = (a: number[], b: number[]) =>
    a.length === b.length && a.every((e, i) => e === b[i])

  // Parses the list that starts at s[offset]; returns the flat values, the shape
  // and the number of characters consumed (relative to offset).
  function parseTensor(s: string, offset = 0): [number[], number[], number] {
    const total_chars = s.length
    let idx = 1
    // No brackets or separators at all: treat the whole string as one scalar.
    if (!s.includes(']') && !s.includes(',')) {
      return [[Number.parseFloat(s)], [], 0]
    }
    if (s[offset] !== '[') {
      // Report the character that was actually inspected.
      throw `Invalid string passed into Tensor parser: found '${s[offset]}', expected '['`
    }
    let array: number[] = []
    let outer_shape = 0
    // null: nothing seen yet · 0: this level holds plain numbers ·
    // number[]: this level holds sub-lists of that shape
    let inner_shape: number[] | null | 0 = null
    let cur_chars = ''
    // Converts the pending characters (if any) into one number of this level.
    const digest = () => {
      if (cur_chars.length) {
        if (inner_shape !== null && inner_shape !== 0) {
          throw `Invalid array (ragged not supported)`
        }
        inner_shape = 0
        array.push(Number(cur_chars))
        cur_chars = ''
      }
    }
    // idx is relative to offset, so the bound has to include offset; otherwise a
    // nested call keeps reading past the end of the string.
    while (offset + idx < total_chars) {
      const char = s[offset + idx]
      if (char === ',' || char === ' ') {
        digest()
        idx += 1
        continue
      } else if (char === '[') {
        const [_array, _shape, _idx] = parseTensor(s, offset + idx)
        idx += _idx
        // digest() assigns inner_shape from inside a closure, which control-flow
        // narrowing cannot see; read it back at its declared type.
        const seen = inner_shape as number[] | null | 0
        if (seen !== null && (seen === 0 || !sameShape(seen, _shape))) {
          throw `Invalid array (ragged not supported): found inner dim == ${_shape}, expected ${seen}`
        }
        outer_shape += 1
        inner_shape = _shape
        array = array.concat(_array)
      } else if (char === ']') {
        digest()
        const seen = inner_shape as number[] | null | 0
        // A level of plain numbers, or an empty list `[]`, is one-dimensional;
        // otherwise prepend the number of sub-lists to their common shape.
        const shape = seen === 0 || seen === null ? [array.length] : [outer_shape, ...seen]
        return [array, shape, idx + 1]
      } else {
        cur_chars += char
        idx += 1
      }
    }
    throw `Never found closing ']'`
  }

  const [array, shape] = parseTensor(arrayString.trim())
  // TODO: this is an issue for larger types
  const native_array = new Float32Array(array)
  const t = sm.tensor(native_array).reshape(shape)
  if (dtype !== sm.dtype.Float32) {
    return t.astype(dtype)
  }
  return t
}