UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

51 lines 2.6 kB
"use strict"; var _AverageWeights_sum, _AverageWeights_count; Object.defineProperty(exports, "__esModule", { value: true }); exports.AverageWeights = void 0; const tslib_1 = require("tslib"); const assert_1 = tslib_1.__importDefault(require("assert")); /** Average of tensors. The average can be updated by adding tensors to it. */ class AverageWeights { constructor() { _AverageWeights_sum.set(this, void 0); _AverageWeights_count.set(this, 0); } /** Whether anything has been added to the average yet. */ isDefined() { return tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f") !== undefined; } /** Add a tensor to the average. */ add(tensor) { if (tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f") == undefined) { tslib_1.__classPrivateFieldSet(this, _AverageWeights_sum, tensor.clone(), "f"); } else { const oldSum = tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f"); tslib_1.__classPrivateFieldSet(this, _AverageWeights_sum, tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f").add(tensor), "f"); oldSum.dispose(); } tslib_1.__classPrivateFieldSet(this, _AverageWeights_count, tslib_1.__classPrivateFieldGet(this, _AverageWeights_count, "f") + 1, "f"); } /** Get the average value. */ average() { assert_1.default(tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f") !== undefined); return tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f").div(tslib_1.__classPrivateFieldGet(this, _AverageWeights_count, "f")); } /** Get the number of values that the average represents. */ get count() { return tslib_1.__classPrivateFieldGet(this, _AverageWeights_count, "f"); } /** Dispose of the {@link AverageWeights}, discarding the memory backing it. */ dispose() { // assert(!this.isDisposed(), "AverageTensor is already disposed"); tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f")?.dispose(); tslib_1.__classPrivateFieldSet(this, _AverageWeights_count, -1, "f"); } /** Whether the {@link AverageWeights} has been disposed. */ get isDisposed() { return tslib_1.__classPrivateFieldGet(this, _AverageWeights_count, "f") === -1 || (tslib_1.__classPrivateFieldGet(this, _AverageWeights_sum, "f")?.isDisposed ?? false); } } exports.AverageWeights = AverageWeights; _AverageWeights_sum = new WeakMap(), _AverageWeights_count = new WeakMap(); //# sourceMappingURL=AverageWeights.js.map