federer
Version:
Experiments in asynchronous federated learning and decentralized learning
68 lines • 2.75 kB
TypeScript
import { Rank } from "@tensorflow/tfjs-node";
import { Dataset, DataSubset } from "../../../common";
import { PreprocessResult } from "../../../coordinator";
import { Environment } from "../../../coordinator/cli";
import { ShardingOptions } from "../coordinator/MnistCoordinator";
import { MnistModelName } from "./model";
import { DatasetName } from "./data/download";
/**
* Raw MNIST data, as downloaded from the Internet.
*
* The first type argument is the rank of the image tensor: images have rank 4
* (number of images, width, height, channels). The second type argument is the
* rank of the label tensor: labels are scalars in the range 0-9, so the list
* of labels has rank 1.
*/
export declare type RawMnistDataset = Dataset<Rank.R4, Rank.R1>;
/**
* MNIST data after preprocessing by {@link preprocess}.
*
* The images have rank 2 if the model is 2NN (a list of flattened image
* vectors), and rank 4 if the model is CNN (number of images, width, height,
* channels). The labels are encoded as one-hot vectors, so they have rank 2.
*/
export declare type ProcessedMnistDataset = Dataset<Rank.R2 | Rank.R4, Rank.R2>;
/**
* A `DataSubset` of preprocessed MNIST data: images of rank 2 or rank 4 and
* labels of rank 2.
*
* NOTE(review): the exact semantics of `DataSubset` are defined in the common
* module and are not visible from this declaration file.
*/
export declare type ProcessedMnistDataSubset = DataSubset<Rank.R4 | Rank.R2, Rank.R2>;
/**
* Interface specializing {@link PreprocessResult} with a more specific type
* describing the ranks of the test set. Every other member is inherited
* unchanged from {@link PreprocessResult}.
*/
export interface MnistPreprocessResult extends PreprocessResult {
/**
* Test set, typed with the image ranks (2 or 4) and label rank (2) of
* preprocessed MNIST data.
*/
testSet: ProcessedMnistDataSubset;
}
/**
* Options accepted by {@link preprocess}.
*/
export interface MnistPreprocessOptions {
/** Name of the raw dataset to use. */
dataset: DatasetName;
/**
* Name of the model; preprocessing might vary slightly depending on the
* model.
*/
modelName: MnistModelName;
/**
* Number of label classes to include. Include fewer than 10 to run smaller
* test experiments.
*/
numberLabelClasses: number;
/**
* Number of clients. Determines how many shards the data is broken up into.
* @see {@link MnistCoordinatorOptions}
*/
numberClients: number;
/** Environment. Dictates where to read data files from. */
environment: Environment;
/**
* Number of digit batches per client. A digit batch is a slice of data that
* mostly contains the same digit.
* @see {@link MnistCoordinatorOptions}
*/
numberDigitBatchesPerClient: number;
/**
* Whether to allow reads from cache (`true`), or to force re-computation of
* the preprocess pipeline (`false`).
*
* NOTE(review): this option is optional; the behavior when it is omitted is
* not visible in this declaration — confirm against the implementation.
*/
allowReadFromCache?: boolean;
/**
* Sharding options, using the `ShardingOptions` type declared alongside the
* MNIST coordinator.
*
* NOTE(review): presumably controls how the data is split into client
* shards — confirm against the coordinator module.
*/
shardingOptions: ShardingOptions;
}
/**
* Preprocess the raw MNIST data into client shards and a test set.
*
* @param options Preprocessing options; see {@link MnistPreprocessOptions}
* @returns A promise resolving to the results of the preprocessing (see
* {@link MnistPreprocessResult}), with all data saved to disk
*/
export declare function preprocess(options: MnistPreprocessOptions): Promise<MnistPreprocessResult>;
//# sourceMappingURL=preprocess.d.ts.map