@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
71 lines (63 loc) • 2.68 kB
text/typescript
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE, ForwardFunc} from '../engine';
import {OneHot, OneHotAttrs, OneHotInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor, Tensor1D} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {op} from './operation';
/**
 * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
 * value `onValue` (defaults to 1), while all other locations take value
 * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
 * `R+1` with the last axis of size `depth`.
 *
 * ```js
 * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
 * ```
 *
 * @param indices `tf.Tensor` of indices with dtype `int32`.
 * @param depth The depth of the one hot dimension. Must be an integer >= 2.
 * @param onValue A number used to fill in the output when the index matches
 * the location.
 * @param offValue A number used to fill in the output when the index does
 * not match the location.
 * @returns A tensor of shape `[...indices.shape, depth]`.
 * @throws Error if `depth` is not an integer (including `NaN`/`Infinity`) or
 * is less than 2.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function oneHot_(
    indices: Tensor|TensorLike, depth: number, onValue = 1,
    offValue = 0): Tensor {
  // `depth < 2` on its own lets NaN through (every comparison with NaN is
  // false) and accepts fractional/infinite depths, none of which can be the
  // size of an output axis. Reject them before any tensor work happens.
  if (!Number.isInteger(depth)) {
    throw new Error(
        `Error in oneHot: depth must be an integer, but it is ${depth}`);
  }
  if (depth < 2) {
    throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
  }
  const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
  // The kernel is invoked on a flat index vector; remember the caller's
  // shape so the result can be reshaped back to rank R+1 afterwards.
  const outShape = [...$indices.shape, depth];
  const flatIndices = $indices.flatten();

  const forward: ForwardFunc<Tensor> = (backend, save) => {
    save([flatIndices]);
    return backend.oneHot(flatIndices as Tensor1D, depth, onValue, offValue);
  };

  const inputs: OneHotInputs = {indices: flatIndices};
  const attrs: OneHotAttrs = {depth, onValue, offValue};
  const result = ENGINE.runKernelFunc(
      forward, inputs as unknown as NamedTensorMap, null /* grad */, OneHot,
      attrs as unknown as NamedAttrMap);
  return result.reshape(outShape);
}
// Public `tf.oneHot` entry point: `op` is handed `oneHot_` keyed by its own
// name (see ./operation for how that key is used).
export const oneHot = op({oneHot_: oneHot_});