@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
136 lines • 4.51 kB
JavaScript
import { CPUTensor } from '../../tensor/cpu/tensor';
import { checkEquivShapes, incrementIndex } from '../../util/shape';
/**
 * Applies a scalar function to every element of a tensor.
 *
 * The output is a freshly allocated CPUTensor with the same shape and dtype
 * as the input; the input tensor is only read, never written.
 *
 * @param a - input tensor (accessed via `shape`, `dtype` and flat-index `get`).
 * @param op - function mapping one element value to the output value.
 * @returns a new CPUTensor holding `op(x)` for every element `x` of `a`.
 */
export function positionWiseUnaryOp(a, op) {
  const out = new CPUTensor(a.shape, undefined, a.dtype);
  let pos = 0;
  while (pos < out.size) {
    const mapped = op(a.get(pos));
    out.set(pos, mapped);
    pos += 1;
  }
  return out;
}
/**
 * Combines two tensors element by element with a binary scalar function.
 *
 * Walks every multi-dimensional index of `resultShape` and reads both operands
 * at that same index. NOTE(review): whether operands whose shapes differ
 * (e.g. size-1 dimensions) are handled correctly depends on `checkEquivShapes`
 * and on how `CPUTensor.get` resolves indices — confirm broadcasting semantics.
 *
 * @param a - left operand tensor; its dtype is used for the result.
 * @param b - right operand tensor.
 * @param op - function `(left, right) => value` applied per position.
 * @param resultShape - shape of the output tensor.
 * @returns a new CPUTensor of shape `resultShape`.
 * @throws Error when `checkEquivShapes(a.shape, b.shape)` rejects the shapes.
 */
export function positionWiseBinaryOp(a, b, op, resultShape) {
  const shapesMatch = checkEquivShapes(a.shape, b.shape);
  if (!shapesMatch) {
    throw new Error('The shapes of the two tensors should be the same for a binary operation');
  }
  const out = new CPUTensor(resultShape, undefined, a.dtype);
  // Multi-dimensional cursor, advanced in place by incrementIndex.
  const cursor = Array.from({ length: resultShape.length }, () => 0);
  let visited = 0;
  while (visited < out.size) {
    const lhs = a.get(cursor);
    const rhs = b.get(cursor);
    out.set(cursor, op(lhs, rhs));
    incrementIndex(cursor, resultShape);
    visited += 1;
  }
  return out;
}
/**
 * Element-wise natural exponential, e^x.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.exp` of every entry of `a`.
 */
export function exp(a) {
  return positionWiseUnaryOp(a, Math.exp);
}
/**
 * Element-wise natural logarithm; non-positive entries follow `Math.log`
 * (0 → -Infinity, negatives → NaN).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.log` of every entry of `a`.
 */
export function log(a) {
  return positionWiseUnaryOp(a, Math.log);
}
/**
 * Element-wise square root; negative entries become NaN (per `Math.sqrt`).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.sqrt` of every entry of `a`.
 */
export function sqrt(a) {
  return positionWiseUnaryOp(a, Math.sqrt);
}
/**
 * Element-wise absolute value.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.abs` of every entry of `a`.
 */
export function abs(a) {
  return positionWiseUnaryOp(a, Math.abs);
}
/**
 * Element-wise sine (entries interpreted as radians).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.sin` of every entry of `a`.
 */
export function sin(a) {
  return positionWiseUnaryOp(a, Math.sin);
}
/**
 * Element-wise cosine (entries interpreted as radians).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.cos` of every entry of `a`.
 */
export function cos(a) {
  return positionWiseUnaryOp(a, Math.cos);
}
/**
 * Element-wise tangent (entries interpreted as radians).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.tan` of every entry of `a`.
 */
export function tan(a) {
  return positionWiseUnaryOp(a, Math.tan);
}
/**
 * Element-wise arcsine; entries outside [-1, 1] become NaN (per `Math.asin`).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.asin` of every entry of `a`.
 */
export function asin(a) {
  return positionWiseUnaryOp(a, Math.asin);
}
/**
 * Element-wise arccosine; entries outside [-1, 1] become NaN (per `Math.acos`).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.acos` of every entry of `a`.
 */
export function acos(a) {
  return positionWiseUnaryOp(a, Math.acos);
}
/**
 * Element-wise arctangent.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.atan` of every entry of `a`.
 */
export function atan(a) {
  return positionWiseUnaryOp(a, Math.atan);
}
/**
 * Element-wise hyperbolic sine.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.sinh` of every entry of `a`.
 */
export function sinh(a) {
  return positionWiseUnaryOp(a, Math.sinh);
}
/**
 * Element-wise hyperbolic cosine.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.cosh` of every entry of `a`.
 */
export function cosh(a) {
  return positionWiseUnaryOp(a, Math.cosh);
}
/**
 * Element-wise hyperbolic tangent.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.tanh` of every entry of `a`.
 */
export function tanh(a) {
  return positionWiseUnaryOp(a, Math.tanh);
}
/**
 * Element-wise inverse hyperbolic sine.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.asinh` of every entry of `a`.
 */
export function asinh(a) {
  return positionWiseUnaryOp(a, Math.asinh);
}
/**
 * Element-wise inverse hyperbolic cosine; entries below 1 become NaN
 * (per `Math.acosh`).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.acosh` of every entry of `a`.
 */
export function acosh(a) {
  return positionWiseUnaryOp(a, Math.acosh);
}
/**
 * Element-wise inverse hyperbolic tangent; ±1 map to ±Infinity and entries
 * outside [-1, 1] become NaN (per `Math.atanh`).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.atanh` of every entry of `a`.
 */
export function atanh(a) {
  return positionWiseUnaryOp(a, Math.atanh);
}
/**
 * Element-wise floor (rounds toward -Infinity).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.floor` of every entry of `a`.
 */
export function floor(a) {
  return positionWiseUnaryOp(a, Math.floor);
}
/**
 * Element-wise ceiling (rounds toward +Infinity).
 * @param a - input tensor.
 * @returns a new tensor holding `Math.ceil` of every entry of `a`.
 */
export function ceil(a) {
  return positionWiseUnaryOp(a, Math.ceil);
}
/**
 * Element-wise rounding with `Math.round`, which sends halves toward
 * +Infinity (2.5 → 3, -2.5 → -2).
 * NOTE(review): if this must match round-half-to-even semantics of other
 * backends/specs, this differs on ties — confirm intended behavior.
 * @param a - input tensor.
 * @returns a new tensor holding `Math.round` of every entry of `a`.
 */
export function round(a) {
  return positionWiseUnaryOp(a, Math.round);
}
/**
 * Element-wise sign: -1 for negative entries, 0 for ±0, 1 for positive entries.
 * NaN entries stay NaN (as with `Math.sign`) instead of being reported as positive.
 * @param a - input tensor.
 * @returns a new tensor of the same shape and dtype as `a`.
 */
export function sign(a) {
  return positionWiseUnaryOp(a, (o1) => {
    // NaN fails both `< 0` and `=== 0`; without this guard it fell through to 1.
    if (Number.isNaN(o1)) {
      return NaN;
    }
    if (o1 < 0) {
      return -1;
    }
    // `=== 0` also matches -0, so both zeros map to (positive) 0.
    return o1 === 0 ? 0 : 1;
  });
}
/**
 * Element-wise negation (unary minus, so 0 becomes -0).
 * @param a - input tensor.
 * @returns a new tensor holding `-x` for every entry `x` of `a`.
 */
export function negate(a) {
  const flip = (value) => -value;
  return positionWiseUnaryOp(a, flip);
}
/**
 * Element-wise `factor * x^power` with a scalar exponent.
 * @param a - input tensor.
 * @param power - scalar exponent passed to `Math.pow`.
 * @param factor - scalar multiplier applied after exponentiation; defaults to 1
 *   so omitting it yields a plain power instead of a NaN-filled tensor.
 * @returns a new tensor of the same shape and dtype as `a`.
 */
export function powerScalar(a, power, factor = 1) {
  return positionWiseUnaryOp(a, (o1) => Math.pow(o1, power) * factor);
}
/**
 * Element-wise affine map `x * factor + add` with scalar coefficients.
 * @param a - input tensor.
 * @param factor - scalar multiplier; defaults to 1 (identity scale).
 * @param add - scalar offset; defaults to 0 (no shift).
 * @returns a new tensor of the same shape and dtype as `a`.
 */
export function addMultiplyScalar(a, factor = 1, add = 0) {
  return positionWiseUnaryOp(a, (o1) => o1 * factor + add);
}
/**
 * Element-wise logistic sigmoid, 1 / (1 + e^-x).
 * @param a - input tensor.
 * @returns a new tensor with every entry mapped into [0, 1].
 */
export function sigmoid(a) {
  const logistic = (x) => 1 / (1 + Math.exp(-x));
  return positionWiseUnaryOp(a, logistic);
}
/**
 * Element-wise hard sigmoid: `alpha * x + beta`, clamped to [0, 1].
 * @param a - input tensor.
 * @param alpha - slope of the linear segment.
 * @param beta - offset of the linear segment.
 * @returns a new tensor of the same shape and dtype as `a`.
 */
export function hardSigmoid(a, alpha, beta) {
  const clamp01 = (x) => Math.max(0, Math.min(1, x));
  return positionWiseUnaryOp(a, (o1) => clamp01(alpha * o1 + beta));
}
/**
 * Element-wise clamp into [min, max]. Either bound may be `undefined`, in
 * which case that side is left unbounded; with neither bound the values are
 * copied unchanged.
 * @param a - input tensor.
 * @param min - optional lower bound.
 * @param max - optional upper bound.
 * @returns a new tensor of the same shape and dtype as `a`.
 */
export function clip(a, min, max) {
  const hasMin = min !== undefined;
  const hasMax = max !== undefined;
  // Pick the clamping function once, so the per-element work has no branching on the bounds.
  let clamp;
  if (hasMin && hasMax) {
    clamp = (v) => Math.min(max, Math.max(min, v));
  } else if (hasMax) {
    clamp = (v) => Math.min(max, v);
  } else if (hasMin) {
    clamp = (v) => Math.max(min, v);
  } else {
    clamp = (v) => v;
  }
  return positionWiseUnaryOp(a, clamp);
}
/**
 * Element-wise weighted sum `alpha * a + beta * b`.
 * @param a - left operand tensor.
 * @param b - right operand tensor; its shape must pass `checkEquivShapes` against `a`.
 * @param resultShape - shape of the output tensor.
 * @param alpha - weight for `a`; defaults to 1.
 * @param beta - weight for `b`; defaults to 1.
 * @returns a new tensor of shape `resultShape`.
 */
export function add(a, b, resultShape, alpha = 1, beta = 1) {
  return positionWiseBinaryOp(a, b, (o1, o2) => o1 * alpha + o2 * beta, resultShape);
}
/**
 * Element-wise weighted difference `alpha * a - beta * b`.
 * @param a - left operand tensor (minuend).
 * @param b - right operand tensor (subtrahend); its shape must pass `checkEquivShapes` against `a`.
 * @param resultShape - shape of the output tensor.
 * @param alpha - weight for `a`; defaults to 1.
 * @param beta - weight for `b`; defaults to 1.
 * @returns a new tensor of shape `resultShape`.
 */
export function subtract(a, b, resultShape, alpha = 1, beta = 1) {
  return positionWiseBinaryOp(a, b, (o1, o2) => o1 * alpha - o2 * beta, resultShape);
}
/**
 * Element-wise scaled product `alpha * a * b`.
 * @param a - left operand tensor.
 * @param b - right operand tensor; its shape must pass `checkEquivShapes` against `a`.
 * @param resultShape - shape of the output tensor.
 * @param alpha - scalar multiplier; defaults to 1.
 * @returns a new tensor of shape `resultShape`.
 */
export function multiply(a, b, resultShape, alpha = 1) {
  return positionWiseBinaryOp(a, b, (o1, o2) => o1 * o2 * alpha, resultShape);
}
/**
 * Element-wise scaled quotient `alpha * (a / b)`; division by zero follows
 * IEEE-754 (±Infinity or NaN).
 * @param a - dividend tensor.
 * @param b - divisor tensor; its shape must pass `checkEquivShapes` against `a`.
 * @param resultShape - shape of the output tensor.
 * @param alpha - scalar multiplier applied to the quotient; defaults to 1.
 * @returns a new tensor of shape `resultShape`.
 */
export function divide(a, b, resultShape, alpha = 1) {
  return positionWiseBinaryOp(a, b, (o1, o2) => (o1 / o2) * alpha, resultShape);
}
/**
 * Element-wise power `a ** b` with tensor exponents, computed via `Math.pow`.
 * @param a - base tensor.
 * @param b - exponent tensor; its shape must pass `checkEquivShapes` against `a`.
 * @param resultShape - shape of the output tensor.
 * @returns a new tensor of shape `resultShape`.
 */
export function power(a, b, resultShape) {
  return positionWiseBinaryOp(a, b, Math.pow, resultShape);
}
/**
 * Gradient of `clip` with respect to its input: passes `grad` through where
 * the forward input lay inside [min, max] (bounds inclusive) and zeroes it
 * where the input was strictly below `min` or strictly above `max`.
 * An `undefined` bound is treated as absent.
 * @param value - forward-pass input tensor.
 * @param grad - incoming gradient tensor.
 * @param resultShape - shape of the output tensor.
 * @param min - optional lower clip bound.
 * @param max - optional upper clip bound.
 * @returns a new tensor of shape `resultShape`.
 */
export function clipBackward(value, grad, resultShape, min, max) {
  const belowMin = (v) => min !== undefined && v < min;
  const aboveMax = (v) => max !== undefined && v > max;
  const passThrough = (v, g) => (belowMin(v) || aboveMax(v) ? 0 : g);
  return positionWiseBinaryOp(value, grad, passThrough, resultShape);
}
//# sourceMappingURL=basic.js.map