/*
 * @tensorflow/tfjs-data
 * Version:
 * TensorFlow Data API in JavaScript
 * 163 lines • 22.7 kB
 * JavaScript
 */
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import { div, max, min, sub } from '@tensorflow/tfjs-core';
/**
 * Provides a function that scales numeric values into the [0, 1] interval.
 *
 * Values are mapped linearly with (value - min) / (max - min). Inputs outside
 * [min, max] map outside [0, 1], and min === max results in a division by
 * zero rather than an error.
 *
 * @param min the lower bound of the inputs, which should be mapped to 0.
 * @param max the upper bound of the inputs, which should be mapped to 1.
 * @return A function that maps an input ElementArray (a number, an array of
 *     numbers, or a tf.Tensor) to a scaled value of the same kind. The returned
 *     function throws if it is given a string.
 */
export function scaleTo01(min, max) {
  const range = max - min;
  return (value) => {
    if (typeof (value) === 'string') {
      throw new Error('Can\'t scale a string.');
    }
    if (value instanceof tf.Tensor) {
      // Create the scalars and the intermediate difference inside tf.tidy so
      // only the returned tensor survives. This avoids leaking a tensor on
      // every call. It also avoids holding long-lived scalars that an
      // enclosing tidy could dispose out from under this closure.
      return tf.tidy(
          () => div(sub(value, tf.scalar(min)), tf.scalar(range)));
    }
    if (value instanceof Array) {
      return value.map(v => (v - min) / range);
    }
    return (value - min) / range;
  };
}
/**
 * Summarizes a single column value of one dataset element.
 *
 * Returns null when the value cannot contribute numeric statistics:
 * - strings;
 * - empty arrays or empty tensors;
 * - numbers that are NaN or +/-Infinity;
 * - anything that is not a number, an array of numbers, or a tf.Tensor
 *   (e.g. null or undefined).
 *
 * @param value The value stored under one key of a dataset element.
 * @return An object holding the value's min, max, mean, variance (as computed
 *     by tf.moments for arrays/tensors, 0 for a scalar) and element count, or
 *     null if the value should be skipped.
 */
function summarizeColumnValue(value) {
  if (value instanceof tf.Tensor) {
    if (value.size === 0) {
      return null;
    }
    // tf.tidy disposes the intermediate tensors produced by min/max/moments.
    // The caller-owned `value` tensor is not created here and is left alone.
    return tf.tidy(() => {
      const moments = tf.moments(value);
      return {
        min: min(value).dataSync()[0],
        max: max(value).dataSync()[0],
        mean: moments.mean.dataSync()[0],
        variance: moments.variance.dataSync()[0],
        length: value.size
      };
    });
  }
  if (value instanceof Array) {
    if (value.length === 0) {
      // reduce() without an initial value throws on an empty array.
      return null;
    }
    // tf.moments converts the array to a tensor; tidy frees it afterwards.
    return tf.tidy(() => {
      const moments = tf.moments(value);
      return {
        min: value.reduce((a, b) => Math.min(a, b)),
        max: value.reduce((a, b) => Math.max(a, b)),
        mean: moments.mean.dataSync()[0],
        variance: moments.variance.dataSync()[0],
        length: value.length
      };
    });
  }
  // The typeof guard keeps null/booleans from being coerced to numbers.
  if (typeof value === 'number' && isFinite(value)) {
    return { min: value, max: value, mean: value, variance: 0, length: 1 };
  }
  return null;
}

/**
 * Folds one value summary into a running column accumulator, in place.
 *
 * Mean and variance are combined following the tf.Transform implementation.
 * The combined mean shifts by the new block's weight. The combined variance
 * mixes both variances by weight and adds the between-block term.
 *
 * @param columnStats The accumulator to update (min, max, mean, variance,
 *     stddev, length).
 * @param summary A non-null result of summarizeColumnValue.
 */
function mergeColumnSummary(columnStats, summary) {
  const previousMean = columnStats.mean;
  const previousLength = columnStats.length;
  const previousVariance = columnStats.variance;
  // summary.length is always > 0, so combinedLength is never zero.
  const combinedLength = previousLength + summary.length;
  const combinedMean = previousMean +
      (summary.length / combinedLength) * (summary.mean - previousMean);
  const combinedVariance = previousVariance +
      (summary.length / combinedLength) *
          (summary.variance +
           ((summary.mean - combinedMean) * (summary.mean - previousMean)) -
           previousVariance);
  columnStats.min = Math.min(columnStats.min, summary.min);
  columnStats.max = Math.max(columnStats.max, summary.max);
  columnStats.length = combinedLength;
  columnStats.mean = combinedMean;
  columnStats.variance = combinedVariance;
  columnStats.stddev = Math.sqrt(combinedVariance);
}

/**
 * Provides a function that calculates column level statistics, i.e. min, max,
 * mean, variance, stddev and the number of values seen.
 *
 * Only numbers, arrays of numbers and tf.Tensors contribute. Strings, empty
 * arrays/tensors, non-finite numbers and other values are skipped. A column
 * that never has a usable value gets no entry in the result. A column with
 * exactly one value reports NaN variance and stddev.
 *
 * @param dataset The Dataset object whose statistics will be calculated.
 * @param sampleSize (Optional) If set, statistics will only be calculated
 *     against a subset of the whole data.
 * @param shuffleWindowSize (Optional) If set, shuffle provided dataset before
 *     calculating statistics.
 * @return A DatasetStatistics object that contains NumericColumnStatistics of
 *     each column.
 */
export async function computeDatasetStatistics(dataset, sampleSize, shuffleWindowSize) {
  let sampleDataset = dataset;
  // TODO(soergel): allow for deep shuffle where possible.
  if (shuffleWindowSize != null) {
    sampleDataset = sampleDataset.shuffle(shuffleWindowSize);
  }
  if (sampleSize != null) {
    sampleDataset = sampleDataset.take(sampleSize);
  }
  // TODO(soergel): prepare the column objects based on a schema.
  const result = {};
  await sampleDataset.forEachAsync(e => {
    for (const key of Object.keys(e)) {
      const summary = summarizeColumnValue(e[key]);
      if (summary == null) {
        // Nothing numeric to record. Skipping before the entry is created
        // keeps all-unusable columns out of the result.
        continue;
      }
      let columnStats = result[key];
      if (columnStats == null) {
        columnStats = {
          min: Number.POSITIVE_INFINITY,
          max: Number.NEGATIVE_INFINITY,
          mean: 0,
          variance: 0,
          stddev: 0,
          length: 0
        };
        result[key] = columnStats;
      }
      mergeColumnSummary(columnStats, summary);
    }
  });
  // Variance and stddev should be NaN for the case of a single element.
  for (const stat of Object.values(result)) {
    if (stat.length === 1) {
      stat.variance = NaN;
      stat.stddev = NaN;
    }
  }
  return result;
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"statistics.js","sourceRoot":"","sources":["../../../../../tfjs-data/src/statistics.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;GAgBG;AAEH,OAAO,KAAK,EAAE,MAAM,uBAAuB,CAAC;AAC5C,OAAO,EAAC,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAC,MAAM,uBAAuB,CAAC;AA8CzD;;;;;;GAMG;AACH,MAAM,UAAU,SAAS,CAAC,GAAW,EAAE,GAAW;IAEhD,MAAM,KAAK,GAAG,GAAG,GAAG,GAAG,CAAC;IACxB,MAAM,SAAS,GAAc,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;IAC5C,MAAM,WAAW,GAAc,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IAChD,OAAO,CAAC,KAAmB,EAAgB,EAAE;QAC3C,IAAI,OAAO,CAAC,KAAK,CAAC,KAAK,QAAQ,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,wBAAwB,CAAC,CAAC;SAC3C;aAAM;YACL,IAAI,KAAK,YAAY,EAAE,CAAC,MAAM,EAAE;gBAC9B,MAAM,MAAM,GAAG,GAAG,CAAC,GAAG,CAAC,KAAK,EAAE,SAAS,CAAC,EAAE,WAAW,CAAC,CAAC;gBACvD,OAAO,MAAM,CAAC;aACf;iBAAM,IAAI,KAAK,YAAY,KAAK,EAAE;gBACjC,OAAO,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC,CAAC;aAC1C;iBAAM;gBACL,OAAO,CAAC,KAAK,GAAG,GAAG,CAAC,GAAG,KAAK,CAAC;aAC9B;SACF;IACH,CAAC,CAAC;AACJ,CAAC;AAED;;;;;;;;;;;GAWG;AACH,MAAM,CAAC,KAAK,UAAU,wBAAwB,CAC1C,OAA+B,EAAE,UAAmB,EACpD,iBAA0B;IAC5B,IAAI,aAAa,GAAG,OAAO,CAAC;IAC5B,wDAAwD;IACxD,IAAI,iBAAiB,IAAI,IAAI,EAAE;QAC7B,aAAa,GAAG,aAAa,CAAC,OAAO,CAAC,iBAAiB,CAAC,CAAC;KAC1D;IACD,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,aAAa,GAAG,aAAa,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;KAChD;IAED,+DAA+D;IAC/D,MAAM,MAAM,GAAsB,EAAE,CAAC;IAErC,MAAM,aAAa,CAAC,YAAY,CAAC,CAAC,CAAC,EAAE;QACnC,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE;YAChC,MAAM,KAAK,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;YACrB,IAAI,OAAO,CAAC,KAAK,CAAC,KAAK,QAAQ,EAAE;gBAC/B,oCAAoC;aACrC;iBAAM;gBACL,IAAI,YAAY,GAAG,CAAC,CAAC;gBACrB,IAAI,cAAc,GAAG,CAAC,CAAC;gBACvB,IAAI,gBAAgB,GAAG,CAAC,CAAC;gBACzB,IAAI,WAAW,GAA4B,MAAM,CAAC,GAAG,CAAC,CAAC;gBACvD,IAAI,WAAW,IAAI,IAAI,EAAE;oBACvB,WAAW,GAAG;wBACZ,GAAG,EAAE,MAAM,CAAC,iBAAiB;wBAC7B,GAAG,EAAE,MAAM,CAAC,iBAAiB;wBAC7B,IAAI,EAAE,CAAC;wBACP,QAAQ,EAAE,CAAC;wBACX,MAAM,EAAE,CAAC;wBACT,MAAM,EAAE,CAAC;qBACV,CAAC;oBACF,MAAM,CAAC,GAAG,CAAC,
GAAG,WAAW,CAAC;iBAC3B;qBAAM;oBACL,YAAY,GAAG,WAAW,CAAC,IAAI,CAAC;oBAChC,cAAc,GAAG,WAAW,CAAC,MAAM,CAAC;oBACpC,gBAAgB,GAAG,WAAW,CAAC,QAAQ,CAAC;iBACzC;gBACD,IAAI,SAAiB,CAAC;gBACtB,IAAI,SAAiB,CAAC;gBAEtB,iEAAiE;gBACjE,iBAAiB;gBACjB,IAAI,WAAW,GAAG,CAAC,CAAC;gBACpB,IAAI,SAAS,GAAG,CAAC,CAAC;gBAClB,IAAI,aAAa,GAAG,CAAC,CAAC;gBACtB,IAAI,cAAc,GAAG,CAAC,CAAC;gBACvB,IAAI,YAAY,GAAG,CAAC,CAAC;gBACrB,IAAI,gBAAgB,GAAG,CAAC,CAAC;gBAEzB,IAAI,KAAK,YAAY,EAAE,CAAC,MAAM,EAAE;oBAC9B,SAAS,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACrC,SAAS,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACrC,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;oBACtC,SAAS,GAAG,WAAW,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBAC3C,aAAa,GAAG,WAAW,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,WAAW,GAAG,KAAK,CAAC,IAAI,CAAC;iBAE1B;qBAAM,IAAI,KAAK,YAAY,KAAK,EAAE;oBACjC,SAAS,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,SAAS,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;oBACtC,SAAS,GAAG,WAAW,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBAC3C,aAAa,GAAG,WAAW,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC;oBACnD,WAAW,GAAG,KAAK,CAAC,MAAM,CAAC;iBAE5B;qBAAM,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,QAAQ,CAAC,KAAK,CAAC,EAAE;oBAC3C,SAAS,GAAG,KAAK,CAAC;oBAClB,SAAS,GAAG,KAAK,CAAC;oBAClB,SAAS,GAAG,KAAK,CAAC;oBAClB,aAAa,GAAG,CAAC,CAAC;oBAClB,WAAW,GAAG,CAAC,CAAC;iBAEjB;qBAAM;oBACL,WAAW,GAAG,IAAI,CAAC;oBACnB,SAAS;iBACV;gBACD,cAAc,GAAG,cAAc,GAAG,WAAW,CAAC;gBAC9C,YAAY,GAAG,YAAY;oBACvB,CAAC,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC;gBAChE,gBAAgB,GAAG,gBAAgB;oBAC/B,CAAC,WAAW,GAAG,cAAc,CAAC;wBAC1B,CAAC,aAAa;4BACb,CAAC,CAAC,SAAS,GAAG,YAAY,CAAC,GAAG,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC;4BACzD,gBAAgB,CAAC,CAAC;gBAE3B,WAAW,CAAC,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,WAAW,CAAC,GAAG,EAAE,SAAS,CAAC,CAAC;gBACvD,WAAW,CAAC,GAA
G,GAAG,IAAI,CAAC,GAAG,CAAC,WAAW,CAAC,GAAG,EAAE,SAAS,CAAC,CAAC;gBACvD,WAAW,CAAC,MAAM,GAAG,cAAc,CAAC;gBACpC,WAAW,CAAC,IAAI,GAAG,YAAY,CAAC;gBAChC,WAAW,CAAC,QAAQ,GAAG,gBAAgB,CAAC;gBACxC,WAAW,CAAC,MAAM,GAAG,IAAI,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;aAClD;SACF;IACH,CAAC,CAAC,CAAC;IACH,sEAAsE;IACtE,KAAK,MAAM,GAAG,IAAI,MAAM,EAAE;QACxB,MAAM,IAAI,GAA4B,MAAM,CAAC,GAAG,CAAC,CAAC;QAClD,IAAI,IAAI,CAAC,MAAM,KAAK,CAAC,EAAE;YACrB,IAAI,CAAC,QAAQ,GAAG,GAAG,CAAC;YACpB,IAAI,CAAC,MAAM,GAAG,GAAG,CAAC;SACnB;KACF;IACD,OAAO,MAAM,CAAC;AAChB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n * =============================================================================\n */\n\nimport * as tf from '@tensorflow/tfjs-core';\nimport {div, max, min, sub} from '@tensorflow/tfjs-core';\n\nimport {Dataset} from './dataset';\n\n// TODO(kangyizhang): eliminate the need for ElementArray and TabularRecord, by\n// computing stats on nested structures via deepMap/deepZip.\n\n/**\n * The value associated with a given key for a single element.\n *\n * Such a value may not have a batch dimension.  
A value may be a scalar or an\n * n-dimensional array.\n */\nexport type ElementArray = number|number[]|tf.Tensor|string;\n\n/**\n * A map from string keys (aka column names) to values for a single element.\n */\nexport type TabularRecord = {\n  [key: string]: ElementArray\n};\n\n// TODO(kangyizhang): Flesh out collected statistics.\n// For numeric columns we should provide mean, stddev, histogram, etc.\n// For string columns we should provide a vocabulary (at least, top-k), maybe a\n// length histogram, etc.\n// Collecting only numeric min and max is just the bare minimum for now.\n\n/** An interface representing numeric statistics of a column. */\nexport interface NumericColumnStatistics {\n  min: number;\n  max: number;\n  mean: number;\n  variance: number;\n  stddev: number;\n  length: number;\n}\n\n/**\n * An interface representing column level NumericColumnStatistics for a\n * Dataset.\n */\nexport interface DatasetStatistics {\n  [key: string]: NumericColumnStatistics;\n}\n\n/**\n * Provides a function that scales numeric values into the [0, 1] interval.\n *\n * @param min the lower bound of the inputs, which should be mapped to 0.\n * @param max the upper bound of the inputs, which should be mapped to 1,\n * @return A function that maps an input ElementArray to a scaled ElementArray.\n */\nexport function scaleTo01(min: number, max: number): (value: ElementArray) =>\n    ElementArray {\n  const range = max - min;\n  const minTensor: tf.Tensor = tf.scalar(min);\n  const rangeTensor: tf.Tensor = tf.scalar(range);\n  return (value: ElementArray): ElementArray => {\n    if (typeof (value) === 'string') {\n      throw new Error('Can\\'t scale a string.');\n    } else {\n      if (value instanceof tf.Tensor) {\n        const result = div(sub(value, minTensor), rangeTensor);\n        return result;\n      } else if (value instanceof Array) {\n        return value.map(v => (v - min) / range);\n      } else {\n        return (value - min) / range;\n      }\n    }\n  
};\n}\n\n/**\n * Provides a function that calculates column level statistics, i.e. min, max,\n * variance, stddev.\n *\n * @param dataset The Dataset object whose statistics will be calculated.\n * @param sampleSize (Optional) If set, statistics will only be calculated\n *     against a subset of the whole data.\n * @param shuffleWindowSize (Optional) If set, shuffle provided dataset before\n *     calculating statistics.\n * @return A DatasetStatistics object that contains NumericColumnStatistics of\n *     each column.\n */\nexport async function computeDatasetStatistics(\n    dataset: Dataset<TabularRecord>, sampleSize?: number,\n    shuffleWindowSize?: number): Promise<DatasetStatistics> {\n  let sampleDataset = dataset;\n  // TODO(soergel): allow for deep shuffle where possible.\n  if (shuffleWindowSize != null) {\n    sampleDataset = sampleDataset.shuffle(shuffleWindowSize);\n  }\n  if (sampleSize != null) {\n    sampleDataset = sampleDataset.take(sampleSize);\n  }\n\n  // TODO(soergel): prepare the column objects based on a schema.\n  const result: DatasetStatistics = {};\n\n  await sampleDataset.forEachAsync(e => {\n    for (const key of Object.keys(e)) {\n      const value = e[key];\n      if (typeof (value) === 'string') {\n        // No statistics for string element.\n      } else {\n        let previousMean = 0;\n        let previousLength = 0;\n        let previousVariance = 0;\n        let columnStats: NumericColumnStatistics = result[key];\n        if (columnStats == null) {\n          columnStats = {\n            min: Number.POSITIVE_INFINITY,\n            max: Number.NEGATIVE_INFINITY,\n            mean: 0,\n            variance: 0,\n            stddev: 0,\n            length: 0\n          };\n          result[key] = columnStats;\n        } else {\n          previousMean = columnStats.mean;\n          previousLength = columnStats.length;\n          previousVariance = columnStats.variance;\n        }\n        let recordMin: number;\n        let 
recordMax: number;\n\n        // Calculate accumulated mean and variance following tf.Transform\n        // implementation\n        let valueLength = 0;\n        let valueMean = 0;\n        let valueVariance = 0;\n        let combinedLength = 0;\n        let combinedMean = 0;\n        let combinedVariance = 0;\n\n        if (value instanceof tf.Tensor) {\n          recordMin = min(value).dataSync()[0];\n          recordMax = max(value).dataSync()[0];\n          const valueMoment = tf.moments(value);\n          valueMean = valueMoment.mean.dataSync()[0];\n          valueVariance = valueMoment.variance.dataSync()[0];\n          valueLength = value.size;\n\n        } else if (value instanceof Array) {\n          recordMin = value.reduce((a, b) => Math.min(a, b));\n          recordMax = value.reduce((a, b) => Math.max(a, b));\n          const valueMoment = tf.moments(value);\n          valueMean = valueMoment.mean.dataSync()[0];\n          valueVariance = valueMoment.variance.dataSync()[0];\n          valueLength = value.length;\n\n        } else if (!isNaN(value) && isFinite(value)) {\n          recordMin = value;\n          recordMax = value;\n          valueMean = value;\n          valueVariance = 0;\n          valueLength = 1;\n\n        } else {\n          columnStats = null;\n          continue;\n        }\n        combinedLength = previousLength + valueLength;\n        combinedMean = previousMean +\n            (valueLength / combinedLength) * (valueMean - previousMean);\n        combinedVariance = previousVariance +\n            (valueLength / combinedLength) *\n                (valueVariance +\n                 ((valueMean - combinedMean) * (valueMean - previousMean)) -\n                 previousVariance);\n\n        columnStats.min = Math.min(columnStats.min, recordMin);\n        columnStats.max = Math.max(columnStats.max, recordMax);\n        columnStats.length = combinedLength;\n        columnStats.mean = combinedMean;\n        columnStats.variance = 
combinedVariance;\n        columnStats.stddev = Math.sqrt(combinedVariance);\n      }\n    }\n  });\n  // Variance and stddev should be NaN for the case of a single element.\n  for (const key in result) {\n    const stat: NumericColumnStatistics = result[key];\n    if (stat.length === 1) {\n      stat.variance = NaN;\n      stat.stddev = NaN;\n    }\n  }\n  return result;\n}\n"]}