keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
475 lines (384 loc) • 17.9 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var activations = _interopRequireWildcard(require("../../activations"));
var _WebGL = require("../../WebGL2");
var _createGLSLProgram = _interopRequireDefault(require("../../webgl/dynamic/createGLSLProgram"));
var tensorUtils = _interopRequireWildcard(require("../../utils/tensorUtils"));
var _cwise = _interopRequireDefault(require("cwise"));
var _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
var _ndarrayGemm = _interopRequireDefault(require("ndarray-gemm"));
var activationProgramSources = _interopRequireWildcard(require("../../activations/programSources"));
/**
 * Interop helper for namespace-style imports.
 *
 * Returns `obj` unchanged when it is truthy and flagged with `__esModule`.
 * Otherwise returns a fresh object that carries every own enumerable property
 * of `obj` (accessor properties are copied as accessors, not evaluated) and
 * exposes `obj` itself under the `default` key.
 *
 * @param {*} obj - the value returned by `require`.
 * @returns {*} `obj` itself, or a namespace-like wrapper around it.
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  const namespace = {};
  if (obj != null) {
    for (const prop in obj) {
      // Skip anything inherited through the prototype chain.
      if (!Object.prototype.hasOwnProperty.call(obj, prop)) {
        continue;
      }
      const descriptor =
        Object.defineProperty && Object.getOwnPropertyDescriptor
          ? Object.getOwnPropertyDescriptor(obj, prop)
          : {};
      if (descriptor.get || descriptor.set) {
        // Preserve getters/setters instead of snapshotting their current value.
        Object.defineProperty(namespace, prop, descriptor);
      } else {
        namespace[prop] = obj[prop];
      }
    }
  }
  namespace.default = obj;
  return namespace;
}
/**
 * Interop helper for default imports.
 *
 * @param {*} obj - the value returned by `require`.
 * @returns {*} `obj` itself when it is truthy and flagged with `__esModule`;
 *   otherwise a new object holding `obj` under the `default` key.
 */
function _interopRequireDefault(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  return { default: obj };
}
// GLSL ES 3.00 fragment shader source for a matrix multiply over textures.
// For each output texel it accumulates the products of texels fetched from
// row `out_y` of A and column `out_x` of B over A's width, and, when the
// `addC` uniform is set, adds the texel of C at (out_x, 0).
const matMulProgramSource = [
  '#version 300 es',
  'precision highp float;',
  '',
  'in vec2 outTex;',
  'uniform sampler2D A;',
  'uniform sampler2D B;',
  'uniform sampler2D C;',
  'uniform bool addC;',
  'out vec4 outColor;',
  '',
  'void main() {',
  ' ivec2 A_size = textureSize(A, 0);',
  ' ivec2 B_size = textureSize(B, 0);',
  ' int out_x = int(float(B_size[0]) * outTex.x);',
  ' int out_y = int(float(A_size[1]) * outTex.y);',
  ' int commonDim = A_size[0];',
  '',
  ' float sum = 0.;',
  ' for (int i = 0; i < commonDim; ++i) {',
  ' float a = texelFetch(A, ivec2(i, out_y), 0).r;',
  ' float b = texelFetch(B, ivec2(out_x, i), 0).r;',
  ' sum += a * b;',
  ' }',
  '',
  ' if (addC) {',
  ' sum += texelFetch(C, ivec2(out_x, 0), 0).r;',
  ' }',
  '',
  ' outColor = vec4(sum);',
  '}',
  '' // yields the trailing newline after the closing brace
].join('\n');
// cwise-built routine that records a matMul row index in an index-map tensor.
// Arguments, in the order declared in `args`:
//   x        - array argument declared with `blockIndices: -1`; the body
//              indexes it as `x[i]` (per the cwise docs this should be the
//              block along the array's last axis — confirm).
//   rowIndex - scalar value to store.
//   size     - scalar upper bound for the scan over `x`.
// The body writes `rowIndex` into the first of the first `size` slots of `x`
// that still holds the sentinel -1 and then stops; when no slot holds -1,
// nothing is written.
// NOTE(review): cwise presumably derives its kernel from this body function,
// so keep the body's form simple if it is ever edited — confirm.
const assignToRowIndicesMap = (0, _cwise.default)({
args: [{
blockIndices: -1
}, 'scalar', 'scalar'],
body: function (x, rowIndex, size) {
for (let i = 0; i < size; i++) {
if (x[i] === -1) {
x[i] = rowIndex;
break;
}
}
}
});
// cwise-built routine that records a matMul column index in an index-map tensor.
// Arguments, in the order declared in `args`:
//   x        - array argument declared with `blockIndices: -1`; the body
//              indexes it as `x[i]` (per the cwise docs this should be the
//              block along the array's last axis — confirm).
//   colIndex - declared as 'array', so (presumably) the element of the second
//              array at the position currently being visited — confirm.
//   size     - scalar upper bound for the scan over `x`.
// The body writes `colIndex` into the first of the first `size` slots of `x`
// that still holds the sentinel -1 and then stops; when no slot holds -1,
// nothing is written.
// NOTE(review): cwise presumably derives its kernel from this body function,
// so keep the body's form simple if it is ever edited — confirm.
const assignToColIndicesMap = (0, _cwise.default)({
args: [{
blockIndices: -1
}, 'array', 'scalar'],
body: function (x, colIndex, size) {
for (let i = 0; i < size; i++) {
if (x[i] === -1) {
x[i] = colIndex;
break;
}
}
}
});
/**
 * Conv2DTranspose layer.
 *
 * The forward pass multiplies a flattened input matrix by a flattened kernel
 * matrix and then scatter-adds each resulting row, reshaped to a
 * [kernelRows, kernelCols, filters] patch, into the output at the position
 * given by the input coordinates times the strides. A CPU path (ndarray ops)
 * and a GPU path (WebGL2 programs) are provided; `this.gpu` selects between
 * them.
 */
class Conv2DTranspose extends _Layer.default {
  /**
   * @param {Object} [attrs] - layer configuration.
   * @param {number} [attrs.filters=1] - number of output filters.
   * @param {number|number[]} [attrs.kernel_size=[3, 3]] - kernel rows/cols; a
   *   single number is used for both.
   * @param {number|number[]} [attrs.strides=[1, 1]] - row/col strides; a
   *   single number is used for both.
   * @param {string} [attrs.padding='valid'] - 'valid' or 'same'; anything else
   *   is reported through `this.throwError`.
   * @param {string} [attrs.data_format='channels_last'] - 'channels_last' or
   *   'channels_first'; anything else is reported through `this.throwError`.
   * @param {string} [attrs.activation='linear'] - key looked up in the
   *   activations module (and in the activation program sources on GPU).
   * @param {boolean} [attrs.use_bias=true] - whether a bias weight is expected.
   */
  constructor(attrs = {}) {
    super(attrs);
    this.layerClass = 'Conv2DTranspose';
    const {
      filters = 1,
      kernel_size = [3, 3],
      strides = [1, 1],
      padding = 'valid',
      data_format = 'channels_last',
      activation = 'linear',
      use_bias = true
    } = attrs;

    // kernelShape is always [filters, rows, cols].
    if (Array.isArray(kernel_size)) {
      this.kernelShape = [filters, ...kernel_size];
    } else {
      this.kernelShape = [filters, kernel_size, kernel_size];
    }

    if (Array.isArray(strides)) {
      this.strides = strides;
    } else {
      this.strides = [strides, strides];
    }

    if (padding === 'valid' || padding === 'same') {
      this.padding = padding;
    } else {
      this.throwError('Invalid padding.');
    }

    if (data_format === 'channels_last' || data_format === 'channels_first') {
      this.dataFormat = data_format;
    } else {
      this.throwError('Only channels_last and channels_first data formats are allowed.');
    }

    this.activation = activation;
    this.activationFunc = activations[activation];
    this.useBias = use_bias;

    // Names of the weights this layer expects, in order.
    this.params = this.useBias ? ['kernel', 'bias'] : ['kernel'];

    // Human-readable summary of the configuration.
    this.description = `${this.kernelShape[0]} ${this.kernelShape.slice(1).join('x')} filters`;
    this.description += this.strides.some(s => s > 1) ? `, ${this.strides.join('x')} striding` : '';
    this.description += this.padding === 'valid' ? `, no border padding` : ', pad to same borders';
    this.description += this.activation !== 'linear' ? `, ${this.activation} activation` : '';

    if (this.gpu) {
      this.matMulProgram = _WebGL.webgl2.compileProgram(matMulProgramSource);
      this.activationProgram = _WebGL.webgl2.compileProgram(activationProgramSources[this.activation]);
    }
  }

  /**
   * Stores the layer weights and precomputes the flattened kernel matrix.
   *
   * For 'channels_first' the kernel tensor (first entry) is transposed with
   * axes (2, 3, 1, 0) in place on `weightsArr[0]` before being handed to the
   * parent class. On GPU, `this.weights['kernel']` is replaced by the
   * flattened matrix and GL textures are created for kernel and bias.
   *
   * @param {Array} weightsArr - weight tensors; index 0 is the kernel.
   */
  setWeights(weightsArr) {
    if (this.dataFormat === 'channels_first') {
      weightsArr[0].tensor = weightsArr[0].tensor.transpose(2, 3, 1, 0);
    }
    // NOTE(review): the meaning of the `false` flag is defined by
    // Layer.setWeights, which is not visible here — confirm.
    super.setWeights(weightsArr, false);

    this._w2row();

    if (this.gpu) {
      this.weights['kernel'] = this.wRowsMat;
      this.weights['kernel'].createGLTexture({
        type: '2d',
        format: 'float'
      });
      if (this.useBias) {
        this.weights['bias'].createGLTexture({
          type: '2d',
          format: 'float'
        });
      }
    }
  }

  /**
   * Runs the layer on `x` using the GPU or CPU path.
   *
   * @param {Tensor} x - input tensor.
   * @returns {Tensor} `this.output` as set by the selected path.
   */
  call(x) {
    if (this.gpu) {
      this._callGPU(x);
    } else {
      this._callCPU(x);
    }
    return this.output;
  }

  /**
   * Computes and caches `this.outputShape` ([rows, cols, channels]) and
   * `this.outputPadding` ([rowBefore, rowAfter, colBefore, colAfter]).
   *
   * The result is computed only once: if both values are already set the
   * method returns immediately, regardless of `inputShape`.
   *
   * @param {number[]} inputShape - input shape; entries 0 and 1 are read as
   *   rows and cols.
   */
  _calcOutputShape(inputShape) {
    if (this.outputShape && this.outputPadding) {
      return;
    }

    const inputRows = inputShape[0];
    const inputCols = inputShape[1];
    const [nbFilter, nbRow, nbCol] = this.kernelShape;

    // 'same': input * stride. 'valid': additionally grows by the part of the
    // kernel that extends past one stride.
    const outputRows = this.padding === 'same' ? inputRows * this.strides[0] : inputRows * this.strides[0] + Math.max(nbRow - this.strides[0], 0);
    const outputCols = this.padding === 'same' ? inputCols * this.strides[1] : inputCols * this.strides[1] + Math.max(nbCol - this.strides[1], 0);
    const outputChannels = nbFilter;

    // For 'same', the amount by which the full scatter result exceeds the
    // output; it is cropped away after the scatter-add.
    const paddingRow = this.padding === 'same' ? Math.max(0, Math.floor((inputRows - 1) * this.strides[0] + nbRow - outputRows)) : 0;
    const paddingCol = this.padding === 'same' ? Math.max(0, Math.floor((inputCols - 1) * this.strides[1] + nbCol - outputCols)) : 0;
    const paddingRowBefore = Math.floor(paddingRow / 2);
    const paddingRowAfter = paddingRow - paddingRowBefore;
    const paddingColBefore = Math.floor(paddingCol / 2);
    const paddingColAfter = paddingCol - paddingColBefore;

    this.outputShape = [outputRows, outputCols, outputChannels];
    this.outputPadding = [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter];
  }

  /**
   * Flattens the input into `this.imColsMat`, a matrix of shape
   * [inputRows * inputCols, inputChannels] with one column per channel.
   * The matrix is allocated on first use and reused afterwards.
   *
   * @param {Tensor} x - input whose tensor shape is read as
   *   [rows, cols, channels].
   * @returns {Tensor} `this.imColsMat`.
   */
  _im2col(x) {
    const [inputRows, inputCols, inputChannels] = x.tensor.shape;

    if (!this.imColsMat) {
      this.imColsMat = new _Tensor.default([], [inputRows * inputCols, inputChannels]);
    }

    const channelRaveled = new _Tensor.default([], [inputRows * inputCols]);
    const channel = new _Tensor.default([], [inputRows, inputCols]);
    for (let c = 0; c < inputChannels; c++) {
      // Copy the channel slice into `channel`, then reuse its backing data as
      // a 1-D tensor.
      _ndarrayOps.default.assign(channel.tensor, x.tensor.pick(null, null, c));
      channelRaveled.replaceTensorData(channel.tensor.data);
      _ndarrayOps.default.assign(this.imColsMat.tensor.pick(null, c), channelRaveled.tensor);
    }
    return this.imColsMat;
  }

  /**
   * Flattens the kernel into `this.wRowsMat`, a matrix of shape
   * [inputChannels, nbRow * nbCol * nbFilter] with one row per input channel.
   * The kernel tensor shape is read as [nbRow, nbCol, nbFilter, inputChannels].
   *
   * @returns {Tensor} `this.wRowsMat`.
   */
  _w2row() {
    const [nbRow, nbCol, nbFilter, inputChannels] = this.weights['kernel'].tensor.shape;

    this.wRowsMat = new _Tensor.default([], [inputChannels, nbRow * nbCol * nbFilter]);

    const channelRaveled = new _Tensor.default([], [nbRow * nbCol * nbFilter]);
    const channel = new _Tensor.default([], [nbRow, nbCol, nbFilter]);
    for (let c = 0; c < inputChannels; c++) {
      _ndarrayOps.default.assign(channel.tensor, this.weights['kernel'].tensor.pick(null, null, null, c));
      channelRaveled.replaceTensorData(channel.tensor.data);
      _ndarrayOps.default.assign(this.wRowsMat.tensor.pick(c, null), channelRaveled.tensor);
    }
    return this.wRowsMat;
  }

  /**
   * CPU forward pass. Sets `this.inputShape` and `this.output`.
   *
   * @param {Tensor} x - input tensor.
   */
  _callCPU(x) {
    this.inputShape = x.tensor.shape;
    this._calcOutputShape(this.inputShape);
    this._im2col(x);

    const inputRows = x.tensor.shape[0];
    const inputCols = x.tensor.shape[1];
    const [nbFilter, nbRow, nbCol] = this.kernelShape;

    // One row per input position, one column per (kernelRow, kernelCol, filter).
    const matMul = new _Tensor.default([], [inputRows * inputCols, nbRow * nbCol * nbFilter]);
    (0, _ndarrayGemm.default)(matMul.tensor, this.imColsMat.tensor, this.wRowsMat.tensor, 1, 1);

    const [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter] = this.outputPadding;
    this.output = new _Tensor.default([], this.outputShape);

    // Scatter target, enlarged by the padding that is cropped afterwards.
    const outputPadded = new _Tensor.default([], [this.outputShape[0] + paddingRowBefore + paddingRowAfter, this.outputShape[1] + paddingColBefore + paddingColAfter, this.outputShape[2]]);

    const patchShape = [nbRow, nbCol, nbFilter];
    const patch = new _Tensor.default([], patchShape);
    const patchRaveled = new _Tensor.default([], [nbRow * nbCol * nbFilter]);
    let index = 0;
    for (let i = 0; i < inputRows; i++) {
      for (let j = 0; j < inputCols; j++) {
        // Reinterpret matMul row `index` as a [nbRow, nbCol, nbFilter] patch
        // and add it into the output window anchored at the strided position.
        _ndarrayOps.default.assign(patchRaveled.tensor, matMul.tensor.pick(index, null));
        patch.replaceTensorData(patchRaveled.tensor.data);
        const iOutPos = i * this.strides[0];
        const jOutPos = j * this.strides[1];
        _ndarrayOps.default.addeq(outputPadded.tensor.hi(iOutPos + nbRow, jOutPos + nbCol, this.outputShape[2]).lo(iOutPos, jOutPos, 0), patch.tensor);
        index += 1;
      }
    }

    // Crop the leading padding; the trailing padding falls outside `hi`.
    _ndarrayOps.default.assign(this.output.tensor, outputPadded.tensor.hi(this.outputShape[0] + paddingRowBefore, this.outputShape[1] + paddingColBefore, this.outputShape[2]).lo(paddingRowBefore, paddingColBefore, 0));

    if (this.useBias) {
      for (let n = 0; n < nbFilter; n++) {
        _ndarrayOps.default.addseq(this.output.tensor.pick(null, null, n), this.weights['bias'].tensor.get(n));
      }
    }

    this.activationFunc(this.output);

    if (this.dataFormat === 'channels_first') {
      this.output.tensor = this.output.tensor.transpose(2, 0, 1);
    }
  }

  /**
   * Builds and caches `this.indexMap` for the GPU path: for every output
   * position, the list of flat indices into the matMul result (row index times
   * the kernel texture's second dimension, plus column index) that contribute
   * to it, with -1 marking unused slots. Built once; subsequent calls return
   * immediately.
   *
   * Requires `this.inputShape`, `this.outputShape`, `this.outputPadding` and
   * the kernel weight's `glTextureShape` to be set.
   */
  _createIndexMap() {
    if (this.indexMap) {
      return;
    }

    const inputRows = this.inputShape[0];
    const inputCols = this.inputShape[1];
    const [nbFilter, nbRow, nbCol] = this.kernelShape;
    const [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter] = this.outputPadding;

    // Number of slots reserved per output position for overlapping patches.
    // Each axis is clamped to at least 1: when a stride exceeds the kernel
    // extent the unclamped factor is zero or negative, which would make the
    // map depth zero/negative (or spuriously positive if both are negative)
    // even though every written position still needs one slot.
    const effectiveKernelSize = Math.max(nbRow - this.strides[0] + 1, 1) * Math.max(nbCol - this.strides[1] + 1, 1);

    const indicesMapShape = [this.outputShape[0], this.outputShape[1], effectiveKernelSize];
    const indicesMapShapePadded = [this.outputShape[0] + paddingRowBefore + paddingRowAfter, this.outputShape[1] + paddingColBefore + paddingColAfter, effectiveKernelSize];
    const outputRowIndicesMap = new _Tensor.default([], indicesMapShape, {
      type: Int32Array
    });
    const outputColIndicesMap = new _Tensor.default([], indicesMapShape, {
      type: Int32Array
    });
    const outputRowIndicesMapPadded = new _Tensor.default([], indicesMapShapePadded, {
      type: Int32Array
    });
    const outputColIndicesMapPadded = new _Tensor.default([], indicesMapShapePadded, {
      type: Int32Array
    });

    // -1 is the "empty slot" sentinel used by the assign helpers.
    _ndarrayOps.default.assigns(outputRowIndicesMap.tensor, -1);
    _ndarrayOps.default.assigns(outputColIndicesMap.tensor, -1);
    _ndarrayOps.default.assigns(outputRowIndicesMapPadded.tensor, -1);
    _ndarrayOps.default.assigns(outputColIndicesMapPadded.tensor, -1);

    // Flat column index of every (kernelRow, kernelCol, filter) entry.
    const matMulColIndicesPatch = new _Tensor.default([], [nbRow, nbCol, nbFilter], {
      type: Int32Array
    });
    for (let i = 0; i < nbRow * nbCol * nbFilter; i++) {
      matMulColIndicesPatch.tensor.data[i] = i;
    }

    for (let i = 0; i < inputRows; i++) {
      for (let j = 0; j < inputCols; j++) {
        const matMulRowIndex = i * inputCols + j;
        const iOutPos = i * this.strides[0];
        const jOutPos = j * this.strides[1];
        // Record, for each output position in this patch's window, which
        // matMul row and (filter-0) column feed it.
        assignToRowIndicesMap(outputRowIndicesMapPadded.tensor.hi(iOutPos + nbRow, jOutPos + nbCol, effectiveKernelSize).lo(iOutPos, jOutPos, 0), matMulRowIndex, effectiveKernelSize);
        assignToColIndicesMap(outputColIndicesMapPadded.tensor.hi(iOutPos + nbRow, jOutPos + nbCol, effectiveKernelSize).lo(iOutPos, jOutPos, 0), matMulColIndicesPatch.tensor.pick(null, null, 0), effectiveKernelSize);
      }
    }

    // Crop the padded maps the same way the CPU path crops its output.
    _ndarrayOps.default.assign(outputRowIndicesMap.tensor, outputRowIndicesMapPadded.tensor.hi(this.outputShape[0] + paddingRowBefore, this.outputShape[1] + paddingColBefore, effectiveKernelSize).lo(paddingRowBefore, paddingColBefore, 0));
    _ndarrayOps.default.assign(outputColIndicesMap.tensor, outputColIndicesMapPadded.tensor.hi(this.outputShape[0] + paddingRowBefore, this.outputShape[1] + paddingColBefore, effectiveKernelSize).lo(paddingRowBefore, paddingColBefore, 0));

    // Combine row/col indices into one flat index per slot, one row per
    // output position.
    const tiledIndicesMapShape = [this.outputShape[0] * this.outputShape[1], effectiveKernelSize];
    this.indexMap = new _Tensor.default([], tiledIndicesMapShape, {
      type: Int32Array
    });
    const channelData = new _Tensor.default([], [effectiveKernelSize], {
      type: Int32Array
    });
    for (let i = 0; i < this.outputShape[0]; i++) {
      for (let j = 0; j < this.outputShape[1]; j++) {
        for (let k = 0; k < effectiveKernelSize; k++) {
          const rowIndex = outputRowIndicesMap.tensor.get(i, j, k);
          const colIndex = outputColIndicesMap.tensor.get(i, j, k);
          if (rowIndex !== -1 && colIndex !== -1) {
            channelData.tensor.set(k, rowIndex * this.weights['kernel'].glTextureShape[1] + colIndex);
          } else {
            channelData.tensor.set(k, -1);
          }
        }
        _ndarrayOps.default.assign(this.indexMap.tensor.pick(i * this.outputShape[1] + j, null), channelData.tensor);
      }
    }

    this.indexMap.createGLTexture({
      type: '2d',
      format: 'int',
      supportsTextureFragments: true
    });
  }

  /**
   * GPU forward pass. Runs the matMul program, then the generated
   * conv2dTranspose program, then (for non-linear activations) the activation
   * program. Intermediate and output tensors are created lazily and reused.
   * When the layer has no outbound layers, the result is read back from the
   * GL texture and reshaped.
   *
   * @param {Tensor} x - input tensor; if it is flagged `is2DReshaped` or
   *   `is2DSquareReshaped` it is used directly and its `originalShape` is
   *   taken as the input shape.
   */
  _callGPU(x) {
    if (x.is2DReshaped || x.is2DSquareReshaped) {
      this.inputShape = x.originalShape;
      this._calcOutputShape(this.inputShape);
    } else {
      this.inputShape = x.tensor.shape;
      this._calcOutputShape(this.inputShape);
      this._im2col(x);
      this.imColsMat.createGLTexture({
        type: '2d',
        format: 'float',
        supportsTextureFragments: true
      });
    }

    const input = x.is2DReshaped || x.is2DSquareReshaped ? x : this.imColsMat;

    // Lazily allocate the matMul result texture.
    if (!this.matMulResult) {
      const outputTextureShape = [input.glTextureShape[0], this.weights['kernel'].glTextureShape[1]];
      this.matMulResult = new _Tensor.default([], outputTextureShape);
      this.matMulResult.createGLTexture({
        type: '2d',
        format: 'float',
        supportsTextureFragments: true
      });
    }

    // A separate pre-activation target is only needed when an activation
    // program runs afterwards.
    if (this.activation !== 'linear' && !this.outputPreactiv) {
      const outputTextureShape = [this.outputShape[0] * this.outputShape[1], this.outputShape[2]];
      this.outputPreactiv = new _Tensor.default([], outputTextureShape);
      this.outputPreactiv.createGLTexture({
        type: '2d',
        format: 'float',
        supportsTextureFragments: true
      });
      this.outputPreactiv.is2DReshaped = true;
      this.outputPreactiv.originalShape = this.outputShape;
      this.outputPreactiv.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.outputShape, false, -1);
    }

    if (!this.output) {
      const outputTextureShape = [this.outputShape[0] * this.outputShape[1], this.outputShape[2]];
      this.output = new _Tensor.default([], outputTextureShape);
      this.output.createGLTexture({
        type: '2d',
        format: 'float',
        supportsTextureFragments: true
      });
      this.output.is2DReshaped = true;
      this.output.originalShape = this.outputShape;
      this.output.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.outputShape, false, -1);
    }

    // Step 1: input matrix x kernel matrix. Bias is not added here (addC = 0);
    // it is passed to the conv2dTranspose program below instead.
    _WebGL.webgl2.runProgram({
      program: this.matMulProgram,
      output: this.matMulResult,
      inputs: [{
        input: input,
        name: 'A'
      }, {
        input: this.weights['kernel'],
        name: 'B'
      }],
      uniforms: [{
        value: 0,
        type: 'bool',
        name: 'addC'
      }],
      supportsTextureFragments: true
    });

    this._createIndexMap();

    const hasFragments = Boolean(this.matMulResult.glTextureFragments);
    if (hasFragments) {
      this.matMulResult.convert2DRowFragmentedGLTextureToColStack();
    }

    // Step 2: gather/sum matMul entries per output position via the index map.
    // The program is generated once from the texture shapes involved.
    if (!this.convTransposeProgram) {
      const convTransposeProgramSource = (0, _createGLSLProgram.default)(
        'conv2dTranspose',
        this.output.glTextureFragmentShape ? this.output.glTextureFragmentShape : this.output.glTextureShape,
        this.matMulResult.glTextureFragmentShape ? this.matMulResult.glTextureFragmentShape : this.matMulResult.glTextureShape,
        this.indexMap.glTextureFragmentShape ? this.indexMap.glTextureFragmentShape : this.indexMap.glTextureShape,
        this.useBias,
        hasFragments
      );
      this.convTransposeProgram = _WebGL.webgl2.compileProgram(convTransposeProgramSource);
    }
    _WebGL.webgl2.runProgram({
      program: this.convTransposeProgram,
      output: this.activation === 'linear' ? this.output : this.outputPreactiv,
      inputs: [{
        input: this.matMulResult,
        name: 'matMulResult'
      }, {
        input: this.indexMap,
        name: 'indexMap'
      }, ...(this.useBias ? [{
        input: this.weights['bias'],
        name: 'bias'
      }] : [])],
      supportsTextureFragments: true
    });

    if (hasFragments) {
      this.matMulResult.removeGLTextureFragmentsAsColStack();
    }

    // Step 3: activation, reading the pre-activation texture.
    if (this.activation !== 'linear') {
      _WebGL.webgl2.runProgram({
        program: this.activationProgram,
        output: this.output,
        inputs: [{
          input: this.outputPreactiv,
          name: 'x'
        }],
        supportsTextureFragments: true
      });
    }

    // Only the final layer reads its result back from the GPU.
    if (this.outbound.length === 0) {
      this.output.transferFromGLTexture();
      this.output.reshapeFrom2D();
      if (this.dataFormat === 'channels_first') {
        this.output.tensor = this.output.tensor.transpose(2, 0, 1);
      }
    }
  }
}
exports.default = Conv2DTranspose;