keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
161 lines (128 loc) • 5.41 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var _WebGL = require("../../WebGL2");
var _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
/**
 * Babel interop helper: an ES-module namespace (flagged with a truthy
 * `__esModule`) is passed through unchanged; any other value, including
 * falsy ones, is wrapped as `{ default: value }`.
 */
function _interopRequireDefault(obj) {
  const isEsModule = Boolean(obj && obj.__esModule);
  if (isEsModule) {
    return obj;
  }
  return { default: obj };
}
// Gather shader for the GPU path of Cropping2D.
// Each output texel reads, from `indexMap`, the flat position of its source
// value inside the input texture `x` (which has `inputCols` columns), converts
// that position to (row, col) and copies the value; an index of -1 yields 0.0.
// Row/column use exact integer arithmetic: GLSL ES float division is only
// accurate to a few ULP, so floor(float(index) / float(inputCols)) can come out
// one row short (and mod() equal to inputCols) when index is an exact multiple.
const mapInputProgramSource = "#version 300 es\nprecision highp float;\nprecision highp isampler2D;\n\nin vec2 outTex;\nuniform sampler2D x;\nuniform isampler2D indexMap;\nuniform int inputCols;\nout vec4 outColor;\n\nvoid main() {\n ivec2 size = textureSize(indexMap, 0);\n int out_x = int(float(size[0]) * outTex.x);\n int out_y = int(float(size[1]) * outTex.y);\n\n int index = texelFetch(indexMap, ivec2(out_x, out_y), 0).r;\n\n if (index != -1) {\n int rowIndex = index / inputCols;\n int colIndex = index - rowIndex * inputCols;\n float val = texelFetch(x, ivec2(colIndex, rowIndex), 0).r;\n outColor = vec4(val);\n } else {\n outColor = vec4(0.0);\n }\n}\n";
/**
 * Cropping2D layer: crops values off both ends of the two spatial axes of a
 * rank-3 input (every shape computation below indexes exactly three axes).
 *
 * CPU path: copies an ndarray window of the input into a new output tensor.
 * GPU path: runs a gather shader that reads each output value from the input
 * texture at a position looked up in a precomputed index map.
 */
class Cropping2D extends _Layer.default {
  /**
   * @param {Object} [attrs] - layer attributes (also forwarded to the base Layer)
   * @param {number|number[]|number[][]} [attrs.cropping=[[0, 0], [0, 0]]] - a single
   *   number (same amount cropped from both ends of both spatial axes), `[a, b]`
   *   (symmetric crop of `a` on the first spatial axis and `b` on the second), or
   *   `[[start0, end0], [start1, end1]]`
   * @param {string} [attrs.data_format='channels_last'] - 'channels_last' or
   *   'channels_first'
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'Cropping2D';
    const {
      cropping = [[0, 0], [0, 0]],
      data_format = 'channels_last'
    } = attrs;

    // Normalize every accepted form to [[start0, end0], [start1, end1]].
    if (!Array.isArray(cropping)) {
      this.cropping = [[cropping, cropping], [cropping, cropping]];
    } else if (Array.isArray(cropping[0])) {
      this.cropping = cropping;
    } else {
      this.cropping = [[cropping[0], cropping[0]], [cropping[1], cropping[1]]];
    }
    this.dataFormat = data_format;
    this.description = JSON.stringify(this.cropping);

    // Compiled once here and reused by every GPU call.
    if (this.gpu) {
      this.mapInputProgram = _WebGL.webgl2.compileProgram(mapInputProgramSource);
    }
  }

  /**
   * Runs the layer on `x` with the GPU or CPU implementation.
   *
   * @param {Tensor} x - input tensor
   * @returns {Tensor} the layer output (also stored on `this.output`)
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * CPU implementation: copies the cropped window of `x.tensor` into a fresh
   * output tensor.
   *
   * For channels_first input the crop is done on a channels-last view and the
   * output tensor is transposed back, so in that case `this.inputShape` and
   * `this.outputShape` are in channels-last order.
   *
   * @param {Tensor} x - input tensor; its `tensor` property is not modified
   */
  _callCPU(x) {
    const channelsFirst = this.dataFormat === 'channels_first';
    // Use a local channels-last view instead of reassigning x.tensor, so the
    // caller's input is never left transposed if something below throws.
    const input = channelsFirst ? x.tensor.transpose(1, 2, 0) : x.tensor;
    this.inputShape = input.shape;
    const [rows, cols, channels] = this.inputShape;
    const [[top, bottom], [left, right]] = this.cropping;
    this.outputShape = [rows - top - bottom, cols - left - right, channels];
    this.output = new _Tensor.default([], this.outputShape);
    // hi() fixes the (exclusive) end of the window, then lo() moves its start.
    _ndarrayOps.default.assign(this.output.tensor, input.hi(rows - bottom, cols - right, channels).lo(top, left, 0));
    if (channelsFirst) {
      this.output.tensor = this.output.tensor.transpose(2, 0, 1);
    }
  }

  /**
   * Builds the integer GL texture used by the gather shader. For every output
   * element, it holds the value found at the corresponding (cropped) position of
   * `indicesForReshaped`. The shader treats that value as a flat position in the
   * input texture, with -1 producing 0.0.
   *
   * Built on the first GPU call and reused afterwards.
   * NOTE(review): it is not rebuilt if the input shape changes between calls.
   *
   * @param {Tensor} indicesForReshaped - per-element texture positions of the
   *   input, in the input's original (pre-reshape) shape; read via `.data`/`.shape`
   * @param {boolean} is2DReshaped - whether the input was laid out with
   *   reshapeTo2D (otherwise reshapeTo2DSquare); the map uses the same layout
   */
  _createIndexMap(indicesForReshaped, is2DReshaped) {
    if (this.indexMap) {
      return;
    }
    const indices = new _Tensor.default(indicesForReshaped.data, indicesForReshaped.shape, {
      type: Int32Array
    });
    this.indexMap = new _Tensor.default([], this.outputShape, {
      type: Int32Array
    });

    // inputShape is in this layer's data format, so the spatial axes are
    // 1 and 2 for channels_first and 0 and 1 for channels_last.
    const [d0, d1, d2] = this.inputShape;
    const [[top, bottom], [left, right]] = this.cropping;
    const channelsFirst = this.dataFormat === 'channels_first';
    const sliceStart = channelsFirst ? [0, top, left] : [top, left, 0];
    const sliceEnd = channelsFirst ? [d0, d1 - bottom, d2 - right] : [d0 - bottom, d1 - right, d2];
    _ndarrayOps.default.assign(this.indexMap.tensor, indices.tensor.hi(...sliceEnd).lo(...sliceStart));

    if (is2DReshaped) {
      this.indexMap.reshapeTo2D();
    } else {
      this.indexMap.reshapeTo2DSquare();
    }
    this.indexMap.createGLTexture({
      type: '2d',
      format: 'int'
    });
  }

  /**
   * GPU implementation: gathers the cropped values from the input texture into
   * the output texture using the index map.
   *
   * The output tensor and texture are allocated on the first call and reused
   * afterwards (NOTE(review): not reallocated if the input shape changes).
   *
   * @param {Tensor} x - input tensor; uploaded as a square 2D texture if it has
   *   no GL texture yet
   */
  _callGPU(x) {
    if (!x.glTexture) {
      x.reshapeTo2DSquare();
      x.createGLTexture({
        type: '2d',
        format: 'float'
      });
    }
    // Shape before the 2D texture reshape, in this layer's data format.
    this.inputShape = x.originalShape;
    const [d0, d1, d2] = this.inputShape;
    const [[top, bottom], [left, right]] = this.cropping;
    this.outputShape = this.dataFormat === 'channels_first' ? [d0, d1 - top - bottom, d2 - left - right] : [d0 - top - bottom, d1 - left - right, d2];

    this._createIndexMap(x.indicesForReshaped, x.is2DReshaped);

    if (!this.output) {
      this.output = new _Tensor.default([], this.outputShape);
      // Use the same 2D layout as the input, matching the index map.
      if (x.is2DReshaped) {
        this.output.reshapeTo2D();
      } else {
        this.output.reshapeTo2DSquare();
      }
      this.output.createGLTexture({
        type: '2d',
        format: 'float'
      });
    }

    _WebGL.webgl2.runProgram({
      program: this.mapInputProgram,
      output: this.output,
      inputs: [{
        input: x,
        name: 'x'
      }, {
        input: this.indexMap,
        name: 'indexMap'
      }],
      uniforms: [{
        // Column count of the input texture; the shader uses it to turn a
        // flat position into (row, col).
        value: x.glTextureShape[1],
        type: 'int',
        name: 'inputCols'
      }]
    });

    // With no outbound layers (presumably the model's final output), read the
    // result back from the GPU and restore its pre-reshape shape.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
      if (this.output.is2DReshaped) {
        this.output.reshapeFrom2D();
      } else {
        this.output.reshapeFrom2DSquare();
      }
    }
  }
}
exports.default = Cropping2D;