UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

71 lines 9.72 kB
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { TensorScatterUpdate } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import { op } from './operation';
import * as scatter_nd_util from './scatter_nd_util';
/**
 * Creates a new tensor by applying sparse updates to individual
 * values or slices of the passed in tensor according to
 * indices. This operator is similar to the scatterNd op, except that the
 * updates are scattered on an existing tensor (as opposed to a zero-tensor).
 *
 * If indices contains duplicates, then we pick the last update for the index.
 *
 * If an out of bound index is found on CPU, an error is returned.
 *
 * Warning: There are some GPU specific semantics for this operation.
 *  - If an out of bound index is found, the index is ignored.
 *  - The order in which updates are applied is nondeterministic, so the output
 *    will be nondeterministic if indices contains duplicates.
 *
 * ```js
 * const shape = [8];
 * const tensor = tf.ones(shape);
 * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
 * const updates = tf.tensor1d([9, 10, 11, 12]);
 *
 * tf.tensorScatterUpdate(tensor, indices, updates).print();
 *    //[1, 11, 1, 10, 9, 1, 1, 12]
 * ```
 *
 * @param tensor A Tensor. Tensor to copy/update.
 * @param indices The tensor contains the indices into the output tensor, must
 *     have at least 2 axes: (num_updates, index_depth). Converted with the
 *     'int32' dtype.
 * @param updates The tensor contains the value for the indices. Must have the
 *     same dtype as `tensor`.
 * @returns The result of running the `TensorScatterUpdate` kernel on the
 *     converted inputs.
 * @throws {Error} If `tensor` and `updates` have different dtypes.
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function tensorScatterUpdate_(tensor, indices, updates) {
    // Public op name handed to convertToTensor as the function-name argument
    // (presumably used in its diagnostics -- confirm in tensor_util_env). It
    // must match the exported name's casing so messages point at the real op.
    const opName = 'tensorScatterUpdate';
    const $tensor = convertToTensor(tensor, 'tensor', opName);
    const $indices = convertToTensor(indices, 'indices', opName, 'int32');
    const $updates = convertToTensor(updates, 'updates', opName);
    // Shape validation is delegated to the helper shared with the scatterNd
    // utilities; it receives the target tensor's shape as the output shape.
    scatter_nd_util.validateInput($updates, $indices, $tensor.shape);
    if ($tensor.dtype !== $updates.dtype) {
        throw new Error(`tensor and updates must have the same dtype, instead they are ${$tensor.dtype} and ${$updates.dtype}.`);
    }
    const inputs = {
        tensor: $tensor,
        indices: $indices,
        updates: $updates
    };
    // The kernel takes no attributes; an empty map is still passed explicitly.
    const attrs = {};
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(TensorScatterUpdate, inputs, attrs);
}
export const tensorScatterUpdate = op({ tensorScatterUpdate_ });
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidGVuc29yX3NjYXR0ZXJfdXBkYXRlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvdGVuc29yX3NjYXR0ZXJfdXBkYXRlLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLG1CQUFtQixFQUFzRCxNQUFNLGlCQUFpQixDQUFDO0FBSXpHLE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUduRCxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQy9CLE9BQU8sS0FBSyxlQUFlLE1BQU0sbUJBQW1CLENBQUM7QUFFckQ7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OztHQThCRztBQUNILFNBQVMsb0JBQW9CLENBQ3pCLE1BQTRCLEVBQUUsT0FBMEIsRUFDeEQsT0FBMEI7SUFDNUIsTUFBTSxPQUFPLEdBQUcsZUFBZSxDQUFDLE1BQU0sRUFBRSxRQUFRLEVBQUUscUJBQXFCLENBQUMsQ0FBQztJQUN6RSxNQUFNLFFBQVEsR0FDVixlQUFlLENBQUMsT0FBTyxFQUFFLFNBQVMsRUFBRSxxQkFBcUIsRUFBRSxPQUFPLENBQUMsQ0FBQztJQUN4RSxNQUFNLFFBQVEsR0FBRyxlQUFlLENBQUMsT0FBTyxFQUFFLFNBQVMsRUFBRSxxQkFBcUIsQ0FBQyxDQUFDO0lBQzVFLGVBQWUsQ0FBQyxhQUFhLENBQUMsUUFBUSxFQUFFLFFBQVEsRUFBRSxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDakUsSUFBSSxPQUFPLENBQUMsS0FBSyxLQUFLLFFBQVEsQ0FBQyxLQUFLLEVBQUU7UUFDcEMsTUFBTSxJQUFJLEtBQUssQ0FDWCxpRUFDSSxPQUFPLENBQUMsS0FBSyxRQUFRLFFBQVEsQ0FBQyxLQUFLLEdBQUcsQ0FBQyxDQUFDO0tBQ2pEO0lBRUQsTUFBTSxNQUFNLEdBQThCO1FBQ3hDLE1BQU0sRUFBRSxPQUFPO1FBQ2YsT0FBTyxFQUFFLFFBQVE7UUFDakIsT0FBTyxFQUFFLFFBQVE7S0FDbEIsQ0FBQztJQUNGLE1BQU0sS0FBSyxHQUE2QixFQUFFLENBQUM7SUFFM0MsMERBQTBEO0lBQzFELE9BQU8sTUFBTSxDQUFDLFNBQVMsQ0FDWixtQkFBbUIsRUFBRSxNQUFtQyxFQUN4RCxLQUFnQyxDQUFjLENBQUM7QUFDNUQsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLG1CQUFtQixHQUFHLEVBQUUsQ0FBQyxFQUFDLG9CQUFvQixFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIyIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbn
NlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtFTkdJTkV9IGZyb20gJy4uL2VuZ2luZSc7XG5pbXBvcnQge1RlbnNvclNjYXR0ZXJVcGRhdGUsIFRlbnNvclNjYXR0ZXJVcGRhdGVBdHRycywgVGVuc29yU2NhdHRlclVwZGF0ZUlucHV0c30gZnJvbSAnLi4va2VybmVsX25hbWVzJztcbmltcG9ydCB7TmFtZWRBdHRyTWFwfSBmcm9tICcuLi9rZXJuZWxfcmVnaXN0cnknO1xuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4uL3RlbnNvcic7XG5pbXBvcnQge05hbWVkVGVuc29yTWFwfSBmcm9tICcuLi90ZW5zb3JfdHlwZXMnO1xuaW1wb3J0IHtjb252ZXJ0VG9UZW5zb3J9IGZyb20gJy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1JhbmssIFRlbnNvckxpa2V9IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0ICogYXMgc2NhdHRlcl9uZF91dGlsIGZyb20gJy4vc2NhdHRlcl9uZF91dGlsJztcblxuLyoqXG4gKiBDcmVhdGVzIGEgbmV3IHRlbnNvciBieSBhcHBseWluZyBzcGFyc2UgdXBkYXRlcyB0byBpbmRpdmlkdWFsXG4gKiB2YWx1ZXMgb3Igc2xpY2VzIHRvIHRoZSBwYXNzZWQgaW4gdGVuc29yIGFjY29yZGluZyB0b1xuICogaW5kaWNlcy4gVGhpcyBvcGVyYXRvciBpcyB0aGUgc2ltaWxhciB0byBzY2F0dGVyTmQgb3AsIGV4Y2VwdCB0aGF0IHRoZVxuICogdWRwYXRlcyBhcmUgc2NhdHRlcmVkIG9uIGFuIGV4aXN0aW5nIHRlbnNvciAoYXMgb3Bwb3NlZCB0byBhIHplcm8tdGVuc29yKS5cbiAqXG4gKiBJZiBpbmRpY2VzIGNvbnRhaW5zIGR1cGxpY2F0ZXMsIHRoZW4gd2UgcGljayB0aGUgbGFzdCB1cGRhdGUgZm9yIHRoZSBpbmRleC5cbiAqXG4gKiBJZiBhbiBvdXQgb2YgYm91bmQgaW5kZXggaXMgZm91bmQgb24gQ1BVLCBhbiBlcnJvciBpcyByZXR1cm5lZC5cbiAqXG4gKiBXYXJuaW5nOiBUaGVyZSBhcmUgc29tZSBHUFUgc3BlY2lmaWMgc2VtYW50aWNzIGZvciB0aGlzIG9wZXJhdGlvbi5cbiAqICAtIElmIGFuIG91dCBvZiBib3VuZCBpbmRleCBpcyBmb3VuZCwgdG
hlIGluZGV4IGlzIGlnbm9yZWQuXG4gKiAgLSBUaGUgb3JkZXIgaW4gd2hpY2ggdXBkYXRlcyBhcmUgYXBwbGllZCBpcyBub25kZXRlcm1pbmlzdGljLCBzbyB0aGUgb3V0cHV0XG4gKiB3aWxsIGJlIG5vbmRldGVybWluaXN0aWMgaWYgaW5kaWNlcyBjb250YWlucyBkdXBsaWNhdGVzLlxuICogYGBganNcbiAqIGNvbnN0IHNoYXBlID0gWzhdO1xuICogY29uc3QgdGVuc29yID0gdGYub25lcyhzaGFwZSk7XG4gKiBjb25zdCBpbmRpY2VzID0gdGYudGVuc29yMmQoWzQsIDMsIDEsIDddLCBbNCwgMV0sICdpbnQzMicpO1xuICogY29uc3QgdXBkYXRlcyA9IHRmLnRlbnNvcjFkKFs5LCAxMCwgMTEsIDEyXSk7XG4gKlxuICogdGYudGVuc29yU2NhdHRlclVwZGF0ZSh0ZW5zb3IsIGluZGljZXMsIHVwZGF0ZXMpLnByaW50KCk7XG4gKiAgICAvL1sxLCAxMSwgMSwgMTAsIDksIDEsIDEsIDEyXVxuICogYGBgXG4gKlxuICogQHBhcmFtIHRlbnNvciBBIFRlbnNvci4gVGVuc29yIHRvIGNvcHkvdXBkYXRlLlxuICogQHBhcmFtIGluZGljZXMgVGhlIHRlbnNvciBjb250YWlucyB0aGUgaW5kaWNlcyBpbnRvIHRoZSBvdXRwdXQgdGVuc29yLCBtdXN0XG4gKiAgICAgaGF2ZSBhdCBsZWFzdCAyIGF4ZXM6IChudW1fdXBkYXRlcywgaW5kZXhfZGVwdGgpLlxuICogQHBhcmFtIHVwZGF0ZXMgVGhlIHRlbnNvciBjb250YWlucyB0aGUgdmFsdWUgZm9yIHRoZSBpbmRpY2VzLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdPcGVyYXRpb25zJywgc3ViaGVhZGluZzogJ1NsaWNpbmcgYW5kIEpvaW5pbmcnfVxuICovXG5mdW5jdGlvbiB0ZW5zb3JTY2F0dGVyVXBkYXRlXzxSIGV4dGVuZHMgUmFuaz4oXG4gICAgdGVuc29yOiBUZW5zb3I8Uj58VGVuc29yTGlrZSwgaW5kaWNlczogVGVuc29yfFRlbnNvckxpa2UsXG4gICAgdXBkYXRlczogVGVuc29yfFRlbnNvckxpa2UpOiBUZW5zb3I8Uj4ge1xuICBjb25zdCAkdGVuc29yID0gY29udmVydFRvVGVuc29yKHRlbnNvciwgJ3RlbnNvcicsICd0ZW5zb3JTY2F0dGVydXBkYXRlJyk7XG4gIGNvbnN0ICRpbmRpY2VzID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihpbmRpY2VzLCAnaW5kaWNlcycsICd0ZW5zb3JTY2F0dGVydXBkYXRlJywgJ2ludDMyJyk7XG4gIGNvbnN0ICR1cGRhdGVzID0gY29udmVydFRvVGVuc29yKHVwZGF0ZXMsICd1cGRhdGVzJywgJ3RlbnNvclNjYXR0ZXJ1cGRhdGUnKTtcbiAgc2NhdHRlcl9uZF91dGlsLnZhbGlkYXRlSW5wdXQoJHVwZGF0ZXMsICRpbmRpY2VzLCAkdGVuc29yLnNoYXBlKTtcbiAgaWYgKCR0ZW5zb3IuZHR5cGUgIT09ICR1cGRhdGVzLmR0eXBlKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICBgdGVuc29yIGFuZCB1cGRhdGVzIG11c3QgaGF2ZSB0aGUgc2FtZSBkdHlwZSwgaW5zdGVhZCB0aGV5IGFyZSAke1xuICAgICAgICAgICAgJHRlbnNvci5kdHlwZX0gYW5kICR7JHVwZGF0ZXMuZHR5cGV9LmApO1xuICB9XG5cbiAgY29uc3QgaW5wdXRzOiBUZW5zb3JTY2F0dG
VyVXBkYXRlSW5wdXRzID0ge1xuICAgIHRlbnNvcjogJHRlbnNvcixcbiAgICBpbmRpY2VzOiAkaW5kaWNlcyxcbiAgICB1cGRhdGVzOiAkdXBkYXRlc1xuICB9O1xuICBjb25zdCBhdHRyczogVGVuc29yU2NhdHRlclVwZGF0ZUF0dHJzID0ge307XG5cbiAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOiBuby11bm5lY2Vzc2FyeS10eXBlLWFzc2VydGlvblxuICByZXR1cm4gRU5HSU5FLnJ1bktlcm5lbChcbiAgICAgICAgICAgICBUZW5zb3JTY2F0dGVyVXBkYXRlLCBpbnB1dHMgYXMgdW5rbm93biBhcyBOYW1lZFRlbnNvck1hcCxcbiAgICAgICAgICAgICBhdHRycyBhcyB1bmtub3duIGFzIE5hbWVkQXR0ck1hcCkgYXMgVGVuc29yPFI+O1xufVxuXG5leHBvcnQgY29uc3QgdGVuc29yU2NhdHRlclVwZGF0ZSA9IG9wKHt0ZW5zb3JTY2F0dGVyVXBkYXRlX30pO1xuIl19