/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* TensorFlow.js Layers: Recurrent Neural Network Layers.
*/
import * as tfc from '@tensorflow/tfjs-core';
import { serialization, tidy, util } from '@tensorflow/tfjs-core';
import { getActivation, serializeActivation } from '../activations';
import * as K from '../backend/tfjs_backend';
import { nameScope } from '../common';
import { getConstraint, serializeConstraint } from '../constraints';
import { InputSpec, SymbolicTensor } from '../engine/topology';
import { Layer } from '../engine/topology';
import { AttributeError, NotImplementedError, ValueError } from '../errors';
import { getInitializer, Initializer, Ones, serializeInitializer } from '../initializers';
import { getRegularizer, serializeRegularizer } from '../regularizers';
import { assertPositiveInteger } from '../utils/generic_utils';
import * as math_utils from '../utils/math_utils';
import { getExactlyOneShape, getExactlyOneTensor, isArrayOfShapes } from '../utils/types_utils';
import { batchGetValue, batchSetValue } from '../variables';
import { deserialize } from './serialization';
/**
* Standardize `apply()` args to a single list of tensor inputs.
*
* When running a model loaded from file, the input tensors `initialState` and
* `constants` are passed to `RNN.apply()` as part of `inputs` instead of the
* dedicated kwargs fields. `inputs` consists of
* `[inputs, initialState0, initialState1, ..., constant0, constant1]` in this
* case.
 * This method makes sure that the arguments are properly
 * separated and that `initialState` and `constants` are `Array`s of tensors
 * (or `null`).
*
* @param inputs Tensor or `Array` of tensors.
* @param initialState Tensor or `Array` of tensors or `null`/`undefined`.
* @param constants Tensor or `Array` of tensors or `null`/`undefined`.
* @returns An object consisting of
* inputs: A tensor.
* initialState: `Array` of tensors or `null`.
* constants: `Array` of tensors or `null`.
* @throws ValueError, if `inputs` is an `Array` but either `initialState` or
* `constants` is provided.
*/
export function standardizeArgs(inputs, initialState, constants, numConstants) {
if (Array.isArray(inputs)) {
if (initialState != null || constants != null) {
throw new ValueError('When inputs is an array, neither initialState nor constants ' +
'should be provided.');
}
if (numConstants != null) {
constants = inputs.slice(inputs.length - numConstants, inputs.length);
inputs = inputs.slice(0, inputs.length - numConstants);
}
if (inputs.length > 1) {
initialState = inputs.slice(1, inputs.length);
}
inputs = inputs[0];
}
function toListOrNull(x) {
if (x == null || Array.isArray(x)) {
return x;
}
else {
return [x];
}
}
initialState = toListOrNull(initialState);
constants = toListOrNull(constants);
return { inputs, initialState, constants };
}
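// Example (an illustrative sketch; `x`, `s0`, `s1`, and `c0` are
// hypothetical tensors): when a loaded model passes everything through
// `inputs`,
//   const {inputs, initialState, constants} =
//       standardizeArgs([x, s0, s1, c0], null, null, 1 /* numConstants */);
// yields `inputs === x`, `initialState === [s0, s1]` and
// `constants === [c0]`.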
/**
* Iterates over the time dimension of a tensor.
*
* @param stepFunction RNN step function.
* Parameters:
* inputs: tensor with shape `[samples, ...]` (no time dimension),
* representing input for the batch of samples at a certain time step.
* states: an Array of tensors.
* Returns:
* outputs: tensor with shape `[samples, outputDim]` (no time dimension).
* newStates: list of tensors, same length and shapes as `states`. The first
* state in the list must be the output tensor at the previous timestep.
* @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
* least 3D).
* @param initialStates Tensor with shape `[samples, outputDim]` (no time
* dimension), containing the initial values of the states used in the step
* function.
* @param goBackwards If `true`, do the iteration over the time dimension in
* reverse order and return the reversed sequence.
 * @param mask Binary tensor with shape `[samples, time, 1]`, with a zero for
* every element that is masked.
* @param constants An Array of constant values passed at each step.
* @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not*
* applicable to this imperative deeplearn.js backend. Its value is ignored.
* @param needPerStepOutputs Whether the per-step outputs are to be
* concatenated into a single tensor and returned (as the second return
* value). Default: `false`. This arg is included so that the relatively
* expensive concatenation of the stepwise outputs can be omitted unless
 *   the stepwise outputs need to be kept (e.g., for an LSTM layer whose
 *   `returnSequences` property is `true`).
* @returns An Array: `[lastOutput, outputs, newStates]`.
 *   lastOutput: the latest output of the RNN, of shape `[samples, ...]`.
 *   outputs: tensor with shape `[samples, time, ...]` where each entry
 *     `output[s, t]` is the output of the step function at time `t` for sample
 *     `s`. This return value is provided if and only if `needPerStepOutputs`
 *     is set to `true`; if it is set to `false`, this return value will be
 *     `undefined`.
 *   newStates: Array of tensors, latest states returned by the step function,
 *     of shape `[samples, ...]`.
* @throws ValueError If input dimension is less than 3.
*
* TODO(nielsene): This needs to be tidy-ed.
*/
export function rnn(stepFunction, inputs, initialStates, goBackwards = false, mask, constants, unroll = false, needPerStepOutputs = false) {
return tfc.tidy(() => {
const ndim = inputs.shape.length;
if (ndim < 3) {
throw new ValueError(`Input should be at least 3D, but is ${ndim}D.`);
}
// Transpose to time-major, i.e., from [batch, time, ...] to [time, batch,
// ...].
const axes = [1, 0].concat(math_utils.range(2, ndim));
inputs = tfc.transpose(inputs, axes);
if (constants != null) {
throw new NotImplementedError('The rnn() function of the deeplearn.js backend does not support ' +
'constants yet.');
}
// Porting Note: the unroll option is ignored by the imperative backend.
if (unroll) {
console.warn('Backend rnn(): the unroll = true option is not applicable to the ' +
'imperative deeplearn.js backend.');
}
if (mask != null) {
mask = tfc.cast(tfc.cast(mask, 'bool'), 'float32');
if (mask.rank === ndim - 1) {
mask = tfc.expandDims(mask, -1);
}
mask = tfc.transpose(mask, axes);
}
if (goBackwards) {
inputs = tfc.reverse(inputs, 0);
if (mask != null) {
mask = tfc.reverse(mask, 0);
}
}
// Porting Note: PyKeras with TensorFlow backend uses a symbolic loop
// (tf.while_loop). But for the imperative deeplearn.js backend, we just
// use the usual TypeScript control flow to iterate over the time steps in
// the inputs.
// Porting Note: PyKeras patches a "_use_learning_phase" attribute to
// outputs.
// This is not idiomatic in TypeScript. The info regarding whether we are
// in a learning (i.e., training) phase for RNN is passed in a different
// way.
const perStepOutputs = [];
let lastOutput;
let states = initialStates;
const timeSteps = inputs.shape[0];
const perStepInputs = tfc.unstack(inputs);
let perStepMasks;
if (mask != null) {
perStepMasks = tfc.unstack(mask);
}
for (let t = 0; t < timeSteps; ++t) {
const currentInput = perStepInputs[t];
const stepOutputs = tfc.tidy(() => stepFunction(currentInput, states));
if (mask == null) {
lastOutput = stepOutputs[0];
states = stepOutputs[1];
}
else {
const maskedOutputs = tfc.tidy(() => {
const stepMask = perStepMasks[t];
const negStepMask = tfc.sub(tfc.onesLike(stepMask), stepMask);
// TODO(cais): Would tfc.where() be better for performance?
const output = tfc.add(tfc.mul(stepOutputs[0], stepMask), tfc.mul(states[0], negStepMask));
const newStates = states.map((state, i) => {
return tfc.add(tfc.mul(stepOutputs[1][i], stepMask), tfc.mul(state, negStepMask));
});
return { output, newStates };
});
lastOutput = maskedOutputs.output;
states = maskedOutputs.newStates;
}
if (needPerStepOutputs) {
perStepOutputs.push(lastOutput);
}
}
let outputs;
if (needPerStepOutputs) {
const axis = 1;
outputs = tfc.stack(perStepOutputs, axis);
}
return [lastOutput, outputs, states];
});
}
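// Illustrative sketch of the contract above (assumed shapes; the identity
// "cell" below passes its input through as both the output and the sole
// state):
//   const step = (cellInputs, states) => [cellInputs, [cellInputs]];
//   const x = tfc.randomNormal([8, 5, 4]); // [samples, time, inputDim].
//   const s0 = [tfc.zeros([8, 4])];        // One state: [samples, dim].
//   const [lastOutput, outputs, newStates] =
//       rnn(step, x, s0, false, null, null, false, true);
//   // lastOutput: shape [8, 4]. outputs: shape [8, 5, 4], returned only
//   // because needPerStepOutputs is true. newStates: [one [8, 4] tensor].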
class RNN extends Layer {
constructor(args) {
super(args);
let cell;
if (args.cell == null) {
throw new ValueError('cell property is missing for the constructor of RNN.');
}
else if (Array.isArray(args.cell)) {
cell = new StackedRNNCells({ cells: args.cell });
}
else {
cell = args.cell;
}
if (cell.stateSize == null) {
throw new ValueError('The RNN cell should have an attribute `stateSize` (a number or an ' +
'Array of numbers, one per RNN state).');
}
this.cell = cell;
this.returnSequences =
args.returnSequences == null ? false : args.returnSequences;
this.returnState = args.returnState == null ? false : args.returnState;
this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
this._stateful = args.stateful == null ? false : args.stateful;
this.unroll = args.unroll == null ? false : args.unroll;
this.supportsMasking = true;
this.inputSpec = [new InputSpec({ ndim: 3 })];
this.stateSpec = null;
this.states_ = null;
// TODO(cais): Add constantsSpec and numConstants.
this.numConstants = null;
// TODO(cais): Look into the use of initial_state in the kwargs of the
// constructor.
this.keptStates = [];
}
// Porting Note: This is the equivalent of `RNN.states` property getter in
// PyKeras.
getStates() {
if (this.states_ == null) {
const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
return math_utils.range(0, numStates).map(x => null);
}
else {
return this.states_;
}
}
// Porting Note: This is the equivalent of the `RNN.states` property setter in
// PyKeras.
setStates(states) {
this.states_ = states;
}
computeOutputShape(inputShape) {
if (isArrayOfShapes(inputShape)) {
inputShape = inputShape[0];
}
// TODO(cais): Remove the casting once stacked RNN cells become supported.
let stateSize = this.cell.stateSize;
if (!Array.isArray(stateSize)) {
stateSize = [stateSize];
}
const outputDim = stateSize[0];
let outputShape;
if (this.returnSequences) {
outputShape = [inputShape[0], inputShape[1], outputDim];
}
else {
outputShape = [inputShape[0], outputDim];
}
if (this.returnState) {
const stateShape = [];
for (const dim of stateSize) {
stateShape.push([inputShape[0], dim]);
}
return [outputShape].concat(stateShape);
}
else {
return outputShape;
}
}
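// Example for computeOutputShape() above: with a cell whose stateSize is 8
// and an input shape of [null, 10, 4]:
//   returnSequences=true,  returnState=false -> [null, 10, 8]
//   returnSequences=false, returnState=true  -> [[null, 8], [null, 8]]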
computeMask(inputs, mask) {
return tfc.tidy(() => {
if (Array.isArray(mask)) {
mask = mask[0];
}
const outputMask = this.returnSequences ? mask : null;
if (this.returnState) {
const stateMask = this.states.map(s => null);
return [outputMask].concat(stateMask);
}
else {
return outputMask;
}
});
}
/**
* Get the current state tensors of the RNN.
*
* If the state hasn't been set, return an array of `null`s of the correct
* length.
*/
get states() {
if (this.states_ == null) {
const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
const output = [];
for (let i = 0; i < numStates; ++i) {
output.push(null);
}
return output;
}
else {
return this.states_;
}
}
set states(s) {
this.states_ = s;
}
build(inputShape) {
// Note inputShape will be an Array of Shapes of initial states and
// constants if these are passed in apply().
const constantShape = null;
if (this.numConstants != null) {
throw new NotImplementedError('Constants support is not implemented in RNN yet.');
}
if (isArrayOfShapes(inputShape)) {
inputShape = inputShape[0];
}
const batchSize = this.stateful ? inputShape[0] : null;
const inputDim = inputShape.slice(2);
this.inputSpec[0] = new InputSpec({ shape: [batchSize, null, ...inputDim] });
// Allow cell (if RNNCell Layer) to build before we set or validate
// stateSpec.
const stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
if (constantShape != null) {
throw new NotImplementedError('Constants support is not implemented in RNN yet.');
}
else {
this.cell.build(stepInputShape);
}
// Set or validate stateSpec.
let stateSize;
if (Array.isArray(this.cell.stateSize)) {
stateSize = this.cell.stateSize;
}
else {
stateSize = [this.cell.stateSize];
}
if (this.stateSpec != null) {
if (!util.arraysEqual(this.stateSpec.map(spec => spec.shape[spec.shape.length - 1]), stateSize)) {
throw new ValueError(`An initialState was passed that is not compatible with ` +
`cell.stateSize. Received stateSpec=${this.stateSpec}; ` +
`however, cell.stateSize is ${this.cell.stateSize}.`);
}
}
else {
this.stateSpec =
stateSize.map(dim => new InputSpec({ shape: [null, dim] }));
}
if (this.stateful) {
this.resetStates();
}
}
/**
* Reset the state tensors of the RNN.
*
* If the `states` argument is `undefined` or `null`, will set the
* state tensor(s) of the RNN to all-zero tensors of the appropriate
* shape(s).
*
* If `states` is provided, will set the state tensors of the RNN to its
* value.
*
* @param states Optional externally-provided initial states.
* @param training Whether this call is done during training. For stateful
* RNNs, this affects whether the old states are kept or discarded. In
* particular, if `training` is `true`, the old states will be kept so
 *   that subsequent backpropagation through time (BPTT) may work properly.
* Else, the old states will be discarded.
*/
resetStates(states, training = false) {
tidy(() => {
if (!this.stateful) {
throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
}
const batchSize = this.inputSpec[0].shape[0];
if (batchSize == null) {
throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
'the batch size of your input tensors: \n' +
'- If using a Sequential model, specify the batch size by ' +
'passing a `batchInputShape` option to your first layer.\n' +
'- If using the functional API, specify the batch size by ' +
'passing a `batchShape` option to your Input layer.');
}
// Initialize state if null.
if (this.states_ == null) {
if (Array.isArray(this.cell.stateSize)) {
this.states_ =
this.cell.stateSize.map(dim => tfc.zeros([batchSize, dim]));
}
else {
this.states_ = [tfc.zeros([batchSize, this.cell.stateSize])];
}
}
else if (states == null) {
// Dispose old state tensors.
tfc.dispose(this.states_);
// For stateful RNNs, fully dispose kept old states.
if (this.keptStates != null) {
tfc.dispose(this.keptStates);
this.keptStates = [];
}
if (Array.isArray(this.cell.stateSize)) {
this.states_ =
this.cell.stateSize.map(dim => tfc.zeros([batchSize, dim]));
}
else {
this.states_[0] = tfc.zeros([batchSize, this.cell.stateSize]);
}
}
else {
if (!Array.isArray(states)) {
states = [states];
}
if (states.length !== this.states_.length) {
throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
`but it received ${states.length} state value(s). Input ` +
`received: ${states}`);
}
if (training === true) {
// Store old state tensors for complete disposal later, i.e., during
// the next no-arg call to this method. We do not dispose the old
// states immediately because BPTT (among other things) requires
// them.
this.keptStates.push(this.states_.slice());
}
else {
tfc.dispose(this.states_);
}
for (let index = 0; index < this.states_.length; ++index) {
const value = states[index];
const dim = Array.isArray(this.cell.stateSize) ?
this.cell.stateSize[index] :
this.cell.stateSize;
const expectedShape = [batchSize, dim];
if (!util.arraysEqual(value.shape, expectedShape)) {
throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
`expected shape=${expectedShape}, received shape=${value.shape}`);
}
this.states_[index] = value;
}
}
this.states_ = this.states_.map(state => tfc.keep(state.clone()));
});
}
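// Usage sketch for resetStates() above (hedged; assumes the public
// `tf.layers.simpleRNN` factory and the `tf` namespace of
// @tensorflow/tfjs):
//   const layer = tf.layers.simpleRNN(
//       {units: 4, stateful: true, batchInputShape: [2, 10, 3]});
//   layer.apply(tf.input({batchShape: [2, 10, 3]})); // Builds the layer.
//   layer.resetStates();                  // Reset to all-zero states.
//   layer.resetStates([tf.ones([2, 4])]); // Set an explicit state value.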
apply(inputs, kwargs) {
// TODO(cais): Figure out whether initialState is in kwargs or inputs.
let initialState = kwargs == null ? null : kwargs['initialState'];
let constants = kwargs == null ? null : kwargs['constants'];
if (kwargs == null) {
kwargs = {};
}
const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
inputs = standardized.inputs;
initialState = standardized.initialState;
constants = standardized.constants;
// If either `initialState` or `constants` is specified and consists of
// `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
// the inputSpec to include them.
let additionalInputs = [];
let additionalSpecs = [];
if (initialState != null) {
kwargs['initialState'] = initialState;
additionalInputs = additionalInputs.concat(initialState);
this.stateSpec = [];
for (const state of initialState) {
this.stateSpec.push(new InputSpec({ shape: state.shape }));
}
// TODO(cais): Use the following instead.
// this.stateSpec = initialState.map(state => new InputSpec({shape:
// state.shape}));
additionalSpecs = additionalSpecs.concat(this.stateSpec);
}
if (constants != null) {
kwargs['constants'] = constants;
additionalInputs = additionalInputs.concat(constants);
// TODO(cais): Add this.constantsSpec.
this.numConstants = constants.length;
}
const isTensor = additionalInputs[0] instanceof SymbolicTensor;
if (isTensor) {
// Compute full input spec, including state and constants.
const fullInput = [inputs].concat(additionalInputs);
const fullInputSpec = this.inputSpec.concat(additionalSpecs);
// Perform the call with temporarily replaced inputSpec.
const originalInputSpec = this.inputSpec;
this.inputSpec = fullInputSpec;
const output = super.apply(fullInput, kwargs);
this.inputSpec = originalInputSpec;
return output;
}
else {
return super.apply(inputs, kwargs);
}
}
// tslint:disable-next-line:no-any
call(inputs, kwargs) {
// Input shape: `[samples, time (padded with zeros), inputDim]`.
// Note that the .build() method of subclasses **must** define
// this.inputSpec and this.stateSpec with complete input shapes.
return tidy(() => {
const mask = kwargs == null ? null : kwargs['mask'];
const training = kwargs == null ? null : kwargs['training'];
let initialState = kwargs == null ? null : kwargs['initialState'];
inputs = getExactlyOneTensor(inputs);
if (initialState == null) {
if (this.stateful) {
initialState = this.states_;
}
else {
initialState = this.getInitialState(inputs);
}
}
const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
if (initialState.length !== numStates) {
throw new ValueError(`RNN Layer has ${numStates} state(s) but was passed ` +
`${initialState.length} initial state(s).`);
}
if (this.unroll) {
console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
}
const cellCallKwargs = { training };
// TODO(cais): Add support for constants.
const step = (inputs, states) => {
// `inputs` and `states` are concatenated to form a single `Array` of
// `tf.Tensor`s as the input to `cell.call()`.
const outputs = this.cell.call([inputs].concat(states), cellCallKwargs);
// Marshal the return value into output and new states.
return [outputs[0], outputs.slice(1)];
};
// TODO(cais): Add support for constants.
const rnnOutputs = rnn(step, inputs, initialState, this.goBackwards, mask, null, this.unroll, this.returnSequences);
const lastOutput = rnnOutputs[0];
const outputs = rnnOutputs[1];
const states = rnnOutputs[2];
if (this.stateful) {
this.resetStates(states, training);
}
const output = this.returnSequences ? outputs : lastOutput;
// TODO(cais): Properly set learning phase flag.
if (this.returnState) {
return [output].concat(states);
}
else {
return output;
}
});
}
getInitialState(inputs) {
return tidy(() => {
// Build an all-zero tensor of shape [samples, outputDim].
let initialState = tfc.zeros(inputs.shape); // [samples, timeSteps, inputDim].
initialState = tfc.sum(initialState, [1, 2]); // [samples].
initialState = K.expandDims(initialState); // [samples, 1].
if (Array.isArray(this.cell.stateSize)) {
return this.cell.stateSize.map(dim => dim > 1 ? K.tile(initialState, [1, dim]) : initialState);
}
else {
return this.cell.stateSize > 1 ?
[K.tile(initialState, [1, this.cell.stateSize])] :
[initialState];
}
});
}
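// Note on getInitialState() above: the zero state is derived from `inputs`
// rather than built from a static shape so that it picks up the runtime
// batch size. Zeros of the input's shape are summed down to [samples],
// expanded to [samples, 1], then tiled to [samples, stateSize].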
get trainableWeights() {
if (!this.trainable) {
return [];
}
// Porting Note: In TypeScript, `this` is always an instance of `Layer`.
return this.cell.trainableWeights;
}
get nonTrainableWeights() {
// Porting Note: In TypeScript, `this` is always an instance of `Layer`.
if (!this.trainable) {
return this.cell.weights;
}
return this.cell.nonTrainableWeights;
}
setFastWeightInitDuringBuild(value) {
super.setFastWeightInitDuringBuild(value);
if (this.cell != null) {
this.cell.setFastWeightInitDuringBuild(value);
}
}
getConfig() {
const baseConfig = super.getConfig();
const config = {
returnSequences: this.returnSequences,
returnState: this.returnState,
goBackwards: this.goBackwards,
stateful: this.stateful,
unroll: this.unroll,
};
if (this.numConstants != null) {
config['numConstants'] = this.numConstants;
}
const cellConfig = this.cell.getConfig();
if (this.getClassName() === RNN.className) {
config['cell'] = {
'className': this.cell.getClassName(),
'config': cellConfig,
};
}
// This order is necessary to prevent the cell's name from replacing the
// layer's name.
return Object.assign(Object.assign(Object.assign({}, cellConfig), baseConfig), config);
}
/** @nocollapse */
static fromConfig(cls, config, customObjects = {}) {
const cellConfig = config['cell'];
const cell = deserialize(cellConfig, customObjects);
return new cls(Object.assign(config, { cell }));
}
}
/** @nocollapse */
RNN.className = 'RNN';
export { RNN };
serialization.registerClass(RNN);
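// Usage sketch (hedged; assumes the public `tf.layers` factories backed by
// the classes in this file):
//   const cells = [tf.layers.simpleRNNCell({units: 8}),
//                  tf.layers.simpleRNNCell({units: 8})];
//   const rnnLayer = tf.layers.rnn({cell: cells, returnSequences: true});
//   const input = tf.input({shape: [10, 20]}); // [timeSteps, features].
//   const output = rnnLayer.apply(input);      // Shape: [null, 10, 8].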
// Porting Note: This is a common parent class for RNN cells. There is no
// equivalent of this in PyKeras. Having a common parent class obviates the
// need for `has_attr(cell, ...)` checks or their TypeScript equivalent.
/**
* An RNNCell layer.
*
* @doc {heading: 'Layers', subheading: 'Classes'}
*/
export class RNNCell extends Layer {
}
class SimpleRNNCell extends RNNCell {
constructor(args) {
super(args);
this.DEFAULT_ACTIVATION = 'tanh';
this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
this.DEFAULT_BIAS_INITIALIZER = 'zeros';
this.units = args.units;
assertPositiveInteger(this.units, `units`);
this.activation = getActivation(args.activation == null ? this.DEFAULT_ACTIVATION : args.activation);
this.useBias = args.useBias == null ? true : args.useBias;
this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
this.biasInitializer =
getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
this.biasRegularizer = getRegularizer(args.biasRegularizer);
this.kernelConstraint = getConstraint(args.kernelConstraint);
this.recurrentConstraint = getConstraint(args.recurrentConstraint);
this.biasConstraint = getConstraint(args.biasConstraint);
this.dropout = math_utils.min([1, math_utils.max([0, args.dropout == null ? 0 : args.dropout])]);
this.recurrentDropout = math_utils.min([
1,
math_utils.max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
]);
this.dropoutFunc = args.dropoutFunc;
this.stateSize = this.units;
this.dropoutMask = null;
this.recurrentDropoutMask = null;
}
build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
// TODO(cais): Use regularizer.
this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
else {
this.bias = null;
}
this.built = true;
}
// Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
// `inputs` and `states`. Here, the two tensors are combined into a
// `Tensor[]` Array as the first input argument.
// Similarly, PyKeras' equivalent of this method returns two values:
// `output` and `[output]`. Here the two are combined into one length-2
// `Tensor[]`, consisting of `output` repeated.
call(inputs, kwargs) {
return tidy(() => {
if (inputs.length !== 2) {
throw new ValueError(`SimpleRNNCell expects 2 input Tensors, got ${inputs.length}.`);
}
let prevOutput = inputs[1];
inputs = inputs[0];
const training = kwargs['training'] == null ? false : kwargs['training'];
if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
this.dropoutMask = generateDropoutMask({
ones: () => tfc.onesLike(inputs),
rate: this.dropout,
training,
dropoutFunc: this.dropoutFunc,
});
}
if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
this.recurrentDropoutMask == null) {
this.recurrentDropoutMask = generateDropoutMask({
ones: () => tfc.onesLike(prevOutput),
rate: this.recurrentDropout,
training,
dropoutFunc: this.dropoutFunc,
});
}
let h;
const dpMask = this.dropoutMask;
const recDpMask = this.recurrentDropoutMask;
if (dpMask != null) {
h = K.dot(tfc.mul(inputs, dpMask), this.kernel.read());
}
else {
h = K.dot(inputs, this.kernel.read());
}
if (this.bias != null) {
h = K.biasAdd(h, this.bias.read());
}
if (recDpMask != null) {
prevOutput = tfc.mul(prevOutput, recDpMask);
}
let output = tfc.add(h, K.dot(prevOutput, this.recurrentKernel.read()));
if (this.activation != null) {
output = this.activation.apply(output);
}
// TODO(cais): Properly set learning phase on output tensor?
return [output, output];
});
}
getConfig() {
const baseConfig = super.getConfig();
const config = {
units: this.units,
activation: serializeActivation(this.activation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
recurrentInitializer: serializeInitializer(this.recurrentInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
recurrentConstraint: serializeConstraint(this.recurrentConstraint),
biasConstraint: serializeConstraint(this.biasConstraint),
dropout: this.dropout,
recurrentDropout: this.recurrentDropout,
};
return Object.assign(Object.assign({}, baseConfig), config);
}
}
/** @nocollapse */
SimpleRNNCell.className = 'SimpleRNNCell';
export { SimpleRNNCell };
serialization.registerClass(SimpleRNNCell);
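// Single-step usage sketch (hedged; a cell transforms one time step and is
// normally wrapped in an `RNN` layer rather than called directly):
//   const cell = tf.layers.simpleRNNCell({units: 4});
//   cell.build([null, 3]);
//   const x = tf.randomNormal([2, 3]); // A batch of inputs at one step.
//   const hPrev = tf.zeros([2, 4]);    // The previous output/state.
//   const [output, newState] = cell.call([x, hPrev], {});
//   // output and newState are the same tensor of shape [2, 4].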
class SimpleRNN extends RNN {
constructor(args) {
args.cell = new SimpleRNNCell(args);
super(args);
// TODO(cais): Add activityRegularizer.
}
call(inputs, kwargs) {
return tidy(() => {
if (this.cell.dropoutMask != null) {
tfc.dispose(this.cell.dropoutMask);
this.cell.dropoutMask = null;
}
if (this.cell.recurrentDropoutMask != null) {
tfc.dispose(this.cell.recurrentDropoutMask);
this.cell.recurrentDropoutMask = null;
}
const mask = kwargs == null ? null : kwargs['mask'];
const training = kwargs == null ? null : kwargs['training'];
const initialState = kwargs == null ? null : kwargs['initialState'];
return super.call(inputs, { mask, training, initialState });
});
}
/** @nocollapse */
static fromConfig(cls, config) {
return new cls(config);
}
}
/** @nocollapse */
SimpleRNN.className = 'SimpleRNN';
export { SimpleRNN };
serialization.registerClass(SimpleRNN);
class GRUCell extends RNNCell {
constructor(args) {
super(args);
this.DEFAULT_ACTIVATION = 'tanh';
this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
this.DEFAULT_BIAS_INITIALIZER = 'zeros';
if (args.resetAfter) {
throw new ValueError(`GRUCell does not support the resetAfter parameter set to true.`);
}
this.units = args.units;
assertPositiveInteger(this.units, 'units');
this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
args.activation);
this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
this.DEFAULT_RECURRENT_ACTIVATION :
args.recurrentActivation);
this.useBias = args.useBias == null ? true : args.useBias;
this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
this.biasInitializer =
getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
this.biasRegularizer = getRegularizer(args.biasRegularizer);
this.kernelConstraint = getConstraint(args.kernelConstraint);
this.recurrentConstraint = getConstraint(args.recurrentConstraint);
this.biasConstraint = getConstraint(args.biasConstraint);
this.dropout = math_utils.min([1, math_utils.max([0, args.dropout == null ? 0 : args.dropout])]);
this.recurrentDropout = math_utils.min([
1,
math_utils.max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
]);
this.dropoutFunc = args.dropoutFunc;
this.implementation = args.implementation;
this.stateSize = this.units;
this.dropoutMask = null;
this.recurrentDropoutMask = null;
}
build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
const inputDim = inputShape[inputShape.length - 1];
this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
else {
this.bias = null;
}
// Porting Note: Unlike the PyKeras implementation, we perform slicing
// of the weights and bias in the call() method, at execution time.
this.built = true;
}
call(inputs, kwargs) {
return tidy(() => {
if (inputs.length !== 2) {
throw new ValueError(`GRUCell expects 2 input Tensors (inputs, h), got ` +
`${inputs.length}.`);
}
const training = kwargs['training'] == null ? false : kwargs['training'];
let hTMinus1 = inputs[1]; // Previous state (the output at t - 1).
inputs = inputs[0];
// Note: For superior performance, TensorFlow.js always uses
// implementation 2, regardless of the actual value of
// config.implementation.
if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
this.dropoutMask = generateDropoutMask({
ones: () => tfc.onesLike(inputs),
rate: this.dropout,
training,
count: 3,
dropoutFunc: this.dropoutFunc,
});
}
if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
this.recurrentDropoutMask == null) {
this.recurrentDropoutMask = generateDropoutMask({
ones: () => tfc.onesLike(hTMinus1),
rate: this.recurrentDropout,
training,
count: 3,
dropoutFunc: this.dropoutFunc,
});
}
const dpMask = this.dropoutMask;
const recDpMask = this.recurrentDropoutMask;
let z;
let r;
let hh;
if (0 < this.dropout && this.dropout < 1) {
inputs = tfc.mul(inputs, dpMask[0]);
}
let matrixX = K.dot(inputs, this.kernel.read());
if (this.useBias) {
matrixX = K.biasAdd(matrixX, this.bias.read());
}
if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
hTMinus1 = tfc.mul(hTMinus1, recDpMask[0]);
}
const recurrentKernelValue = this.recurrentKernel.read();
const [rk1, rk2] = tfc.split(recurrentKernelValue, [2 * this.units, this.units], recurrentKernelValue.rank - 1);
const matrixInner = K.dot(hTMinus1, rk1);
const [xZ, xR, xH] = tfc.split(matrixX, 3, matrixX.rank - 1);
const [recurrentZ, recurrentR] = tfc.split(matrixInner, 2, matrixInner.rank - 1);
z = this.recurrentActivation.apply(tfc.add(xZ, recurrentZ));
r = this.recurrentActivation.apply(tfc.add(xR, recurrentR));
const recurrentH = K.dot(tfc.mul(r, hTMinus1), rk2);
hh = this.activation.apply(tfc.add(xH, recurrentH));
const h = tfc.add(tfc.mul(z, hTMinus1), tfc.mul(tfc.add(1, tfc.neg(z)), hh));
// TODO(cais): Add use_learning_phase flag properly.
return [h, h];
});
}
getConfig() {
const baseConfig = super.getConfig();
const config = {
units: this.units,
activation: serializeActivation(this.activation),
recurrentActivation: serializeActivation(this.recurrentActivation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
recurrentInitializer: serializeInitializer(this.recurrentInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
recurrentConstraint: serializeConstraint(this.recurrentConstraint),
biasConstraint: serializeConstraint(this.biasConstraint),
dropout: this.dropout,
recurrentDropout: this.recurrentDropout,
implementation: this.implementation,
resetAfter: false
};
return Object.assign(Object.assign({}, baseConfig), config);
}
}
/** @nocollapse */
GRUCell.className = 'GRUCell';
export { GRUCell };
serialization.registerClass(GRUCell);
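// For reference, GRUCell.call() above computes the standard GRU update
// (implementation 2, resetAfter=false), where x_z, x_r, x_h are slices of
// the input projection and U_z, U_r, U_h are slices of the recurrent
// kernel:
//   z_t = recurrentActivation(x_z + h_{t-1} . U_z)  // Update gate.
//   r_t = recurrentActivation(x_r + h_{t-1} . U_r)  // Reset gate.
//   hh  = activation(x_h + (r_t * h_{t-1}) . U_h)   // Candidate state.
//   h_t = z_t * h_{t-1} + (1 - z_t) * hh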
class GRU extends RNN {
constructor(args) {
if (args.implementation === 0) {
console.warn('`implementation=0` has been deprecated, and now defaults to ' +
'`implementation=1`. Please update your layer call.');
}
args.cell = new GRUCell(args);
super(args);
// TODO(cais): Add activityRegularizer.
}
call(inputs, kwargs) {
return tidy(() => {
if (this.cell.dropoutMask != null) {
tfc.dispose(this.cell.dropoutMask);
this.cell.dropoutMask = null;
}
if (this.cell.recurrentDropoutMask != null) {
tfc.dispose(this.cell.recurrentDropoutMask);
this.cell.recurrentDropoutMask = null;
}
const mask = kwargs == null ? null : kwargs['mask'];
const training = kwargs == null ? null : kwargs['training'];
const initialState = kwargs == null ? null : kwargs['initialState'];
return super.call(inputs, { mask, training, initialState });
});
}
/** @nocollapse */
static fromConfig(cls, config) {
if (config['implementation'] === 0) {
config['implementation'] = 1;
}
return new cls(config);
}
}
/** @nocollapse */
GRU.className = 'GRU';
export { GRU };
serialization.registerClass(GRU);
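// Usage sketch (hedged; assumes the public `tf.layers.gru` factory):
//   const gru = tf.layers.gru({units: 16, returnState: true});
//   const seq = tf.input({shape: [20, 8]}); // [timeSteps, features].
//   const [out, lastH] = gru.apply(seq);    // Both of shape [null, 16].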
class LSTMCell extends RNNCell {
constructor(args) {
super(args);
this.DEFAULT_ACTIVATION = 'tanh';
this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
this.DEFAULT_BIAS_INITIALIZER = 'zeros';
this.units = args.units;
assertPositiveInteger(this.units, 'units');
this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
args.activation);
this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
this.DEFAULT_RECURRENT_ACTIVATION :
args.recurrentActivation);
this.useBias = args.useBias == null ? true : args.useBias;
this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
this.biasInitializer =
getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
this.unitForgetBias = args.unitForgetBias;
this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
this.biasRegularizer = getRegularizer(args.biasRegularizer);
this.kernelConstraint = getConstraint(args.kernelConstraint);
this.recurrentConstraint = getConstraint(args.recurrentConstraint);
this.biasConstraint = getConstraint(args.biasConstraint);
this.dropout = math_utils.min([1, math_utils.max([0, args.dropout == null ? 0 : args.dropout])]);
this.recurrentDropout = math_utils.min([
1,
math_utils.max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
]);
this.dropoutFunc = args.dropoutFunc;
this.implementation = args.implementation;
this.stateSize = [this.units, this.units];
this.dropoutMask = null;
this.recurrentDropoutMask = null;
}
build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
const inputDim = inputShape[inputShape.length - 1];
this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
let biasInitializer;
if (this.useBias) {
if (this.unitForgetBias) {
const capturedBiasInit = this.biasInitializer;
const capturedUnits = this.units;
class CustomInit extends Initializer {
apply(shape, dtype) {
// Bias layout: [inputGate, forgetGate (all ones), cell and outputGate].
const bI = capturedBiasInit.apply([capturedUnits]);
const bF = (new Ones()).apply([capturedUnits]);
const bCAndH = capturedBiasInit.apply([capturedUnits * 2]);
return K.concatAlongFirstAxis(K.concatAlongFirstAxis(bI, bF), bCAndH);
}
}
/** @nocollapse */
CustomInit.className = 'CustomInit';
biasInitializer = new CustomInit();
}
else {
biasInitializer = this.biasInitializer;
}
this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
else {
this.bias = null;
}
// Porting Note: Unlike the PyKeras implementation, we perform slicing
// of the weights and bias in the call() method, at execution time.
this.built = true;
}
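// Note on `unitForgetBias` (see build() above): initializing the
// forget-gate slice of the bias to ones is a common heuristic that biases
// the LSTM toward remembering early in training (recommended by
// Jozefowicz et al., 2015).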
call(inputs, kwargs) {
return tidy(() => {
const training = kwargs['training'] == null ? false : kwargs['training'];
if (inputs.length !== 3) {
throw new ValueError(`LSTMCell expects 3 input Tensors (inputs, h, c), got ` +
`${inputs.length}.`);
}
let hTMinus1 = inputs[1]; // Previous memory state.
const cTMinus1 = inputs[2]; // Previous carry state.
inputs = inputs[0];
if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
this.dropoutMask = generateDropoutMask({
ones: () => tfc.onesLike(inputs),
rate: this.dropout,
training,
count: 4,
dropoutFunc: this.dropoutFunc
});
}
if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
this.recurrentDropoutMask == null) {
this.recurrentDropoutMask = generateDropoutMask({
ones: () => tfc.onesLike(hTMinus1),
rate: this.recurrentDropout,
training,
count: 4,
dropoutFunc: this.dropoutFunc
});
}
const dpMask = this.dropoutMask;
const recDpMask = this.recurrentDropoutMask;
// Note: For superior performance, TensorFlow.js always uses
// implementation 2 regardless of the actual value of
// config.implementation.
let i;
let f;
let c;
let o;
if (0 < this.dropout && this.dropout < 1) {
inputs = tfc.mul(inputs, dpMask[0]);
}
let z = K.dot(inputs, this.kernel.read());
if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
hTMinus1 = tfc.mul(hTMinus1, recDpMask[0]);
}
z = tfc.add(z, K.dot(hTMinus1, this.recurrentKernel.read()));
if (this.useBias) {
z = K.biasAdd(z, this.bias.read());