UNPKG

@tensorflow/tfjs-data

Version:

TensorFlow Data API in JavaScript

163 lines 22.7 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * =============================================================================
 */
import * as tf from '@tensorflow/tfjs-core';
import { div, max, min, sub } from '@tensorflow/tfjs-core';

/**
 * Provides a function that scales numeric values into the [0, 1] interval.
 *
 * The returned function computes `(value - min) / (max - min)`:
 *  - for a `tf.Tensor` it returns a new tensor built with `sub` and `div`;
 *  - for an array it returns a new array with every entry scaled;
 *  - for any other non-string value it returns the scaled number.
 *
 * No special handling is applied when `max === min`; the division by a zero
 * range is performed as-is.
 *
 * @param min the lower bound of the inputs, which should be mapped to 0.
 * @param max the upper bound of the inputs, which should be mapped to 1,
 * @return A function that maps an input ElementArray to a scaled ElementArray.
 *     That function throws an Error when it is given a string.
 */
export function scaleTo01(min, max) {
    const range = max - min;
    // Both scalars are created once, up front, and are captured by the
    // returned closure so the tensor branch can reuse them on every call.
    const minTensor = tf.scalar(min);
    const rangeTensor = tf.scalar(range);
    return (value) => {
        if (typeof (value) === 'string') {
            throw new Error('Can\'t scale a string.');
        }
        if (value instanceof tf.Tensor) {
            return div(sub(value, minTensor), rangeTensor);
        }
        if (value instanceof Array) {
            return value.map(v => (v - min) / range);
        }
        return (value - min) / range;
    };
}

/**
 * Summarizes a single (non-string) column value as plain numbers.
 *
 * Tensor work runs inside `tf.tidy`, so every intermediate tensor created
 * here (the `min`/`max` reductions, the `tf.moments` outputs and any tensor
 * created from an array input) is released before returning. The input
 * `value` itself is created outside the tidy scope and is left untouched.
 *
 * @param value A tf.Tensor, an array of numbers, or a scalar.
 * @return `{min, max, mean, variance, length}` for the value, or `null` when
 *     the value carries nothing to accumulate: an empty tensor, an empty
 *     array, or a scalar that is NaN / not finite.
 */
function summarizeValue(value) {
    if (value instanceof tf.Tensor) {
        if (value.size === 0) {
            // Nothing to reduce; folding a zero-length sample into an empty
            // column would compute 0 / 0 and poison the running statistics.
            return null;
        }
        return tf.tidy(() => {
            const valueMoment = tf.moments(value);
            return {
                min: min(value).dataSync()[0],
                max: max(value).dataSync()[0],
                mean: valueMoment.mean.dataSync()[0],
                variance: valueMoment.variance.dataSync()[0],
                length: value.size
            };
        });
    }
    if (value instanceof Array) {
        if (value.length === 0) {
            // `reduce` without an initial value throws on an empty array.
            return null;
        }
        const valueMoment = tf.tidy(() => {
            const moment = tf.moments(value);
            return {
                mean: moment.mean.dataSync()[0],
                variance: moment.variance.dataSync()[0]
            };
        });
        return {
            min: value.reduce((a, b) => Math.min(a, b)),
            max: value.reduce((a, b) => Math.max(a, b)),
            mean: valueMoment.mean,
            variance: valueMoment.variance,
            length: value.length
        };
    }
    if (!isNaN(value) && isFinite(value)) {
        // A lone scalar is a sample of one: it is its own min, max and mean.
        return { min: value, max: value, mean: value, variance: 0, length: 1 };
    }
    return null;
}

/**
 * Provides a function that calculates column level statistics, i.e. min, max,
 * variance, stddev.
 *
 * Every element of the dataset is visited once. For each key of an element,
 * string values are ignored; tensors, arrays and finite scalars are folded
 * into that column's running min, max, mean, variance, stddev and length.
 * Empty tensors/arrays and NaN or non-finite scalars are skipped.
 *
 * Note that a column entry is created as soon as a non-string value is seen
 * for its key, even if that value ends up being skipped; such a column keeps
 * its initial values (min = +Infinity, max = -Infinity, length = 0).
 *
 * @param dataset The Dataset object whose statistics will be calculated.
 * @param sampleSize (Optional) If set, statistics will only be calculated
 *     against a subset of the whole data.
 * @param shuffleWindowSize (Optional) If set, shuffle provided dataset before
 *     calculating statistics.
 * @return A DatasetStatistics object that contains NumericColumnStatistics of
 *     each column.
 */
export async function computeDatasetStatistics(dataset, sampleSize, shuffleWindowSize) {
    let sampleDataset = dataset;
    // TODO(soergel): allow for deep shuffle where possible.
    if (shuffleWindowSize != null) {
        sampleDataset = sampleDataset.shuffle(shuffleWindowSize);
    }
    if (sampleSize != null) {
        sampleDataset = sampleDataset.take(sampleSize);
    }
    // TODO(soergel): prepare the column objects based on a schema.
    const result = {};
    await sampleDataset.forEachAsync(e => {
        for (const key of Object.keys(e)) {
            const value = e[key];
            if (typeof (value) === 'string') {
                // No statistics for string element.
                continue;
            }
            let columnStats = result[key];
            if (columnStats == null) {
                columnStats = {
                    min: Number.POSITIVE_INFINITY,
                    max: Number.NEGATIVE_INFINITY,
                    mean: 0,
                    variance: 0,
                    stddev: 0,
                    length: 0
                };
                result[key] = columnStats;
            }
            const summary = summarizeValue(value);
            if (summary == null) {
                continue;
            }
            // A freshly created column starts at mean = variance = length = 0,
            // which is exactly the "nothing seen yet" state for the update.
            const previousMean = columnStats.mean;
            const previousLength = columnStats.length;
            const previousVariance = columnStats.variance;
            const valueLength = summary.length;
            const valueMean = summary.mean;
            const valueVariance = summary.variance;
            // Calculate accumulated mean and variance following tf.Transform
            // implementation
            const combinedLength = previousLength + valueLength;
            const combinedMean = previousMean +
                (valueLength / combinedLength) * (valueMean - previousMean);
            const combinedVariance = previousVariance +
                (valueLength / combinedLength) *
                    (valueVariance +
                        ((valueMean - combinedMean) * (valueMean - previousMean)) -
                        previousVariance);
            columnStats.min = Math.min(columnStats.min, summary.min);
            columnStats.max = Math.max(columnStats.max, summary.max);
            columnStats.length = combinedLength;
            columnStats.mean = combinedMean;
            columnStats.variance = combinedVariance;
            columnStats.stddev = Math.sqrt(combinedVariance);
        }
    });
    // Variance and stddev should be NaN for the case of a single element.
    for (const key of Object.keys(result)) {
        const stat = result[key];
        if (stat.length === 1) {
            stat.variance = NaN;
            stat.stddev = NaN;
        }
    }
    return result;
}
// NOTE(review): the inline source map below was emitted for the original
// generated layout of this file; regenerate it if exact mappings matter.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"statistics.js","sourceRoot":"","sources":["../../../../../tfjs-data/src/statistics.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;GAgBG;AAEH,OAAO,KAAK,EAAE,MAAM,uBAAuB,CAAC;AAC5C,OAAO,EAAC,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAC,MAAM,uBAAuB,CAAC;AA8CzD;;;;;;GAMG;AACH,MAAM,UAAU,SAAS,CAAC,GAAW,EAAE,GAAW;IAEhD,MAAM,KAAK,GAAG,GAAG,GAAG,GAAG,CAAC;IACxB,MAAM,SAAS,GAAc,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;IAC5C,MAAM,WAAW,GAAc,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IAChD,OAAO,CAAC,KAAmB,EAAgB,EAAE;QAC3C,IAAI,OAAO,CAAC,KAAK,CAAC,KAAK,QAAQ,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,wBAAwB,CAAC,CAAC;SAC3C;aAAM;YACL,IAAI,KAAK,YAAY,EAAE,CAAC,MAAM,EAAE;gBAC9B,MAAM,MAAM,GAAG,GAAG,CAAC,GAAG,CAAC,KAAK,EAAE,SAAS,CAAC,EAAE,WAAW,CAAC,CAAC;gBACvD,OAAO,MAAM,CAAC;aACf;iBAAM,IAAI,KAAK,YAAY,KAAK,EAAE;gBACjC,OAAO,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC,CAAC;aAC1C;iBAAM;gBACL,OAAO,CAAC,KAAK,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC;aAC9B;SACF;IACH,CAAC,CAAC;AACJ,CAAC;AAED;;;;;;;;;;;GAWG;AACH,MAAM,CAAC,KAAK,UAAU,wBAAwB
,CAC1C,OAA+B,EAAE,UAAmB,EACpD,iBAA0B;IAC5B,IAAI,aAAa,GAAG,OAAO,CAAC;IAC5B,wDAAwD;IACxD,IAAI,iBAAiB,IAAI,IAAI,EAAE;QAC7B,aAAa,GAAG,aAAa,CAAC,OAAO,CAAC,iBAAiB,CAAC,CAAC;KAC1D;IACD,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,aAAa,GAAG,aAAa,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;KAChD;IAED,+DAA+D;IAC/D,MAAM,MAAM,GAAsB,EAAE,CAAC;IAErC,MAAM,aAAa,CAAC,YAAY,CAAC,CAAC,CAAC,EAAE;QACnC,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE;YAChC,MAAM,KAAK,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;YACrB,IAAI,OAAO,CAAC,KAAK,CAAC,KAAK,QAAQ,EAAE;gBAC/B,oCAAoC;aACrC;iBAAM;gBACL,IAAI,YAAY,GAAG,CAAC,CAAC;gBACrB,IAAI,cAAc,GAAG,CAAC,CAAC;gBACvB,IAAI,gBAAgB,GAAG,CAAC,CAAC;gBACzB,IAAI,WAAW,GAA4B,MAAM,CAAC,GAAG,CAAC,CAAC;gBACvD,IAAI,WAAW,IAAI,IAAI,EAAE;oBACvB,WAAW,GAAG;wBACZ,GAAG,EAAE,MAAM,CAAC,iBAAiB;wBAC7B,GAAG,EAAE,MAAM,CAAC,iBAAiB;wBAC7B,IAAI,EAAE,CAAC;wBACP,QAAQ,EAAE,CAAC;wBACX,MAAM,EAAE,CAAC;wBACT,MAAM,EAAE,CAAC;qBACV,CAAC;oBACF,MAAM,CAAC,GAAG,CAAC,GAAG,WAAW,CAAC;iBAC3B;qBAAM;oBACL,YAAY,GAAG,WAAW,CAAC,IAAI,CAAC;oBAChC,cAAc,GAAG,WAAW,CAAC,MAAM,CAAC;oBACpC,gBAAgB,GAAG,WAAW,CAAC,QAAQ,CAAC;iBACzC;gBACD,IAAI,SAAiB,CAAC;gBACtB,IAAI,SAAiB,CAAC;gBAEtB,iEAAiE;gBACjE,iBAAiB;gBACjB,IAAI,WAAW,GAAG,CAAC,CAAC;gBACpB,IAAI,SAAS,GAAG,CAAC,CAAC;gBAClB,IAAI,aAAa,GAAG,CAAC,CAAC;gBACtB,IAAI,cAAc,GAAG,CAAC,CAAC;gBACvB,IAAI,YAAY,GAAG,CAAC,CAAC;gBACrB,IAAI,gBAAgB,GAAG,CAAC,CAAC;gBAEzB,IAAI,KAAK,YAAY,EAAE,CAAC,MAAM,EAAE;oBAC9B,SAAS,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACrC,SAAS,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACrC,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;oBACtC,SAAS,GAAG,WAAW,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBAC3C,aAAa,GAAG,WAAW,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,WAAW,GAAG,KAAK,CAAC,IAAI,CAAC;iBAE1B;qBAAM,IAAI,KAAK,YAAY,KAAK,EAAE;oBACjC,SAAS,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,SAAS,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAA
E,CAAC,CAAC,CAAC,CAAC;oBACnD,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;oBACtC,SAAS,GAAG,WAAW,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBAC3C,aAAa,GAAG,WAAW,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,WAAW,GAAG,KAAK,CAAC,MAAM,CAAC;iBAE5B;qBAAM,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,QAAQ,CAAC,KAAK,CAAC,EAAE;oBAC3C,SAAS,GAAG,KAAK,CAAC;oBAClB,SAAS,GAAG,KAAK,CAAC;oBAClB,SAAS,GAAG,KAAK,CAAC;oBAClB,aAAa,GAAG,CAAC,CAAC;oBAClB,WAAW,GAAG,CAAC,CAAC;iBAEjB;qBAAM;oBACL,WAAW,GAAG,IAAI,CAAC;oBACnB,SAAS;iBACV;gBACD,cAAc,GAAG,cAAc,GAAG,WAAW,CAAC;gBAC9C,YAAY,GAAG,YAAY;oBACvB,CAAC,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC;gBAChE,gBAAgB,GAAG,gBAAgB;oBAC/B,CAAC,WAAW,GAAG,cAAc,CAAC;wBAC1B,CAAC,aAAa;4BACb,CAAC,CAAC,SAAS,GAAG,YAAY,CAAC,GAAG,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC;4BACzD,gBAAgB,CAAC,CAAC;gBAE3B,WAAW,CAAC,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,WAAW,CAAC,GAAG,EAAE,SAAS,CAAC,CAAC;gBACvD,WAAW,CAAC,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,WAAW,CAAC,GAAG,EAAE,SAAS,CAAC,CAAC;gBACvD,WAAW,CAAC,MAAM,GAAG,cAAc,CAAC;gBACpC,WAAW,CAAC,IAAI,GAAG,YAAY,CAAC;gBAChC,WAAW,CAAC,QAAQ,GAAG,gBAAgB,CAAC;gBACxC,WAAW,CAAC,MAAM,GAAG,IAAI,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;aAClD;SACF;IACH,CAAC,CAAC,CAAC;IACH,sEAAsE;IACtE,KAAK,MAAM,GAAG,IAAI,MAAM,EAAE;QACxB,MAAM,IAAI,GAA4B,MAAM,CAAC,GAAG,CAAC,CAAC;QAClD,IAAI,IAAI,CAAC,MAAM,KAAK,CAAC,EAAE;YACrB,IAAI,CAAC,QAAQ,GAAG,GAAG,CAAC;YACpB,IAAI,CAAC,MAAM,GAAG,GAAG,CAAC;SACnB;KACF;IACD,OAAO,MAAM,CAAC;AAChB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n * =============================================================================\n */\n\nimport * as tf from '@tensorflow/tfjs-core';\nimport {div, max, min, sub} from '@tensorflow/tfjs-core';\n\nimport {Dataset} from './dataset';\n\n// TODO(kangyizhang): eliminate the need for ElementArray and TabularRecord, by\n// computing stats on nested structures via deepMap/deepZip.\n\n/**\n * The value associated with a given key for a single element.\n *\n * Such a value may not have a batch dimension.  A value may be a scalar or an\n * n-dimensional array.\n */\nexport type ElementArray = number|number[]|tf.Tensor|string;\n\n/**\n * A map from string keys (aka column names) to values for a single element.\n */\nexport type TabularRecord = {\n  [key: string]: ElementArray\n};\n\n// TODO(kangyizhang): Flesh out collected statistics.\n// For numeric columns we should provide mean, stddev, histogram, etc.\n// For string columns we should provide a vocabulary (at least, top-k), maybe a\n// length histogram, etc.\n// Collecting only numeric min and max is just the bare minimum for now.\n\n/** An interface representing numeric statistics of a column. 
*/\nexport interface NumericColumnStatistics {\n  min: number;\n  max: number;\n  mean: number;\n  variance: number;\n  stddev: number;\n  length: number;\n}\n\n/**\n * An interface representing column level NumericColumnStatistics for a\n * Dataset.\n */\nexport interface DatasetStatistics {\n  [key: string]: NumericColumnStatistics;\n}\n\n/**\n * Provides a function that scales numeric values into the [0, 1] interval.\n *\n * @param min the lower bound of the inputs, which should be mapped to 0.\n * @param max the upper bound of the inputs, which should be mapped to 1,\n * @return A function that maps an input ElementArray to a scaled ElementArray.\n */\nexport function scaleTo01(min: number, max: number): (value: ElementArray) =>\n    ElementArray {\n  const range = max - min;\n  const minTensor: tf.Tensor = tf.scalar(min);\n  const rangeTensor: tf.Tensor = tf.scalar(range);\n  return (value: ElementArray): ElementArray => {\n    if (typeof (value) === 'string') {\n      throw new Error('Can\\'t scale a string.');\n    } else {\n      if (value instanceof tf.Tensor) {\n        const result = div(sub(value, minTensor), rangeTensor);\n        return result;\n      } else if (value instanceof Array) {\n        return value.map(v => (v - min) / range);\n      } else {\n        return (value - min) / range;\n      }\n    }\n  };\n}\n\n/**\n * Provides a function that calculates column level statistics, i.e. 
min, max,\n * variance, stddev.\n *\n * @param dataset The Dataset object whose statistics will be calculated.\n * @param sampleSize (Optional) If set, statistics will only be calculated\n *     against a subset of the whole data.\n * @param shuffleWindowSize (Optional) If set, shuffle provided dataset before\n *     calculating statistics.\n * @return A DatasetStatistics object that contains NumericColumnStatistics of\n *     each column.\n */\nexport async function computeDatasetStatistics(\n    dataset: Dataset<TabularRecord>, sampleSize?: number,\n    shuffleWindowSize?: number): Promise<DatasetStatistics> {\n  let sampleDataset = dataset;\n  // TODO(soergel): allow for deep shuffle where possible.\n  if (shuffleWindowSize != null) {\n    sampleDataset = sampleDataset.shuffle(shuffleWindowSize);\n  }\n  if (sampleSize != null) {\n    sampleDataset = sampleDataset.take(sampleSize);\n  }\n\n  // TODO(soergel): prepare the column objects based on a schema.\n  const result: DatasetStatistics = {};\n\n  await sampleDataset.forEachAsync(e => {\n    for (const key of Object.keys(e)) {\n      const value = e[key];\n      if (typeof (value) === 'string') {\n        // No statistics for string element.\n      } else {\n        let previousMean = 0;\n        let previousLength = 0;\n        let previousVariance = 0;\n        let columnStats: NumericColumnStatistics = result[key];\n        if (columnStats == null) {\n          columnStats = {\n            min: Number.POSITIVE_INFINITY,\n            max: Number.NEGATIVE_INFINITY,\n            mean: 0,\n            variance: 0,\n            stddev: 0,\n            length: 0\n          };\n          result[key] = columnStats;\n        } else {\n          previousMean = columnStats.mean;\n          previousLength = columnStats.length;\n          previousVariance = columnStats.variance;\n        }\n        let recordMin: number;\n        let recordMax: number;\n\n        // Calculate accumulated mean and variance following 
tf.Transform\n        // implementation\n        let valueLength = 0;\n        let valueMean = 0;\n        let valueVariance = 0;\n        let combinedLength = 0;\n        let combinedMean = 0;\n        let combinedVariance = 0;\n\n        if (value instanceof tf.Tensor) {\n          recordMin = min(value).dataSync()[0];\n          recordMax = max(value).dataSync()[0];\n          const valueMoment = tf.moments(value);\n          valueMean = valueMoment.mean.dataSync()[0];\n          valueVariance = valueMoment.variance.dataSync()[0];\n          valueLength = value.size;\n\n        } else if (value instanceof Array) {\n          recordMin = value.reduce((a, b) => Math.min(a, b));\n          recordMax = value.reduce((a, b) => Math.max(a, b));\n          const valueMoment = tf.moments(value);\n          valueMean = valueMoment.mean.dataSync()[0];\n          valueVariance = valueMoment.variance.dataSync()[0];\n          valueLength = value.length;\n\n        } else if (!isNaN(value) && isFinite(value)) {\n          recordMin = value;\n          recordMax = value;\n          valueMean = value;\n          valueVariance = 0;\n          valueLength = 1;\n\n        } else {\n          columnStats = null;\n          continue;\n        }\n        combinedLength = previousLength + valueLength;\n        combinedMean = previousMean +\n            (valueLength / combinedLength) * (valueMean - previousMean);\n        combinedVariance = previousVariance +\n            (valueLength / combinedLength) *\n                (valueVariance +\n                 ((valueMean - combinedMean) * (valueMean - previousMean)) -\n                 previousVariance);\n\n        columnStats.min = Math.min(columnStats.min, recordMin);\n        columnStats.max = Math.max(columnStats.max, recordMax);\n        columnStats.length = combinedLength;\n        columnStats.mean = combinedMean;\n        columnStats.variance = combinedVariance;\n        columnStats.stddev = Math.sqrt(combinedVariance);\n      }\n   
 }\n  });\n  // Variance and stddev should be NaN for the case of a single element.\n  for (const key in result) {\n    const stat: NumericColumnStatistics = result[key];\n    if (stat.length === 1) {\n      stat.variance = NaN;\n      stat.stddev = NaN;\n    }\n  }\n  return result;\n}\n"]}