UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

156 lines 21 kB
/** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ import { argMax, clone, dispose, mul, reshape, tensor1d, tidy } from '@tensorflow/tfjs-core'; function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) { const numOutputs = outputNames.length; if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) { return outputNames.map(name => null); } if (numOutputs === 1) { if (Array.isArray(xWeight) && xWeight.length === 1) { return xWeight; } else if (typeof xWeight === 'object' && outputNames[0] in xWeight) { return [xWeight[outputNames[0]]]; } else { return [xWeight]; } } if (Array.isArray(xWeight)) { if (xWeight.length !== numOutputs) { throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` + `element(s), but the model has ${numOutputs} outputs. ` + `Make sure a set of weights is provided for each model output.`); } return xWeight; } else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 && typeof xWeight[Object.keys(xWeight)[0]] === 'object') { const output = []; outputNames.forEach(outputName => { if (outputName in xWeight) { output.push(xWeight[outputName]); } else { output.push(null); } }); return output; } else { throw new Error(`The model has multiple (${numOutputs}) outputs, ` + `so ${weightType} must be either an array with ` + `${numOutputs} elements or an object with ${outputNames} keys. ` + `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`); } } /** * Standardize class weighting objects. * * This function takes a single class-weighting object, an array of them, * or a map from output name to class-weighting object. It compares it to the * output name(s) of the model, base on which it outputs an array of * class-weighting objects of which the length matches the number of outputs. * * @param classWeight Input class-weighting object(s). * @param outputNames All output name(s) of the model. * @return An array of class-weighting objects. The length of the array matches * the model's number of outputs. */ export function standardizeClassWeights(classWeight, outputNames) { return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight'); } export function standardizeSampleWeights(classWeight, outputNames) { return standardizeSampleOrClassWeights(classWeight, outputNames, 'sampleWeight'); } /** * Standardize by-sample and/or by-class weights for training. * * Note that this function operates on one model output at a time. For a model * with multiple outputs, you must call this function multiple times. * * @param y The target tensor that the by-sample and/or by-class weight is for. * The values of y are assumed to encode the classes, either directly * as an integer index, or as one-hot encoding. * @param sampleWeight By-sample weights. * @param classWeight By-class weights: an object mapping class indices * (integers) to a weight (float) to apply to the model's loss for the * samples from this class during training. This can be useful to tell the * model to "pay more attention" to samples from an under-represented class. * @param sampleWeightMode The mode for the sample weights. * @return A Promise of weight tensor, of which the size of the first dimension * matches that of `y`. */ export async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) { if (sampleWeight != null || sampleWeightMode != null) { // TODO(cais): Once 'temporal' mode is implemented, document it in the doc // string. throw new Error('Support sampleWeight is not implemented yet'); } if (classWeight != null) { // Apply class weights per sample. const yClasses = tidy(() => { if (y.shape.length === 1) { // Assume class indices. return clone(y); } else if (y.shape.length === 2) { if (y.shape[1] > 1) { // Assume one-hot encoding of classes. const axis = 1; return argMax(y, axis); } else if (y.shape[1] === 1) { // Class index. return reshape(y, [y.shape[0]]); } else { throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` + `during handling of class weights. The size is expected to be ` + `>= 1.`); } } else { throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` + `handling of class weights. The rank is expected to be 1 or 2.`); } }); const yClassIndices = Array.from(await yClasses.data()); dispose(yClasses); const classSampleWeight = []; yClassIndices.forEach(classIndex => { if (classWeight[classIndex] == null) { throw new Error(`classWeight must contain all classes in the training data. ` + `The class ${classIndex} exists in the data but not in ` + `classWeight`); } else { classSampleWeight.push(classWeight[classIndex]); } }); return tensor1d(classSampleWeight, 'float32'); } else { return null; } } /** * Apply per-sample weights on the loss values from a number of samples. * * @param losses Loss tensor of shape `[batchSize]`. * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`. * @returns Tensor of the same shape as`losses`. */ export function computeWeightedLoss(losses, sampleWeights) { return mul(losses, sampleWeights); } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"training_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/engine/training_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,OAAO,EAAC,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,GAAG,EAAE,OAAO,EAAoB,QAAQ,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAuB7G,SAAS,+BAA+B,CACpC,OAAiD,EAAE,WAAqB,EACxE,UAAwC;IAC1C,MAAM,UAAU,GAAG,WAAW,CAAC,MAAM,CAAC;IACtC,IAAI,OAAO,IAAI,IAAI,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,CAAC,EAAE;QACvE,OAAO,WAAW,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC;KACtC;IACD,IAAI,UAAU,KAAK,CAAC,EAAE;QACpB,IAAI,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YAClD,OAAO,OAAO,CAAC;SAChB;aAAM,IAAI,OAAO,OAAO,KAAK,QAAQ,IAAI,WAAW,CAAC,CAAC,CAAC,IAAI,OAAO,EAAE;YACnE,OAAO,CAAE,OAA0B,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SACtD;aAAM;YACL,OAAO,CAAC,OAAsB,CAAC,CAAC;SACjC;KACF;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE;QAC1B,IAAI,OAAO,CAAC,MAAM,KAAK,UAAU,EAAE;YACjC,MAAM,IAAI,KAAK,CACX,YAAY,UAAU,mBAAmB,OAAO,CAAC,MAAM,GAAG;gBAC1D,iCAAiC,UAAU,YAAY;gBACvD,+DAA+D,CAAC,CAAC;SACtE;QACD,OAAO,OAAO,CAAC;KAChB;SAAM,IACH,OAAO,OAAO,KAAK,QAAQ,IAAI,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,MAAM,GAAG,CAAC;QAC9D,OAAQ,OAA0B,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;YACvD,QAAQ,EAAE;QAChB,MAAM,MAAM,GAAkB,EAAE,CAAC;QACjC,WAAW,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE;YAC/B,IAAI,UAAU,IAAI,OAAO,EAAE;gBACzB,MAAM,CAAC,IAAI,CAAE,OAA0B,CAAC,UAAU,CAAC,CAAC,CAAC;aACtD;iBAAM;gBACL,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;aACnB;QACH,CAAC,CAAC,CAAC;QACH,OAAO,MAAM,CAAC;KACf;SAAM;QACL,MAAM,IAAI,KAAK,CACX,2BAA2B,UAAU,aAAa;YAClD,MAAM,UAAU,gCAAgC;YAChD,GAAG,UAAU,+BAA+B,WAAW,SAAS;YAChE,YAAY,UAAU,oBAAoB,IAAI,CAAC,SAAS,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;KAC1E;AACH,CAAC;AAED;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,uBAAuB,CACnC,WAAqD,EACrD,WAAqB;IACvB,OAAO,+BAA+B,CAClC,WAAW,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC;AAC/C,CAAC;AAED,MAAM,UAAU,wBAAwB,CACpC,WAAqD,EACrD,WAAqB;IACvB,OAAO,+BAA+B,CAClC,WAAW,EAAE,WAAW,EAAE,cAAc,CAAC,CAAC;AAChD,CAAC;AAED;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,CAAC,KAAK,UAAU,kBAAkB,CACpC,CAAS,EAAE,YAAqB,EAAE,WAAyB,EAC3D,gBAA6B;IAC/B,IAAI,YAAY,IAAI,IAAI,IAAI,gBAAgB,IAAI,IAAI,EAAE;QACpD,0EAA0E;QAC1E,UAAU;QACV,MAAM,IAAI,KAAK,CAAC,6CAA6C,CAAC,CAAC;KAChE;IAED,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,kCAAkC;QAClC,MAAM,QAAQ,GAAa,IAAI,CAAC,GAAG,EAAE;YACnC,IAAI,CAAC,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;gBACxB,wBAAwB;gBACxB,OAAO,KAAK,CAAC,CAAC,CAAa,CAAC;aAC7B;iBAAM,IAAI,CAAC,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC/B,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE;oBAClB,sCAAsC;oBACtC,MAAM,IAAI,GAAG,CAAC,CAAC;oBACf,OAAO,MAAM,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;iBACxB;qBAAM,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;oBAC3B,eAAe;oBACf,OAAO,OAAO,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;iBACjC;qBAAM;oBACL,MAAM,IAAI,KAAK,CACX,+CAA+C,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI;wBAC7D,+DAA+D;wBAC/D,OAAO,CAAC,CAAC;iBACd;aACF;iBAAM;gBACL,MAAM,IAAI,KAAK,CACX,yCAAyC,CAAC,CAAC,IAAI,WAAW;oBAC1D,+DAA+D,CAAC,CAAC;aACtE;QACH,CAAC,CAAC,CAAC;QAEH,MAAM,aAAa,GAAG,KAAK,CAAC,IAAI,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC,CAAC;QACxD,OAAO,CAAC,QAAQ,CAAC,CAAC;QAClB,MAAM,iBAAiB,GAAa,EAAE,CAAC;QACvC,aAAa,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE;YACjC,IAAI,WAAW,CAAC,UAAU,CAAC,IAAI,IAAI,EAAE;gBACnC,MAAM,IAAI,KAAK,CACX,6DAA6D;oBAC7D,aAAa,UAAU,iCAAiC;oBACxD,aAAa,CAAC,CAAC;aACpB;iBAAM;gBACL,iBAAiB,CAAC,IAAI,CAAC,WAAW,CAAC,UAAU,CAAC,CAAC,CAAC;aACjD;QACH,CAAC,CAAC,CAAC;QAEH,OAAO,QAAQ,CAAC,iBAAiB,EAAE,SAAS,CAAC,CAAC;KAC/C;SAAM;QACL,OAAO,IAAI,CAAC;KACb;AACH,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,mBAAmB,CAAC,MAAc,EAAE,aAAqB;IACvE,OAAO,GAAG,CAAC,MAAM,EAAE,aAAa,CAAC,CAAC;AACpC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\nimport {argMax, clone, dispose, mul, reshape, Tensor, Tensor1D, tensor1d, tidy} from '@tensorflow/tfjs-core';\n\n/**\n * For multi-class classification problems, this object is designed to store a\n * mapping from class index to the \"weight\" of the class, where higher weighted\n * classes have larger impact on loss, accuracy, and other metrics.\n *\n * This is useful for cases in which you want the model to \"pay more attention\"\n * to examples from an under-represented class, e.g., in unbalanced datasets.\n */\nexport type ClassWeight = {\n  [classIndex: number]: number\n};\n\n/**\n * Class weighting for a model with multiple outputs.\n *\n * This object maps each output name to a class-weighting object.\n */\nexport type ClassWeightMap = {\n  [outputName: string]: ClassWeight\n};\n\nfunction standardizeSampleOrClassWeights(\n    xWeight: ClassWeight|ClassWeight[]|ClassWeightMap, outputNames: string[],\n    weightType: 'sampleWeight'|'classWeight'): ClassWeight[] {\n  const numOutputs = outputNames.length;\n  if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {\n    return outputNames.map(name => null);\n  }\n  if (numOutputs === 1) {\n    if (Array.isArray(xWeight) && xWeight.length === 1) {\n      return xWeight;\n    } else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {\n      return [(xWeight as ClassWeightMap)[outputNames[0]]];\n    } else {\n      return [xWeight as ClassWeight];\n    }\n  }\n  if (Array.isArray(xWeight)) {\n    if (xWeight.length !== numOutputs) {\n      throw new Error(\n          `Provided ${weightType} is an array of ${xWeight.length} ` +\n          `element(s), but the model has ${numOutputs} outputs. ` +\n          `Make sure a set of weights is provided for each model output.`);\n    }\n    return xWeight;\n  } else if (\n      typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&\n      typeof (xWeight as ClassWeightMap)[Object.keys(xWeight)[0]] ===\n          'object') {\n    const output: ClassWeight[] = [];\n    outputNames.forEach(outputName => {\n      if (outputName in xWeight) {\n        output.push((xWeight as ClassWeightMap)[outputName]);\n      } else {\n        output.push(null);\n      }\n    });\n    return output;\n  } else {\n    throw new Error(\n        `The model has multiple (${numOutputs}) outputs, ` +\n        `so ${weightType} must be either an array with ` +\n        `${numOutputs} elements or an object with ${outputNames} keys. ` +\n        `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);\n  }\n}\n\n/**\n * Standardize class weighting objects.\n *\n * This function takes a single class-weighting object, an array of them,\n * or a map from output name to class-weighting object. It compares it to the\n * output name(s) of the model, base on which it outputs an array of\n * class-weighting objects of which the length matches the number of outputs.\n *\n * @param classWeight Input class-weighting object(s).\n * @param outputNames All output name(s) of the model.\n * @return An array of class-weighting objects. The length of the array matches\n *   the model's number of outputs.\n */\nexport function standardizeClassWeights(\n    classWeight: ClassWeight|ClassWeight[]|ClassWeightMap,\n    outputNames: string[]): ClassWeight[] {\n  return standardizeSampleOrClassWeights(\n      classWeight, outputNames, 'classWeight');\n}\n\nexport function standardizeSampleWeights(\n    classWeight: ClassWeight|ClassWeight[]|ClassWeightMap,\n    outputNames: string[]): ClassWeight[] {\n  return standardizeSampleOrClassWeights(\n      classWeight, outputNames, 'sampleWeight');\n}\n\n/**\n * Standardize by-sample and/or by-class weights for training.\n *\n * Note that this function operates on one model output at a time. For a model\n * with multiple outputs, you must call this function multiple times.\n *\n * @param y The target tensor that the by-sample and/or by-class weight is for.\n *     The values of y are assumed to encode the classes, either directly\n *     as an integer index, or as one-hot encoding.\n * @param sampleWeight By-sample weights.\n * @param classWeight By-class weights: an object mapping class indices\n *     (integers) to a weight (float) to apply to the model's loss for the\n *     samples from this class during training. This can be useful to tell the\n *     model to \"pay more attention\" to samples from an under-represented class.\n * @param sampleWeightMode The mode for the sample weights.\n * @return A Promise of weight tensor, of which the size of the first dimension\n *     matches that of `y`.\n */\nexport async function standardizeWeights(\n    y: Tensor, sampleWeight?: Tensor, classWeight?: ClassWeight,\n    sampleWeightMode?: 'temporal'): Promise<Tensor> {\n  if (sampleWeight != null || sampleWeightMode != null) {\n    // TODO(cais): Once 'temporal' mode is implemented, document it in the doc\n    // string.\n    throw new Error('Support sampleWeight is not implemented yet');\n  }\n\n  if (classWeight != null) {\n    // Apply class weights per sample.\n    const yClasses: Tensor1D = tidy(() => {\n      if (y.shape.length === 1) {\n        // Assume class indices.\n        return clone(y) as Tensor1D;\n      } else if (y.shape.length === 2) {\n        if (y.shape[1] > 1) {\n          // Assume one-hot encoding of classes.\n          const axis = 1;\n          return argMax(y, axis);\n        } else if (y.shape[1] === 1) {\n          // Class index.\n          return reshape(y, [y.shape[0]]);\n        } else {\n          throw new Error(\n              `Encountered unexpected last-dimension size (${y.shape[1]}) ` +\n              `during handling of class weights. The size is expected to be ` +\n              `>= 1.`);\n        }\n      } else {\n        throw new Error(\n            `Unexpected rank of target (y) tensor (${y.rank}) during ` +\n            `handling of class weights. The rank is expected to be 1 or 2.`);\n      }\n    });\n\n    const yClassIndices = Array.from(await yClasses.data());\n    dispose(yClasses);\n    const classSampleWeight: number[] = [];\n    yClassIndices.forEach(classIndex => {\n      if (classWeight[classIndex] == null) {\n        throw new Error(\n            `classWeight must contain all classes in the training data. ` +\n            `The class ${classIndex} exists in the data but not in ` +\n            `classWeight`);\n      } else {\n        classSampleWeight.push(classWeight[classIndex]);\n      }\n    });\n\n    return tensor1d(classSampleWeight, 'float32');\n  } else {\n    return null;\n  }\n}\n\n/**\n * Apply per-sample weights on the loss values from a number of samples.\n *\n * @param losses Loss tensor of shape `[batchSize]`.\n * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.\n * @returns Tensor of the same shape as`losses`.\n */\nexport function computeWeightedLoss(losses: Tensor, sampleWeights: Tensor) {\n  return mul(losses, sampleWeights);\n}\n"]}