@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
477 lines • 19.2 kB
JavaScript
import Tensor, { tensorValuesConstructor, } from '../../types';
import { compareShapes } from '../../util/shape';
// Rust tensor classes from the WASM module, one per dtype. They stay
// `undefined` until `wasmLoaded` resolves.
let WASMTF32;
let WASMTF64;
let WASMTI32;
let WASMTI16;
let WASMTI8;
let WASMTU32;
let WASMTU16;
let WASMTU8;

/**
 * Maps each dtype name ('float64', 'float32', 'int32', ...) to the Rust
 * tensor class implementing it. `undefined` until `wasmLoaded` resolves.
 */
export let tensorConstructor;

/**
 * Resolves (with `undefined`) once the Rust/WASM tensor module has been
 * imported and the class bindings above have been populated.
 *
 * The promise is derived directly from the dynamic import instead of being
 * wrapped in `new Promise(...)`. That way a failure to load the module
 * rejects `wasmLoaded` rather than leaving it pending forever while the
 * import error goes unhandled.
 */
export const wasmLoaded = import('../../wasm/rust_wasm_tensor').then((wasm) => {
  WASMTF32 = wasm.TensorF32;
  WASMTF64 = wasm.TensorF64;
  WASMTI32 = wasm.TensorI32;
  WASMTI16 = wasm.TensorI16;
  WASMTI8 = wasm.TensorI8;
  WASMTU32 = wasm.TensorU32;
  WASMTU16 = wasm.TensorU16;
  WASMTU8 = wasm.TensorU8;
  tensorConstructor = {
    float64: WASMTF64,
    float32: WASMTF32,
    int32: WASMTI32,
    int16: WASMTI16,
    int8: WASMTI8,
    uint32: WASMTU32,
    uint16: WASMTU16,
    uint8: WASMTU8,
  };
});
/**
 * Returns whether `wasmTensor` is an instance of one of the floating point
 * Rust tensor classes. The class bindings are `undefined` until `wasmLoaded`
 * resolves, so calling this earlier throws a TypeError.
 *
 * @param wasmTensor Rust tensor instance.
 * @returns {boolean}
 */
function isFloatWasmTensor(wasmTensor) {
  return wasmTensor instanceof WASMTF64 || wasmTensor instanceof WASMTF32;
}

/**
 * Returns whether `wasmTensor` stores signed values (a float or signed
 * integer Rust tensor class).
 *
 * @param wasmTensor Rust tensor instance.
 * @returns {boolean}
 */
function isSignedWasmTensor(wasmTensor) {
  return (
    isFloatWasmTensor(wasmTensor) ||
    wasmTensor instanceof WASMTI32 ||
    wasmTensor instanceof WASMTI16 ||
    wasmTensor instanceof WASMTI8
  );
}

/**
 * Throws unless `wasmTensor` is a float32/float64 tensor.
 *
 * @param wasmTensor Rust tensor instance.
 * @param {string} opName Operation name used at the start of the message.
 * @throws {Error} When the tensor is not a float tensor.
 */
function requireFloat(wasmTensor, opName) {
  if (!isFloatWasmTensor(wasmTensor)) {
    throw new Error(`${opName} can only be called on float tensors`);
  }
}

/**
 * Throws unless `wasmTensor` is a signed (float or signed integer) tensor.
 *
 * @param wasmTensor Rust tensor instance.
 * @param {string} opName Operation name used at the start of the message.
 * @throws {Error} When the tensor is unsigned.
 */
function requireSigned(wasmTensor, opName) {
  if (!isSignedWasmTensor(wasmTensor)) {
    throw new Error(`${opName} can only be called on signed tensors`);
  }
}

/**
 * Tensor whose data lives in a Rust tensor compiled to WebAssembly
 * (loaded from `../../wasm/rust_wasm_tensor`).
 *
 * The Rust classes are only bound after `wasmLoaded` resolves. Before that,
 * creating a tensor from a plain array or calling any dtype-checked
 * operation throws a TypeError.
 *
 * Every operation wraps its Rust result in a new `WASMTensor` that carries
 * the dtype of its input, so `dtype` stays in sync with the wrapped Rust
 * class. WASM memory is released explicitly with `delete()`.
 */
export class WASMTensor extends Tensor {
  /**
   * @param values Plain JS array of values (then `shape` is required), or an
   *   existing Rust tensor instance, which is stored as-is without copying.
   * @param shape Shape of the tensor; only used when `values` is an array.
   * @param {string} [dtype='float32'] Data type of the tensor.
   * @throws {Error} When `values` is an array and `shape` is missing.
   */
  constructor(values, shape, dtype) {
    super(dtype || 'float32');
    if (values instanceof Array) {
      if (shape === undefined) {
        throw new Error('Need the shape when creating a Wasm tensor from values');
      }
      const array = new tensorValuesConstructor[this.dtype](values);
      this.wasmTensor = tensorConstructor[this.dtype].create(shape, array);
    } else {
      this.wasmTensor = values;
    }
  }

  /**
   * Creates a 1-D float32 tensor holding start, start + delta, ... up to
   * (excluding) `limit`. Empty when the range is empty.
   */
  static range(start, limit, delta) {
    const size = Math.max(Math.ceil((limit - start) / delta), 0);
    const values = new Array(size);
    for (let i = 0; i < size; i++) {
      values[i] = start + i * delta;
    }
    return new WASMTensor(values, new Uint32Array([size]));
  }

  cast(dtype) {
    throw new Error('Method not implemented.');
  }

  /** Resolves to the values returned by the Rust `get_vals`. */
  getValues() {
    return Promise.resolve(this.wasmTensor.get_vals());
  }

  /** Returns the shape as a plain JS array. */
  getShape() {
    return Array.from(this.wasmTensor.get_shape());
  }

  /**
   * Creates a tensor with this tensor's shape and dtype in which every entry
   * is `value`.
   */
  constantLike(value) {
    // TODO: Maybe more efficient in WASM?
    // Pass one value per element (as `range` does) instead of a single value
    // for the whole shape.
    const size = this.getShape().reduce((acc, dim) => acc * dim, 1);
    return new WASMTensor(new Array(size).fill(value), this.wasmTensor.get_shape(), this.dtype);
  }

  /** Creates a shape-[1] tensor of this dtype holding `value`. */
  singleConstant(value) {
    return new WASMTensor([value], new Uint32Array([1]), this.dtype);
  }

  /** Frees the wrapped Rust tensor. Calling it again is a no-op. */
  delete() {
    if (this.wasmTensor !== undefined) {
      this.wasmTensor.free();
      //@ts-ignore
      this.wasmTensor = undefined;
    }
  }

  /** Returns a new tensor wrapping `wasmTensor.copy()`, with the same dtype. */
  copy() {
    return new WASMTensor(this.wasmTensor.copy(), undefined, this.dtype);
  }

  // ---- Element-wise operations. Float-only ones throw for other dtypes. ----

  exp() {
    requireFloat(this.wasmTensor, 'Exp');
    return new WASMTensor(this.wasmTensor.exp(), undefined, this.dtype);
  }

  log() {
    requireFloat(this.wasmTensor, 'Log');
    return new WASMTensor(this.wasmTensor.log(), undefined, this.dtype);
  }

  sqrt() {
    requireFloat(this.wasmTensor, 'Sqrt');
    return new WASMTensor(this.wasmTensor.sqrt(), undefined, this.dtype);
  }

  /** Absolute value; only for float and signed integer tensors. */
  abs() {
    requireSigned(this.wasmTensor, 'Abs');
    return new WASMTensor(this.wasmTensor.abs(), undefined, this.dtype);
  }

  sin() {
    requireFloat(this.wasmTensor, 'Sin');
    return new WASMTensor(this.wasmTensor.sin(), undefined, this.dtype);
  }

  cos() {
    requireFloat(this.wasmTensor, 'Cos');
    return new WASMTensor(this.wasmTensor.cos(), undefined, this.dtype);
  }

  tan() {
    requireFloat(this.wasmTensor, 'Tan');
    return new WASMTensor(this.wasmTensor.tan(), undefined, this.dtype);
  }

  asin() {
    requireFloat(this.wasmTensor, 'Asin');
    return new WASMTensor(this.wasmTensor.asin(), undefined, this.dtype);
  }

  acos() {
    requireFloat(this.wasmTensor, 'Acos');
    return new WASMTensor(this.wasmTensor.acos(), undefined, this.dtype);
  }

  atan() {
    requireFloat(this.wasmTensor, 'Atan');
    return new WASMTensor(this.wasmTensor.atan(), undefined, this.dtype);
  }

  sinh() {
    requireFloat(this.wasmTensor, 'Sinh');
    return new WASMTensor(this.wasmTensor.sinh(), undefined, this.dtype);
  }

  cosh() {
    requireFloat(this.wasmTensor, 'Cosh');
    return new WASMTensor(this.wasmTensor.cosh(), undefined, this.dtype);
  }

  tanh() {
    requireFloat(this.wasmTensor, 'Tanh');
    return new WASMTensor(this.wasmTensor.tanh(), undefined, this.dtype);
  }

  asinh() {
    requireFloat(this.wasmTensor, 'Asinh');
    return new WASMTensor(this.wasmTensor.asinh(), undefined, this.dtype);
  }

  acosh() {
    requireFloat(this.wasmTensor, 'Acosh');
    return new WASMTensor(this.wasmTensor.acosh(), undefined, this.dtype);
  }

  atanh() {
    requireFloat(this.wasmTensor, 'Atanh');
    return new WASMTensor(this.wasmTensor.atanh(), undefined, this.dtype);
  }

  sigmoid() {
    requireFloat(this.wasmTensor, 'Sigmoid');
    return new WASMTensor(this.wasmTensor.sigmoid(), undefined, this.dtype);
  }

  /** Forwards `alpha` and `beta` to the Rust `hard_sigmoid`. */
  hardSigmoid(alpha, beta) {
    requireFloat(this.wasmTensor, 'HardSigmoid');
    return new WASMTensor(this.wasmTensor.hard_sigmoid(alpha, beta), undefined, this.dtype);
  }

  /** Negation; only for float and signed integer tensors. */
  negate() {
    requireSigned(this.wasmTensor, 'Negate');
    return new WASMTensor(this.wasmTensor.negate(), undefined, this.dtype);
  }

  powerScalar(power, factor) {
    return new WASMTensor(this.wasmTensor.power_scalar(power, factor), undefined, this.dtype);
  }

  addMultiplyScalar(factor, add) {
    return new WASMTensor(this.wasmTensor.add_multiply_scalar(factor, add), undefined, this.dtype);
  }

  /** Sign; only for float and signed integer tensors. */
  sign() {
    requireSigned(this.wasmTensor, 'Sign');
    return new WASMTensor(this.wasmTensor.sign(), undefined, this.dtype);
  }

  /**
   * Returns a tensor equal to this one with `values` written at the offsets
   * `starts`.
   */
  setValues(values, starts) {
    if (!(values instanceof WASMTensor)) {
      throw new Error('Can only set WASM values to WASM values');
    }
    return new WASMTensor(
      this.wasmTensor.set_values(values.wasmTensor, new Uint32Array(starts)),
      undefined,
      this.dtype
    );
  }

  // ---- Binary operations. The result takes the dtype of `th`. ----

  add_impl(th, tensor,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    _resultShape, alpha, beta) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only add WASM tensor to WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.addition(tensor.wasmTensor, alpha, beta), undefined, th.dtype);
  }

  subtract_impl(th, tensor,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    resultShape, alpha, beta) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only subtract WASM tensor from WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.subtraction(tensor.wasmTensor, alpha, beta), undefined, th.dtype);
  }

  multiply_impl(th, tensor,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    resultShape, alpha) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only multiply WASM tensor with WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.multiply(tensor.wasmTensor, alpha), undefined, th.dtype);
  }

  divide_impl(th, tensor,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    resultShape, alpha) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only divide WASM tensor by WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.divide(tensor.wasmTensor, alpha), undefined, th.dtype);
  }

  power_impl(th, tensor,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    resultShape) {
    if (!(tensor instanceof WASMTensor) || !(th instanceof WASMTensor)) {
      throw new Error('Can only take WASM tensor to power of WASM tensor');
    }
    return new WASMTensor(th.wasmTensor.power(tensor.wasmTensor), undefined, th.dtype);
  }

  matMul(tensor) {
    if (!(tensor instanceof WASMTensor)) {
      throw new Error('Can only matMul WASM tensor with WASM tensor');
    }
    return new WASMTensor(this.wasmTensor.matmul(tensor.wasmTensor), undefined, this.dtype);
  }

  /**
   * General matrix multiply of this tensor with `b`. When `c` is given,
   * the Rust `gemm_with_c` is used with `c` and `beta`. Otherwise `gemm` is
   * called without them.
   */
  gemm_impl(b, aTranspose, bTranspose, alpha, beta, c) {
    if (!(b instanceof WASMTensor && (c === undefined || c instanceof WASMTensor))) {
      throw new Error('Can only do gemm with WASM tensors');
    }
    const result = c !== undefined
      ? this.wasmTensor.gemm_with_c(b.wasmTensor, aTranspose, bTranspose, alpha, c.wasmTensor, beta)
      : this.wasmTensor.gemm(b.wasmTensor, aTranspose, bTranspose, alpha);
    return new WASMTensor(result, undefined, this.dtype);
  }

  // ---- Reductions over `axes`. ----

  sum_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.sum(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  sumSquare_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.sum_square(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  product_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.product(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  max_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.max(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  min_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.min(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  reduceMean_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.reduce_mean(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  reduceMeanSquare_impl(axes, keepDims) {
    return new WASMTensor(this.wasmTensor.reduce_mean_square(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  reduceLogSum_impl(axes, keepDims) {
    requireFloat(this.wasmTensor, 'ReduceLogSum');
    return new WASMTensor(this.wasmTensor.reduce_log_sum(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  reduceLogSumExp_impl(axes, keepDims) {
    requireFloat(this.wasmTensor, 'ReduceLogSumExp');
    return new WASMTensor(this.wasmTensor.reduce_log_sum_exp(new Uint32Array(axes), keepDims), undefined, this.dtype);
  }

  /**
   * Integer code passed to the Rust convolution for the fused activation:
   * 'id' -> 0, 'relu' -> 1, any other value -> 2.
   */
  getActivationFlag(activation) {
    if (activation === 'id') {
      return 0;
    } else if (activation === 'relu') {
      return 1;
    } else {
      return 2;
    }
  }

  /** Convolution with `kernel`, optionally adding `bias`. */
  conv_impl(kernel, dilations, group, pads, strides, activation, bias) {
    if (!(kernel instanceof WASMTensor) ||
      (bias !== undefined && !(bias instanceof WASMTensor))) {
      throw new Error('Can only do convolution of WASM tensor with WASM tensor');
    }
    const activationFlag = this.getActivationFlag(activation);
    const result = bias !== undefined
      ? this.wasmTensor.conv_with_bias(kernel.wasmTensor, bias.wasmTensor, new Uint32Array(dilations), group, new Uint32Array(pads), new Uint32Array(strides), activationFlag)
      : this.wasmTensor.conv(kernel.wasmTensor, new Uint32Array(dilations), group, new Uint32Array(pads), new Uint32Array(strides), activationFlag);
    return new WASMTensor(result, undefined, this.dtype);
  }

  convTranspose_impl(kernel, dilations, group, pads, strides) {
    if (!(kernel instanceof WASMTensor)) {
      throw new Error('Can only do transpose convolution of WASM tensor with WASM tensor');
    }
    return new WASMTensor(
      this.wasmTensor.conv_transpose(kernel.wasmTensor, new Uint32Array(dilations), group, new Uint32Array(pads), new Uint32Array(strides)),
      undefined,
      this.dtype
    );
  }

  averagePool_impl(kernelShape, pads, strides, includePad) {
    return new WASMTensor(
      this.wasmTensor.average_pool(new Uint32Array(kernelShape), new Uint32Array(pads), new Uint32Array(strides), includePad),
      undefined,
      this.dtype
    );
  }

  reshape_impl(shape) {
    return new WASMTensor(this.wasmTensor.reshape(new Uint32Array(shape)), undefined, this.dtype);
  }

  /** Concatenates `tensor` along `axis`; negative axes count from the end. */
  concat(tensor, axis) {
    if (!(tensor instanceof WASMTensor)) {
      throw new Error('Can only concat WASM tensor to WASM tensor');
    }
    if (axis < 0) {
      axis += this.getShape().length;
    }
    return new WASMTensor(this.wasmTensor.concat(tensor.wasmTensor, axis), undefined, this.dtype);
  }

  transpose_impl(permutation) {
    return new WASMTensor(this.wasmTensor.transpose(new Uint32Array(permutation)), undefined, this.dtype);
  }

  /** Clips values to [min, max]; either bound may be omitted. */
  clip(min, max) {
    if (min !== undefined && max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip(min, max), undefined, this.dtype);
    } else if (min !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_min(min), undefined, this.dtype);
    } else if (max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_max(max), undefined, this.dtype);
    }
    return this.copy();
  }

  /**
   * Backward pass of `clip` for the input `this` and incoming gradient
   * `grad`. Either bound may be omitted.
   */
  clipBackward(grad, min, max) {
    if (!(grad instanceof WASMTensor)) {
      throw new Error('Can only do grad backward with Wasm tensor');
    }
    if (min !== undefined && max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_backward(min, max, grad.wasmTensor), undefined, this.dtype);
    } else if (min !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_min_backward(min, grad.wasmTensor), undefined, this.dtype);
    } else if (max !== undefined) {
      return new WASMTensor(this.wasmTensor.clip_max_backward(max, grad.wasmTensor), undefined, this.dtype);
    }
    // Without bounds clip is the identity, so the gradient passes through
    // unchanged (not the input values).
    return grad.copy();
  }

  repeat(repeats) {
    return new WASMTensor(this.wasmTensor.repeat(new Uint32Array(repeats)), undefined, this.dtype);
  }

  /** Broadcasts this tensor to `shape`. */
  expand(shape) {
    const thisShape = this.getShape();
    // The middle element of alignShapes' result is not needed here.
    const [alignedShape, , resultShape] = this.alignShapes(thisShape, shape);
    if (compareShapes(thisShape, resultShape)) {
      return this.copy();
    }
    const reshaped = this.reshape(alignedShape, false);
    const result = new WASMTensor(
      reshaped.wasmTensor.expand(new Uint32Array(resultShape)),
      undefined,
      this.dtype
    );
    // The reshaped intermediate wraps its own Rust tensor (see reshape_impl),
    // so free it. Skip this if reshape handed back this tensor or its
    // storage.
    if (reshaped !== this && reshaped.wasmTensor !== this.wasmTensor) {
      reshaped.delete();
    }
    return result;
  }

  pad_impl(pads, mode, value) {
    return new WASMTensor(
      this.wasmTensor.pad(new Uint32Array(pads), WASMTensor.padModeToInt[mode], value),
      undefined,
      this.dtype
    );
  }

  /**
   * Gathers along `axis`. Reads `indices.values` and `indices.shape`
   * directly. NOTE(review): presumably a CPU tensor — confirm at callers.
   */
  gather(axis, indices) {
    return new WASMTensor(
      this.wasmTensor.gather(axis, indices.values, new Uint32Array(indices.shape)),
      undefined,
      this.dtype
    );
  }

  floor() {
    requireFloat(this.wasmTensor, 'Floor');
    return new WASMTensor(this.wasmTensor.floor(), undefined, this.dtype);
  }

  ceil() {
    requireFloat(this.wasmTensor, 'Ceil');
    return new WASMTensor(this.wasmTensor.ceil(), undefined, this.dtype);
  }

  round() {
    requireFloat(this.wasmTensor, 'Round');
    return new WASMTensor(this.wasmTensor.round(), undefined, this.dtype);
  }

  /** Slices; `steps` is passed as Int32Array so negative steps are kept. */
  slice_impl(starts, ends, axes, steps) {
    return new WASMTensor(
      this.wasmTensor.slice(new Uint32Array(starts), new Uint32Array(ends), new Uint32Array(axes), new Int32Array(steps)),
      undefined,
      this.dtype
    );
  }

  upsample(scales) {
    return new WASMTensor(this.wasmTensor.upsample(new Float32Array(scales)), undefined, this.dtype);
  }

  /**
   * Normalizes using `mean`, `variance`, `epsilon`, `scale` and `bias`.
   * All tensor arguments must be WASM tensors, and this tensor must be a
   * float tensor.
   */
  normalize(mean, variance, epsilon, scale, bias) {
    if (!(mean instanceof WASMTensor) ||
      !(variance instanceof WASMTensor) ||
      !(scale instanceof WASMTensor) ||
      !(bias instanceof WASMTensor)) {
      throw new Error('Can only normalize with WASM tensors');
    }
    requireFloat(this.wasmTensor, 'Normalize');
    return new WASMTensor(
      this.wasmTensor.normalize(mean.wasmTensor, variance.wasmTensor, epsilon, scale.wasmTensor, bias.wasmTensor),
      undefined,
      this.dtype
    );
  }
}
// Integer codes for the padding modes, as passed to the Rust `pad` binding
// by `WASMTensor#pad_impl`.
Object.assign(WASMTensor, {
  padModeToInt: { constant: 0, reflect: 1, edge: 2 },
});
//# sourceMappingURL=tensor.js.map