UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

73 lines (72 loc) 3.44 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_utils" />
import { Tensor } from '@tensorflow/tfjs-core';

/**
 * For multi-class classification problems, this object is designed to store a
 * mapping from class index to the "weight" of the class, where higher weighted
 * classes have larger impact on loss, accuracy, and other metrics.
 *
 * This is useful for cases in which you want the model to "pay more attention"
 * to examples from an under-represented class, e.g., in unbalanced datasets.
 */
export type ClassWeight = {
  [classIndex: number]: number;
};

/**
 * Class weighting for a model with multiple outputs.
 *
 * This object maps each output name to a class-weighting object.
 */
export type ClassWeightMap = {
  [outputName: string]: ClassWeight;
};

/**
 * Standardize class weighting objects.
 *
 * This function takes a single class-weighting object, an array of them,
 * or a map from output name to class-weighting object. It compares the input
 * against the output name(s) of the model and, based on them, returns an array
 * of class-weighting objects whose length matches the number of outputs.
 *
 * @param classWeight Input class-weighting object(s).
 * @param outputNames All output name(s) of the model.
 * @returns An array of class-weighting objects. The length of the array matches
 *   the model's number of outputs.
 */
export declare function standardizeClassWeights(
  classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
  outputNames: string[]
): ClassWeight[];

/**
 * Standardize sample-weighting objects (the sample-weight counterpart of
 * `standardizeClassWeights`, judging by the name).
 *
 * The signature is identical to `standardizeClassWeights`: it accepts a single
 * weighting object, an array of them, or a map from output name to weighting
 * object, together with the model's output names, and returns an array of
 * weighting objects.
 *
 * NOTE(review): the parameter is named `classWeight` and typed with the
 * class-weight types even though this function is about sample weights;
 * presumably it shares its implementation with `standardizeClassWeights`
 * (one entry per model output) — confirm against the implementation.
 *
 * @param classWeight Input weighting object(s).
 * @param outputNames All output name(s) of the model.
 * @returns An array of weighting objects.
 */
export declare function standardizeSampleWeights(
  classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
  outputNames: string[]
): ClassWeight[];

/**
 * Standardize by-sample and/or by-class weights for training.
 *
 * Note that this function operates on one model output at a time. For a model
 * with multiple outputs, you must call this function multiple times.
 *
 * @param y The target tensor that the by-sample and/or by-class weight is for.
 *   The values of y are assumed to encode the classes, either directly
 *   as an integer index, or as one-hot encoding.
 * @param sampleWeight By-sample weights.
 * @param classWeight By-class weights: an object mapping class indices
 *   (integers) to a weight (float) to apply to the model's loss for the
 *   samples from this class during training. This can be useful to tell the
 *   model to "pay more attention" to samples from an under-represented class.
 * @param sampleWeightMode The mode for the sample weights. The only accepted
 *   value is `'temporal'`; it may be omitted.
 * @returns A Promise of a weight tensor, the size of whose first dimension
 *   matches that of `y`.
 */
export declare function standardizeWeights(
  y: Tensor,
  sampleWeight?: Tensor,
  classWeight?: ClassWeight,
  sampleWeightMode?: 'temporal'
): Promise<Tensor>;

/**
 * Apply per-sample weights on the loss values from a number of samples.
 *
 * @param losses Loss tensor of shape `[batchSize]`.
 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
 * @returns Tensor of the same shape as `losses`.
 */
export declare function computeWeightedLoss(
  losses: Tensor,
  sampleWeights: Tensor
): Tensor<import("@tensorflow/tfjs-core").Rank>;