keras-js
Version:
Run Keras models in the browser, with GPU support using WebGL
419 lines (336 loc) • 13.3 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.default = void 0;
var _Layer = _interopRequireDefault(require("../../Layer"));
var _Tensor = _interopRequireDefault(require("../../Tensor"));
var activations = _interopRequireWildcard(require("../../activations"));
var _WebGL = require("../../WebGL2");
var tensorUtils = _interopRequireWildcard(require("../../utils/tensorUtils"));
var _ndarrayOps = _interopRequireDefault(require("ndarray-ops"));
var _ndarrayGemm = _interopRequireDefault(require("ndarray-gemm"));
var _Conv2D = _interopRequireDefault(require("./Conv2D"));
var activationProgramSources = _interopRequireWildcard(require("../../activations/programSources"));
function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { var desc = Object.defineProperty && Object.getOwnPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : {}; if (desc.get || desc.set) { Object.defineProperty(newObj, key, desc); } else { newObj[key] = obj[key]; } } } } newObj.default = obj; return newObj; } }
function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }
class _DepthwiseConv2D extends _Conv2D.default {
constructor(attrs = {}) {
super(attrs);
}
_calcOutputShape(inputShape) {
super._calcOutputShape(inputShape);
const nbFilter = this.kernelShape[0];
const inputChannels = inputShape[2];
this.outputShape[2] = nbFilter * inputChannels;
}
_im2col(x) {
const [inputRows, inputCols, inputChannels] = x.tensor.shape;
const nbRow = this.kernelShape[1];
const nbCol = this.kernelShape[2];
const outputRows = this.outputShape[0];
const outputCols = this.outputShape[1];
const nbPatches = outputRows * outputCols;
const patchLen = nbRow * nbCol;
if (!this.imColsMat) {
this.imColsMat = new _Tensor.default([], [nbPatches * inputChannels, patchLen]);
}
let patch = new _Tensor.default([], [nbRow, nbCol, 1]);
let offset = 0;
for (let c = 0; c < inputChannels; c++) {
for (let i = 0, limit = inputRows - nbRow; i <= limit; i += this.strides[0]) {
for (let j = 0, limit = inputCols - nbCol; j <= limit; j += this.strides[1]) {
_ndarrayOps.default.assign(patch.tensor, x.tensor.hi(i + nbRow, j + nbCol, c + 1).lo(i, j, c));
this.imColsMat.tensor.data.set(patch.tensor.data, offset);
offset += patchLen;
}
}
}
return this.imColsMat;
}
_w2row() {
const inputChannels = this.weights['kernel'].tensor.shape[2];
const [nbFilter, nbRow, nbCol] = this.kernelShape;
const patchLen = nbRow * nbCol;
this.wRowsMat = new _Tensor.default([], [patchLen, nbFilter * inputChannels]);
let patch = new _Tensor.default([], [nbRow, nbCol]);
let patchRaveled = new _Tensor.default([], [patchLen]);
let p = 0;
for (let c = 0; c < inputChannels; c++) {
for (let n = 0; n < nbFilter; n++) {
_ndarrayOps.default.assign(patch.tensor, this.weights['kernel'].tensor.pick(null, null, c, n));
patchRaveled.replaceTensorData(patch.tensor.data);
_ndarrayOps.default.assign(this.wRowsMat.tensor.pick(null, p), patchRaveled.tensor);
p += 1;
}
}
return this.wRowsMat;
}
_callCPU(x) {
this.inputShape = x.tensor.shape;
this._calcOutputShape(this.inputShape);
x = this._padInput(x);
this._im2col(x);
const nbFilter = this.kernelShape[0];
const outputRows = this.outputShape[0];
const outputCols = this.outputShape[1];
const nbPatches = outputRows * outputCols;
const inputChannels = this.inputShape[2];
const matMul = new _Tensor.default([], [nbPatches * inputChannels, nbFilter * inputChannels]);
(0, _ndarrayGemm.default)(matMul.tensor, this.imColsMat.tensor, this.wRowsMat.tensor, 1, 1);
this.output = new _Tensor.default([], this.outputShape);
const outputDataLength = outputRows * outputCols * nbFilter * inputChannels;
let dataFiltered = new Float32Array(outputDataLength);
for (let c = 0; c < inputChannels; c++) {
for (let n = c * outputDataLength + c * nbFilter; n < (c + 1) * outputDataLength; n += nbFilter * inputChannels) {
for (let m = 0; m < nbFilter; m++) {
dataFiltered[n + m - c * outputDataLength] = matMul.tensor.data[n + m];
}
}
}
this.output.replaceTensorData(dataFiltered);
}
_createIndexMap(indicesForReshaped) {
if (this.indexMap) {
return;
}
let [inputRows, inputCols, inputChannels] = this.inputShape;
let indices = new _Tensor.default(indicesForReshaped.data, indicesForReshaped.shape, {
type: Int32Array
});
if (this.padding === 'same') {
const [paddingRowBefore, paddingRowAfter, paddingColBefore, paddingColAfter] = this.inputPadding;
inputRows = inputRows + paddingRowBefore + paddingRowAfter;
inputCols = inputCols + paddingColBefore + paddingColAfter;
const padValue = -1;
indices = this._padInput(indices, padValue);
}
const nbRow = this.kernelShape[1];
const nbCol = this.kernelShape[2];
const outputRows = this.outputShape[0];
const outputCols = this.outputShape[1];
const nbPatches = outputRows * outputCols;
const patchLen = nbRow * nbCol;
this.indexMap = new _Tensor.default([], [nbPatches * inputChannels, patchLen], {
type: Int32Array
});
const indicesPatch = new _Tensor.default([], [nbRow, nbCol, 1]);
let offset = 0;
for (let c = 0; c < inputChannels; c++) {
for (let i = 0, limit = inputRows - nbRow; i <= limit; i += this.strides[0]) {
for (let j = 0, limit = inputCols - nbCol; j <= limit; j += this.strides[1]) {
_ndarrayOps.default.assign(indicesPatch.tensor, indices.tensor.hi(i + nbRow, j + nbCol, c + 1).lo(i, j, c));
this.indexMap.tensor.data.set(indicesPatch.tensor.data, offset);
offset += patchLen;
}
}
}
this.indexMap.createGLTexture({
type: '2d',
format: 'int',
supportsTextureFragments: true
});
}
_createOutputReshapeIndexMap() {
if (this.reshapeIndexMap) {
return;
}
const nbFilter = this.kernelShape[0];
const reshape = [this.outputShape[0] * this.outputShape[1], this.outputShape[2]];
const reshapeRowIndices = new _Tensor.default([], reshape, {
type: Int32Array
});
const reshapeColIndices = new _Tensor.default([], reshape, {
type: Int32Array
});
this.reshapeIndexMap = new _Tensor.default([], reshape, {
type: Int32Array
});
for (let j = 0; j < reshape[1]; j++) {
for (let i = 0; i < reshape[0]; i++) {
_ndarrayOps.default.assigns(reshapeRowIndices.tensor.pick(i, j), i + Math.floor(j / nbFilter) * reshape[0]);
}
}
for (let j = 0; j < reshape[1]; j++) {
_ndarrayOps.default.assigns(reshapeColIndices.tensor.pick(null, j), j);
}
_ndarrayOps.default.muls(this.reshapeIndexMap.tensor, reshapeRowIndices.tensor, reshape[1]);
_ndarrayOps.default.addeq(this.reshapeIndexMap.tensor, reshapeColIndices.tensor);
this.reshapeIndexMap.createGLTexture({
type: '2d',
format: 'int',
supportsTextureFragments: true
});
}
_callGPU(x) {
super._callGPU(x);
this._createOutputReshapeIndexMap();
if (!this.outputReshaped) {
const reshape = [this.outputShape[0] * this.outputShape[1], this.outputShape[2]];
this.outputReshaped = new _Tensor.default([], reshape);
this.outputReshaped.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
this.outputReshaped.is2DReshaped = true;
this.outputReshaped.originalShape = this.outputShape;
this.outputReshaped.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this.outputShape, false, -1);
}
if (this.output.glTextureFragments) {
this.output.convert2DRowFragmentedGLTextureToColStack();
}
_WebGL.webgl2.runProgram({
program: this.output.glTextureFragments ? this.mapInputFragmentsProgram : this.mapInputProgram,
output: this.outputReshaped,
inputs: [{
input: this.output,
name: 'x'
}, {
input: this.reshapeIndexMap,
name: 'indexMap'
}],
uniforms: [{
value: this.output.glTextureShape[1],
type: 'int',
name: 'inputCols'
}],
supportsTextureFragments: true
});
if (this.output.glTextureFragments) {
this.output.removeGLTextureFragmentsAsColStack();
}
}
}
class SeparableConv2D extends _Layer.default {
constructor(attrs = {}) {
super(attrs);
this.layerClass = 'SeparableConv2D';
const {
filters = 1,
kernel_size = [1, 1],
strides = [1, 1],
padding = 'valid',
data_format = 'channels_last',
depth_multiplier = 1,
activation = 'linear',
use_bias = true
} = attrs;
if (Array.isArray(kernel_size)) {
this.kernelShape = [filters, ...kernel_size];
} else {
this.kernelShape = [filters, kernel_size, kernel_size];
}
if (Array.isArray(strides)) {
this.strides = strides;
} else {
this.strides = [strides, strides];
}
if (padding === 'valid' || padding === 'same') {
this.padding = padding;
} else {
this.throwError('Invalid padding.');
}
if (data_format === 'channels_last' || data_format === 'channels_first') {
this.dataFormat = data_format;
} else {
this.throwError('Only channels_last and channels_first data formats are allowed.');
}
this.activation = activation;
this.activationFunc = activations[activation];
if (padding === 'valid' || padding === 'same') {
this.padding = padding;
} else {
this.throwError('Invalid padding.');
}
this.useBias = use_bias;
this.params = this.useBias ? ['depthwise_kernel', 'pointwise_kernel', 'bias'] : ['depthwise_kernel', 'pointwise_kernel'];
this.depthwiseConvAttrs = {
filters: depth_multiplier,
kernel_size: [this.kernelShape[1], this.kernelShape[2]],
strides: this.strides,
padding,
data_format,
activation: 'linear',
use_bias: false,
gpu: attrs.gpu
};
this.pointwiseConvAttrs = {
filters,
kernel_size: [1, 1],
strides: [1, 1],
padding,
data_format,
activation: 'linear',
use_bias,
gpu: attrs.gpu
};
this.description = `${this.kernelShape[0]} ${this.kernelShape.slice(1).join('x')} filters`;
this.description += this.strides.some(s => s > 1) ? `, ${this.strides.join('x')} striding` : '';
this.description += this.padding === 'valid' ? `, no border padding` : ', pad to same borders';
this.description += depth_multiplier > 1 ? `, depth multiplier: ${depth_multiplier}` : '';
this.description += this.activation !== 'linear' ? `, ${this.activation} activation` : '';
if (this.gpu) {
this.activationProgram = _WebGL.webgl2.compileProgram(activationProgramSources[this.activation]);
}
}
setWeights(weightsArr) {
this._depthwiseConv = new _DepthwiseConv2D(this.depthwiseConvAttrs);
this._depthwiseConv.setWeights(weightsArr.slice(0, 1));
this._pointwiseConv = new _Conv2D.default(this.pointwiseConvAttrs);
this._pointwiseConv.setWeights(weightsArr.slice(1, 3));
}
call(x) {
if (this.gpu) {
this._callGPU(x);
} else {
this._callCPU(x);
}
return this.output;
}
_callCPU(x) {
this._depthwiseConv._callCPU(x);
this._pointwiseConv._callCPU(this._depthwiseConv.output);
this.output = this._pointwiseConv.output;
this.activationFunc(this.output);
}
_callGPU(x) {
this._depthwiseConv.outbound = [null];
this._pointwiseConv.outbound = [null];
this._depthwiseConv._callGPU(x);
this._pointwiseConv._callGPU(this._depthwiseConv.outputReshaped);
if (this.activation === 'linear') {
this.output = this._pointwiseConv.output;
} else {
if (!this.output) {
this.output = new _Tensor.default([], this._pointwiseConv.output.glTextureShape);
this.output.createGLTexture({
type: '2d',
format: 'float',
supportsTextureFragments: true
});
this.output.is2DReshaped = true;
this.output.originalShape = this._pointwiseConv.output.originalShape;
this.output.indicesForReshaped = tensorUtils.createIndicesFor2DReshaped(this._pointwiseConv.output.originalShape, false, -1);
}
this.outputPreactiv = this._pointwiseConv.output;
_WebGL.webgl2.runProgram({
program: this.activationProgram,
output: this.output,
inputs: [{
input: this.outputPreactiv,
name: 'x'
}],
supportsTextureFragments: true
});
}
if (this.outbound.length === 0) {
this.output.transferFromGLTexture();
this.output.reshapeFrom2D();
if (this.dataFormat === 'channels_first') {
this.output.tensor = this.output.tensor.transpose(2, 0, 1);
}
}
}
}
exports.default = SeparableConv2D;