@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
253 lines • 34.9 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Advanced activation layers.
*/
import { add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, tidy } from '@tensorflow/tfjs-core';
import { Softmax as softmaxActivation } from '../activations';
import { getConstraint, serializeConstraint } from '../constraints';
import { InputSpec, Layer } from '../engine/topology';
import { NotImplementedError, ValueError } from '../errors';
import { getInitializer, serializeInitializer } from '../initializers';
import { getRegularizer, serializeRegularizer } from '../regularizers';
import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
/**
 * Rectified Linear Unit activation layer.
 *
 * Computes `relu(x)` and, when `maxValue` is set, additionally clips the
 * result to `[0, maxValue]`. The output shape equals the input shape.
 */
class ReLU extends Layer {
  /**
   * @param args Optional layer arguments. If `args.maxValue` is provided it is
   *   used as the upper bound of the output; otherwise no upper bound applies.
   */
  constructor(args) {
    super(args == null ? {} : args);
    this.supportsMasking = true;
    if (args != null) {
      this.maxValue = args.maxValue;
    }
  }

  /**
   * Applies the activation to exactly one input tensor.
   *
   * The work runs inside `tidy` so that the intermediate `relu` tensor created
   * before clipping is disposed when the layer is applied eagerly. Only the
   * returned tensor is kept.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      const x = getExactlyOneTensor(inputs);
      const rectified = relu(x);
      if (this.maxValue != null) {
        return clipByValue(rectified, 0, this.maxValue);
      }
      return rectified;
    });
  }

  /** Element-wise activation: the output shape is the input shape. */
  computeOutputShape(inputShape) {
    return inputShape;
  }

  /** Serializes `maxValue`, merged with the base `Layer` config. */
  getConfig() {
    const config = { maxValue: this.maxValue };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
/** @nocollapse */
ReLU.className = 'ReLU';
export { ReLU };
serialization.registerClass(ReLU);
/**
 * Leaky Rectified Linear Unit layer, implemented with `leakyRelu` and a
 * configurable negative slope `alpha` (0.3 unless overridden).
 */
class LeakyReLU extends Layer {
  /**
   * @param args Optional layer arguments; `args.alpha` overrides the default
   *   negative slope when it is not null/undefined.
   */
  constructor(args) {
    super(args == null ? {} : args);
    this.DEFAULT_ALPHA = 0.3;
    const requestedAlpha = args == null ? undefined : args.alpha;
    this.alpha = requestedAlpha != null ? requestedAlpha : this.DEFAULT_ALPHA;
  }

  /** Applies `leakyRelu` with this layer's `alpha` to the single input. */
  call(inputs, kwargs) {
    return leakyRelu(getExactlyOneTensor(inputs), this.alpha);
  }

  /** Element-wise activation: the output shape is the input shape. */
  computeOutputShape(inputShape) {
    return inputShape;
  }

  /** Serializes `alpha`, merged with the base `Layer` config. */
  getConfig() {
    return Object.assign({ alpha: this.alpha }, super.getConfig());
  }
}
/** @nocollapse */
LeakyReLU.className = 'LeakyReLU';
export { LeakyReLU };
serialization.registerClass(LeakyReLU);
/**
 * Parameterized Rectified Linear Unit layer.
 *
 * Applies `prelu(x, alpha)`, where `alpha` is a trainable weight shaped like
 * the non-batch dimensions of the input. Along every axis listed in
 * `sharedAxes`, `alpha` has size 1 and is broadcast.
 */
class PReLU extends Layer {
  /**
   * @param args Optional layer arguments:
   *   - `alphaInitializer`: initializer for `alpha` (defaults to `'zeros'`).
   *   - `alphaRegularizer`: optional regularizer for `alpha`.
   *   - `alphaConstraint`: optional constraint for `alpha`.
   *   - `sharedAxes`: a number or an array of numbers naming the input axes
   *     along which `alpha` is shared.
   * @throws ValueError if `sharedAxes` is not null, a number, or an array.
   */
  constructor(args) {
    super(args == null ? {} : args);
    this.DEFAULT_ALPHA_INITIALIZER = 'zeros';
    if (args == null) {
      args = {};
    }
    this.supportsMasking = true;
    this.alphaInitializer =
        getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);
    this.alphaRegularizer = getRegularizer(args.alphaRegularizer);
    this.alphaConstraint = getConstraint(args.alphaConstraint);
    if (args.sharedAxes == null) {
      this.sharedAxes = null;
    }
    else if (Array.isArray(args.sharedAxes)) {
      this.sharedAxes = args.sharedAxes;
    }
    else if (typeof args.sharedAxes === 'number') {
      this.sharedAxes = [args.sharedAxes];
    }
    else {
      throw new ValueError(`Expected sharedAxes to be a number or an array of numbers, ` +
          `but got ${args.sharedAxes}`);
    }
  }

  /**
   * Creates the `alpha` weight and sets the input spec.
   *
   * Axis indices in `sharedAxes` count the batch axis as 0, so axis `i` maps
   * to `paramShape[i - 1]`.
   */
  build(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    const paramShape = inputShape.slice(1);
    if (this.sharedAxes != null) {
      for (const i of this.sharedAxes) {
        paramShape[i - 1] = 1;
      }
    }
    this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint);
    // Set input spec. `alpha` has size 1 along the shared axes and
    // broadcasts there, so those axes may vary between calls. Only the
    // non-shared axes are pinned to their build-time sizes (as in Keras).
    const axes = {};
    if (this.sharedAxes != null) {
      for (let i = 1; i < inputShape.length; ++i) {
        if (!this.sharedAxes.includes(i)) {
          axes[i] = inputShape[i];
        }
      }
    }
    this.inputSpec = [new InputSpec({
        ndim: inputShape.length,
        axes,
      })];
    this.built = true;
  }

  /** Applies `prelu` to the single input using the current `alpha` value. */
  call(inputs, kwargs) {
    inputs = getExactlyOneTensor(inputs);
    return prelu(inputs, this.alpha.read());
  }

  /** Serializes the alpha initializer/regularizer/constraint and `sharedAxes`. */
  getConfig() {
    const config = {
      alphaInitializer: serializeInitializer(this.alphaInitializer),
      alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
      alphaConstraint: serializeConstraint(this.alphaConstraint),
      sharedAxes: this.sharedAxes
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
/** @nocollapse */
PReLU.className = 'PReLU';
export { PReLU };
serialization.registerClass(PReLU);
/**
 * Exponential Linear Unit layer, implemented with `elu`.
 *
 * Only the default `alpha` of 1.0 is accepted. Any other value is rejected at
 * construction time.
 */
class ELU extends Layer {
  /**
   * @param args Optional layer arguments; `args.alpha` must be unset or 1.0.
   * @throws NotImplementedError for any other `alpha`.
   */
  constructor(args) {
    super(args == null ? {} : args);
    this.DEFAULT_ALPHA = 1.0;
    const requestedAlpha = args == null ? undefined : args.alpha;
    const isSupported =
        requestedAlpha == null || requestedAlpha === this.DEFAULT_ALPHA;
    if (!isSupported) {
      throw new NotImplementedError(`Non-default alpha value (${requestedAlpha}) is not supported by the ` +
          `ELU layer yet.`);
    }
    // Past the check above, alpha is either unset or equal to the default.
    this.alpha = this.DEFAULT_ALPHA;
  }

  /** Applies `elu` to the single input (alpha is always the default here). */
  call(inputs, kwargs) {
    return elu(getExactlyOneTensor(inputs));
  }

  /** Element-wise activation: the output shape is the input shape. */
  computeOutputShape(inputShape) {
    return inputShape;
  }

  /** Serializes `alpha`, merged with the base `Layer` config. */
  getConfig() {
    return Object.assign({ alpha: this.alpha }, super.getConfig());
  }
}
/** @nocollapse */
ELU.className = 'ELU';
export { ELU };
serialization.registerClass(ELU);
/**
 * Thresholded Rectified Linear Unit layer.
 *
 * Returns `x` where `x > theta` and 0 elsewhere, computed as
 * `x * cast(x > theta, 'float32')`. `theta` defaults to 1.0.
 */
class ThresholdedReLU extends Layer {
  /**
   * @param args Optional layer arguments; `args.theta` overrides the default
   *   threshold when it is not null/undefined.
   */
  constructor(args) {
    super(args == null ? {} : args);
    this.DEFAULT_THETA = 1.0;
    if (args == null) {
      args = {};
    }
    this.theta = args.theta == null ? this.DEFAULT_THETA : args.theta;
  }

  /**
   * Applies the thresholded activation to exactly one input tensor.
   *
   * `greater` and `cast` allocate intermediate tensors. Running inside `tidy`
   * disposes them, so only the product survives when the layer is applied
   * eagerly.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      const x = getExactlyOneTensor(inputs);
      return mul(x, cast(greater(x, this.theta), 'float32'));
    });
  }

  /** Element-wise activation: the output shape is the input shape. */
  computeOutputShape(inputShape) {
    return inputShape;
  }

  /** Serializes `theta`, merged with the base `Layer` config. */
  getConfig() {
    const config = { theta: this.theta };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
/** @nocollapse */
ThresholdedReLU.className = 'ThresholdedReLU';
export { ThresholdedReLU };
serialization.registerClass(ThresholdedReLU);
/**
 * Softmax activation layer with optional masking.
 *
 * Normalizes over `axis`, which is a number or an array of numbers. With
 * several axes, the normalization is computed jointly over all of them.
 */
class Softmax extends Layer {
  /**
   * @param args Optional layer arguments; `args.axis` overrides the default
   *   normalization axis when it is not null/undefined.
   */
  constructor(args) {
    super(args == null ? {} : args);
    // NOTE(review): the TypeScript source documents the default as `-1` (the
    // last axis), but the value used here is `1`. The two agree only for
    // rank-2 inputs. The value is kept so existing behavior does not change.
    this.DEFAULT_AXIS = 1.0;
    if (args == null) {
      args = {};
    }
    this.softmax = new softmaxActivation().apply;
    this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
  }

  /**
   * Applies softmax to exactly one input tensor.
   *
   * If `kwargs.mask` is given, masked positions (mask value 0) receive a bias
   * of -1e9 before normalization, so they end up with ~0 probability. When
   * `axis` lists more than one axis, the result is
   * `exp(x - logSumExp(x, axis, keepDims = true))`.
   */
  call(inputs, kwargs) {
    // TODO(pforderique): Add tests for when `this.axis` is a number[].
    return tidy(() => {
      let x = getExactlyOneTensor(inputs);
      // `kwargs` may be omitted when `call` is invoked directly.
      const mask = kwargs != null ? kwargs['mask'] : null;
      if (mask != null) {
        // Since mask is 1.0 for positions we want to keep and 0.0 for masked
        // positions, (1 - mask) * -1e9 is 0.0 for positions we want to attend
        // and -1e9 for masked positions. The scalar 1 broadcasts against the
        // mask, so no full-size ones tensor has to be allocated.
        const adder = mul(sub(scalar(1), cast(mask, x.dtype)), scalar(-1e9));
        // Since we are adding it to the raw scores before the softmax, this
        // is effectively the same as removing these entirely.
        x = add(x, adder);
      }
      if (Array.isArray(this.axis)) {
        if (this.axis.length > 1) {
          return exp(sub(x, logSumExp(x, this.axis, true)));
        }
        return this.softmax(x, this.axis[0]);
      }
      return this.softmax(x, this.axis);
    });
  }

  /** Softmax preserves shape: the output shape is the input shape. */
  computeOutputShape(inputShape) {
    return inputShape;
  }

  /** Serializes `axis`, merged with the base `Layer` config. */
  getConfig() {
    const config = { axis: this.axis };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
/** @nocollapse */
Softmax.className = 'Softmax';
export { Softmax };
serialization.registerClass(Softmax);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"advanced_activations.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/advanced_activations.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,EAAC,GAAG,EAAE,IAAI,EAAE,WAAW,EAAE,GAAG,EAAE,GAAG,EAAE,OAAO,EAAE,SAAS,EAAE,SAAS,EAAE,GAAG,EAAE,IAAI,EAAE,KAAK,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,EAAE,GAAG,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAExK,OAAO,EAAC,OAAO,IAAI,iBAAiB,EAAC,MAAM,gBAAgB,CAAC;AAC5D,OAAO,EAAa,aAAa,EAAE,mBAAmB,EAAC,MAAM,gBAAgB,CAAC;AAC9E,OAAO,EAAC,SAAS,EAAE,KAAK,EAAY,MAAM,oBAAoB,CAAC;AAC/D,OAAO,EAAC,mBAAmB,EAAE,UAAU,EAAC,MAAM,WAAW,CAAC;AAC1D,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,cAAc,EAAe,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAElF,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAU7E,MAAa,IAAK,SAAQ,KAAK;IAK7B,YAAY,IAAoB;QAC9B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAChC,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;SAC/B;IACH,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACrC,IAAI,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC;QAC1B,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;SAChD;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAC,CAAC;QACnE,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA9BD,kBAAkB;AACX,cAAS,GAAG,MAAM,CAAC;SAFf,IAAI;AAiCjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AASlC,MAAa,SAAU,SAAQ,KAAK;IAOlC,YAAY,IAAyB;QACnC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,
CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,SAAS,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;IAClC,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA5BD,kBAAkB;AACX,mBAAS,GAAG,WAAW,AAAd,CAAe;SAFpB,SAAS;AA+BtB,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC,CAAC;AA6BvC,MAAa,KAAM,SAAQ,KAAK;IAW9B,YAAY,IAAqB;QAC/B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,8BAAyB,GAA0B,OAAO,CAAC;QAIlE,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,CAAC,gBAAgB;YACjB,cAAc,CAAC,IAAI,CAAC,gBAAgB,IAAI,IAAI,CAAC,yBAAyB,CAAC,CAAC;QAC5E,IAAI,CAAC,gBAAgB,GAAG,cAAc,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;QAC9D,IAAI,CAAC,eAAe,GAAG,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QAC3D,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC;SACxB;aAAM,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YACzC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;SACnC;aAAM,IAAI,OAAO,IAAI,CAAC,UAAU,KAAK,QAAQ,EAAE;YAC9C,IAAI,CAAC,UAAU,GAAG,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;SACrC;aAAM;YACL,MAAM,IAAI,UAAU,CAChB,6DAA6D;gBAC7D,WAAW,IAAI,CAAC,UAAU,EAAE,CAAC,CAAC;SACnC;IACH,CAAC;IAEQ,KAAK,CAAC,UAAyB;QACtC,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAU,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAC9C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,KAAK,MAAM,CAAC,IAAI,IAAI,CAAC,UAAU,EAAE;gBAC/B,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;aACvB;SACF;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,SAAS,CACvB,OAAO,EAAE,UAAU,EAAE,SAAS,EAAE,IAAI,CAAC,gBAAgB,EACrD,IAAI,CAAC,gBAAgB,EAAE,IAAI,EAAE,IAAI,CAAC,eAAe,CAAC,CAAC;QACvD,kBAAkB;QAClB,MAAM,IAAI,GAA6B,EAAE,CAAC;QAC1C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBAC1C,IAAI,CAAC,CAAC,C
AAC,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;aACzB;SACF;QACD,IAAI,CAAC,SAAS,GAAG,CAAC,IAAI,SAAS,CAAC;gBAC9B,IAAI,EAAE,UAAU,CAAC,MAAM;gBACvB,IAAI;aACL,CAAC,CAAC,CAAC;QACJ,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACrC,OAAO,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,EAAE,CAAC,CAAC;IAC1C,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,gBAAgB,EAAE,oBAAoB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC7D,gBAAgB,EAAE,oBAAoB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC7D,eAAe,EAAE,mBAAmB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC1D,UAAU,EAAE,IAAI,CAAC,UAAU;SAC5B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA1ED,kBAAkB;AACX,eAAS,GAAG,OAAO,AAAV,CAAW;SAFhB,KAAK;AA6ElB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AASnC,MAAa,GAAI,SAAQ,KAAK;IAO5B,YAAY,IAAmB;QAC7B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,IAAI,CAAC,KAAK,IAAI,IAAI,IAAI,IAAI,CAAC,KAAK,KAAK,IAAI,CAAC,aAAa,EAAE;YAC3D,MAAM,IAAI,mBAAmB,CACzB,4BAA4B,IAAI,CAAC,KAAK,4BAA4B;gBAClE,gBAAgB,CAAC,CAAC;SACvB;QAED,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AAnCD,kBAAkB;AACX,aAAS,GAAG,KAAK,AAAR,CAAS;SAFd,GAAG;AAsChB,aAAa,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC;AASjC,MAAa,eAAgB,SAAQ,KAAK;IAOxC,YAAY,IAA+B;QACzC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC
;SACX;QAED,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;IACzD,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA7BD,kBAAkB;AACX,yBAAS,GAAG,iBAAiB,AAApB,CAAqB;SAF1B,eAAe;AAgC5B,aAAa,CAAC,aAAa,CAAC,eAAe,CAAC,CAAC;AAU7C,MAAa,OAAQ,SAAQ,KAAK;IAOhC,YAAY,IAAuB;QACjC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,iBAAY,GAAG,GAAG,CAAC;QAI1B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,OAAO,GAAG,IAAI,iBAAiB,EAAE,CAAC,KAAK,CAAC;QAC7C,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC;IAChE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,mEAAmE;QACnE,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YACpC,MAAM,IAAI,GAAG,MAAM,CAAC,MAAM,CAAW,CAAC;YACtC,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,qEAAqE;gBACrE,kEAAkE;gBAClE,8DAA8D;gBAC9D,MAAM,KAAK,GACT,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAE7D,oEAAoE;gBACpE,sDAAsD;gBACtD,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;aACnB;YACD,IAAI,IAAI,CAAC,IAAI,YAAY,KAAK,EAAE;gBAC9B,IAAI,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,CAAC,EAAE;oBACxB,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC;iBACnD;qBAAM;oBACL,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;iBACtC;aACF;YACD,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;QACpC,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACn
D,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,IAAI,EAAE,IAAI,CAAC,IAAI,EAAC,CAAC;QAC3D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AAnDD,kBAAkB;AACX,iBAAS,GAAG,SAAS,AAAZ,CAAa;SAFlB,OAAO;AAsDpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n *  Advanced activation layers.\n */\n\nimport {add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, Tensor, tidy} from '@tensorflow/tfjs-core';\n\nimport {Softmax as softmaxActivation} from '../activations';\nimport {Constraint, getConstraint, serializeConstraint} from '../constraints';\nimport {InputSpec, Layer, LayerArgs} from '../engine/topology';\nimport {NotImplementedError, ValueError} from '../errors';\nimport {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';\nimport {Shape} from '../keras_format/common';\nimport {getRegularizer, Regularizer, serializeRegularizer} from '../regularizers';\nimport {Kwargs} from '../types';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\nimport {LayerVariable} from '../variables';\n\nexport declare interface ReLULayerArgs extends LayerArgs {\n  /**\n   * Float, the maximum output value.\n   */\n  maxValue?: number;\n}\n\nexport class ReLU extends Layer {\n  /** @nocollapse */\n  static className = 'ReLU';\n  maxValue: number;\n\n  constructor(args?: ReLULayerArgs) {\n    super(args == null ? 
{} : args);\n    this.supportsMasking = true;\n    if (args != null) {\n      this.maxValue = args.maxValue;\n    }\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    inputs = getExactlyOneTensor(inputs);\n    let output = relu(inputs);\n    if (this.maxValue != null) {\n      output = clipByValue(output, 0, this.maxValue);\n    }\n    return output;\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {maxValue: this.maxValue};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ReLU);\n\nexport declare interface LeakyReLULayerArgs extends LayerArgs {\n  /**\n   * Float `>= 0`. Negative slope coefficient. Defaults to `0.3`.\n   */\n  alpha?: number;\n}\n\nexport class LeakyReLU extends Layer {\n  /** @nocollapse */\n  static className = 'LeakyReLU';\n  readonly alpha: number;\n\n  readonly DEFAULT_ALPHA = 0.3;\n\n  constructor(args?: LeakyReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n    this.alpha = args.alpha == null ? 
this.DEFAULT_ALPHA : args.alpha;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return leakyRelu(x, this.alpha);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {alpha: this.alpha};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(LeakyReLU);\n\nexport declare interface PReLULayerArgs extends LayerArgs {\n  /**\n   * Initializer for the learnable alpha.\n   */\n  alphaInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * Regularizer for the learnable alpha.\n   */\n  alphaRegularizer?: Regularizer;\n\n  /**\n   * Constraint for the learnable alpha.\n   */\n  alphaConstraint?: Constraint;\n\n  /**\n   * The axes along which to share learnable parameters for the activation\n   * function. For example, if the incoming feature maps are from a 2D\n   * convolution with output shape `[numExamples, height, width, channels]`,\n   * and you wish to share parameters across space (height and width) so that\n   * each filter channels has only one set of parameters, set\n   * `shared_axes: [1, 2]`.\n   */\n  sharedAxes?: number|number[];\n}\n\nexport class PReLU extends Layer {\n  /** @nocollapse */\n  static className = 'PReLU';\n  private readonly alphaInitializer: Initializer;\n  private readonly alphaRegularizer: Regularizer;\n  private readonly alphaConstraint: Constraint;\n  private readonly sharedAxes: number[];\n  private alpha: LayerVariable;\n\n  readonly DEFAULT_ALPHA_INITIALIZER: InitializerIdentifier = 'zeros';\n\n  constructor(args?: PReLULayerArgs) {\n    super(args == null ? 
{} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    this.supportsMasking = true;\n    this.alphaInitializer =\n        getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);\n    this.alphaRegularizer = getRegularizer(args.alphaRegularizer);\n    this.alphaConstraint = getConstraint(args.alphaConstraint);\n    if (args.sharedAxes == null) {\n      this.sharedAxes = null;\n    } else if (Array.isArray(args.sharedAxes)) {\n      this.sharedAxes = args.sharedAxes;\n    } else if (typeof args.sharedAxes === 'number') {\n      this.sharedAxes = [args.sharedAxes];\n    } else {\n      throw new ValueError(\n          `Expected sharedAxes to be a number or an array of numbers, ` +\n          `but got ${args.sharedAxes}`);\n    }\n  }\n\n  override build(inputShape: Shape|Shape[]) {\n    inputShape = getExactlyOneShape(inputShape);\n    const paramShape: Shape = inputShape.slice(1);\n    if (this.sharedAxes != null) {\n      for (const i of this.sharedAxes) {\n        paramShape[i - 1] = 1;\n      }\n    }\n    this.alpha = this.addWeight(\n        'alpha', paramShape, 'float32', this.alphaInitializer,\n        this.alphaRegularizer, true, this.alphaConstraint);\n    // Set input spec.\n    const axes: {[axis: number]: number} = {};\n    if (this.sharedAxes != null) {\n      for (let i = 1; i < inputShape.length; ++i) {\n        axes[i] = inputShape[i];\n      }\n    }\n    this.inputSpec = [new InputSpec({\n      ndim: inputShape.length,\n      axes,\n    })];\n    this.built = true;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    inputs = getExactlyOneTensor(inputs);\n    return prelu(inputs, this.alpha.read());\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      alphaInitializer: serializeInitializer(this.alphaInitializer),\n      alphaRegularizer: serializeRegularizer(this.alphaRegularizer),\n      alphaConstraint: 
serializeConstraint(this.alphaConstraint),\n      sharedAxes: this.sharedAxes\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(PReLU);\n\nexport declare interface ELULayerArgs extends LayerArgs {\n  /**\n   * Float `>= 0`. Negative slope coefficient. Defaults to `1.0`.\n   */\n  alpha?: number;\n}\n\nexport class ELU extends Layer {\n  /** @nocollapse */\n  static className = 'ELU';\n  readonly alpha: number;\n\n  readonly DEFAULT_ALPHA = 1.0;\n\n  constructor(args?: ELULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    if (args.alpha != null && args.alpha !== this.DEFAULT_ALPHA) {\n      throw new NotImplementedError(\n          `Non-default alpha value (${args.alpha}) is not supported by the ` +\n          `ELU layer yet.`);\n    }\n\n    this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return elu(x);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {alpha: this.alpha};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ELU);\n\nexport declare interface ThresholdedReLULayerArgs extends LayerArgs {\n  /**\n   * Float >= 0. Threshold location of activation.\n   */\n  theta?: number;\n}\n\nexport class ThresholdedReLU extends Layer {\n  /** @nocollapse */\n  static className = 'ThresholdedReLU';\n  readonly theta: number;\n\n  readonly DEFAULT_THETA = 1.0;\n\n  constructor(args?: ThresholdedReLULayerArgs) {\n    super(args == null ? 
{} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    this.theta = args.theta == null ? this.DEFAULT_THETA : args.theta;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return mul(x, cast(greater(x, this.theta), 'float32'));\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {theta: this.theta};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ThresholdedReLU);\n\nexport declare interface SoftmaxLayerArgs extends LayerArgs {\n  /**\n   * Integer, axis along which the softmax normalization is applied.\n   * Defaults to `-1` (i.e., the last axis).\n   */\n  axis?: number|number[];\n}\n\nexport class Softmax extends Layer {\n  /** @nocollapse */\n  static className = 'Softmax';\n  readonly axis: number|number[];\n  readonly softmax: (t: Tensor, a?: number) => Tensor;\n  readonly DEFAULT_AXIS = 1.0;\n\n  constructor(args?: SoftmaxLayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n    this.softmax = new softmaxActivation().apply;\n    this.axis = args.axis == null ? 
this.DEFAULT_AXIS : args.axis;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    // TODO(pforderique): Add tests for when `this.axis` is a number[].\n    return tidy(() => {\n      let x = getExactlyOneTensor(inputs);\n      const mask = kwargs['mask'] as Tensor;\n      if (mask != null) {\n        // Since mask is 1.0 for positions we want to keep and 0.0 for masked\n        // positions, this operation will create a tensor which is 0.0 for\n        // positions we want to attend and -1e.9 for masked positions.\n        const adder =\n          mul(sub(ones(x.shape), cast(mask, x.dtype)), scalar(-1e9));\n\n        // Since we are adding it to the raw scores before the softmax, this\n        // is effectively the same as removing these entirely.\n        x = add(x, adder);\n      }\n      if (this.axis instanceof Array) {\n        if (this.axis.length > 1) {\n          return exp(sub(x, logSumExp(x, this.axis, true)));\n        } else {\n          return this.softmax(x, this.axis[0]);\n        }\n      }\n      return this.softmax(x, this.axis);\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {axis: this.axis};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(Softmax);\n"]}