UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

85 lines 3.46 kB
import * as tf from "@tensorflow/tfjs-node";
import { Dataset, DataSubset, DataSubsetFilepaths } from "../../common";
import { DataDistributionStats } from "./distribution-stats";

/** Everything produced by one run of a {@link PreprocessPipeline}. */
export interface PreprocessResult {
  /** The held-out test set, loaded in memory. */
  testSet: DataSubset;
  /** Locations of the files to which the test set was saved. */
  testSetFiles: DataSubsetFilepaths;
  /**
   * Locations of each client's training set; the array holds one entry per
   * client.
   */
  clientTrainFiles: DataSubsetFilepaths[];
  /** Summary of how the data ended up distributed across the clients. */
  dataDistribution: DataDistributionStats;
}

/**
 * A complete data-preprocessing pipeline for Federated Learning (FL).
 *
 * In FL the raw data has to be split into "shards", and each client's
 * training set corresponds to exactly one shard. Besides those per-client
 * training sets, the pipeline also produces a test set, which the
 * coordinator uses to evaluate the global model.
 */
export declare class PreprocessPipeline {
  private readonly numberLabelClasses;
  private readonly pipeline;
  /** Directory holding cached results from an earlier run. */
  private readonly rootDir;
  /** File listing the paths of the cached results from an earlier run. */
  private readonly pathsFile;
  /** Directories into which output files are written. */
  private readonly directories;

  /**
   * @param experimentName Name of the experiment. NOTE(review): presumably
   *   used to derive where results are cached — confirm in the implementation.
   * @param pipelineName Name of this pipeline. NOTE(review): presumably also
   *   part of the cache location — confirm.
   * @param numberLabelClasses Number of label classes in the data.
   * @param pipeline The functions implementing each step of the pipeline.
   */
  constructor(
    experimentName: string,
    pipelineName: string,
    numberLabelClasses: number,
    pipeline: Readonly<PreprocessPipelineFunctions>
  );

  /**
   * Runs the pipeline and resolves with its result.
   *
   * @param allowReadFromCache Whether results cached by a previous run may be
   *   returned (presumably instead of recomputing them). The default used when
   *   this is omitted is not visible from this declaration.
   */
  run(allowReadFromCache?: boolean): Promise<PreprocessResult>;

  private runAndCache;
  private runPipeline;
  private createTestSet;
  private createClientTrainSets;
  private preprocess;
  private readCachedResults;
  private cacheResults;
  private saveDistributionMatrix;
  private saveShards;
}

/**
 * The individual steps that together make up a federated-learning
 * preprocessing pipeline.
 */
export interface PreprocessPipelineFunctions {
  /** Loads the raw data. Only invoked when it is actually needed. */
  readRawData: () => Promise<Dataset>;
  /**
   * Optional first-pass filter that selects which part of the raw data an
   * experiment uses, discarding the rest.
   *
   * When `undefined`, the data is not filtered at all.
   */
  filter?: FilterFn;
  /** Splits the data into shards. */
  shard: ShardFn;
  /**
   * Transformations that turn the data into a form an ML model can consume.
   * In a production system these would run on the clients; to keep this
   * implementation simple, the coordinator applies them ahead of time.
   */
  preprocess: {
    /** Transform for the data items (presumably the model inputs). */
    preprocessItems: PreprocessFn;
    /** Transform for the labels. */
    preprocessLabels: PreprocessFn;
  };
}

/** Filters a dataset, discarding the datapoints that should not be used. */
export type FilterFn = (data: DataSubset) => Promise<DataSubset>;

/** Splits an entire dataset into shards. */
export type ShardFn = (data: DataSubset) => Promise<DataSubset[]>;

/** Transforms a tensor to prepare it for consumption by a model. */
export type PreprocessFn = (tensor: tf.Tensor) => tf.Tensor;
//# sourceMappingURL=PreprocessPipeline.d.ts.map