federer
Version:
Experiments in asynchronous federated learning and decentralized learning
73 lines • 3.26 kB
TypeScript
/// <reference types="node" />
import * as tf from "@tensorflow/tfjs-node";
/**
 * A list of tensor shapes. Each inner array is presumably the dimension list
 * of one tensor, matching the order of `Weights.weights` (confirm against the
 * implementation).
 */
declare type Shapes = Array<Array<number>>;
/**
 * The `.npz`-encoded form of a set of weights.
 *
 * @note
 * This is a union rather than a single type because the same bytes change
 * wrapper as they travel between processes. Every member is backed by an
 * `ArrayBuffer`, and the serialization library (tfjs-npy-node) can
 * deserialize any of them.
 *
 * The typical journey of a payload:
 *
 * - the server serializes weights into an `ArrayBuffer`;
 * - socket.io carries that `ArrayBuffer` to the client;
 * - socket.io on the client hands it over wrapped in a `Buffer`;
 * - the client forwards it to its worker via comlink;
 * - the worker finally sees it as a `Uint8Array`.
 */
export declare type SerializedWeights = tf.TypedArray | Buffer | ArrayBuffer;
/**
 * Right-hand operand accepted by the arithmetic methods of {@link Weights}:
 * another {@link Weights}, an array of tensor-likes, or a single tensor-like.
 */
declare type WeightsOperand = Weights | tf.TensorLike[] | tf.TensorLike;

/**
 * Wraps a set of TensorFlow.js model weights. It offers construction,
 * cloning, (de)serialization, saving/loading and arithmetic combination.
 *
 * The weights are stored as plain `tf.Tensor[]`, not as the `tf.Variable[]`
 * TensorFlow.js uses for model weights. The two differ in ways that matter here:
 *
 * - a `tf.Variable` has fixed dimensions;
 * - a `tf.Variable` is not released by `tf.tidy`; it has to be disposed of
 *   explicitly with `.dispose()`.
 *
 * Federated Learning rewrites weights constantly. Plain tensors keep `tf.tidy`
 * and similar helpers usable, which helps avoid leaks. When variables are
 * actually needed, convert with {@link Weights.toVariables}.
 *
 * Releasing memory is up to the caller: use `tf.tidy`, {@link tidy}, or call
 * {@link Weights.dispose} explicitly.
 */
export declare class Weights {
    /** The wrapped tensors. */
    readonly weights: tf.Tensor[];

    constructor(weights: ReadonlyArray<tf.Tensor>);

    // ---- factories & persistence ------------------------------------------

    /**
     * Builds a {@link Weights} with one tensor per entry of `shapes`.
     * Presumably zero-filled, judging by the name; confirm against the
     * implementation.
     */
    static zero(shapes: Readonly<Shapes>): Weights;

    /**
     * Builds a {@link Weights} with one tensor per entry of `shapes`.
     * Presumably randomly initialised; the distribution is not visible here.
     */
    static random(shapes: Readonly<Shapes>): Weights;

    /** Rebuilds a {@link Weights} from any {@link SerializedWeights} form. */
    static deserialize(serialized: SerializedWeights): Weights;

    /**
     * Reads weights from `file` and returns them directly (no Promise).
     * The file format is presumably `.npz`, matching `save`; confirm.
     */
    static load(file: string): Weights;

    /** Asynchronously writes already-serialized weights to `file`. */
    static saveSerialized(file: string, serialized: SerializedWeights): Promise<void>;

    /** Blocking counterpart of {@link Weights.saveSerialized}. */
    static saveSerializedSync(file: string, serialized: SerializedWeights): void;

    // ---- state & lifecycle ------------------------------------------------

    /** Shapes of the wrapped tensors. */
    get shapes(): Readonly<Shapes>;

    /**
     * Whether the wrapped tensors have been released. Presumably this becomes
     * true after {@link Weights.dispose}; confirm.
     */
    get isDisposed(): boolean;

    /**
     * Returns a new {@link Weights}. Whether the tensors are deep-copied is not
     * visible from this declaration.
     */
    clone(): Weights;

    /**
     * Calls `tf.dispose` on every wrapped tensor. Do not use this object
     * afterwards.
     */
    dispose(): void;

    /** Converts the wrapped tensors into `tf.Variable`s. */
    toVariables(): tf.Variable[];

    // ---- serialization ----------------------------------------------------

    /** Encodes the weights into their `.npz` {@link SerializedWeights} form. */
    serialize(): Promise<SerializedWeights>;

    /** Blocking counterpart of {@link Weights.serialize}. */
    serializeSync(): SerializedWeights;

    /** Asynchronously writes these weights to `file`. */
    save(file: string): Promise<void>;

    // ---- arithmetic -------------------------------------------------------
    // Each method returns a new Weights. The private `elementWise` and
    // `checkDimensions` helpers suggest the operations are element-wise and
    // check dimensions first; confirm against the implementation.

    add(that: WeightsOperand): Weights;
    sub(that: WeightsOperand): Weights;
    mul(that: WeightsOperand): Weights;
    div(that: WeightsOperand): Weights;
    /** Division; presumably mirrors `tf.divNoNan` (0 where the divisor is 0). */
    divNoNan(that: WeightsOperand): Weights;

    private checkDimensions;
    private elementWise;
    private getValues;
}
export {};
//# sourceMappingURL=weights.d.ts.map