@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
57 lines • 2.44 kB
JavaScript
import { CPUTensor } from '../tensor/cpu/tensor';
import { TENSOR_FLOAT, TENSOR_INT64 } from './definitions';
// eslint-disable-next-line node/no-extraneous-import
import Long from 'long';
import { getSize } from '../util/shape';
// Inclusive bounds of the signed 32-bit range that int64 tensor values are stored in.
const INT32_MIN = -2147483648;
const INT32_MAX = 2147483647;

/**
 * Reports whether an optional array-like proto field is present and non-empty.
 *
 * @param {ArrayLike<*>|null|undefined} field - Proto field to inspect.
 * @returns {boolean} True when the field is set and has at least one element.
 */
function hasElements(field) {
  return field !== undefined && field !== null && field.length > 0;
}

/**
 * Converts a Long (or plain number) to a number clamped into the int32 range.
 *
 * Writing an out-of-range number into an Int32Array wraps it modulo 2^32, which
 * turns e.g. INT64_MAX into 0. Clamping keeps very large/small values at the
 * matching int32 extreme instead.
 *
 * @param {Long|number} value - 64-bit integer value to narrow.
 * @returns {number} The value limited to [INT32_MIN, INT32_MAX].
 */
function toClampedInt32(value) {
  const asNumber = Long.isLong(value) ? value.toNumber() : Number(value);
  return Math.min(INT32_MAX, Math.max(INT32_MIN, asNumber));
}

/**
 * Decodes little-endian int64 bytes into an Int32Array, clamping each value.
 *
 * @param {Uint8Array} rawData - Byte view whose length must be a multiple of 8.
 * @returns {Int32Array} One entry per 8-byte group of the input.
 * @throws {Error} If the byte length is not a multiple of 8.
 */
function decodeInt64RawData(rawData) {
  if (rawData.length % 8 !== 0) {
    throw new Error(`Int64 raw data length ${rawData.length} is not a multiple of 8`);
  }
  const values = new Int32Array(rawData.length / 8);
  for (let i = 0; i < values.length; i++) {
    const bytes = Array.from(rawData.slice(i * 8, i * 8 + 8));
    values[i] = toClampedInt32(Long.fromBytesLE(bytes));
  }
  return values;
}

/**
 * Builds a CPUTensor from a decoded tensor proto message.
 *
 * Only the TENSOR_FLOAT and TENSOR_INT64 data types are handled. Float tensors
 * are read from `floatData` or, failing that, from `rawData`. Int64 tensors are
 * read from `rawData` or, failing that, from `int64Data`, and are stored as
 * 'int32' with out-of-range values clamped. A tensor without data is accepted
 * only when its shape has zero elements. Empty `dims` is treated as shape [1].
 *
 * The proto is not modified: Long entries in `dims` are converted in a copy.
 *
 * @param {object} tensorProto - Message exposing `dims`, `dataType`, and the
 *   optional `segment`, `floatData`, `rawData` and `int64Data` fields.
 * @param {boolean} [castFloats=false] - When true, float tensors built from
 *   `rawData` are tagged 'float16' instead of 'float32'.
 * @returns {CPUTensor} Tensor holding the proto's values.
 * @throws {Error} If `segment` is set, `dims` is missing, the data type is not
 *   supported, non-empty data is missing, or int64 raw data is misaligned.
 */
export function createTensor(tensorProto, castFloats = false) {
  if (tensorProto.segment !== undefined && tensorProto.segment !== null) {
    throw new Error('Handling of tensor proto segment not yet implemented');
  }

  const dims = tensorProto.dims;
  if (dims === undefined || dims === null) {
    throw new Error('Tensor shape must be specified');
  }

  // Convert into a fresh array so the caller's proto keeps its original dims.
  let shape = Array.from(dims, (dim) => (Long.isLong(dim) ? dim.toNumber() : dim));
  if (shape.length === 0) {
    shape = [1];
  }

  const size = getSize(shape);
  const rawData = tensorProto.rawData;

  if (tensorProto.dataType === TENSOR_FLOAT) {
    if (hasElements(tensorProto.floatData)) {
      // NOTE(review): castFloats is not applied on this path, only to rawData below.
      return new CPUTensor(shape, tensorProto.floatData);
    }
    if (hasElements(rawData)) {
      // Copy exactly this view's bytes into a fresh buffer so the Float32Array
      // covers only this tensor's data, whatever the view's byteOffset is.
      const start = rawData.byteOffset;
      const buffer = rawData.buffer.slice(start, start + rawData.byteLength);
      return new CPUTensor(shape, new Float32Array(buffer), castFloats ? 'float16' : 'float32');
    }
    if (size === 0) {
      return new CPUTensor(shape);
    }
    throw new Error('Cant process float tensor without float or raw data');
  }

  if (tensorProto.dataType === TENSOR_INT64) {
    if (hasElements(rawData)) {
      return new CPUTensor(shape, decodeInt64RawData(rawData), 'int32');
    }
    if (hasElements(tensorProto.int64Data)) {
      const values = Int32Array.from(tensorProto.int64Data, (entry) => toClampedInt32(entry));
      return new CPUTensor(shape, values, 'int32');
    }
    if (size === 0) {
      // Mirrors the float branch: an empty tensor legitimately carries no data.
      return new CPUTensor(shape, new Int32Array(0), 'int32');
    }
    throw new Error('Cant process int64 tensor without raw data');
  }

  throw new Error(`Handling of tensor type ${tensorProto.dataType} not yet implemented`);
}
//# sourceMappingURL=util.js.map