federer
Version:
Experiments in asynchronous federated learning and decentralized learning
51 lines • 2.6 kB
JavaScript
"use strict";
var _AverageWeights_sum, _AverageWeights_count;
Object.defineProperty(exports, "__esModule", { value: true });
exports.AverageWeights = void 0;
const tslib_1 = require("tslib");
const assert_1 = tslib_1.__importDefault(require("assert"));
/**
 * Running average of tensors. The average can be updated by adding tensors to it.
 *
 * Tensors are expected to expose `clone()`, `add(other)`, `div(scalar)`, `dispose()`
 * and an `isDisposed` flag (NOTE(review): presumably TensorFlow.js tensors — confirm).
 * The instance keeps its own running-sum tensor, which {@link AverageWeights#dispose}
 * releases; once disposed, the instance can no longer be added to or averaged.
 */
class AverageWeights {
    /** Sum of all added tensors; `undefined` before the first `add` and after `dispose`. */
    #sum = undefined;
    /** Number of tensors added so far; `-1` marks a disposed instance. */
    #count = 0;

    /** Whether anything has been added to the average (and it has not been disposed since). */
    isDefined() {
        return this.#sum !== undefined;
    }

    /**
     * Add a tensor to the average.
     *
     * The first tensor is cloned, so the caller keeps ownership of its argument;
     * later tensors are summed into a new tensor and the previous sum is released.
     *
     * @param tensor - tensor to fold into the average.
     * @throws {AssertionError} if this instance has been disposed.
     */
    add(tensor) {
        // Without this guard, adding to a disposed, never-filled instance would
        // restart it with count 0 (dispose sets count to -1), and average() would
        // then divide by zero.
        assert_1.default(!this.isDisposed, "AverageWeights is already disposed");
        if (this.#sum === undefined) {
            this.#sum = tensor.clone();
        }
        else {
            const oldSum = this.#sum;
            this.#sum = oldSum.add(tensor);
            // Release the previous sum only once the new one was created successfully.
            oldSum.dispose();
        }
        this.#count += 1;
    }

    /**
     * Get the average value: the running sum divided by the number of added tensors.
     *
     * @returns the result of `sum.div(count)` (NOTE(review): presumably a new tensor
     *   owned by the caller — confirm).
     * @throws {AssertionError} if the instance is disposed or nothing has been added yet.
     */
    average() {
        assert_1.default(!this.isDisposed, "AverageWeights is already disposed");
        assert_1.default(this.#sum !== undefined, "AverageWeights is empty");
        return this.#sum.div(this.#count);
    }

    /** Get the number of values that the average represents (`-1` once disposed). */
    get count() {
        return this.#count;
    }

    /**
     * Dispose of the {@link AverageWeights}, discarding the memory backing it.
     * Calling it more than once is harmless: the sum reference is cleared on the first call.
     */
    dispose() {
        this.#sum?.dispose();
        this.#sum = undefined;
        this.#count = -1;
    }

    /**
     * Whether the {@link AverageWeights} has been disposed, either through `dispose()`
     * or because its running-sum tensor was disposed from outside.
     */
    get isDisposed() {
        return this.#count === -1 || (this.#sum?.isDisposed ?? false);
    }
}
// Backing stores for the class's emulated private fields (sum and count).
_AverageWeights_sum = new WeakMap();
_AverageWeights_count = new WeakMap();
// Public CommonJS export of the class.
exports.AverageWeights = AverageWeights;
//# sourceMappingURL=AverageWeights.js.map