// keras-js — run Keras models in the browser, with GPU support using WebGL.
// Version: (not specified)
// File listing: 386 lines (309 loc), 14.7 kB, JavaScript.
// Enable strict mode for this compiled CommonJS module.
"use strict";
// Flag the exports object as a transpiled ES module; the interop helpers in
// this file return such objects unwrapped when they see `__esModule`.
Object.defineProperty(exports, "__esModule", {
value: true
});
// Placeholder for the default export; the real value is assigned at the end of the file.
exports.default = void 0;
var _WebGL = require("./WebGL2");
var tensorUtils = _interopRequireWildcard(require("./utils/tensorUtils"));
var _ndarray = _interopRequireDefault(require("ndarray"));
var _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
var _ndarraySqueeze = _interopRequireDefault(require("ndarray-squeeze"));
/**
 * Normalizes a required module so that its payload is reachable via `.default`.
 *
 * @param {*} obj - Value returned by `require`.
 * @returns {*} `obj` itself when it is truthy and flagged `__esModule`;
 *   otherwise a fresh `{ default: obj }` wrapper.
 */
function _interopRequireDefault(obj) {
  const isTranspiledModule = Boolean(obj) && Boolean(obj.__esModule);
  if (isTranspiledModule) {
    return obj;
  }
  return { default: obj };
}
/**
 * Builds a namespace-like object for a required module.
 *
 * @param {*} obj - Value returned by `require`.
 * @returns {*} `obj` itself when it is truthy and flagged `__esModule`;
 *   otherwise a new object holding every own enumerable property of `obj`
 *   (accessors are copied as accessors, not evaluated) plus a `default`
 *   property that points at `obj`.
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  const namespace = {};
  if (obj != null) {
    for (const name of Object.keys(obj)) {
      const canDescribe = Object.defineProperty && Object.getOwnPropertyDescriptor;
      const descriptor = canDescribe ? Object.getOwnPropertyDescriptor(obj, name) : {};
      const isAccessor = descriptor.get || descriptor.set;
      if (isAccessor) {
        // Preserve getters/setters instead of snapshotting their current value.
        Object.defineProperty(namespace, name, descriptor);
      } else {
        namespace[name] = obj[name];
      }
    }
  }
  namespace.default = obj;
  return namespace;
}
/**
 * Tensor pairs an ndarray (`this.tensor`) with optional WebGL2 texture state
 * (`glTexture`, or `glTextureFragments` when the texture is split by rows).
 */
class Tensor {
  /**
   * @param {TypedArray|Array|null|undefined} data - Initial values. An instance
   *   of `options.type` is wrapped without copying; a plain Array is copied
   *   into a new `options.type` array. Empty or missing data yields a
   *   zero-filled tensor of `shape`.
   * @param {number[]} shape - Tensor dimensions.
   * @param {Object} [options]
   * @param {Function} [options.type=Float32Array] - Typed-array constructor
   *   used for storage.
   */
  constructor(data, shape, options = {}) {
    this.arrayType = options.type || Float32Array;
    // `data` may legitimately be null/undefined; never dereference it unguarded.
    const hasData = Boolean(data && data.length);

    if (hasData && (data instanceof this.arrayType || data instanceof Array)) {
      tensorUtils.checkShape(data, shape);
      if (data instanceof this.arrayType) {
        this.tensor = (0, _ndarray.default)(data, shape);
      } else {
        this.tensor = (0, _ndarray.default)(new this.arrayType(data), shape);
      }
    } else if (!hasData && shape && shape.length) {
      // No data supplied: allocate zero-filled storage matching the shape.
      const size = shape.reduce((a, b) => a * b, 1);
      this.tensor = (0, _ndarray.default)(new this.arrayType(size), shape);
    } else {
      // NOTE(review): non-empty data that is neither `arrayType` nor Array
      // (e.g. a different typed array) also lands here and is discarded —
      // kept for backward compatibility; confirm whether that is intended.
      this.tensor = (0, _ndarray.default)(new this.arrayType([]), []);
    }
  }

  /**
   * Creates the WebGL2 texture(s) for the current tensor data.
   *
   * @param {Object} [options]
   * @param {string} [options.type='2d'] - One of '2d', '2d_array', '3d'.
   * @param {string} [options.format='float'] - Passed to
   *   `webgl2.getWebGLTextureOptions`.
   * @param {boolean} [options.supportsTextureFragments=false] - When true, a
   *   '2d' texture with more rows than MAX_TEXTURE_SIZE is split into row
   *   fragments.
   * @throws {Error} If the tensor rank cannot be mapped to the requested
   *   texture type, or the type is unknown.
   */
  createGLTexture({
    type = '2d',
    format = 'float',
    supportsTextureFragments = false
  } = {}) {
    const rank = this.tensor.shape.length;
    let shape = [];
    if (rank === 1) {
      // 1D tensors are stored as a single-row 2D texture.
      shape = [1, this.tensor.shape[0]];
      this.is1D = true;
    } else if (rank === 2) {
      shape = this.tensor.shape;
    } else if (rank === 3 && (type === '2d_array' || type === '3d')) {
      shape = this.tensor.shape;
    } else {
      throw new Error('[Tensor] cannot create WebGL2 texture.');
    }

    this.glTextureShape = shape;
    this.glTextureType = type;
    this.glTextureFormat = format;

    if (type === '2d') {
      if (this.glTextureShape[0] > _WebGL.MAX_TEXTURE_SIZE && supportsTextureFragments) {
        this._create2DRowFragmentedGLTexture();
      } else {
        this._create2DGLTexture();
      }
    } else if (type === '2d_array' || type === '3d') {
      this._create3DGLTexture();
    } else {
      throw new Error(`[Tensor] invalid type ${type}.`);
    }
  }

  /**
   * Applies the wrap/filter parameters shared by every texture this class
   * creates (clamp-to-edge wrapping, nearest filtering) to the texture
   * currently bound to `textureTarget`.
   *
   * @param {WebGL2RenderingContext} gl
   * @param {number} textureTarget
   */
  _setGLTextureParameters(gl, textureTarget) {
    gl.texParameteri(textureTarget, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(textureTarget, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(textureTarget, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
    gl.texParameteri(textureTarget, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
  }

  /**
   * Creates a single 2D texture (`this.glTexture`) and uploads the tensor
   * data; width is `glTextureShape[1]`, height is `glTextureShape[0]`.
   */
  _create2DGLTexture() {
    const gl = _WebGL.webgl2.context;
    const {
      textureTarget,
      textureInternalFormat,
      textureFormat,
      textureType
    } = _WebGL.webgl2.getWebGLTextureOptions(this.glTextureType, this.glTextureFormat);

    this.glTexture = gl.createTexture();
    _WebGL.webgl2.storeRef('texture', this.glTexture);
    gl.bindTexture(textureTarget, this.glTexture);

    const shape = this.glTextureShape;
    const data = this.tensor.data;
    gl.texImage2D(textureTarget, 0, textureInternalFormat, shape[1], shape[0], 0, textureFormat, textureType, data);
    this._setGLTextureParameters(gl, textureTarget);
  }

  /**
   * Splits the tensor by rows into textures of MAX_TEXTURE_SIZE rows each and
   * stores them in `this.glTextureFragments`. The last fragment is uploaded
   * from a zero-padded buffer so every fragment has the same shape
   * (`this.glTextureFragmentShape`).
   */
  _create2DRowFragmentedGLTexture() {
    const gl = _WebGL.webgl2.context;
    const {
      textureTarget,
      textureInternalFormat,
      textureFormat,
      textureType
    } = _WebGL.webgl2.getWebGLTextureOptions(this.glTextureType, this.glTextureFormat);

    this.glTextureFragments = [];
    this.glTextureFragmentShape = [_WebGL.MAX_TEXTURE_SIZE, this.glTextureShape[1]];
    const shape = this.glTextureFragmentShape;
    const fragmentSize = shape[0] * shape[1];
    const numFragments = Math.ceil(this.glTextureShape[0] / _WebGL.MAX_TEXTURE_SIZE);

    let offset = 0;
    for (let k = 0; k < numFragments; k++) {
      const glTexture = gl.createTexture();
      _WebGL.webgl2.storeRef('texture', glTexture);
      gl.bindTexture(textureTarget, glTexture);

      let data;
      if (k === numFragments - 1) {
        // The final fragment may be shorter than a full fragment; pad with zeros.
        data = new this.arrayType(fragmentSize);
        data.set(this.tensor.data.slice(offset, offset + fragmentSize), 0);
      } else {
        data = this.tensor.data.slice(offset, offset + fragmentSize);
      }

      gl.texImage2D(textureTarget, 0, textureInternalFormat, shape[1], shape[0], 0, textureFormat, textureType, data);
      this._setGLTextureParameters(gl, textureTarget);
      this.glTextureFragments.push(glTexture);
      offset += fragmentSize;
    }
  }

  /**
   * Creates a 3D / 2D-array texture (`this.glTexture`). The upload buffer is
   * produced by `tensorUtils.data3DLayoutForGL`; width, height and depth are
   * `glTextureShape[1]`, `[0]` and `[2]` respectively.
   */
  _create3DGLTexture() {
    const gl = _WebGL.webgl2.context;
    const {
      textureTarget,
      textureInternalFormat,
      textureFormat,
      textureType
    } = _WebGL.webgl2.getWebGLTextureOptions(this.glTextureType, this.glTextureFormat);

    this.glTexture = gl.createTexture();
    _WebGL.webgl2.storeRef('texture', this.glTexture);
    gl.bindTexture(textureTarget, this.glTexture);

    const shape = this.glTextureShape;
    const data = tensorUtils.data3DLayoutForGL(this.arrayType, this.tensor, this.glTextureShape);
    gl.texImage3D(textureTarget, 0, textureInternalFormat, shape[1], shape[0], shape[2], 0, textureFormat, textureType, data);
    this._setGLTextureParameters(gl, textureTarget);
  }

  /**
   * Copies all row fragments side by side (along the column axis) into a
   * single texture, `this.glTextureFragmentsAsColStack`, creating it on first
   * use and reusing it afterwards.
   *
   * @throws {Error} If the tensor has no fragment textures.
   */
  convert2DRowFragmentedGLTextureToColStack() {
    if (!this.glTextureFragments || !this.glTextureFragmentShape) {
      throw new Error('[Tensor] no glTextureFragments available.');
    }
    const gl = _WebGL.webgl2.context;
    const {
      textureTarget,
      textureInternalFormat,
      textureFormat,
      textureType
    } = _WebGL.webgl2.getWebGLTextureOptions(this.glTextureType, this.glTextureFormat);
    const [fragmentRows, fragmentCols] = this.glTextureFragmentShape;

    if (!this.glTextureFragmentsAsColStack) {
      this.glTextureFragmentsAsColStack = gl.createTexture();
      _WebGL.webgl2.storeRef('texture', this.glTextureFragmentsAsColStack);
      gl.bindTexture(textureTarget, this.glTextureFragmentsAsColStack);

      const numFragments = this.glTextureFragments.length;
      this.glTextureFragmentsAsColStackShape = [fragmentRows, fragmentCols * numFragments];
      const shape = this.glTextureFragmentsAsColStackShape;
      // Allocate zero-filled storage; the fragments are copied in below.
      const data = new this.arrayType(shape.reduce((a, b) => a * b, 1));
      gl.texImage2D(textureTarget, 0, textureInternalFormat, shape[1], shape[0], 0, textureFormat, textureType, data);
      this._setGLTextureParameters(gl, textureTarget);
    } else {
      gl.bindTexture(textureTarget, this.glTextureFragmentsAsColStack);
    }

    // Attach each fragment to a temporary read framebuffer and copy it into
    // the bound column-stack texture at x offset k * fragmentCols.
    const fbo = gl.createFramebuffer();
    gl.bindFramebuffer(gl.READ_FRAMEBUFFER, fbo);
    this.glTextureFragments.forEach((texture, k) => {
      gl.framebufferTexture2D(gl.READ_FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
      gl.copyTexSubImage2D(textureTarget, 0, k * fragmentCols, 0, 0, 0, fragmentCols, fragmentRows);
    });
    gl.deleteFramebuffer(fbo);
  }

  /**
   * Deletes the column-stack texture, if one exists, and clears its
   * bookkeeping properties.
   */
  removeGLTextureFragmentsAsColStack() {
    if (this.glTextureFragmentsAsColStack) {
      const gl = _WebGL.webgl2.context;
      gl.deleteTexture(this.glTextureFragmentsAsColStack);
      delete this.glTextureFragmentsAsColStack;
      delete this.glTextureFragmentsAsColStackShape;
    }
  }

  /**
   * Deletes the main texture and any row-fragment textures. The column-stack
   * texture is not touched here; see removeGLTextureFragmentsAsColStack().
   */
  deleteGLTexture() {
    const gl = _WebGL.webgl2.context;
    if (this.glTexture) {
      gl.deleteTexture(this.glTexture);
      delete this.glTexture;
    }
    if (this.glTextureFragments) {
      this.glTextureFragments.forEach(texture => {
        gl.deleteTexture(texture);
      });
      delete this.glTextureFragments;
    }
  }

  /**
   * Overwrites the tensor's data in place and, when `this.glTexture` exists,
   * re-uploads it to that texture.
   *
   * NOTE(review): row-fragment textures (`glTextureFragments`) are not
   * refreshed by this method — confirm whether callers rely on that.
   *
   * @param {TypedArray|Array} data - Non-empty instance of `this.arrayType`
   *   or a plain Array.
   * @throws {Error} If `data` is empty or of an unsupported type.
   */
  replaceTensorData(data) {
    if (data && data.length && data instanceof this.arrayType) {
      this.tensor.data.set(data);
    } else if (data && data.length && data instanceof Array) {
      this.tensor.data.set(new this.arrayType(data));
    } else {
      throw new Error('[Tensor] invalid input for replaceTensorData method.');
    }

    if (this.glTexture) {
      const gl = _WebGL.webgl2.context;
      const {
        textureTarget,
        textureFormat,
        textureType
      } = _WebGL.webgl2.getWebGLTextureOptions(this.glTextureType, this.glTextureFormat);
      gl.bindTexture(textureTarget, this.glTexture);

      const shape = this.glTextureShape;
      if (this.glTextureType === '2d') {
        const textureData = this.tensor.data;
        gl.texSubImage2D(textureTarget, 0, 0, 0, shape[1], shape[0], textureFormat, textureType, textureData, 0);
      } else if (this.glTextureType === '2d_array' || this.glTextureType === '3d') {
        const textureData = tensorUtils.data3DLayoutForGL(this.arrayType, this.tensor, shape);
        gl.texSubImage3D(textureTarget, 0, 0, 0, 0, shape[1], shape[0], shape[2], textureFormat, textureType, textureData, 0);
      }
    }
  }

  /**
   * Reads texture contents back into `this.tensor` using
   * `webgl2.bindOutputTexture` and `webgl2.readData`. Fragmented textures are
   * read fragment by fragment, dropping the padding of the last fragment.
   * A tensor flagged `is1D` whose texture has a single row is squeezed back
   * to 1D.
   */
  transferFromGLTexture() {
    if (this.glTextureFragments) {
      this.tensor = (0, _ndarray.default)(new this.arrayType(this.glTextureShape[0] * this.glTextureShape[1]), this.glTextureShape);
      const lastIndex = this.glTextureFragments.length - 1;
      let offset = 0;
      for (let k = 0; k <= lastIndex; k++) {
        _WebGL.webgl2.bindOutputTexture(this.glTextureFragments[k], this.glTextureFragmentShape);
        const fragmentData = _WebGL.webgl2.readData(this.glTextureFragmentShape);
        if (k === lastIndex) {
          // Only the remaining elements are real data; the rest is padding.
          const truncate = this.tensor.data.length - offset;
          this.tensor.data.set(fragmentData.subarray(0, truncate), offset);
        } else {
          this.tensor.data.set(fragmentData, offset);
        }
        offset += fragmentData.length;
      }
    } else {
      _WebGL.webgl2.bindOutputTexture(this.glTexture, this.glTextureShape);
      this.tensor = (0, _ndarray.default)(new this.arrayType([]), this.glTextureShape);
      this.tensor.data = _WebGL.webgl2.readData(this.glTextureShape);
    }

    if (this.is1D && this.glTextureShape[0] === 1) {
      this.tensor = (0, _ndarraySqueeze.default)(this.tensor, [0]);
    }
  }

  /**
   * Replaces `this.tensor` with a 2D copy of shape
   * [product of all leading axes, size of last axis]. Records
   * `originalShape`, `indicesForReshaped` (from
   * `tensorUtils.createIndicesFor2DReshaped`) and sets `is2DReshaped`.
   */
  reshapeTo2D() {
    const axis = this.tensor.shape.length - 1;
    const axisSize = this.tensor.shape[axis];
    const otherAxes = this.tensor.shape.slice(0, axis);
    const otherAxesSize = otherAxes.reduce((a, b) => a * b, 1);

    const reshaped = (0, _ndarray.default)(new this.arrayType(otherAxesSize * axisSize), [otherAxesSize, axisSize]);
    const otherAxesData = (0, _ndarray.default)(new this.arrayType(otherAxesSize), otherAxes);
    const otherAxesDataRaveled = (0, _ndarray.default)(new this.arrayType(otherAxesSize), [otherAxesSize]);
    const axisSlices = Array(this.tensor.shape.length).fill(null);

    for (let n = 0; n < axisSize; n++) {
      axisSlices[axis] = n;
      // Copy slice n of the last axis into contiguous storage, then view that
      // same storage as 1D and write it to column n of the result.
      _ndarrayOps.default.assign(otherAxesData, this.tensor.pick(...axisSlices));
      otherAxesDataRaveled.data = otherAxesData.data;
      _ndarrayOps.default.assign(reshaped.pick(null, n), otherAxesDataRaveled);
    }

    this.originalShape = this.tensor.shape;
    this.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.tensor.shape, false, axis);
    this.tensor = reshaped;
    this.is2DReshaped = true;
  }

  /**
   * Restores a tensor previously reshaped to 2D back to `originalShape`,
   * writing column n of the 2D tensor to index n along `axis`.
   *
   * @param {number} [axis=-1] - Axis the columns map to; negative values
   *   count from the end of `originalShape`.
   * @throws {Error} If the tensor is not in 2D-reshaped form or
   *   `originalShape` is missing.
   */
  reshapeFrom2D(axis = -1) {
    if (!this.is2DReshaped) {
      throw new Error('[Tensor] not in reshaped 2D representation.');
    }
    if (!this.originalShape) {
      throw new Error('[Tensor] does not contain originalShape.');
    }
    const targetAxis = axis < 0 ? this.originalShape.length + axis : axis;

    const channelDataSize = this.tensor.shape[0];
    const channels = this.tensor.shape[1];
    const reshaped = (0, _ndarray.default)(new this.arrayType(this.originalShape.reduce((a, b) => a * b, 1)), this.originalShape);
    const channelDataRaveled = (0, _ndarray.default)(new this.arrayType(channelDataSize), [channelDataSize]);
    const unraveledChannelShape = [...this.originalShape.slice(0, targetAxis), ...this.originalShape.slice(targetAxis + 1)];
    const unraveledChannel = (0, _ndarray.default)(new this.arrayType(unraveledChannelShape.reduce((a, b) => a * b, 1)), unraveledChannelShape);
    const axisSlices = Array(this.originalShape.length).fill(null);

    for (let n = 0; n < channels; n++) {
      // Copy column n into contiguous storage, view it with the remaining
      // axes' shape, and assign it to slice n along the target axis.
      _ndarrayOps.default.assign(channelDataRaveled, this.tensor.pick(null, n));
      unraveledChannel.data = channelDataRaveled.data;
      axisSlices[targetAxis] = n;
      _ndarrayOps.default.assign(reshaped.pick(...axisSlices), unraveledChannel);
    }
    this.tensor = reshaped;
  }

  /**
   * Replaces `this.tensor` with a square 2D tensor whose side is
   * ceil(sqrt(size)); the original data is copied to the front of the buffer
   * and the remainder stays zero. Records `originalShape`,
   * `indicesForReshaped` and sets `is2DSquareReshaped`.
   */
  reshapeTo2DSquare() {
    const squareDim = Math.ceil(Math.sqrt(this.tensor.size));
    const reshaped = (0, _ndarray.default)(new this.arrayType(squareDim ** 2), [squareDim, squareDim]);
    reshaped.data.set(this.tensor.data);

    this.originalShape = this.tensor.shape;
    this.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.tensor.shape, true);
    this.tensor = reshaped;
    this.is2DSquareReshaped = true;
  }

  /**
   * Restores a tensor previously reshaped to a 2D square back to
   * `originalShape`, discarding the trailing padding.
   *
   * @throws {Error} If the tensor is not in 2D-square form or `originalShape`
   *   is missing.
   */
  reshapeFrom2DSquare() {
    if (!this.is2DSquareReshaped) {
      throw new Error('[Tensor] not in reshaped 2D square representation.');
    }
    if (!this.originalShape) {
      throw new Error('[Tensor] does not contain originalShape.');
    }
    const size = this.originalShape.reduce((a, b) => a * b, 1);
    const reshaped = (0, _ndarray.default)(new this.arrayType(size), this.originalShape);
    reshaped.data.set(this.tensor.data.subarray(0, size));
    this.tensor = reshaped;
  }
}
// Publish Tensor as this module's default export.
exports.default = Tensor;