UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

281 lines 43.9 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * TFJS-based einsum dense layer.
 */
/* Original source: keras/layers/core/einsum_dense.py */
import { einsum, serialization, tidy } from '@tensorflow/tfjs-core';
import { getActivation, serializeActivation } from '../../activations';
import { getConstraint, serializeConstraint } from '../../constraints';
import { Layer } from '../../engine/topology';
import { ValueError } from '../../errors';
import { getInitializer, serializeInitializer } from '../../initializers';
import { getRegularizer, serializeRegularizer } from '../../regularizers';

/**
 * Analyzes an einsum string to determine the required weight shape.
 *
 * Every `...` in `equation` is first replaced by the placeholder `0`. The
 * result must then match one of three forms:
 *   - `[X],[Y]->[Z]`        (no ellipsis)
 *   - `...[X],[Y]->...[Z]`  (ellipsis on the left)
 *   - `[X]...,[Y]->[Z]...`  (ellipsis on the right)
 *
 * @param equation The einsum equation, e.g. `'ab,bc->ac'`.
 * @param biasAxes Output axes that get a bias, or null/undefined for no bias.
 * @param inputShape The full input shape, including the batch dimension.
 * @param outputShape The partial output shape (excludes the batch dimension
 *     and any elided dimensions).
 * @returns `[weightShape, biasShape, fullOutputShape]`, as computed by
 *     `analyzeSplitString`.
 * @throws ValueError if the equation matches none of the supported forms.
 */
export function analyzeEinsumString(equation, biasAxes, inputShape, outputShape) {
  const dotReplacedString = equation.replace(/\.\.\./g, '0');

  // This is the case where no ellipses are present in the string.
  let splitString = dotReplacedString.match(/([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)/);
  if (splitString) {
    return analyzeSplitString(splitString, biasAxes, inputShape, outputShape);
  }

  // This is the case where ellipses are present on the left.
  splitString = dotReplacedString.match(/0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)/);
  if (splitString) {
    return analyzeSplitString(splitString, biasAxes, inputShape, outputShape, true);
  }

  // This is the case where ellipses are present on the right.
  splitString = dotReplacedString.match(/([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0/);
  if (splitString) {
    return analyzeSplitString(splitString, biasAxes, inputShape, outputShape);
  }

  throw new ValueError(`Invalid einsum equation '${equation}'. Equations must be in the form ` +
      '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....');
}

/**
 * Analyzes a pre-split einsum string to find the weight shape.
 *
 * @param splitString A regex match whose groups 1, 2 and 3 hold the input,
 *     weight and output specs, with no ellipsis placeholder.
 * @param biasAxes Output axes that get a bias, or null/undefined for no bias.
 * @param inputShape The full input shape, including the batch dimension.
 * @param outputShape The partial output shape. A plain number is treated as a
 *     one-element shape.
 * @param leftElided True when the ellipsis sits on the left (`...ab,...`).
 * @returns `[weightShape, biasShape, fullOutputShape]`. `biasShape` is null
 *     when `biasAxes` is null/undefined.
 * @throws ValueError if:
 *     - a shared input/output dimension has mismatched sizes;
 *     - an output dimension appears in neither the input nor the weight spec;
 *     - a weight dimension cannot be sized;
 *     - a bias axis is not part of the output spec.
 */
export function analyzeSplitString(splitString, biasAxes, inputShape, outputShape, leftElided = false) {
  const inputSpec = splitString[1];
  const weightSpec = splitString[2];
  const outputSpec = splitString[3];
  // Number of input dimensions not named by the input spec (covered by `...`).
  const elided = inputShape.length - inputSpec.length;

  const newOutputShape = Array.isArray(outputShape) ? outputShape.slice() : [outputShape];
  newOutputShape.unshift(inputShape[0]);

  if (elided > 0 && leftElided) {
    for (let i = 1; i < elided; i++) {
      // We already inserted the 0th input dimension at dim 0, so elided input
      // dimension i belongs at output position i. Inserting each one at
      // position 1 (as the Keras original does) would reverse their order.
      // NOTE(review): the TypeScript source embedded in this file's source map
      // still uses `splice(1, ...)` and needs the same fix.
      newOutputShape.splice(i, 0, inputShape[i]);
    }
  } else if (elided > 0 && !leftElided) {
    for (let i = inputShape.length - elided; i < inputShape.length; i++) {
      newOutputShape.push(inputShape[i]);
    }
  }

  const inputSpecArr = Array.from(inputSpec);
  const outputSpecArr = Array.from(outputSpec);
  let inputDimMap, outputDimMap;

  if (leftElided) {
    // If we have beginning dimensions elided, we need to use negative
    // indexing to determine where in the input dimension our values are.
    inputDimMap = new Map(inputSpecArr.map((dim, i) => {
      // This converts any negative indices to positive ones.
      const idx = i + elided - inputShape.length;
      const positiveIdx = ((idx % inputShape.length) + inputShape.length) % inputShape.length;
      return [dim, positiveIdx];
    }));

    // Because we've constructed the full output shape already, we don't need
    // to do negative indexing.
    outputDimMap = new Map(outputSpecArr.map((dim, i) => [dim, i + elided]));
  } else {
    inputDimMap = new Map(inputSpecArr.map((dim, i) => [dim, i]));
    outputDimMap = new Map(outputSpecArr.map((dim, i) => [dim, i]));
  }

  // Dimensions named in both input and output must agree, unless the caller
  // left the output size as null (to be inferred).
  for (const dim of inputSpec) {
    const inputShapeAtDim = inputShape[inputDimMap.get(dim)];
    if (outputDimMap.has(dim)) {
      const outputShapeAtDim = newOutputShape[outputDimMap.get(dim)];
      if (outputShapeAtDim !== null && outputShapeAtDim !== inputShapeAtDim) {
        throw new ValueError(`Input shape and output shape do not match at shared dimension ` +
            `'${dim}'. Input shape is ${inputShapeAtDim}, and output shape ` +
            `is ${outputShapeAtDim}.`);
      }
    }
  }

  for (const dim of outputSpec) {
    if (!inputSpec.includes(dim) && !weightSpec.includes(dim)) {
      throw new ValueError(`Dimension '${dim}' was specified in the output '${outputSpec}' ` +
          `but has no corresponding dimension in the input spec ` +
          `'${inputSpec}' or weight spec '${weightSpec}'`);
    }
  }

  // Each weight dimension takes its size from the input if named there,
  // otherwise from the (full) output shape.
  const weightShape = [];
  for (const dim of weightSpec) {
    if (inputDimMap.has(dim)) {
      weightShape.push(inputShape[inputDimMap.get(dim)]);
    } else if (outputDimMap.has(dim)) {
      weightShape.push(newOutputShape[outputDimMap.get(dim)]);
    } else {
      throw new ValueError(`Weight dimension '${dim}' did not have a match in either the ` +
          `input spec '${inputSpec}' or the output spec '${outputSpec}'. For ` +
          `this layer, the weight must be fully specified.`);
    }
  }

  let biasShape;
  if (biasAxes != null) {
    const numLeftElided = leftElided ? elided : 0;
    const idxMap = {};
    for (let i = 0; i < outputSpec.length; i++) {
      idxMap[outputSpec[i]] = newOutputShape[i + numLeftElided];
    }

    for (const char of biasAxes) {
      if (!outputSpec.includes(char)) {
        throw new ValueError(`Bias dimension '${char}' was requested, but is not part of the ` +
            `output spec '${outputSpec}'`);
      }
    }

    // The bias spans the output spec from its first biased axis onward.
    // Non-biased axes in that span get size 1 so the bias broadcasts.
    const firstBiasLocation = Math.min(...biasAxes.split('').map(char => outputSpec.indexOf(char)));
    const biasOutputSpec = outputSpec.slice(firstBiasLocation);
    biasShape = biasOutputSpec.split('').map(char => biasAxes.includes(char) ? idxMap[char] : 1);

    // Right-elided dimensions trail the output spec, so pad the bias with 1s
    // to broadcast over them.
    if (!leftElided) {
      for (let i = 0; i < elided; i++) {
        biasShape.push(1);
      }
    }
  } else {
    biasShape = null;
  }
  return [weightShape, biasShape, newOutputShape];
}

/**
 * A layer that uses `tf.einsum` as the backing computation.
 *
 * This layer can perform einsum calculations of arbitrary dimensionality.
 *
 * Examples:
 *
 * **Biased dense layer with einsums**
 *
 * This example shows how to instantiate a standard Keras dense layer using
 * einsum operations. This example is equivalent to
 * `tf.layers.dense({units: 64, useBias: true})`.
 *
 * ```js
 * const layer = new EinsumDense({
 *    equation: "ab,bc->ac", outputShape: 64, biasAxes: "c"});
 * const inputTensor = tf.input({shape: [32]});
 * const outputTensor = layer.apply(inputTensor);
 * console.log(outputTensor.shape);  // [null, 64]
 * ```
 *
 * **Applying a dense layer to a sequence**
 *
 * This example shows how to instantiate a layer that applies the same dense
 * operation to every element in a sequence. Here, the `outputShape` has two
 * values (since there are two non-batch dimensions in the output); the first
 * dimension in the `outputShape` is `null`, because the sequence dimension
 * `b` has an unknown shape.
 *
 * ```js
 * const layer = new EinsumDense({
 *    equation: "abc,cd->abd", outputShape: [null, 64], biasAxes: "d"});
 * const inputTensor = tf.input({shape: [32, 128]});
 * const outputTensor = layer.apply(inputTensor);
 * console.log(outputTensor.shape);  // [null, 32, 64]
 * ```
 *
 * **Applying a dense layer to a sequence using ellipses**
 *
 * This example shows how to instantiate a layer that applies the same dense
 * operation to every element in a sequence, but uses the ellipsis notation
 * instead of specifying the batch and sequence dimensions.
 *
 * Because we are using ellipsis notation and have specified only one axis, the
 * `outputShape` arg is a single value. When instantiated in this way, the
 * layer can handle any number of sequence dimensions - including the case
 * where no sequence dimension exists.
 *
 * ```js
 * const layer = new EinsumDense({
 *    equation: "...x,xy->...y", outputShape: 64, biasAxes: "y"});
 * const inputTensor = tf.input({shape: [32, 128]});
 * const outputTensor = layer.apply(inputTensor);
 * console.log(outputTensor.shape);  // [null, 32, 64]
 * ```
 */
class EinsumDense extends Layer {
  constructor(args) {
    super(args);
    this.equation = args.equation;
    this.biasAxes = args.biasAxes;
    // A scalar `outputShape` is shorthand for a single-element shape.
    this.partialOutputShape = Array.isArray(args.outputShape) ? args.outputShape : [args.outputShape];
    this.activation = getActivation(args.activation);
    // `!= null` matches both null and undefined, like the `??` in the TS source.
    this.kernelInitializer = getInitializer(
        args.kernelInitializer != null ? args.kernelInitializer : 'glorotUniform');
    this.biasInitializer = getInitializer(
        args.biasInitializer != null ? args.biasInitializer : 'zeros');
    this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    this.biasRegularizer = getRegularizer(args.biasRegularizer);
    this.kernelConstraint = getConstraint(args.kernelConstraint);
    this.biasConstraint = getConstraint(args.biasConstraint);
  }

  /** The kernel weight variable; undefined until `build` runs. */
  get kernel() {
    return this._kernel;
  }

  /** The bias weight variable, or null when no `biasAxes` were given. */
  get bias() {
    return this._bias;
  }

  /**
   * Derives kernel/bias/output shapes from the equation and creates weights.
   */
  build(inputShape) {
    const [kernelShape, biasShape, fullOutputShape] = analyzeEinsumString(
        this.equation, this.biasAxes, inputShape, this.partialOutputShape);
    this.fullOutputShape = fullOutputShape;
    this._kernel = this.addWeight('kernel', kernelShape, this.dtype, this.kernelInitializer,
        this.kernelRegularizer, true, this.kernelConstraint);
    if (biasShape != null) {
      this._bias = this.addWeight('bias', biasShape, this.dtype, this.biasInitializer,
          this.biasRegularizer, true, this.biasConstraint);
    } else {
      this._bias = null;
    }
    super.build(inputShape);
  }

  /** Returns the full output shape computed during `build`. */
  computeOutputShape(_) {
    return this.fullOutputShape;
  }

  getConfig() {
    const config = {
      outputShape: this.partialOutputShape,
      equation: this.equation,
      activation: serializeActivation(this.activation),
      biasAxes: this.biasAxes,
      kernelInitializer: serializeInitializer(this.kernelInitializer),
      biasInitializer: serializeInitializer(this.biasInitializer),
      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
      biasRegularizer: serializeRegularizer(this.biasRegularizer),
      kernelConstraint: serializeConstraint(this.kernelConstraint),
      biasConstraint: serializeConstraint(this.biasConstraint),
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /**
   * Computes `einsum(equation, ...inputs, kernel)`, adds the bias if present,
   * then applies the activation if present.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      inputs = Array.isArray(inputs) ? inputs : [inputs];
      // NOTE(review): the tfjs-core `tf.einsum` docs list the `...` notation as
      // unsupported; confirm whether ellipsis equations accepted by `build` can
      // actually execute here.
      let ret = einsum(this.equation, ...inputs, this.kernel.read());
      if (this.bias != null) {
        ret = ret.add(this.bias.read());
      }
      if (this.activation != null) {
        ret = this.activation.apply(ret);
      }
      return ret;
    });
  }
}
/** @nocollapse */
EinsumDense.className = 'EinsumDense';
export { EinsumDense };
serialization.registerClass(EinsumDense);
// NOTE(review): the inline source map below was generated for the original
// compiled layout; its `mappings` are likely stale after reformatting and
// should be regenerated from the TypeScript source.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"einsum_dense.js","sourceRoot":"","sources":["../../../../../../../tfjs-layers/src/layers/nlp/einsum_dense.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,wDAAwD;AACxD,OAAO,EAAoB,MAAM,EAAE,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAEtF,OAAO,EAAc,aAAa,EAAE,mBAAmB,EAAE,MAAM,mBAAmB,CAAC;AACnF,OAAO,EAAoC,aAAa,EAAE,mBAAmB,EAAE,MAAM,mBAAmB,CAAC;AACzG,OAAO,EAAE,KAAK,EAAa,MAAM,uBAAuB,CAAC;AACzD,OAAO,EAAE,UAAU,EAAE,MAAM,cAAc,CAAC;AAC1C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,oBAAoB,CAAC;AAG9G,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,oBAAoB,CAAC;AAI9G;;GAEG;AACH,MAAM,UAAU,mBAAmB,CACjC,QAAgB,EAChB,QAAgB,EAChB,UAAiB,EACjB,WAAkB;IAElB,MAAM,iBAAiB,GAAG,QAAQ,CAAC,OAAO,CAAC,SAAS,EAAE,GAAG,CAAC,CAAC;IAE3D,gEAAgE;IAChE,IAAI,WAAW,GACb,iBAAiB,CAAC,KAAK,CAAC,sCAAsC,CAAC,CAAC;IAClE,IAAI,WAAW,EAAE;QACf,OAAO,kBAAkB,CACvB,WAAW,EAAE,QAAQ,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;KACnD;IAED,2DAA2D;IAC3D,WAAW;QACT,iBAAiB,CAAC,KAAK,CAAC,wCAAwC,CAAC,CAAC;IACpE,IAAI,WAAW,EAAE;QACf,OAAO,kBAAkB,CACvB,WAAW,EAAE,QAAQ,EAAE,UAAU,EAAE,WAAW,EAAE,IAAI,CAAC,CAAC;KACzD;IAED,4DAA4D;IAC5D,WAAW;QACT,iBAAiB,CAAC,KAAK,CAAC,2CAA2C,CAAC,CAAC;IACvE,IAAI,WAAW,EAAE;QACf,OAAO,kBAAkB,CACvB,WAAW,EAAE,QAAQ,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;KACnD;IAED,MAAM,IAAI,UAAU,CAClB,4BAA4B,QAAQ,mCAAmC;QACvE,0DAA0D,CAC3D,CAAC;AACJ,CAAC;AAED;;GAEG;AACH,MAAM,UAAU,kBAAkB,CAChC,WAA6B,EAC7B,QAAgB,EAChB,UAAiB,EACjB,WAAyB,EACzB,UAAU,GAAG,KAAK;IAElB,MAAM,SAAS,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC;IACjC,MAAM,UAAU,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC;IAClC,MAAM,UAAU,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC;IAClC,MAAM,MAAM,GAAG,UAAU,CAAC,MAAM,GAAG,SAAS,CAAC,M
AAM,CAAC;IAEpD,MAAM,cAAc,GAAU,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC;QACxD,WAAW,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC;IACtC,cAAc,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC;IAEtC,IAAI,MAAM,GAAG,CAAC,IAAI,UAAU,EAAE;QAC5B,KAAI,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;YAC9B,mEAAmE;YACnE,+BAA+B;YAC/B,cAAc,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC;SAC5C;KACF;SAAM,IAAI,MAAM,GAAG,CAAC,IAAI,CAAC,UAAU,EAAE;QACpC,KAAI,IAAI,CAAC,GAAG,UAAU,CAAC,MAAM,GAAG,MAAM,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YAClE,cAAc,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC;SACpC;KACF;IAED,MAAM,YAAY,GAAG,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC3C,MAAM,aAAa,GAAG,KAAK,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;IAC7C,IAAI,WAAW,EAAE,YAAY,CAAC;IAE9B,IAAI,UAAU,EAAE;QACd,kEAAkE;QAClE,qEAAqE;QACrE,WAAW,GAAG,IAAI,GAAG,CACnB,YAAY,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE;YAC1B,uDAAuD;YACvD,MAAM,GAAG,GAAG,CAAC,GAAG,MAAM,GAAG,UAAU,CAAC,MAAM,CAAC;YAC3C,MAAM,WAAW,GACf,CAAC,CAAC,GAAG,GAAG,UAAU,CAAC,MAAM,CAAC,GAAG,UAAU,CAAC,MAAM,CAAC,GAAG,UAAU,CAAC,MAAM,CAAC;YACtE,OAAO,CAAC,GAAG,EAAE,WAAW,CAAC,CAAC;QAC5B,CAAC,CAAC,CACH,CAAC;QAEF,yEAAyE;QACzE,2BAA2B;QAC3B,YAAY,GAAG,IAAI,GAAG,CACpB,aAAa,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CACjD,CAAC;KACH;SAAM;QACL,WAAW,GAAG,IAAI,GAAG,CACnB,YAAY,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CACvC,CAAC;QACF,YAAY,GAAG,IAAI,GAAG,CACpB,aAAa,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CACxC,CAAC;KACH;IAED,KAAK,MAAM,GAAG,IAAI,SAAS,EAAE;QAC3B,MAAM,eAAe,GAAG,UAAU,CAAC,WAAW,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;QACzD,IAAI,YAAY,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;YACzB,MAAM,gBAAgB,GAAG,cAAc,CAAC,YAAY,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YAC/D,IAAI,gBAAgB,KAAK,IAAI,IAAI,gBAAgB,KAAK,eAAe,EAAE;gBACrE,MAAM,IAAI,UAAU,CAClB,gEAAgE;oBAChE,IAAI,GAAG,qBAAqB,eAAe,qBAAqB;oBAChE,MAAM,gBAAgB,GAAG,CAC1B,CAAC;aACH;SACF;KACF;IAED,KAAK,MAAM,GAAG,
IAAI,UAAU,EAAE;QAC5B,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE;YACzD,MAAM,IAAI,UAAU,CAClB,cAAc,GAAG,kCAAkC,UAAU,IAAI;gBACjE,uDAAuD;gBACvD,IAAI,SAAS,qBAAqB,UAAU,GAAG,CAChD,CAAC;SACH;KACF;IAED,MAAM,WAAW,GAAU,EAAE,CAAC;IAC9B,KAAK,MAAM,GAAG,IAAI,UAAU,EAAE;QAC5B,IAAI,WAAW,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;YACxB,WAAW,CAAC,IAAI,CAAC,UAAU,CAAC,WAAW,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;SACpD;aAAM,IAAI,YAAY,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;YAChC,WAAW,CAAC,IAAI,CAAC,cAAc,CAAC,YAAY,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;SACzD;aAAM;YACL,MAAM,IAAI,UAAU,CAClB,qBAAqB,GAAG,uCAAuC;gBAC/D,eAAe,SAAS,yBAAyB,UAAU,SAAS;gBACpE,iDAAiD,CAClD,CAAC;SACH;KACF;IAED,IAAI,SAAgB,CAAC;IACrB,IAAI,QAAQ,IAAI,IAAI,EAAE;QACpB,MAAM,aAAa,GAAG,UAAU,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,MAAM,GAA+B,EAAE,CAAC;QAC9C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YAC1C,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,GAAG,cAAc,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC;SAC3D;QAED,KAAK,MAAM,IAAI,IAAI,QAAQ,EAAE;YAC3B,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAC,IAAI,CAAC,EAAE;gBAC9B,MAAM,IAAI,UAAU,CAClB,mBAAmB,IAAI,0CAA0C;oBACjE,gBAAgB,UAAU,GAAG,CAC9B,CAAC;aACH;SACF;QAED,MAAM,iBAAiB,GAAG,IAAI,CAAC,GAAG,CAChC,GAAG,QAAQ,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,UAAU,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAC5D,CAAC;QACF,MAAM,cAAc,GAAG,UAAU,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAE3D,SAAS,GAAG,cAAc,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAC9C,QAAQ,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAC3C,CAAC;QAEF,IAAI,CAAC,UAAU,EAAE;YACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC/B,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aACnB;SACF;KACF;SAAM;QACL,SAAS,GAAG,IAAI,CAAC;KAClB;IACD,OAAO,CAAC,WAAW,EAAE,SAAS,EAAE,cAAc,CAAC,CAAC;AAClD,CAAC;AAqED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAqDG;AACH,MAAa,WAAY,SAAQ,KAAK;IAiBpC,YAAY,IAAqB;;QAC/B,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QA
AQ,CAAC;QAC9B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,kBAAkB;YACrB,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;QAC1E,IAAI,CAAC,UAAU,GAAG,aAAa,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QACjD,IAAI,CAAC,iBAAiB,GAAG,cAAc,CACrC,MAAA,IAAI,CAAC,iBAAiB,mCAAI,eAAe,CAAC,CAAC;QAC7C,IAAI,CAAC,eAAe,GAAG,cAAc,CAAC,MAAA,IAAI,CAAC,eAAe,mCAAI,OAAO,CAAC,CAAC;QACvE,IAAI,CAAC,iBAAiB,GAAG,cAAc,CAAC,IAAI,CAAC,iBAAiB,CAAC,CAAC;QAChE,IAAI,CAAC,eAAe,GAAG,cAAc,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QAC5D,IAAI,CAAC,gBAAgB,GAAG,aAAa,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;QAC7D,IAAI,CAAC,cAAc,GAAG,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;IAC3D,CAAC;IAED,IAAI,MAAM;QACR,OAAO,IAAI,CAAC,OAAO,CAAC;IACtB,CAAC;IAED,IAAI,IAAI;QACN,OAAO,IAAI,CAAC,KAAK,CAAC;IACpB,CAAC;IAEQ,KAAK,CAAC,UAAiB;QAC9B,MAAM,CAAC,WAAW,EAAE,SAAS,EAAE,eAAe,CAAC,GAAG,mBAAmB,CACnE,IAAI,CAAC,QAAQ,EACb,IAAI,CAAC,QAAQ,EACb,UAAU,EACV,IAAI,CAAC,kBAAkB,CACxB,CAAC;QACF,IAAI,CAAC,eAAe,GAAG,eAAe,CAAC;QACvC,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAC3B,QAAQ,EACR,WAAW,EACX,IAAI,CAAC,KAAK,EACV,IAAI,CAAC,iBAAiB,EACtB,IAAI,CAAC,iBAAiB,EACtB,IAAI,EACJ,IAAI,CAAC,gBAAgB,CACtB,CAAC;QAEF,IAAI,SAAS,IAAI,IAAI,EAAE;YACrB,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,SAAS,CACzB,MAAM,EACN,SAAS,EACT,IAAI,CAAC,KAAK,EACV,IAAI,CAAC,eAAe,EACpB,IAAI,CAAC,eAAe,EACpB,IAAI,EACJ,IAAI,CAAC,cAAc,CACpB,CAAC;SACH;aAAM;YACL,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;SACnB;QACD,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC1B,CAAC;IAEQ,kBAAkB,CAAC,CAAQ;QAClC,OAAO,IAAI,CAAC,eAAe,CAAC;IAC9B,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,WAAW,EAAE,IAAI,CAAC,kBAAkB;YACpC,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,UAAU,EAAE,mBAAmB,CAAC,IAAI,CAAC,UAAU,CAAC;YAChD,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,iBAAiB,EAAE,oBAAoB,CAAC,IAAI,CAAC,iBAAiB,CAAC;YAC/D,eAAe,EAAE,oBAAoB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC3D,iBAAiB,EAAE,oBAAoB,CAAC,IAAI,CAAC,iBAAiB,CAAC;YAC/D,eAAe,EAAE,oBAAoB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC3D,gBAAgB,EAAE,mBAAmB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC5D,cAAc,EAAE,mBAAmB,CAAC,IAAI,CAAC,cAAc,CAAC;SACz
D,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YACnD,IAAI,GAAG,GAAG,MAAM,CAAC,IAAI,CAAC,QAAQ,EAAE,GAAG,MAAM,EAAE,IAAI,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC,CAAC;YAC/D,IAAI,IAAI,CAAC,IAAI,IAAI,IAAI,EAAE;gBACrB,GAAG,GAAG,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;aACjC;YACD,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;gBAC3B,GAAG,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;aAClC;YACD,OAAO,GAAG,CAAC;QACb,CAAC,CAAC,CAAC;IACL,CAAC;;AA5GD,kBAAkB;AACF,qBAAS,GAAG,aAAa,CAAC;SAF/B,WAAW;AA+GxB,aAAa,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  TFJS-based einsum dense layer.\n */\n\n/* Original source: keras/layers/core/einsum_dense.py */\nimport { Tensor, Tensor2D, einsum, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { Activation, getActivation, serializeActivation } from '../../activations';\nimport { Constraint, ConstraintIdentifier, getConstraint, serializeConstraint } from '../../constraints';\nimport { Layer, LayerArgs } from '../../engine/topology';\nimport { ValueError } from '../../errors';\nimport 
{ Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../initializers';\nimport { ActivationIdentifier } from '../../keras_format/activation_config';\nimport { Shape } from '../../keras_format/common';\nimport { Regularizer, RegularizerIdentifier, getRegularizer, serializeRegularizer } from '../../regularizers';\nimport { Kwargs } from '../../types';\nimport { LayerVariable } from '../../variables';\n\n/**\n * Analyzes an einsum string to determine the required weight shape.\n */\nexport function analyzeEinsumString(\n  equation: string,\n  biasAxes: string,\n  inputShape: Shape,\n  outputShape: Shape\n): [Shape, Shape, Shape] {\n  const dotReplacedString = equation.replace(/\\.\\.\\./g, '0');\n\n  // This is the case where no ellipses are present in the string.\n  let splitString =\n    dotReplacedString.match(/([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)/);\n  if (splitString) {\n    return analyzeSplitString(\n      splitString, biasAxes, inputShape, outputShape);\n  }\n\n  // This is the case where ellipses are present on the left.\n  splitString =\n    dotReplacedString.match(/0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)/);\n  if (splitString) {\n    return analyzeSplitString(\n      splitString, biasAxes, inputShape, outputShape, true);\n  }\n\n  // This is the case where ellipses are present on the right.\n  splitString =\n    dotReplacedString.match(/([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0/);\n  if (splitString) {\n    return analyzeSplitString(\n      splitString, biasAxes, inputShape, outputShape);\n  }\n\n  throw new ValueError(\n    `Invalid einsum equation '${equation}'. 
Equations must be in the form ` +\n    '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'\n  );\n}\n\n/**\n * Analyze an pre-split einsum string to find the weight shape.\n */\nexport function analyzeSplitString(\n  splitString: RegExpMatchArray,\n  biasAxes: string,\n  inputShape: Shape,\n  outputShape: Shape|number,\n  leftElided = false\n): [Shape, Shape, Shape] {\n  const inputSpec = splitString[1];\n  const weightSpec = splitString[2];\n  const outputSpec = splitString[3];\n  const elided = inputShape.length - inputSpec.length;\n\n  const newOutputShape: Shape = Array.isArray(outputShape) ?\n    outputShape.slice() : [outputShape];\n  newOutputShape.unshift(inputShape[0]);\n\n  if (elided > 0 && leftElided) {\n    for(let i = 1; i < elided; i++) {\n      // We already inserted the 0th input dimension at dim 0, so we need\n      // to start at location 1 here.\n      newOutputShape.splice(1, 0, inputShape[i]);\n    }\n  } else if (elided > 0 && !leftElided) {\n    for(let i = inputShape.length - elided; i < inputShape.length; i++) {\n      newOutputShape.push(inputShape[i]);\n    }\n  }\n\n  const inputSpecArr = Array.from(inputSpec);\n  const outputSpecArr = Array.from(outputSpec);\n  let inputDimMap, outputDimMap;\n\n  if (leftElided) {\n    // If we have beginning dimensions elided, we need to use negative\n    // indexing to determine where in the input dimension our values are.\n    inputDimMap = new Map<string, number>(\n      inputSpecArr.map((dim, i) => {\n        // This converts any negative indices to positive ones.\n        const idx = i + elided - inputShape.length;\n        const positiveIdx =\n          ((idx % inputShape.length) + inputShape.length) % inputShape.length;\n        return [dim, positiveIdx];\n      })\n    );\n\n    // Because we've constructed the full output shape already, we don't need\n    // to do negative indexing.\n    outputDimMap = new Map<string, number>(\n      outputSpecArr.map((dim, i) => [dim, i + elided])\n   
 );\n  } else {\n    inputDimMap = new Map<string, number>(\n      inputSpecArr.map((dim, i) => [dim, i])\n    );\n    outputDimMap = new Map<string, number>(\n      outputSpecArr.map((dim, i) => [dim, i])\n    );\n  }\n\n  for (const dim of inputSpec) {\n    const inputShapeAtDim = inputShape[inputDimMap.get(dim)];\n    if (outputDimMap.has(dim)) {\n      const outputShapeAtDim = newOutputShape[outputDimMap.get(dim)];\n      if (outputShapeAtDim !== null && outputShapeAtDim !== inputShapeAtDim) {\n        throw new ValueError(\n          `Input shape and output shape do not match at shared dimension `+\n          `'${dim}'. Input shape is ${inputShapeAtDim}, and output shape ` +\n          `is ${outputShapeAtDim}.`\n        );\n      }\n    }\n  }\n\n  for (const dim of outputSpec) {\n    if (!inputSpec.includes(dim) && !weightSpec.includes(dim)) {\n      throw new ValueError(\n        `Dimension '${dim}' was specified in the output '${outputSpec}' ` +\n        `but has no corresponding dimension in the input spec ` +\n        `'${inputSpec}' or weight spec '${weightSpec}'`\n      );\n    }\n  }\n\n  const weightShape: Shape = [];\n  for (const dim of weightSpec) {\n    if (inputDimMap.has(dim)) {\n      weightShape.push(inputShape[inputDimMap.get(dim)]);\n    } else if (outputDimMap.has(dim)) {\n      weightShape.push(newOutputShape[outputDimMap.get(dim)]);\n    } else {\n      throw new ValueError(\n        `Weight dimension '${dim}' did not have a match in either the ` +\n        `input spec '${inputSpec}' or the output spec '${outputSpec}'. For ` +\n        `this layer, the weight must be fully specified.`\n      );\n    }\n  }\n\n  let biasShape: Shape;\n  if (biasAxes != null) {\n    const numLeftElided = leftElided ? 
elided : 0;\n    const idxMap: { [char: string]: number } = {};\n    for (let i = 0; i < outputSpec.length; i++) {\n      idxMap[outputSpec[i]] = newOutputShape[i + numLeftElided];\n    }\n\n    for (const char of biasAxes) {\n      if (!outputSpec.includes(char)) {\n        throw new ValueError(\n          `Bias dimension '${char}' was requested, but is not part of the ` +\n          `output spec '${outputSpec}'`\n        );\n      }\n    }\n\n    const firstBiasLocation = Math.min(\n      ...biasAxes.split('').map(char => outputSpec.indexOf(char))\n    );\n    const biasOutputSpec = outputSpec.slice(firstBiasLocation);\n\n    biasShape = biasOutputSpec.split('').map(char =>\n      biasAxes.includes(char) ? idxMap[char] : 1\n    );\n\n    if (!leftElided) {\n      for (let i = 0; i < elided; i++) {\n        biasShape.push(1);\n      }\n    }\n  } else {\n    biasShape = null;\n  }\n  return [weightShape, biasShape, newOutputShape];\n}\n\nexport declare interface EinsumDenseArgs extends LayerArgs {\n  /**\n   * An equation describing the einsum to perform. This equation must be a\n   * valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or\n   * `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum\n   * axis expression sequence.\n   */\n  equation: string;\n\n  /**\n   * The expected shape of the output tensor (excluding the batch dimension and\n   * any dimensions represented by ellipses). You can specify None for any\n   * dimension that is unknown or can be inferred from the input shape.\n   */\n  outputShape: Shape|number;\n\n  /**\n   * Activation function to use. If you don't specify anything, no activation\n   * is applied (that is, a \"linear\" activation: `a(x) = x`).\n   */\n  activation?: ActivationIdentifier;\n\n  /**\n   * A string containing the output dimension(s) to apply a bias to. 
Each\n   * character in the `biasAxes` string should correspond to a character\n   * in the output portion of the `equation` string.\n   */\n  biasAxes?: string;\n\n  /**\n   * Initializer for the `kernel` weights matrix.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  kernelInitializer?: InitializerIdentifier;\n\n  /**\n   * Initializer for the bias vector.\n   * Defaults to `\"zeros\"`.\n   */\n  biasInitializer?: InitializerIdentifier;\n\n  /**\n   * Regularizer function applied to the `kernel` weights matrix.\n   */\n  kernelRegularizer?: RegularizerIdentifier;\n\n  /**\n   * Regularizer function applied to the bias vector.\n   */\n  biasRegularizer?: RegularizerIdentifier;\n\n  /**\n   * Regularizer function applied to the output of the layer (its \"activation\").\n   */\n  activityRegularizer?: RegularizerIdentifier;\n\n  /**\n   * Constraint function applied to the `kernel` weights matrix.\n   */\n  kernelConstraint?: ConstraintIdentifier;\n\n  /**\n   * Constraint function applied to the bias vector.\n   */\n  biasConstraint?: ConstraintIdentifier;\n}\n\n/**\n * A layer that uses `tf.einsum` as the backing computation.\n *\n * This layer can perform einsum calculations of arbitrary dimensionality.\n *\n * Examples:\n *\n * **Biased dense layer with einsums**\n *\n * This example shows how to instantiate a standard Keras dense layer using\n * einsum operations. This example is equivalent to\n * tf.layers.Dense({units: 64, useBias: true})`.\n *\n * const layer = new EinsumDense({\n *    equation: \"ab,bc->ac\", outputShape: 4, biasAxes: \"c\"});\n * const inputTensor = tf.input({shape: [32]});\n * const outputTensor = layer.call(inputTensor);\n * console.log(outputTensor);  // [null, 64]\n *\n * **Applying a dense layer to a sequence**\n *\n * This example shows how to instantiate a layer that applies the same dense\n * operation to every element in a sequence. 
Here, the `outputShape` has two\n * values (since there are two non-batch dimensions in the output); the first\n * dimension in the `outputShape` is `null`, because the sequence dimension\n * `b` has an unknown shape.\n *\n * ```js\n * const layer = new EinsumDense({\n *    equation: \"abc,cd->abd\", outputShape: [null, 64], biasAxes: \"d\"});\n * const inputTensor = tf.input({shape: [32, 128]});\n * const outputTensor = layer.call(inputTensor);\n * console.log(outputTensor);  // [null, 32, 64]\n * ```\n *\n * **Applying a dense layer to a sequence using ellipses**\n *\n * This example shows how to instantiate a layer that applies the same dense\n * operation to every element in a sequence, but uses the ellipsis notation\n * instead of specifying the batch and sequence dimensions.\n *\n * Because we are using ellipsis notation and have specified only one axis, the\n * `outputShape` arg is a single value. When instantiated in this way, the\n * layer can handle any number of sequence dimensions - including the case\n * where no sequence dimension exists.\n *\n * ```js\n * const layer = new EinsumDense({\n *    equation: \"...x,xy->...y\", outputShape: 64, biasAxes: \"y\"});\n * const inputTensor = tf.input({shape: [32, 128]});\n * const outputTensor = layer.call(inputTensor);\n * console.log(outputTensor);  // [null, 32, 64]\n * ``\n */\nexport class EinsumDense extends Layer {\n  /** @nocollapse */\n  static readonly className = 'EinsumDense';\n  private readonly equation: string;\n  private readonly biasAxes: string;\n  private readonly partialOutputShape: Shape;\n  private readonly activation: Activation;\n  private readonly kernelInitializer: Initializer;\n  private readonly biasInitializer: Initializer;\n  private readonly kernelRegularizer: Regularizer;\n  private readonly biasRegularizer: Regularizer;\n  private readonly kernelConstraint: Constraint;\n  private readonly biasConstraint: Constraint;\n  private fullOutputShape: Shape;\n  private _kernel: 
LayerVariable;\n  private _bias: LayerVariable;\n\n  constructor(args: EinsumDenseArgs) {\n    super(args);\n    this.equation = args.equation;\n    this.biasAxes = args.biasAxes;\n    this.partialOutputShape =\n      Array.isArray(args.outputShape) ? args.outputShape : [args.outputShape];\n    this.activation = getActivation(args.activation);\n    this.kernelInitializer = getInitializer(\n      args.kernelInitializer ?? 'glorotUniform');\n    this.biasInitializer = getInitializer(args.biasInitializer ?? 'zeros');\n    this.kernelRegularizer = getRegularizer(args.kernelRegularizer);\n    this.biasRegularizer = getRegularizer(args.biasRegularizer);\n    this.kernelConstraint = getConstraint(args.kernelConstraint);\n    this.biasConstraint = getConstraint(args.biasConstraint);\n  }\n\n  get kernel(): LayerVariable {\n    return this._kernel;\n  }\n\n  get bias(): LayerVariable {\n    return this._bias;\n  }\n\n  override build(inputShape: Shape): void {\n    const [kernelShape, biasShape, fullOutputShape] = analyzeEinsumString(\n      this.equation,\n      this.biasAxes,\n      inputShape,\n      this.partialOutputShape\n    );\n    this.fullOutputShape = fullOutputShape;\n    this._kernel = this.addWeight(\n      'kernel',\n      kernelShape,\n      this.dtype,\n      this.kernelInitializer,\n      this.kernelRegularizer,\n      true,\n      this.kernelConstraint,\n    );\n\n    if (biasShape != null) {\n      this._bias = this.addWeight(\n        'bias',\n        biasShape,\n        this.dtype,\n        this.biasInitializer,\n        this.biasRegularizer,\n        true,\n        this.biasConstraint,\n      );\n    } else {\n      this._bias = null;\n    }\n    super.build(inputShape);\n  }\n\n  override computeOutputShape(_: Shape): Shape {\n    return this.fullOutputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      outputShape: this.partialOutputShape,\n      equation: this.equation,\n      activation: 
serializeActivation(this.activation),\n      biasAxes: this.biasAxes,\n      kernelInitializer: serializeInitializer(this.kernelInitializer),\n      biasInitializer: serializeInitializer(this.biasInitializer),\n      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),\n      biasRegularizer: serializeRegularizer(this.biasRegularizer),\n      kernelConstraint: serializeConstraint(this.kernelConstraint),\n      biasConstraint: serializeConstraint(this.biasConstraint),\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D {\n    return tidy(() => {\n      inputs = Array.isArray(inputs) ? inputs : [inputs];\n      let ret = einsum(this.equation, ...inputs, this.kernel.read());\n      if (this.bias != null) {\n        ret = ret.add(this.bias.read());\n      }\n      if (this.activation != null) {\n        ret = this.activation.apply(ret);\n      }\n      return ret;\n    });\n  }\n}\nserialization.registerClass(EinsumDense);\n"]}