@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
58 lines • 2.38 kB
JavaScript
// Downlevel async/await helper (the standard helper emitted by the TypeScript
// compiler). It drives a generator function as if it were an async function:
// every value the generator yields is awaited, and the settled result is sent
// back into the generator. Inside the `function*` bodies below, `yield` therefore
// plays the role of `await`.
// An existing `this.__awaiter` is reused when present. At the top level of an
// ES module `this` is undefined, so the local definition is used here.
//   thisArg    - `this` binding for the generator body.
//   _arguments - arguments forwarded to the generator (none when falsy).
//   P          - Promise constructor to use; defaults to the global `Promise`.
//   generator  - generator function implementing the async body.
// Returns a promise for the generator's return value. It rejects with any error
// thrown inside the generator, or with any rejection the generator does not catch.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
// Wrap non-`P` values in a resolved `P` so they can be chained with `.then`.
// Only called from `step`, which runs after `P` has been defaulted below.
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with an awaited value; synchronous throws reject the outer promise.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Re-throw an awaited rejection into the generator so its try/catch can handle it.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Finish on `done`; otherwise wait for the yielded value, then resume.
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
// Instantiate the generator with the given `this`/arguments and run it to its first yield.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
import { CPUTensor } from '../../tensor/cpu/tensor';
import { GPUTensor } from '../../tensor/gpu/tensor';
import { WASMTensor } from '../../tensor/wasm/tensor';
import { OnnxNode } from '../node';
/**
 * ONNX `Range` operator node.
 *
 * Reads the `start`, `limit` and `delta` inputs and builds the output tensor
 * with the backend selected by the most recent `toCPU` / `toWASM` / `toGPU`
 * call. The CPU backend is used by default.
 */
export class RangeNode extends OnnxNode {
    constructor(attributes, inputs, outputs, constants, onnxVersion, mode) {
        super(attributes, inputs, outputs, constants, onnxVersion, mode);
        // Backend used to materialise the output; switched by toCPU/toWASM/toGPU.
        this.backend = 'CPU';
    }
    /**
     * Computes the range tensor.
     *
     * @param inputs - `[start, limit, delta]`. Only the first value of each
     *   tensor is read; these are presumably scalar tensors as in the ONNX
     *   spec (TODO confirm).
     * @returns a one-element array holding the tensor produced by
     *   `CPUTensor.range`, `WASMTensor.range` or `GPUTensor.range`.
     * @throws Error when `delta` is zero. A zero step can never reach `limit`,
     *   so the result would be undefined.
     */
    forward(inputs) {
        return __awaiter(this, void 0, void 0, function* () {
            const [start, limit, delta] = inputs;
            // The three reads are independent of each other, so run them concurrently.
            const values = yield Promise.all([
                this.toValues(start),
                this.toValues(limit),
                this.toValues(delta),
            ]);
            const [startValue, limitValue, deltaValue] = values.map((v) => v[0]);
            if (deltaValue === 0) {
                throw new Error('Range: delta must not be zero');
            }
            if (this.backend === 'CPU') {
                return [CPUTensor.range(startValue, limitValue, deltaValue)];
            }
            if (this.backend === 'WASM') {
                return [WASMTensor.range(startValue, limitValue, deltaValue)];
            }
            // Any other backend value (i.e. 'GPU') uses the GPU implementation.
            return [GPUTensor.range(startValue, limitValue, deltaValue)];
        });
    }
    /**
     * No-op: this class stores only the `backend` string, so it has nothing of
     * its own to free.
     */
    delete() { }
    /** @returns the ONNX operator type, `'Range'`. */
    getType() {
        return 'Range';
    }
    /** Makes subsequent `forward` calls produce CPU tensors. */
    toCPU() {
        return __awaiter(this, void 0, void 0, function* () {
            this.backend = 'CPU';
        });
    }
    /** Makes subsequent `forward` calls produce WASM tensors. */
    toWASM() {
        return __awaiter(this, void 0, void 0, function* () {
            this.backend = 'WASM';
        });
    }
    /** Makes subsequent `forward` calls produce GPU tensors. */
    toGPU() {
        return __awaiter(this, void 0, void 0, function* () {
            this.backend = 'GPU';
        });
    }
}
//# sourceMappingURL=range.js.map