UNPKG

ml5-save

Version:
962 lines (827 loc) 26.3 kB
import * as tf from '@tensorflow/tfjs';
import { saveBlob } from '../utils/io';
import nnUtils from './NeuralNetworkUtils';

/**
 * NeuralNetworkData
 * Stores the raw training rows (`this.data.raw`, an array of `{ xs, ys }`
 * objects) together with the metadata derived from them (`this.meta`), and
 * provides helpers to summarize, normalize, one-hot encode, convert, save
 * and load that data.
 */
class NeuralNetworkData {
  constructor() {
    this.meta = {
      inputUnits: null, // set by getDataUnits(): input shape as an array
      outputUnits: null, // set by getDataUnits(): value returned by getInputMetaUnits()
      // objects describing input/output data by property name
      inputs: {}, // { name1: {dtype}, name2: {dtype} }
      outputs: {}, // { name1: {dtype} }
      isNormalized: false, // Boolean - keep this in meta for model saving/loading
    };

    this.isMetadataReady = false;
    this.isWarmedUp = false;

    this.data = {
      raw: [], // array of {xs:{}, ys:{}}
    };

    // Every method is bound so it can be handed around as a bare callback.
    // summarize data
    this.createMetadata = this.createMetadata.bind(this);
    this.getDataStats = this.getDataStats.bind(this);
    this.getInputMetaStats = this.getInputMetaStats.bind(this);
    this.getDataUnits = this.getDataUnits.bind(this);
    this.getInputMetaUnits = this.getInputMetaUnits.bind(this);
    this.getDTypesFromData = this.getDTypesFromData.bind(this);
    // add data
    this.addData = this.addData.bind(this);
    // data conversion
    this.convertRawToTensors = this.convertRawToTensors.bind(this);
    // data normalization / unnormalization
    this.normalizeDataRaw = this.normalizeDataRaw.bind(this);
    this.normalizeInputData = this.normalizeInputData.bind(this);
    this.normalizeArray = this.normalizeArray.bind(this);
    this.unnormalizeArray = this.unnormalizeArray.bind(this);
    // one hot
    this.applyOneHotEncodingsToDataRaw = this.applyOneHotEncodingsToDataRaw.bind(this);
    this.getDataOneHot = this.getDataOneHot.bind(this);
    this.getInputMetaOneHot = this.getInputMetaOneHot.bind(this);
    this.createOneHotEncodings = this.createOneHotEncodings.bind(this);
    // Saving / loading data
    this.loadDataFromUrl = this.loadDataFromUrl.bind(this);
    this.loadJSON = this.loadJSON.bind(this);
    this.loadCSV = this.loadCSV.bind(this);
    this.loadBlob = this.loadBlob.bind(this);
    this.loadData = this.loadData.bind(this);
    this.saveData = this.saveData.bind(this);
    this.saveMeta = this.saveMeta.bind(this);
    this.loadMeta = this.loadMeta.bind(this);
    // data loading helpers
    this.findEntries = this.findEntries.bind(this);
    this.formatRawData = this.formatRawData.bind(this);
    this.csvToJSON = this.csvToJSON.bind(this);
  }

  /**
   * ////////////////////////////////////////////////////////
   * Summarize Data
   * ////////////////////////////////////////////////////////
   */

  /**
   * create the metadata from the data
   * this covers:
   *  1. getting the datatype from the data
   *  2. getting the min and max from the data
   *  3. getting the oneHot encoded values
   *  4. getting the inputShape and outputUnits from the data
   * Sets `this.isMetadataReady` to true when done.
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @param {Array|null} inputShape - optional explicit input shape, forwarded to getDataUnits()
   * @returns {Object} a shallow copy of `this.meta`
   */
  createMetadata(dataRaw, inputShape = null) {
    // get the data type for each property
    this.getDTypesFromData(dataRaw);
    // get the stats - min, max
    this.getDataStats(dataRaw);
    // onehot encode
    this.getDataOneHot(dataRaw);
    // calculate the input units from the data
    this.getDataUnits(dataRaw, inputShape);

    this.isMetadataReady = true;
    return { ...this.meta };
  }

  /*
   * ////////////////////////////////////////////////
   * data Summary
   * ////////////////////////////////////////////////
   */

  /**
   * get stats about the data: stores min/max for every input and output
   * property in `this.meta`
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @returns {Object} the updated meta object
   */
  getDataStats(dataRaw) {
    const meta = Object.assign({}, this.meta);

    meta.inputs = this.getInputMetaStats(dataRaw, meta.inputs, 'xs');
    meta.outputs = this.getInputMetaStats(dataRaw, meta.outputs, 'ys');

    this.meta = {
      ...this.meta,
      ...meta,
    };

    return meta;
  }

  /**
   * getRawStats
   * get back the min and max of each label.
   * 'string' properties get min 0 / max 1, 'number' properties use the column
   * values, and 'array' properties use all values of all rows flattened.
   * The passed-in meta object and its entries are not mutated.
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @param {Object} inputOrOutputMeta - `{ name: {dtype, ...} }`
   * @param {string} xsOrYs - 'xs' or 'ys'
   * @returns {Object} a copy of the meta with `min` and `max` added per property
   */
  // eslint-disable-next-line class-methods-use-this
  getInputMetaStats(dataRaw, inputOrOutputMeta, xsOrYs) {
    const inputMeta = Object.assign({}, inputOrOutputMeta);

    Object.keys(inputMeta).forEach(k => {
      const { dtype } = inputMeta[k];

      if (dtype === 'string') {
        inputMeta[k] = { ...inputMeta[k], min: 0, max: 1 };
      } else if (dtype === 'number' || dtype === 'array') {
        const column = dataRaw.map(item => item[xsOrYs][k]);
        // array-valued properties are summarized over every element of every row
        const dataAsArray = dtype === 'array' ? column.flat() : column;
        inputMeta[k] = {
          ...inputMeta[k],
          min: nnUtils.getMin(dataAsArray),
          max: nnUtils.getMax(dataAsArray),
        };
      }
    });

    return inputMeta;
  }

  /**
   * get the data units, inputshape and output units and store them in
   * `this.meta.inputUnits` / `this.meta.outputUnits`
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @param {Array|null} _arrayShape - explicit input shape; when falsy the
   *   shape is derived from the input meta instead
   * @returns {Object} the updated meta object
   */
  getDataUnits(dataRaw, _arrayShape = null) {
    const meta = Object.assign({}, this.meta);

    // if the data has a shape pass it in
    const inputShape = _arrayShape
      ? _arrayShape
      : [this.getInputMetaUnits(dataRaw, meta.inputs)].flat();

    const outputShape = this.getInputMetaUnits(dataRaw, meta.outputs);

    meta.inputUnits = inputShape;
    meta.outputUnits = outputShape;

    this.meta = {
      ...this.meta,
      ...meta,
    };

    return meta;
  }

  /**
   * get the number of units described by a meta object:
   * +1 per 'number' property, +uniqueValues.length per 'string' property.
   * An 'array' property resets the running value to an empty array, because
   * the shape cannot be derived here.
   * @param {*} _dataRaw - unused, kept for signature compatibility
   * @param {Object} _inputsMeta - `{ name: {dtype, uniqueValues?} }`
   * @returns {number|Array} the unit count, or `[]` when the last property seen is an array
   */
  // eslint-disable-next-line class-methods-use-this, no-unused-vars
  getInputMetaUnits(_dataRaw, _inputsMeta) {
    let units = 0;
    const inputsMeta = Object.assign({}, _inputsMeta);

    Object.entries(inputsMeta).forEach(arr => {
      const { dtype } = arr[1];
      if (dtype === 'number') {
        units += 1;
      } else if (dtype === 'string') {
        const { uniqueValues } = arr[1];
        units += uniqueValues.length;
      } else if (dtype === 'array') {
        // TODO: User must input the shape of the
        // image size correctly.
        units = [];
      }
    });

    return units;
  }

  /**
   * getDTypesFromData
   * gets the data types of the data we're using
   * important for handling oneHot.
   * Only the first row is inspected; `this.meta.inputs` / `this.meta.outputs`
   * are rebuilt from scratch.
   * @param {Array<{xs: Object, ys: Object}>} _dataRaw
   * @returns {Object} the updated meta object
   * @throws {Error} when `_dataRaw` has no first row to sample
   */
  getDTypesFromData(_dataRaw) {
    const sample = _dataRaw && _dataRaw[0];
    if (!sample) {
      throw new Error('getDTypesFromData() needs at least one row of data to infer the data types');
    }

    const meta = {
      ...this.meta,
      inputs: {},
      outputs: {},
    };

    Object.keys(sample.xs).forEach(prop => {
      meta.inputs[prop] = {
        dtype: nnUtils.getDataType(sample.xs[prop]),
      };
    });

    Object.keys(sample.ys).forEach(prop => {
      meta.outputs[prop] = {
        dtype: nnUtils.getDataType(sample.ys[prop]),
      };
    });

    // TODO: check if all entries have the same dtype.
    // otherwise throw an error

    this.meta = meta;

    return meta;
  }

  /**
   * ////////////////////////////////////////////////////////
   * Add Data
   * ////////////////////////////////////////////////////////
   */

  /**
   * Add Data: appends one `{ xs, ys }` row to `this.data.raw`
   * @param {object} xInputObj, {key: value}, key must be the name of the property value must be a String, Number, or Array
   * @param {*} yInputObj, {key: value}, key must be the name of the property value must be a String, Number, or Array
   */
  addData(xInputObj, yInputObj) {
    this.data.raw.push({
      xs: xInputObj,
      ys: yInputObj,
    });
  }

  /**
   * ////////////////////////////////////////////////////////
   * Tensor handling
   * ////////////////////////////////////////////////////////
   */

  /**
   * convertRawToTensors
   * converts array of {xs, ys} to tensors.
   * Values are picked in the key order of `this.meta.inputs` /
   * `this.meta.outputs` and flattened per row.
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @returns {{inputs: *, outputs: *}} tensors shaped
   *   `[rows, ...inputUnits]` and `[rows, outputUnits]`
   */
  convertRawToTensors(dataRaw) {
    const meta = Object.assign({}, this.meta);
    const dataLength = dataRaw.length;
    // getDataUnits() stores a shape array, but loadMeta() assigns whatever the
    // file contains - tolerate a scalar unit count as well.
    const inputShape = [meta.inputUnits].flat();

    return tf.tidy(() => {
      const inputArr = [];
      const outputArr = [];

      dataRaw.forEach(row => {
        // get xs
        const xs = Object.keys(meta.inputs)
          .map(k => row.xs[k])
          .flat();
        inputArr.push(xs);

        // get ys
        const ys = Object.keys(meta.outputs)
          .map(k => row.ys[k])
          .flat();
        outputArr.push(ys);
      });

      const inputs = tf.tensor(inputArr.flat(), [dataLength, ...inputShape]);
      const outputs = tf.tensor(outputArr.flat(), [dataLength, meta.outputUnits]);

      return {
        inputs,
        outputs,
      };
    });
  }

  /**
   * ////////////////////////////////////////////////////////
   * data normalization / unnormalization
   * ////////////////////////////////////////////////////////
   */

  /**
   * normalize the dataRaw input
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @returns {*} result of nnUtils.zipArrays() over the normalized `{xs}` rows
   *   and `{ys}` rows (presumably one merged `{xs, ys}` object per row - confirm
   *   against NeuralNetworkUtils)
   */
  normalizeDataRaw(dataRaw) {
    const meta = Object.assign({}, this.meta);

    const normXs = this.normalizeInputData(dataRaw, meta.inputs, 'xs');
    const normYs = this.normalizeInputData(dataRaw, meta.outputs, 'ys');

    return nnUtils.zipArrays(normXs, normYs);
  }

  /**
   * normalizeRaws
   * Normalizes every property listed in the meta, column by column, and
   * rebuilds one row object per input row.
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @param {Object} inputOrOutputMeta - `{ name: {dtype, min, max, legend?} }`
   * @param {string} xsOrYs - 'xs' or 'ys'
   * @returns {Array<Object>} rows shaped `{ [xsOrYs]: { name: normalizedValue } }`
   */
  normalizeInputData(dataRaw, inputOrOutputMeta, xsOrYs) {
    // the data length
    const dataLength = dataRaw.length;
    // the copy of the inputs.meta[inputOrOutput]
    const inputMeta = Object.assign({}, inputOrOutputMeta);

    // normalized columns, keyed by property name
    const normalized = {};
    Object.keys(inputMeta).forEach(k => {
      // get the min and max values
      const options = {
        min: inputMeta[k].min,
        max: inputMeta[k].max,
      };

      const dataAsArray = dataRaw.map(item => item[xsOrYs][k]);
      // depending on the input type, normalize accordingly
      if (inputMeta[k].dtype === 'string') {
        options.legend = inputMeta[k].legend;
        normalized[k] = this.normalizeArray(dataAsArray, options);
      } else if (inputMeta[k].dtype === 'number') {
        normalized[k] = this.normalizeArray(dataAsArray, options);
      } else if (inputMeta[k].dtype === 'array') {
        normalized[k] = dataAsArray.map(item => this.normalizeArray(item, options));
      }
    });

    // create a normalized version of data.raws
    return Array.from({ length: dataLength }, (item, idx) => {
      const row = {
        [xsOrYs]: {},
      };

      Object.keys(inputMeta).forEach(k => {
        row[xsOrYs][k] = normalized[k][idx];
      });

      return row;
    });
  }

  /**
   * normalizeArray
   * With a legend, each value is replaced by its legend entry (values without
   * a truthy legend entry are passed through). Otherwise every value must be a
   * number and is mapped through nnUtils.normalizeValue().
   * @param {Array} inputArray
   * @param {{min: *, max: *, legend?: Object}} options
   * @returns {Array}
   * @throws {Error} when there is no legend and a non-number value is present
   */
  // eslint-disable-next-line class-methods-use-this
  normalizeArray(inputArray, options) {
    const { min, max } = options;

    // if the data are onehot encoded, replace the string
    // value with the onehot array
    // if none exists, return the given value
    if (options.legend) {
      return inputArray.map(v => (options.legend[v] ? options.legend[v] : v));
    }

    // if the dtype is a number
    if (inputArray.every(v => typeof v === 'number')) {
      return inputArray.map(v => nnUtils.normalizeValue(v, min, max));
    }

    throw new Error('error in inputArray of normalizeArray() function');
  }

  /**
   * unNormalizeArray
   * With a legend, each one-hot array is mapped back to the legend key whose
   * encoding matches element by element (`undefined` when none matches).
   * Otherwise every value must be a number and is mapped through
   * nnUtils.unnormalizeValue().
   * @param {Array} inputArray
   * @param {{min: *, max: *, legend?: Object}} options
   * @returns {Array}
   * @throws {Error} when there is no legend and a non-number value is present
   */
  // eslint-disable-next-line class-methods-use-this
  unnormalizeArray(inputArray, options) {
    const { min, max } = options;

    // if the data is onehot encoded then remap the
    // values from those oneHot arrays
    if (options.legend) {
      const legendEntries = Object.entries(options.legend);

      return inputArray.map(v => {
        let res;
        legendEntries.forEach(([key, val]) => {
          // later matches overwrite earlier ones, as before
          if (v.every((num, idx) => num === val[idx])) res = key;
        });
        return res;
      });
    }

    // if the dtype is a number
    if (inputArray.every(v => typeof v === 'number')) {
      return inputArray.map(v => nnUtils.unnormalizeValue(v, min, max));
    }

    throw new Error('error in inputArray of unnormalizeArray() function');
  }

  /*
   * ////////////////////////////////////////////////
   * One hot encoding handling
   * ////////////////////////////////////////////////
   */

  /**
   * applyOneHotEncodingsToDataRaw
   * does not set this.data.raws
   * but rather returns them
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @returns {Array<{xs: Object, ys: Object}>} copies of the rows in which every
   *   property that has a legend is replaced by its legend entry
   */
  applyOneHotEncodingsToDataRaw(dataRaw) {
    const meta = Object.assign({}, this.meta);

    return dataRaw.map(row => {
      const xs = {
        ...row.xs,
      };
      const ys = {
        ...row.ys,
      };

      Object.keys(meta.inputs).forEach(k => {
        if (meta.inputs[k].legend) {
          xs[k] = meta.inputs[k].legend[row.xs[k]];
        }
      });

      Object.keys(meta.outputs).forEach(k => {
        if (meta.outputs[k].legend) {
          ys[k] = meta.outputs[k].legend[row.ys[k]];
        }
      });

      return {
        xs,
        ys,
      };
    });
  }

  /**
   * getDataOneHot
   * creates onehot encodings for the input and outputs
   * and adds them to the meta info
   * @param {Array<{xs: Object, ys: Object}>} dataRaw
   * @returns {Object} the updated meta object
   */
  getDataOneHot(dataRaw) {
    const meta = Object.assign({}, this.meta);

    meta.inputs = this.getInputMetaOneHot(dataRaw, meta.inputs, 'xs');
    meta.outputs = this.getInputMetaOneHot(dataRaw, meta.outputs, 'ys');

    this.meta = {
      ...this.meta,
      ...meta,
    };

    return meta;
  }

  /**
   * getOneHotMeta
   * For every 'string' property, collects the unique values found in the data
   * and merges the resulting `{ uniqueValues, legend }` into that property's meta.
   * @param {Array<{xs: Object, ys: Object}>} _dataRaw
   * @param {Object} _inputsMeta - `{ name: {dtype, ...} }`
   * @param {string} xsOrYs - 'xs' or 'ys'
   * @returns {Object} a copy of the meta with one-hot info added
   */
  getInputMetaOneHot(_dataRaw, _inputsMeta, xsOrYs) {
    const inputsMeta = Object.assign({}, _inputsMeta);

    Object.entries(inputsMeta).forEach(([key, { dtype }]) => {
      if (dtype === 'string') {
        const uniqueVals = [...new Set(_dataRaw.map(obj => obj[xsOrYs][key]))];
        const oneHotMeta = this.createOneHotEncodings(uniqueVals);
        inputsMeta[key] = {
          ...inputsMeta[key],
          ...oneHotMeta,
        };
      }
    });

    return inputsMeta;
  }

  /**
   * Returns a legend mapping the
   * data values to oneHot encoded values
   * @param {Array} _uniqueValuesArray - the distinct values of one property
   * @returns {{uniqueValues: Array, legend: Object}} `legend[value]` is the
   *   one-hot array for the value at that index of `uniqueValues`
   */
  // eslint-disable-next-line class-methods-use-this
  createOneHotEncodings(_uniqueValuesArray) {
    // tidy() so the intermediate tensors are released once the plain arrays
    // have been read out
    return tf.tidy(() => {
      const output = {
        uniqueValues: _uniqueValuesArray,
        legend: {},
      };

      const uniqueVals = _uniqueValuesArray;
      // get back values from 0 to the length of the uniqueVals array
      const onehotValues = uniqueVals.map((item, idx) => idx);
      // oneHot encode the values in the 1d tensor
      const oneHotEncodedValues = tf.oneHot(tf.tensor1d(onehotValues, 'int32'), uniqueVals.length);
      // convert them from tensors back out to an array
      const oneHotEncodedValuesArray = oneHotEncodedValues.arraySync();

      // populate the legend with the key/values
      uniqueVals.forEach((uVal, uIdx) => {
        output.legend[uVal] = oneHotEncodedValuesArray[uIdx];
      });

      return output;
    });
  }

  /**
   * ////////////////////////////////////////////////
   * saving / loading data
   * ////////////////////////////////////////////////
   */

  /**
   * Loads data from a URL using the appropriate function:
   * `.csv` -> loadCSV, `.json` -> loadJSON, a url containing 'blob' -> loadBlob
   * @param {string} dataUrl
   * @param {Array} inputs - input labels
   * @param {Array} outputs - output labels
   * @returns {Promise<Array>} the formatted rows
   * @throws {Error} for unsupported urls or when loading fails
   */
  async loadDataFromUrl(dataUrl, inputs, outputs) {
    try {
      let result;

      if (dataUrl.endsWith('.csv')) {
        result = await this.loadCSV(dataUrl, inputs, outputs);
      } else if (dataUrl.endsWith('.json')) {
        result = await this.loadJSON(dataUrl, inputs, outputs);
      } else if (dataUrl.includes('blob')) {
        result = await this.loadBlob(dataUrl, inputs, outputs);
      } else {
        throw new Error('Not a valid data format. Must be csv or json');
      }

      return result;
    } catch (error) {
      console.error(error);
      throw new Error(error);
    }
  }

  /**
   * loadJSON
   * @param {string|Object} dataUrlOrJson - a url to fetch, or already-parsed json
   * @param {Array} inputLabels
   * @param {Array} outputLabels
   * @returns {Promise<Array>} the formatted rows (also stored in `this.data.raw`)
   * @throws {Error} when fetching, parsing or formatting fails
   */
  async loadJSON(dataUrlOrJson, inputLabels, outputLabels) {
    try {
      let json;
      // handle loading parsedJson
      if (dataUrlOrJson instanceof Object) {
        json = Object.assign({}, dataUrlOrJson);
      } else {
        const data = await fetch(dataUrlOrJson);
        json = await data.json();
      }

      // format the data.raw array
      return this.formatRawData(json, inputLabels, outputLabels);
    } catch (err) {
      console.error('error loading json');
      throw new Error(err);
    }
  }

  /**
   * loadCSV
   * @param {string} dataUrl
   * @param {Array} inputLabels
   * @param {Array} outputLabels
   * @returns {Promise<Array>} the formatted rows (also stored in `this.data.raw`)
   * @throws {Error} when loading or formatting fails
   */
  async loadCSV(dataUrl, inputLabels, outputLabels) {
    try {
      const myCsv = tf.data.csv(dataUrl);
      const loadedData = await myCsv.toArray();
      const json = {
        entries: loadedData,
      };

      // format the data.raw array
      return this.formatRawData(json, inputLabels, outputLabels);
    } catch (err) {
      console.error('error loading csv', err);
      throw new Error(err);
    }
  }

  /**
   * loadBlob
   * Fetches the url as text; the text is parsed as JSON when
   * nnUtils.isJsonOrString() says so, otherwise it is treated as CSV.
   * @param {string} dataUrlOrJson
   * @param {Array} inputLabels
   * @param {Array} outputLabels
   * @returns {Promise<Array>} the formatted rows (also stored in `this.data.raw`)
   * @throws {Error} when fetching, parsing or formatting fails
   */
  async loadBlob(dataUrlOrJson, inputLabels, outputLabels) {
    try {
      const data = await fetch(dataUrlOrJson);
      const text = await data.text();

      const json = nnUtils.isJsonOrString(text) ? JSON.parse(text) : this.csvToJSON(text);

      return await this.loadJSON(json, inputLabels, outputLabels);
    } catch (err) {
      console.log('mmm might be passing in a string or something!', err);
      throw new Error(err);
    }
  }

  /**
   * loadData from fileinput or path.
   * Replaces `this.data.raw` with the `entries` / `data` array found in the
   * loaded json; when none is found `this.data.raw` becomes an empty array
   * and a message is logged.
   * @param {string|ArrayLike<File>} filesOrPath - a url/path, or a list of
   *   files of which only the first is read
   * @param {Function} [callback] - called without arguments once loading is done
   * @throws {Error} when reading, fetching or parsing fails
   */
  async loadData(filesOrPath = null, callback) {
    try {
      let loadedData;

      if (typeof filesOrPath !== 'string') {
        const file = filesOrPath[0];
        if (file.name.includes('.json')) {
          const temp = await file.text();
          loadedData = JSON.parse(temp);
        } else {
          console.log('data must be a json object containing an array called "data" or "entries');
        }
      } else {
        const response = await fetch(filesOrPath);
        const text = await response.text();
        if (nnUtils.isJsonOrString(text)) {
          loadedData = JSON.parse(text);
        } else {
          console.log(
            'Whoops! something went wrong. Either this kind of data is not supported yet or there is an issue with .loadData',
          );
        }
      }

      // findEntries() hands back a plain object when no array was found;
      // never store that in data.raw, which must stay an array
      const entries = this.findEntries(loadedData);
      this.data.raw = Array.isArray(entries) ? entries : [];

      // check if a data or entries property exists
      if (this.data.raw.length === 0) {
        console.log('data must be a json object containing an array called "data" ');
      }

      if (callback) {
        callback();
      }
    } catch (error) {
      throw new Error(error);
    }
  }

  /**
   * saveData
   * Serializes `{ data: this.data.raw }` and hands it to saveBlob() as
   * `<name>.json`.
   * @param {string} [name] - file name without extension; defaults to a
   *   `YYYY-M-D_H-M-S` timestamp (no zero padding)
   */
  async saveData(name) {
    const today = new Date();
    const date = `${String(today.getFullYear())}-${String(today.getMonth() + 1)}-${String(
      today.getDate(),
    )}`;
    const time = `${String(today.getHours())}-${String(today.getMinutes())}-${String(
      today.getSeconds(),
    )}`;
    const datetime = `${date}_${time}`;

    const dataName = name || datetime;

    const output = {
      data: this.data.raw,
    };

    await saveBlob(JSON.stringify(output), `${dataName}.json`, 'text/plain');
  }

  /**
   * Saves metadata of the data as `<modelName>_meta.json` via saveBlob()
   * @param {string|Function} [nameOrCb] - model name (default 'model'), or the callback
   * @param {Function} [cb] - callback, only used when `nameOrCb` is a string
   */
  async saveMeta(nameOrCb, cb) {
    let modelName = 'model';
    let callback;

    if (typeof nameOrCb === 'function') {
      callback = nameOrCb;
    } else if (typeof nameOrCb === 'string') {
      modelName = nameOrCb;
      if (typeof cb === 'function') {
        callback = cb;
      }
    }

    await saveBlob(JSON.stringify(this.meta), `${modelName}_meta.json`, 'text/plain');
    if (callback) {
      callback();
    }
  }

  /**
   * load the metadata that belongs to a model and store it in `this.meta`.
   * Marks the metadata as ready and the instance as warmed up.
   * @param {FileList|Object|string} filesOrPath
   *   - FileList: the file whose name contains `_meta.json` is read
   *   - Object: `{model: URL, metadata: URL, weights: URL}`, `metadata` is fetched
   *   - string: path of the model file; `model_meta.json` is fetched from the same directory
   * @param {Function} [callback] - called without arguments once loading is done
   * @returns {Promise<Object>} the loaded meta
   * @throws {Error} when a FileList contains no `_meta.json` file
   */
  async loadMeta(filesOrPath = null, callback) {
    // FileList only exists in browsers; referencing it bare would throw elsewhere
    const isFileList = typeof FileList !== 'undefined' && filesOrPath instanceof FileList;

    if (isFileList) {
      const metaFile = Array.from(filesOrPath).find(file => file.name.includes('_meta.json'));
      if (!metaFile) {
        throw new Error('loadMeta() could not find a "_meta.json" file in the given files');
      }

      this.meta = JSON.parse(await metaFile.text());
    } else if (filesOrPath instanceof Object) {
      // filesOrPath = {model: URL, metadata: URL, weights: URL}
      const response = await fetch(filesOrPath.metadata);
      this.meta = JSON.parse(await response.text());
    } else {
      const metaPath = `${filesOrPath.substring(0, filesOrPath.lastIndexOf('/'))}/model_meta.json`;
      const response = await fetch(metaPath);
      this.meta = await response.json();
    }

    this.isMetadataReady = true;
    this.isWarmedUp = true;

    if (callback) {
      callback();
    }
    return this.meta;
  }

  /*
   * ////////////////////////////////////////////////
   * data loading helpers
   * ////////////////////////////////////////////////
   */

  /**
   * // TODO: convert ys into strings, if the task is classification
    // if (this.config.architecture.task === "classification" && typeof output.ys[prop] !== "string") {
    //   output.ys[prop] += "";
    // }
   * formatRawData
   * takes a json and set the this.data.raw
   * Labels missing from a row are reported with console.error and left out of
   * that row.
   * @param {*} json
   * @param {Array} inputLabels
   * @param {Array} outputLabels
   * @returns {Array<{xs: Object, ys: Object}>} the formatted rows
   * @throws {Error} when no `entries` / `data` array can be found in `json`
   */
  formatRawData(json, inputLabels, outputLabels) {
    // Recurse through the json object to find
    // an array containing `entries` or `data`
    const dataArray = this.findEntries(json);
    const missingDataMessage = `your data must be contained in an array in \n a property called 'entries' or 'data' of your json object`;

    if (!Array.isArray(dataArray)) {
      throw new Error(missingDataMessage);
    }

    if (dataArray.length === 0) {
      console.log(missingDataMessage);
    }

    // create an array of json objects [{xs,ys}]
    const result = dataArray.map((item, idx) => {
      const output = {
        xs: {},
        ys: {},
      };

      inputLabels.forEach(k => {
        if (item[k] !== undefined) {
          output.xs[k] = item[k];
        } else {
          console.error(`the input label ${k} does not exist at row ${idx}`);
        }
      });

      outputLabels.forEach(k => {
        if (item[k] !== undefined) {
          output.ys[k] = item[k];
        } else {
          console.error(`the output label ${k} does not exist at row ${idx}`);
        }
      });

      return output;
    });

    // set this.data.raw
    this.data.raw = result;

    return result;
  }

  /**
   * csvToJSON
   * Creates a json object from a csv string. The first non-blank line is the
   * header row; every following non-blank line becomes one object keyed by
   * header. Values stay strings, and fields are split on every comma (quoted
   * fields are not supported).
   * @param {string} csv
   * @returns {{entries: Array<Object>}}
   */
  // via: http://techslides.com/convert-csv-to-json-in-javascript
  // eslint-disable-next-line class-methods-use-this
  csvToJSON(csv) {
    // split the string by linebreak (LF or CRLF) and drop blank lines so that
    // e.g. a trailing newline does not turn into an empty row
    const lines = csv.split(/\r?\n/).filter(line => line.trim() !== '');
    if (lines.length === 0) {
      return {
        entries: [],
      };
    }

    // get the header row as an array
    const headers = lines[0].split(',');

    // create a json object for each remaining row
    const result = lines.slice(1).map(line => {
      const currentline = line.split(',');
      const row = {};
      // for each header, create a key/value pair
      headers.forEach((k, idx) => {
        row[k] = currentline[idx];
      });
      return row;
    });

    return {
      entries: result,
    };
  }

  /**
   * findEntries
   * recursively attempt to find the entries
   * or data array for the given json object.
   * An `entries` array wins over a `data` array on the same level; nested
   * objects are searched depth-first in key order.
   * @param {*} _data
   * @returns {Array|Object} the array that was found, or a shallow copy of
   *   `_data` when there is none
   */
  findEntries(_data) {
    const parentCopy = Object.assign({}, _data);

    if (parentCopy.entries && parentCopy.entries instanceof Array) {
      return parentCopy.entries;
    }
    if (parentCopy.data && parentCopy.data instanceof Array) {
      return parentCopy.data;
    }

    // A plain loop is needed here: a `return` inside a forEach callback only
    // leaves the callback, so the nested result would be lost.
    const keys = Object.keys(parentCopy);
    for (let i = 0; i < keys.length; i += 1) {
      const child = parentCopy[keys[i]];
      if (child !== null && typeof child === 'object') {
        const found = this.findEntries(child);
        if (Array.isArray(found)) {
          return found;
        }
      }
    }

    return parentCopy;
  }
}

export default NeuralNetworkData;