UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

68 lines 2.75 kB
// NOTE(review): this declaration file appears to be compiler-generated (see the
// sourceMappingURL comment at the end). Persistent changes presumably belong in
// the corresponding `preprocess.ts` source — confirm before editing here.
import type { Rank } from "@tensorflow/tfjs-node";
import type { Dataset, DataSubset } from "../../../common";
import type { PreprocessResult } from "../../../coordinator";
import type { Environment } from "../../../coordinator/cli";
import type { ShardingOptions } from "../coordinator/MnistCoordinator";
import type { MnistModelName } from "./model";
import type { DatasetName } from "./data/download";

/**
 * Raw MNIST data, as downloaded from the Internet. Images have rank 4 (number
 * of images, width, height, channels), and labels are scalars in the range 0-9,
 * so the list of labels has rank 1.
 */
export declare type RawMnistDataset = Dataset<Rank.R4, Rank.R1>;

/**
 * MNIST data after preprocessing by {@link preprocess}. The images have rank 2
 * if the model is 2NN (list of vectors of flat images), and rank 4 if the model
 * is CNN (number of images, width, height, channels). The labels are encoded as
 * one-hot vectors, so have rank 2.
 */
export declare type ProcessedMnistDataset = Dataset<Rank.R2 | Rank.R4, Rank.R2>;

/**
 * A {@link DataSubset} with the same image/label ranks as
 * {@link ProcessedMnistDataset}. Used as the type of the test set in
 * {@link MnistPreprocessResult}.
 */
export declare type ProcessedMnistDataSubset = DataSubset<Rank.R2 | Rank.R4, Rank.R2>;

/**
 * Interface specializing {@link PreprocessResult} with a more specific type
 * describing the ranks of the test set.
 */
export interface MnistPreprocessResult extends PreprocessResult {
    /** Preprocessed test set; image rank depends on the model (2NN vs CNN). */
    testSet: ProcessedMnistDataSubset;
}

/** Options accepted by {@link preprocess}. */
export interface MnistPreprocessOptions {
    /** Name of the raw dataset to use. */
    dataset: DatasetName;
    /**
     * Name of the model; preprocessing might vary slightly depending on the model.
     */
    modelName: MnistModelName;
    /**
     * Number of label classes to include. Include fewer than 10 to run smaller
     * test experiments.
     */
    numberLabelClasses: number;
    /**
     * Number of clients. Tells us into how many shards to break up the data.
     * @see {@link MnistCoordinatorOptions}
     */
    numberClients: number;
    /** Environment. Dictates where to read data files from. */
    environment: Environment;
    /**
     * Number of digit batches per client. A digit batch is a slice of data,
     * which mostly includes the same digit.
     * @see {@link MnistCoordinatorOptions}
     */
    numberDigitBatchesPerClient: number;
    /**
     * Whether to allow reads from cache (`true`), or to force re-computation of
     * the preprocess pipeline (`false`).
     */
    allowReadFromCache?: boolean;
    /**
     * Sharding configuration used when splitting the data into client shards.
     * See {@link ShardingOptions} for the available fields.
     */
    shardingOptions: ShardingOptions;
}

/**
 * Preprocess the raw MNIST data into client shards and a test set.
 *
 * @param options Preprocessing options
 * @returns The results of the preprocessing, with all data saved to disk
 */
export declare function preprocess(options: MnistPreprocessOptions): Promise<MnistPreprocessResult>;
//# sourceMappingURL=preprocess.d.ts.map