UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

477 lines 19.2 kB
import Tensor, { tensorValuesConstructor, } from '../../types';
import { compareShapes } from '../../util/shape';

// wasm-bindgen tensor classes, one per element type. They stay undefined
// until the WASM module has been imported (see `wasmLoaded`).
let WASMTF32;
let WASMTF64;
let WASMTI32;
let WASMTI16;
let WASMTI8;
let WASMTU32;
let WASMTU16;
let WASMTU8;

/**
 * Maps a dtype name ('float64', 'float32', 'int32', ...) to the matching
 * wasm-bindgen tensor class. Undefined until `wasmLoaded` has resolved.
 */
export let tensorConstructor;

/**
 * Resolves (with undefined) once the WASM tensor module has been imported and
 * the class bindings above are populated. If the import fails, the promise
 * rejects with that error. The previous version hung forever in that case.
 */
export const wasmLoaded = import('../../wasm/rust_wasm_tensor').then(x => {
  WASMTF32 = x.TensorF32;
  WASMTF64 = x.TensorF64;
  WASMTI32 = x.TensorI32;
  WASMTI16 = x.TensorI16;
  WASMTI8 = x.TensorI8;
  WASMTU32 = x.TensorU32;
  WASMTU16 = x.TensorU16;
  WASMTU8 = x.TensorU8;
  tensorConstructor = {
    float64: WASMTF64,
    float32: WASMTF32,
    int32: WASMTI32,
    int16: WASMTI16,
    int8: WASMTI8,
    uint32: WASMTU32,
    uint16: WASMTU16,
    uint8: WASMTU8,
  };
});

/**
 * Returns the dtype name whose wasm-bindgen class `wasmTensor` is an instance
 * of. Returns undefined when nothing matches or the module is not loaded yet.
 */
function dtypeOfWasmTensor(wasmTensor) {
  if (tensorConstructor === undefined) {
    return undefined;
  }
  for (const dtype of Object.keys(tensorConstructor)) {
    if (wasmTensor instanceof tensorConstructor[dtype]) {
      return dtype;
    }
  }
  return undefined;
}

/**
 * Picks the dtype for a new WASMTensor. An explicit dtype wins. When wrapping
 * an existing wasm tensor, the dtype is taken from its class, so op results
 * keep their real element type. Otherwise it falls back to 'float32'.
 */
function resolveDtype(values, dtype) {
  if (dtype) {
    return dtype;
  }
  if (!(values instanceof Array)) {
    const inferred = dtypeOfWasmTensor(values);
    if (inferred !== undefined) {
      return inferred;
    }
  }
  return 'float32';
}

/** True when the wasm tensor holds float64 or float32 values. */
function isFloatWasm(wasmTensor) {
  return wasmTensor instanceof WASMTF64 || wasmTensor instanceof WASMTF32;
}

/** True when the wasm tensor holds a signed type (float or signed int). */
function isSignedWasm(wasmTensor) {
  return (isFloatWasm(wasmTensor) ||
    wasmTensor instanceof WASMTI32 ||
    wasmTensor instanceof WASMTI16 ||
    wasmTensor instanceof WASMTI8);
}

/** Throws `<opName> can only be called on float tensors` for non-float input. */
function assertFloat(wasmTensor, opName) {
  if (!isFloatWasm(wasmTensor)) {
    throw new Error(`${opName} can only be called on float tensors`);
  }
}

/** Throws `<opName> can only be called on signed tensors` for unsigned input. */
function assertSigned(wasmTensor, opName) {
  if (!isSignedWasm(wasmTensor)) {
    throw new Error(`${opName} can only be called on signed tensors`);
  }
}

/**
 * Tensor backed by a wasm-bindgen tensor object (`this.wasmTensor`). The
 * wrapped object owns WASM memory and must be released with `delete()`.
 */
export class WASMTensor extends Tensor {
  /**
   * @param values Either a plain array of element values (then `shape` is
   *   required) or an existing wasm-bindgen tensor to wrap.
   * @param shape Shape used when `values` is a plain array; ignored otherwise.
   * @param dtype Element type. Defaults to 'float32' for plain arrays. For a
   *   wrapped wasm tensor it is inferred from the tensor's class.
   */
  constructor(values, shape, dtype) {
    super(resolveDtype(values, dtype));
    if (values instanceof Array) {
      if (shape === undefined) {
        throw new Error('Need the shape when creating a Wasm tensor from values');
      }
      const array = new tensorValuesConstructor[this.dtype](values);
      this.wasmTensor = tensorConstructor[this.dtype].create(shape, array);
    } else {
      this.wasmTensor = values;
    }
  }

  /** 1-D float32 tensor holding start, start + delta, ... while below `limit`. */
  static range(start, limit, delta) {
    const size = Math.max(Math.ceil((limit - start) / delta), 0);
    const values = new Array(size);
    for (let i = 0; i < size; i++) {
      values[i] = start + i * delta;
    }
    return new WASMTensor(values, new Uint32Array([size]));
  }

  cast(dtype) {
    throw new Error('Method not implemented.');
  }

  getValues() {
    return Promise.resolve(this.wasmTensor.get_vals());
  }

  getShape() {
    return Array.from(this.wasmTensor.get_shape());
  }

  /** Tensor with this tensor's shape and dtype, every entry set to `value`. */
  constantLike(value) {
    // The value array must hold one entry per element of the shape. The
    // previous version passed a single value together with the full shape.
    // TODO: Maybe more efficient in WASM?
    const size = this.getShape().reduce((acc, dim) => acc * dim, 1);
    return new WASMTensor(new Array(size).fill(value), this.wasmTensor.get_shape(), this.dtype);
  }

  singleConstant(value) {
    return new WASMTensor([value], new Uint32Array([1]), this.dtype);
  }

  /** Frees the underlying WASM memory. Safe to call more than once. */
  delete() {
    if (this.wasmTensor !== undefined) {
      this.wasmTensor.free();
      //@ts-ignore
      this.wasmTensor = undefined;
    }
  }

  copy() {
    return new WASMTensor(this.wasmTensor.copy(), undefined, this.dtype);
  }

  exp() {
    assertFloat(this.wasmTensor, 'Exp');
    return new WASMTensor(this.wasmTensor.exp());
  }

  log() {
    assertFloat(this.wasmTensor, 'Log');
    return new WASMTensor(this.wasmTensor.log());
  }

  sqrt() {
    assertFloat(this.wasmTensor, 'Sqrt');
    return new WASMTensor(this.wasmTensor.sqrt());
  }

  abs() {
    assertSigned(this.wasmTensor, 'Abs');
    return new WASMTensor(this.wasmTensor.abs());
  }

  sin() {
    assertFloat(this.wasmTensor, 'Sin');
    return new WASMTensor(this.wasmTensor.sin());
  }

  cos() {
    assertFloat(this.wasmTensor, 'Cos');
    return new WASMTensor(this.wasmTensor.cos());
  }

  tan() {
    assertFloat(this.wasmTensor, 'Tan');
    return new WASMTensor(this.wasmTensor.tan());
  }

  asin() {
    assertFloat(this.wasmTensor, 'Asin');
    return new WASMTensor(this.wasmTensor.asin());
  }

  acos() {
    assertFloat(this.wasmTensor, 'Acos');
    return new WASMTensor(this.wasmTensor.acos());
  }

  atan() {
    assertFloat(this.wasmTensor, 'Atan');
    return new WASMTensor(this.wasmTensor.atan());
  }

  sinh() {
    assertFloat(this.wasmTensor, 'Sinh');
    return new WASMTensor(this.wasmTensor.sinh());
  }

  cosh() {
    assertFloat(this.wasmTensor, 'Cosh');
    return new WASMTensor(this.wasmTensor.cosh());
  }

  tanh() {
    assertFloat(this.wasmTensor, 'Tanh');
    return new WASMTensor(this.wasmTensor.tanh());
  }

  asinh() {
    assertFloat(this.wasmTensor, 'Asinh');
    return new WASMTensor(this.wasmTensor.asinh());
  }

  acosh() {
    assertFloat(this.wasmTensor, 'Acosh');
    return new WASMTensor(this.wasmTensor.acosh());
  }

  atanh() {
    assertFloat(this.wasmTensor, 'Atanh');
    return new WASMTensor(this.wasmTensor.atanh());
  }

  sigmoid() {
    assertFloat(this.wasmTensor, 'Sigmoid');
    return new WASMTensor(this.wasmTensor.sigmoid());
  }

  hardSigmoid(alpha, beta) {
    assertFloat(this.wasmTensor, 'HardSigmoid');
    return new WASMTensor(this.wasmTensor.hard_sigmoid(alpha, beta));
  }

  negate() {
    assertSigned(this.wasmTensor, 'Negate');
    return new WASMTensor(this.wasmTensor.negate());
  }

  powerScalar(power, factor) {
    return new WASMTensor(this.wasmTensor.power_scalar(power, factor));
  }

  addMultiplyScalar(factor, add) {
    return new WASMTensor(this.wasmTensor.add_multiply_scalar(factor, add));
  }

  sign() {
    assertSigned(this.wasmTensor, 'Sign');
    return new WASMTensor(this.wasmTensor.sign());
  }

  setValues(values, starts) {
    if (!(values instanceof WASMTensor)) {
      throw new Error('Can only set WASM values to WASM values');
    }
    return new WASMTensor(this.wasmTensor.set_values(values.wasmTensor, new Uint32Array(starts)));
  }

  add_impl(th, tensor,
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  _resultShape, alpha, beta) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only add WASM tensor to WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.addition(tensor.wasmTensor, alpha, beta));
  }

  subtract_impl(th, tensor,
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  resultShape, alpha, beta) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only subtract WASM tensor from WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.subtraction(tensor.wasmTensor, alpha, beta));
  }

  multiply_impl(th, tensor,
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  resultShape, alpha) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only multiply WASM tensor with WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.multiply(tensor.wasmTensor, alpha));
  }

  divide_impl(th, tensor,
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  resultShape, alpha) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only divide WASM tensor by WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.divide(tensor.wasmTensor, alpha));
  }

  power_impl(th, tensor,
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  resultShape) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only take WASM tensor to power of WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.power(tensor.wasmTensor));
  }

  matMul(tensor) {
    if (!(tensor instanceof WASMTensor)) {
      throw new Error('Can only matMul WASM tensor with WASM tensor');
    }
    return new WASMTensor(this.wasmTensor.matmul(tensor.wasmTensor));
  }

  /** Uses the `gemm_with_c` binding when `c` is given, otherwise `gemm`. */
  gemm_impl(b, aTranspose, bTranspose, alpha, beta, c) {
    if (!(b instanceof WASMTensor && (c === undefined || c instanceof WASMTensor))) {
      throw new Error('Can only do gemm with WASM tensors');
    }
    if (c !== undefined) {
      return new WASMTensor(this.wasmTensor.gemm_with_c(b.wasmTensor, aTranspose, bTranspose, alpha, c.wasmTensor, beta));
    }
    return new WASMTensor(this.wasmTensor.gemm(b.wasmTensor, aTranspose, bTranspose, alpha));
  }

  sum_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.sum(new Uint32Array(axes), keepDims));
  }

  sumSquare_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.sum_square(new Uint32Array(axes), keepDims));
  }

  product_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.product(new Uint32Array(axes), keepDims));
  }

  max_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.max(new Uint32Array(axes), keepDims));
  }

  min_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.min(new Uint32Array(axes), keepDims));
  }

  reduceMean_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.reduce_mean(new Uint32Array(axes), keepDims));
  }

  reduceMeanSquare_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.reduce_mean_square(new Uint32Array(axes), keepDims));
  }

  reduceLogSum_impl(axes, keepDims) {
    assertFloat(this.wasmTensor, 'ReduceLogSum');
    return new WASMTensor(this.wasmTensor.reduce_log_sum(new Uint32Array(axes), keepDims));
  }

  reduceLogSumExp_impl(axes, keepDims) {
    assertFloat(this.wasmTensor, 'ReduceLogSumExp');
    return new WASMTensor(this.wasmTensor.reduce_log_sum_exp(new Uint32Array(axes), keepDims));
  }

  /** Integer flag passed to the conv bindings: 'id' -> 0, 'relu' -> 1, else 2. */
  getActivationFlag(activation) {
    if (activation === 'id') {
      return 0;
    } else if (activation === 'relu') {
      return 1;
    }
    return 2;
  }

  conv_impl(kernel, dilations, group, pads, strides, activation, bias) {
    if (!(kernel instanceof WASMTensor) || (bias !== undefined && !(bias instanceof WASMTensor))) {
      throw new Error('Can only do convolution of WASM tensor with WASM tensor');
    }
    const activationFlag = this.getActivationFlag(activation);
    if (bias !== undefined) {
      return new WASMTensor(this.wasmTensor.conv_with_bias(kernel.wasmTensor, bias.wasmTensor, new Uint32Array(dilations), group, new Uint32Array(pads), new Uint32Array(strides), activationFlag));
    }
    return new WASMTensor(this.wasmTensor.conv(kernel.wasmTensor, new Uint32Array(dilations), group, new Uint32Array(pads), new Uint32Array(strides), activationFlag));
  }

  convTranspose_impl(kernel, dilations, group, pads, strides) {
    if (!(kernel instanceof WASMTensor)) {
      throw new Error('Can only do transpose convolution of WASM tensor with WASM tensor');
    }
    return new WASMTensor(this.wasmTensor.conv_transpose(kernel.wasmTensor, new Uint32Array(dilations), group, new Uint32Array(pads), new Uint32Array(strides)));
  }

  averagePool_impl(kernelShape, pads, strides, includePad) {
    return new WASMTensor(this.wasmTensor.average_pool(new Uint32Array(kernelShape), new Uint32Array(pads), new Uint32Array(strides), includePad));
  }

  reshape_impl(shape) {
    return new WASMTensor(this.wasmTensor.reshape(new Uint32Array(shape)));
  }

  /** Concatenates along `axis`. A negative axis counts from the last dimension. */
  concat(tensor, axis) {
    if (!(tensor instanceof WASMTensor)) {
      throw new Error('Can only concat WASM tensor to WASM tensor');
    }
    if (axis < 0) {
      axis += this.getShape().length;
    }
    return new WASMTensor(this.wasmTensor.concat(tensor.wasmTensor, axis));
  }

  transpose_impl(permutation) {
    return new WASMTensor(this.wasmTensor.transpose(new Uint32Array(permutation)));
  }

  /** Clips to [min, max]. Either bound may be omitted; with neither, returns a copy. */
  clip(min, max) {
    if (min !== undefined && max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip(min, max));
    } else if (min !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_min(min));
    } else if (max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_max(max));
    }
    return this.copy();
  }

  clipBackward(grad, min, max) {
    if (!(grad instanceof WASMTensor)) {
      throw new Error('Can only do grad backward with Wasm tensor');
    }
    if (min !== undefined && max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_backward(min, max, grad.wasmTensor));
    } else if (min !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_min_backward(min, grad.wasmTensor));
    } else if (max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_max_backward(max, grad.wasmTensor));
    }
    return this.copy();
  }

  repeat(repeats) {
    return new WASMTensor(this.wasmTensor.repeat(new Uint32Array(repeats)));
  }

  /**
   * Broadcasts this tensor to `shape`. The intermediate reshaped tensor owns
   * its own WASM memory, so it is freed once the expansion has been computed.
   */
  expand(shape) {
    const thisShape = this.getShape();
    const [alignedShape, , resultShape] = this.alignShapes(thisShape, shape);
    if (compareShapes(thisShape, resultShape)) {
      return this.copy();
    }
    const reshaped = this.reshape(alignedShape, false);
    try {
      return new WASMTensor(reshaped.wasmTensor.expand(new Uint32Array(resultShape)));
    } finally {
      // Never free memory that still belongs to `this`.
      if (reshaped !== this && reshaped.wasmTensor !== this.wasmTensor) {
        reshaped.delete();
      }
    }
  }

  pad_impl(pads, mode, value) {
    return new WASMTensor(this.wasmTensor.pad(new Uint32Array(pads), WASMTensor.padModeToInt[mode], value));
  }

  /** Reads `indices.values` and `indices.shape` directly from the index tensor. */
  gather(axis, indices) {
    return new WASMTensor(this.wasmTensor.gather(axis, indices.values, new Uint32Array(indices.shape)));
  }

  floor() {
    assertFloat(this.wasmTensor, 'Floor');
    return new WASMTensor(this.wasmTensor.floor());
  }

  ceil() {
    assertFloat(this.wasmTensor, 'Ceil');
    return new WASMTensor(this.wasmTensor.ceil());
  }

  round() {
    assertFloat(this.wasmTensor, 'Round');
    return new WASMTensor(this.wasmTensor.round());
  }

  slice_impl(starts, ends, axes, steps) {
    return new WASMTensor(this.wasmTensor.slice(new Uint32Array(starts), new Uint32Array(ends), new Uint32Array(axes), new Int32Array(steps)));
  }

  upsample(scales) {
    return new WASMTensor(this.wasmTensor.upsample(new Float32Array(scales)));
  }

  normalize(mean, variance, epsilon, scale, bias) {
    if (!(mean instanceof WASMTensor) ||
      !(variance instanceof WASMTensor) ||
      !(scale instanceof WASMTensor) ||
      !(bias instanceof WASMTensor)) {
      throw new Error('Can only normalize with WASM tensors');
    }
    assertFloat(this.wasmTensor, 'Normalize');
    return new WASMTensor(this.wasmTensor.normalize(mean.wasmTensor, variance.wasmTensor, epsilon, scale.wasmTensor, bias.wasmTensor));
  }
}

// Integer codes passed to the wasm `pad` binding for each pad mode.
WASMTensor.padModeToInt = {
  constant: 0,
  reflect: 1,
  edge: 2,
};
//# sourceMappingURL=tensor.js.map