UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

112 lines 18.6 kB
/** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Cached MHA layer based on `MultiHeadAttention`. */ /* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */ import { cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core'; import { ValueError } from '../../../errors'; import { MultiHeadAttention } from '../multihead_attention'; import { sliceUpdate } from '../utils'; /** * MultiHeadAttention layer with cache support. * * This layer is suitable for use in autoregressive decoding. It can be use * to cache decoder self-attention and cross-attention. The forward pass * can happen in one of three modes: * - No cache, same as regular multi-head attention. * - Static cache (`cacheUpdateIndex` is None). In this case, the * cached key/value projections will be used and the input values will * be ignored. * - Updated cache (`cacheUpdateIndex` is not None). In this case, new * key/value projections are computed using the input, and spliced into * the cache at the specified index. * * Note that caching is useful only during inference and should not be used * during training. * * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension, * `T` is the target sequence length, and `S` in the source sequence length. * Note that during generative decoding, `T` is usually 1 (you are * generating a target sequence of length one to predict the next token). * * Returns: * An `(attentionOutput, cache)` tuple. `attentionOutput` is the result * of the computation, of shape `(B, T, dim)`, where `T` is for target * sequence shapes and `dim` is the query input last dimension if * `outputShape` is `null`. Otherwise, the multi-head outputs are * projected to the shape specified by `outputShape`. `cache` is the * updated cache. */ export class CachedMultiHeadAttention extends MultiHeadAttention { call(query, kwargs) { return this.callAndReturnCache(query, kwargs)[0]; } /** * Exactly like `call` except also returns the updated cache. */ callAndReturnCache(query, { value, key, attentionMask, cache, cacheUpdateIndex }) { return tidy(() => { if (!this.builtFromSignature) { this.buildFromSignature(query.shape, value.shape, key ? key.shape : null); } if (key == null) { key = value; } query = this.queryDense.apply(query); // If cache is not `null`, we will use the cache to compute the final key // and value tensors. If `cacheUpdateIndex` is not `null`, we will first // update the cache before use. To do this, we first call the // `keyDense` and `valueDense` layers, and copy the outputs into the // cache at the specified index. `cache = null` handles the training // case, where we don't use the cache at all. if (cache != null) { const keyCache = cache.gather([0], 1).squeeze(); const valueCache = cache.gather([1], 1).squeeze(); if (cacheUpdateIndex == null) { key = keyCache; value = valueCache; } else { const keyUpdate = this.keyDense.apply(key); const valueUpdate = this.valueDense.apply(value); const start = [0, cacheUpdateIndex, 0, 0]; key = sliceUpdate(keyCache, start, keyUpdate); value = sliceUpdate(valueCache, start, valueUpdate); cache = stack([key, value], 1); } } else { if (cacheUpdateIndex != null) { throw new ValueError('`cacheUpdateIndex` should not be set if `cache` is `null`. ' + `Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`); } key = this.keyDense.apply(key); value = this.valueDense.apply(value); } query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype)))); let attentionScores = einsum(this.dotProductEquation, key, query); attentionScores = this.maskedSoftmax(attentionScores, attentionMask); attentionScores = this.dropoutLayer.apply(attentionScores); let attentionOutput = einsum(this.combineEquation, attentionScores, value); attentionOutput = this.outputDense.apply(attentionOutput); return [attentionOutput, cache]; }); } } serialization.registerClass(CachedMultiHeadAttention); //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY2FjaGVkX211bHRpaGVhZF9hdHRlbnRpb24uanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMvbGF5ZXJzL25scC9tb2RlbGluZy9jYWNoZWRfbXVsdGloZWFkX2F0dGVudGlvbi50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSDs7R0FFRztBQUVILCtFQUErRTtBQUMvRSxPQUFPLEVBQVUsSUFBSSxFQUFFLE1BQU0sRUFBRSxHQUFHLEVBQUUsVUFBVSxFQUFFLGFBQWEsRUFBRSxJQUFJLEVBQUUsS0FBSyxFQUFFLElBQUksRUFBRSxNQUFNLHVCQUF1QixDQUFDO0FBRWhILE9BQU8sRUFBRSxVQUFVLEVBQUUsTUFBTSxpQkFBaUIsQ0FBQztBQUM3QyxPQUFPLEVBQUUsa0JBQWtCLEVBQUUsTUFBTSx3QkFBd0IsQ0FBQztBQUM1RCxPQUFPLEVBQUUsV0FBVyxFQUFFLE1BQU0sVUFBVSxDQUFDO0FBaUR2Qzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0E2Qkc7QUFDSCxNQUFNLE9BQU8sd0JBQXlCLFNBQVEsa0JBQWtCO0lBRXJELElBQUksQ0FDWCxLQUFhLEVBQUUsTUFBdUM7UUFFdEQsT0FBTyxJQUFJLENBQUMsa0JBQWtCLENBQUMsS0FBSyxFQUFFLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ25ELENBQUM7SUFFRDs7T0FFRztJQUNILGtCQUFrQixDQUNoQixLQUFhLEVBQ2IsRUFDRSxLQUFLLEVBQ0wsR0FBRyxFQUNILGFBQWEsRUFDYixLQUFLLEVBQ0wsZ0JBQWdCLEVBQ2lCO1FBRW5DLE9BQU8sSUFBSSxDQUFDLEdBQUcsRUFBRTtZQUNmLElBQUksQ0FBQyxJQUFJLENBQUMsa0JBQWtCLEVBQUU7Z0JBQzVCLElBQUksQ0FBQyxrQkFBa0IsQ0FDckIsS0FBSyxDQUFDLEtBQUssRUFBRSxLQUFLLENBQUMsS0FBSyxFQUFFLEdBQUcsQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsSUFBSSxDQUFDLENBQUM7YUFDckQ7WUFDRCxJQUFJLEdBQUcsSUFBSSxJQUFJLEVBQUU7Z0JBQ2YsR0FBRyxHQUFHLEtBQUssQ0FBQzthQUNiO1lBRUQsS0FBSyxHQUFHLElBQUksQ0FBQyxVQUFVLENBQUMsS0FBSyxDQUFDLEtBQUssQ0FBVyxDQUFDO1lBQy9DLHlFQUF5RTtZQUN6RSx3RUFBd0U7WUFDeEUsNkRBQTZEO1lBQzdELG9FQUFvRTtZQUNwRSxvRUFBb0U7WUFDcEUsNkNBQTZDO1lBQzdDLElBQUksS0FBSyxJQUFJLElBQUksRUFBRTtnQkFDakIsTUFBTSxRQUFRLEdBQUcsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLE9BQU8sRUFBRSxDQUFDO2dCQUNoRCxNQUFNLFVBQVUsR0FBRyxLQUFLLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsT0FBTyxFQUFFLENBQUM7Z0JBQ2xELElBQUksZ0JBQWdCLElBQUksSUFBSSxFQUFFO29CQUM1QixHQUFHLEdBQUcsUUFBUSxDQUFDO29CQUNmLEtBQUssR0FBRyxVQUFVLENBQUM7aUJBQ3BCO3FCQUFNO29CQUNMLE1BQU0sU0FBUyxHQUFHLElBQUksQ0FBQyxRQUFRLENBQUMsS0FBSyxDQUFDLEdBQUcsQ0FBVyxDQUFDO29CQUNyRCxNQUFNLFdBQVcsR0FBRyxJQUFJLENBQUMsVUFBVSxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQVcsQ0FBQztvQkFDM0QsTUFBTSxLQUFLLEdBQUcsQ0FBQyxDQUFDLEVBQUUsZ0JBQWdCLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO29CQUMxQyxHQUFHLEdBQUcsV0FBVyxDQUFDLFFBQVEsRUFBRSxLQUFLLEVBQUUsU0FBUyxDQUFDLENBQUM7b0JBQzlDLEtBQUssR0FBRyxXQUFXLENBQUMsVUFBVSxFQUFFLEtBQUssRUFBRSxXQUFXLENBQUMsQ0FBQztvQkFDcEQsS0FBSyxHQUFHLEtBQUssQ0FBQyxDQUFDLEdBQUcsRUFBRSxLQUFLLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQztpQkFDaEM7YUFDRjtpQkFBTTtnQkFDTCxJQUFJLGdCQUFnQixJQUFJLElBQUksRUFBRTtvQkFDNUIsTUFBTSxJQUFJLFVBQVUsQ0FDbEIsNkRBQTZEO3dCQUM3RCxtQkFBbUIsS0FBSyxzQkFBc0IsZ0JBQWdCLEVBQUUsQ0FDakUsQ0FBQztpQkFDSDtnQkFDRCxHQUFHLEdBQUcsSUFBSSxDQUFDLFFBQVEsQ0FBQyxLQUFLLENBQUMsR0FBRyxDQUFXLENBQUM7Z0JBQ3pDLEtBQUssR0FBRyxJQUFJLENBQUMsVUFBVSxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQVcsQ0FBQzthQUNoRDtZQUVELEtBQUssR0FBRyxHQUFHLENBQUMsS0FBSyxFQUFFLFVBQVUsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxNQUFNLEVBQUUsS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1lBQ3JFLElBQUksZUFBZSxHQUFHLE1BQU0sQ0FBQyxJQUFJLENBQUMsa0JBQWtCLEVBQUUsR0FBRyxFQUFFLEtBQUssQ0FBQyxDQUFDO1lBQ2xFLGVBQWUsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLGVBQWUsRUFBRSxhQUFhLENBQUMsQ0FBQztZQUNyRSxlQUFlLEdBQUcsSUFBSSxDQUFDLFlBQVksQ0FBQyxLQUFLLENBQUMsZUFBZSxDQUFXLENBQUM7WUFFckUsSUFBSSxlQUFlLEdBQ2pCLE1BQU0sQ0FBQyxJQUFJLENBQUMsZUFBZSxFQUFFLGVBQWUsRUFBRSxLQUFLLENBQUMsQ0FBQztZQUN2RCxlQUFlLEdBQUcsSUFBSSxDQUFDLFdBQVcsQ0FBQyxLQUFLLENBQUMsZUFBZSxDQUFXLENBQUM7WUFFcEUsT0FBTyxDQUFDLGVBQWUsRUFBRSxLQUFLLENBQUMsQ0FBQztRQUNsQyxDQUFDLENBQUMsQ0FBQztJQUNMLENBQUM7Q0FDRjtBQUNELGFBQWEsQ0FBQyxhQUFhLENBQUMsd0JBQXdCLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIzIEdvb2dsZSBMTEMuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuLyoqXG4gKiAgQ2FjaGVkIE1IQSBsYXllciBiYXNlZCBvbiBgTXVsdGlIZWFkQXR0ZW50aW9uYC5cbiAqL1xuXG4vKiBPcmlnaW5hbCBzb3VyY2U6IGtlcmFzX25scC9sYXllcnMvbW9kZWxpbmcvY2FjaGVkX211bHRpX2hlYWRfYXR0ZW50aW9uLnB5ICovXG5pbXBvcnQgeyBUZW5zb3IsIGNhc3QsIGVpbnN1bSwgbXVsLCByZWNpcHJvY2FsLCBzZXJpYWxpemF0aW9uLCBzcXJ0LCBzdGFjaywgdGlkeSB9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7IFZhbHVlRXJyb3IgfSBmcm9tICcuLi8uLi8uLi9lcnJvcnMnO1xuaW1wb3J0IHsgTXVsdGlIZWFkQXR0ZW50aW9uIH0gZnJvbSAnLi4vbXVsdGloZWFkX2F0dGVudGlvbic7XG5pbXBvcnQgeyBzbGljZVVwZGF0ZSB9IGZyb20gJy4uL3V0aWxzJztcblxuZXhwb3J0IGRlY2xhcmUgaW50ZXJmYWNlIENhY2hlZE11bHRpSGVhZEF0dGVudGlvbk9wdGlvbnMge1xuICAvKipcbiAgICogUXVlcnkgYFRlbnNvcmAgb2Ygc2hhcGUgYChCLCBULCBkaW0pYC5cbiAgICovXG5cbiAgLyoqXG4gICAqIFZhbHVlIGBUZW5zb3JgIG9mIHNoYXBlIGAoQiwgUyosIGRpbSlgLiBJZiBgY2FjaGVgIGlzIGBudWxsYCwgYFMqYFxuICAgKiBtdXN0IGVxdWFsIGBTYCBhbmQgbWF0Y2ggdGhlIHNoYXBlIG9mIGBhdHRlbnRpb25NYXNrYC4gSWYgYGNhY2hlYCBpc1xuICAgKiBub3QgYG51bGxgLCBgUypgIGNhbiBiZSBhbnkgbGVuZ3RoIGxlc3MgdGhhbiBgU2AsIGFuZCB0aGUgY29tcHV0ZWRcbiAgICogdmFsdWUgd2lsbCBiZSBzcGxpY2VkIGludG8gYGNhY2hlYCBhdCBgY2FjaGVVcGRhdGVJbmRleGAuXG4gICAqL1xuICB2YWx1ZTogVGVuc29yO1xuXG4gIC8qKlxuICAgKiBLZXkgYFRlbnNvcmAgb2Ygc2hhcGUgYChCLCBTKiwgZGltKWAuICBJZiBgY2FjaGVgIGlzIGBudWxsYCwgYFMqYCBtdXN0XG4gICAqIGVxdWFsIGBTYCBhbmQgbWF0Y2ggdGhlIHNoYXBlIG9mIGBhdHRlbnRpb25NYXNrYC4gSWYgYGNhY2hlYCBpcyBub3QgYG51bGxgLFxuICAgKiBgUypgIGNhbiBiZSBhbnkgbGVuZ3RoIGxlc3MgdGhhbiBgU2AsIGFuZCB0aGUgY29tcHV0ZWQgdmFsdWUgd2lsbCBiZVxuICAgKiBzcGxpY2VkIGludG8gYGNhY2hlYCBhdCBgY2FjaGVVcGRhdGVJbmRleGAuXG4gICAqL1xuICBrZXk/OiBUZW5zb3I7XG5cbiAgLyoqXG4gICAqIEEgYm9vbGVhbiBtYXNrIG9mIHNoYXBlIGAoQiwgVCwgUylgLiBgYXR0ZW50aW9uTWFza2AgcHJldmVudHNcbiAgICogYXR0ZW50aW9uIHRvIGNlcnRhaW4gcG9zaXRpb25zLiBUaGUgYm9vbGVhbiBtYXNrIHNwZWNpZmllcyB3aGljaFxuICAgKiBxdWVyeSBlbGVtZW50cyBjYW4gYXR0ZW5kIHRvIHdoaWNoIGtleSBlbGVtZW50cywgMSBpbmRpY2F0ZXNcbiAgICogYXR0ZW50aW9uIGFuZCAwIGluZGljYXRlcyBubyBhdHRlbnRpb24uIEJyb2FkY2FzdGluZyBjYW4gaGFwcGVuIGZvclxuICAgKiB0aGUgbWlzc2luZyBiYXRjaCBkaW1lbnNpb25zIGFuZCB0aGUgaGVhZCBkaW1lbnNpb24uXG4gICAqL1xuICBhdHRlbnRpb25NYXNrPzogVGVuc29yO1xuXG4gIC8qKlxuICAgKiBBIGRlbnNlIGZsb2F0IFRlbnNvci4gVGhlIGtleS92YWx1ZSBjYWNoZSwgb2Ygc2hhcGVcbiAgICogYFtCLCAyLCBTLCBudW1IZWFkcywga2V5RGltc11gLCB3aGVyZSBgU2AgbXVzdCBhZ3JlZSB3aXRoIHRoZVxuICAgKiBgYXR0ZW50aW9uTWFza2Agc2hhcGUuIFRoaXMgYXJndW1lbnQgaXMgaW50ZW5kZWQgZm9yIHVzZSBkdXJpbmdcbiAgICogZ2VuZXJhdGlvbiB0byBhdm9pZCByZWNvbXB1dGluZyBpbnRlcm1lZGlhdGUgc3RhdGUuXG4gICAqL1xuICBjYWNoZT86IFRlbnNvcjtcblxuICAvKipcbiAgICogSW50ZWdlciBvciBJbnRlZ2VyIGBUZW5zb3JgLiBUaGUgaW5kZXggYXQgd2hpY2ggdG8gdXBkYXRlIGBjYWNoZWBcbiAgICogKHVzdWFsbHkgdGhlIGluZGV4IG9mIHRoZSBjdXJyZW50IHRva2VuIGJlaW5nIHByb2Nlc3NlZCB3aGVuIHJ1bm5pbmdcbiAgICogZ2VuZXJhdGlvbikuIElmIGBjYWNoZVVwZGF0ZUluZGV4PW51bGxgIHdoaWxlIGBjYWNoZWAgaXMgc2V0LCB0aGUgY2FjaGVcbiAgICogd2lsbCBub3QgYmUgdXBkYXRlZC5cbiAgICovXG4gIGNhY2hlVXBkYXRlSW5kZXg/OiBudW1iZXI7XG59XG5cbi8qKlxuICogTXVsdGlIZWFkQXR0ZW50aW9uIGxheWVyIHdpdGggY2FjaGUgc3VwcG9ydC5cbiAqXG4gKiBUaGlzIGxheWVyIGlzIHN1aXRhYmxlIGZvciB1c2UgaW4gYXV0b3JlZ3Jlc3NpdmUgZGVjb2RpbmcuIEl0IGNhbiBiZSB1c2VcbiAqIHRvIGNhY2hlIGRlY29kZXIgc2VsZi1hdHRlbnRpb24gYW5kIGNyb3NzLWF0dGVudGlvbi4gVGhlIGZvcndhcmQgcGFzc1xuICogY2FuIGhhcHBlbiBpbiBvbmUgb2YgdGhyZWUgbW9kZXM6XG4gKiAtIE5vIGNhY2hlLCBzYW1lIGFzIHJlZ3VsYXIgbXVsdGktaGVhZCBhdHRlbnRpb24uXG4gKiAtIFN0YXRpYyBjYWNoZSAoYGNhY2hlVXBkYXRlSW5kZXhgIGlzIE5vbmUpLiBJbiB0aGlzIGNhc2UsIHRoZVxuICogICAgIGNhY2hlZCBrZXkvdmFsdWUgcHJvamVjdGlvbnMgd2lsbCBiZSB1c2VkIGFuZCB0aGUgaW5wdXQgdmFsdWVzIHdpbGxcbiAqICAgICBiZSBpZ25vcmVkLlxuICogLSBVcGRhdGVkIGNhY2hlIChgY2FjaGVVcGRhdGVJbmRleGAgaXMgbm90IE5vbmUpLiBJbiB0aGlzIGNhc2UsIG5ld1xuICogICAgIGtleS92YWx1ZSBwcm9qZWN0aW9ucyBhcmUgY29tcHV0ZWQgdXNpbmcgdGhlIGlucHV0LCBhbmQgc3BsaWNlZCBpbnRvXG4gKiAgICAgdGhlIGNhY2hlIGF0IHRoZSBzcGVjaWZpZWQgaW5kZXguXG4gKlxuICogTm90ZSB0aGF0IGNhY2hpbmcgaXMgdXNlZnVsIG9ubHkgZHVyaW5nIGluZmVyZW5jZSBhbmQgc2hvdWxkIG5vdCBiZSB1c2VkXG4gKiBkdXJpbmcgdHJhaW5pbmcuXG4gKlxuICogV2UgdXNlIHRoZSBub3RhdGlvbiBgQmAsIGBUYCwgYFNgIGJlbG93LCB3aGVyZSBgQmAgaXMgdGhlIGJhdGNoIGRpbWVuc2lvbixcbiAqIGBUYCBpcyB0aGUgdGFyZ2V0IHNlcXVlbmNlIGxlbmd0aCwgYW5kIGBTYCBpbiB0aGUgc291cmNlIHNlcXVlbmNlIGxlbmd0aC5cbiAqIE5vdGUgdGhhdCBkdXJpbmcgZ2VuZXJhdGl2ZSBkZWNvZGluZywgYFRgIGlzIHVzdWFsbHkgMSAoeW91IGFyZVxuICogZ2VuZXJhdGluZyBhIHRhcmdldCBzZXF1ZW5jZSBvZiBsZW5ndGggb25lIHRvIHByZWRpY3QgdGhlIG5leHQgdG9rZW4pLlxuICpcbiAqIFJldHVybnM6XG4gKiAgICAgQW4gYChhdHRlbnRpb25PdXRwdXQsIGNhY2hlKWAgdHVwbGUuIGBhdHRlbnRpb25PdXRwdXRgIGlzIHRoZSByZXN1bHRcbiAqICAgICBvZiB0aGUgY29tcHV0YXRpb24sIG9mIHNoYXBlIGAoQiwgVCwgZGltKWAsIHdoZXJlIGBUYCBpcyBmb3IgdGFyZ2V0XG4gKiAgICAgc2VxdWVuY2Ugc2hhcGVzIGFuZCBgZGltYCBpcyB0aGUgcXVlcnkgaW5wdXQgbGFzdCBkaW1lbnNpb24gaWZcbiAqICAgICBgb3V0cHV0U2hhcGVgIGlzIGBudWxsYC4gT3RoZXJ3aXNlLCB0aGUgbXVsdGktaGVhZCBvdXRwdXRzIGFyZVxuICogICAgIHByb2plY3RlZCB0byB0aGUgc2hhcGUgc3BlY2lmaWVkIGJ5IGBvdXRwdXRTaGFwZWAuIGBjYWNoZWAgaXMgdGhlXG4gKiAgICAgdXBkYXRlZCBjYWNoZS5cbiAqL1xuZXhwb3J0IGNsYXNzIENhY2hlZE11bHRpSGVhZEF0dGVudGlvbiBleHRlbmRzIE11bHRpSGVhZEF0dGVudGlvbiB7XG5cbiAgb3ZlcnJpZGUgY2FsbChcbiAgICBxdWVyeTogVGVuc29yLCBrd2FyZ3M6IENhY2hlZE11bHRpSGVhZEF0dGVudGlvbk9wdGlvbnNcbiAgKTogVGVuc29yIHtcbiAgICByZXR1cm4gdGhpcy5jYWxsQW5kUmV0dXJuQ2FjaGUocXVlcnksIGt3YXJncylbMF07XG4gIH1cblxuICAvKipcbiAgICogRXhhY3RseSBsaWtlIGBjYWxsYCBleGNlcHQgYWxzbyByZXR1cm5zIHRoZSB1cGRhdGVkIGNhY2hlLlxuICAgKi9cbiAgY2FsbEFuZFJldHVybkNhY2hlKFxuICAgIHF1ZXJ5OiBUZW5zb3IsXG4gICAge1xuICAgICAgdmFsdWUsXG4gICAgICBrZXksXG4gICAgICBhdHRlbnRpb25NYXNrLFxuICAgICAgY2FjaGUsXG4gICAgICBjYWNoZVVwZGF0ZUluZGV4XG4gICAgfSA6IENhY2hlZE11bHRpSGVhZEF0dGVudGlvbk9wdGlvbnNcbiAgKTogW1RlbnNvciwgVGVuc29yXSB7XG4gICAgcmV0dXJuIHRpZHkoKCkgPT4ge1xuICAgICAgaWYgKCF0aGlzLmJ1aWx0RnJvbVNpZ25hdHVyZSkge1xuICAgICAgICB0aGlzLmJ1aWxkRnJvbVNpZ25hdHVyZShcbiAgICAgICAgICBxdWVyeS5zaGFwZSwgdmFsdWUuc2hhcGUsIGtleSA/IGtleS5zaGFwZSA6IG51bGwpO1xuICAgICAgfVxuICAgICAgaWYgKGtleSA9PSBudWxsKSB7XG4gICAgICAgIGtleSA9IHZhbHVlO1xuICAgICAgfVxuXG4gICAgICBxdWVyeSA9IHRoaXMucXVlcnlEZW5zZS5hcHBseShxdWVyeSkgYXMgVGVuc29yO1xuICAgICAgLy8gSWYgY2FjaGUgaXMgbm90IGBudWxsYCwgd2Ugd2lsbCB1c2UgdGhlIGNhY2hlIHRvIGNvbXB1dGUgdGhlIGZpbmFsIGtleVxuICAgICAgLy8gYW5kIHZhbHVlIHRlbnNvcnMuIElmIGBjYWNoZVVwZGF0ZUluZGV4YCBpcyBub3QgYG51bGxgLCB3ZSB3aWxsIGZpcnN0XG4gICAgICAvLyB1cGRhdGUgdGhlIGNhY2hlIGJlZm9yZSB1c2UuIFRvIGRvIHRoaXMsIHdlIGZpcnN0IGNhbGwgdGhlXG4gICAgICAvLyBga2V5RGVuc2VgIGFuZCBgdmFsdWVEZW5zZWAgbGF5ZXJzLCBhbmQgY29weSB0aGUgb3V0cHV0cyBpbnRvIHRoZVxuICAgICAgLy8gY2FjaGUgYXQgdGhlIHNwZWNpZmllZCBpbmRleC4gYGNhY2hlID0gbnVsbGAgaGFuZGxlcyB0aGUgdHJhaW5pbmdcbiAgICAgIC8vIGNhc2UsIHdoZXJlIHdlIGRvbid0IHVzZSB0aGUgY2FjaGUgYXQgYWxsLlxuICAgICAgaWYgKGNhY2hlICE9IG51bGwpIHtcbiAgICAgICAgY29uc3Qga2V5Q2FjaGUgPSBjYWNoZS5nYXRoZXIoWzBdLCAxKS5zcXVlZXplKCk7XG4gICAgICAgIGNvbnN0IHZhbHVlQ2FjaGUgPSBjYWNoZS5nYXRoZXIoWzFdLCAxKS5zcXVlZXplKCk7XG4gICAgICAgIGlmIChjYWNoZVVwZGF0ZUluZGV4ID09IG51bGwpIHtcbiAgICAgICAgICBrZXkgPSBrZXlDYWNoZTtcbiAgICAgICAgICB2YWx1ZSA9IHZhbHVlQ2FjaGU7XG4gICAgICAgIH0gZWxzZSB7XG4gICAgICAgICAgY29uc3Qga2V5VXBkYXRlID0gdGhpcy5rZXlEZW5zZS5hcHBseShrZXkpIGFzIFRlbnNvcjtcbiAgICAgICAgICBjb25zdCB2YWx1ZVVwZGF0ZSA9IHRoaXMudmFsdWVEZW5zZS5hcHBseSh2YWx1ZSkgYXMgVGVuc29yO1xuICAgICAgICAgIGNvbnN0IHN0YXJ0ID0gWzAsIGNhY2hlVXBkYXRlSW5kZXgsIDAsIDBdO1xuICAgICAgICAgIGtleSA9IHNsaWNlVXBkYXRlKGtleUNhY2hlLCBzdGFydCwga2V5VXBkYXRlKTtcbiAgICAgICAgICB2YWx1ZSA9IHNsaWNlVXBkYXRlKHZhbHVlQ2FjaGUsIHN0YXJ0LCB2YWx1ZVVwZGF0ZSk7XG4gICAgICAgICAgY2FjaGUgPSBzdGFjayhba2V5LCB2YWx1ZV0sIDEpO1xuICAgICAgICB9XG4gICAgICB9IGVsc2Uge1xuICAgICAgICBpZiAoY2FjaGVVcGRhdGVJbmRleCAhPSBudWxsKSB7XG4gICAgICAgICAgdGhyb3cgbmV3IFZhbHVlRXJyb3IoXG4gICAgICAgICAgICAnYGNhY2hlVXBkYXRlSW5kZXhgIHNob3VsZCBub3QgYmUgc2V0IGlmIGBjYWNoZWAgaXMgYG51bGxgLiAnICtcbiAgICAgICAgICAgIGBSZWNlaXZlZDogY2FjaGU9JHtjYWNoZX0sIGNhY2hlVXBkYXRlSW5kZXg9JHtjYWNoZVVwZGF0ZUluZGV4fWBcbiAgICAgICAgICApO1xuICAgICAgICB9XG4gICAgICAgIGtleSA9IHRoaXMua2V5RGVuc2UuYXBwbHkoa2V5KSBhcyBUZW5zb3I7XG4gICAgICAgIHZhbHVlID0gdGhpcy52YWx1ZURlbnNlLmFwcGx5KHZhbHVlKSBhcyBUZW5zb3I7XG4gICAgICB9XG5cbiAgICAgIHF1ZXJ5ID0gbXVsKHF1ZXJ5LCByZWNpcHJvY2FsKHNxcnQoY2FzdCh0aGlzLmtleURpbSwgcXVlcnkuZHR5cGUpKSkpO1xuICAgICAgbGV0IGF0dGVudGlvblNjb3JlcyA9IGVpbnN1bSh0aGlzLmRvdFByb2R1Y3RFcXVhdGlvbiwga2V5LCBxdWVyeSk7XG4gICAgICBhdHRlbnRpb25TY29yZXMgPSB0aGlzLm1hc2tlZFNvZnRtYXgoYXR0ZW50aW9uU2NvcmVzLCBhdHRlbnRpb25NYXNrKTtcbiAgICAgIGF0dGVudGlvblNjb3JlcyA9IHRoaXMuZHJvcG91dExheWVyLmFwcGx5KGF0dGVudGlvblNjb3JlcykgYXMgVGVuc29yO1xuXG4gICAgICBsZXQgYXR0ZW50aW9uT3V0cHV0ID1cbiAgICAgICAgZWluc3VtKHRoaXMuY29tYmluZUVxdWF0aW9uLCBhdHRlbnRpb25TY29yZXMsIHZhbHVlKTtcbiAgICAgIGF0dGVudGlvbk91dHB1dCA9IHRoaXMub3V0cHV0RGVuc2UuYXBwbHkoYXR0ZW50aW9uT3V0cHV0KSBhcyBUZW5zb3I7XG5cbiAgICAgIHJldHVybiBbYXR0ZW50aW9uT3V0cHV0LCBjYWNoZV07XG4gICAgfSk7XG4gIH1cbn1cbnNlcmlhbGl6YXRpb24ucmVnaXN0ZXJDbGFzcyhDYWNoZWRNdWx0aUhlYWRBdHRlbnRpb24pO1xuIl19