ml-double-q-learning
Version: (not captured in this snapshot)
Library implementing the double Q-learning algorithm.
78 lines (77 loc) • 4.18 kB
JavaScript
;
// Standard TypeScript-compiler __awaiter helper (emitted inline when tslib is
// not imported): drives a generator that yields promises, so the downlevelled
// async method bodies below run as generator functions wrapped in a Promise.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    return new (P || (P = Promise))(function (resolve, reject) {
        // Resume the generator with a resolved value, or inject an error into it.
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        // Each yielded value is wrapped in a fresh promise so plain values and
        // thenables are handled uniformly before resuming the generator.
        function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
Object.defineProperty(exports, "__esModule", { value: true });
const ml_q_learning_1 = require("ml-q-learning");
/**
 * Element-wise sum of two per-action statistics arrays.
 * Combines the Q_A and Q_B value estimates of double Q-learning into a single
 * array that an action-picking strategy can consume.
 *
 * @param {number[]} actionsStatsA - Per-action values from table A.
 * @param {number[]} actionsStatsB - Per-action values from table B; expected to
 *     have the same length as `actionsStatsA`.
 * @returns {number[]} New array where result[i] === actionsStatsA[i] + actionsStatsB[i].
 */
function sumActionsStats(actionsStatsA, actionsStatsB) {
    // map() replaces the original reduce-and-push loop (which also shadowed the
    // outer `actionsStats` name): same output, clearer intent.
    return actionsStatsA.map((actionStatsA, index) => actionStatsA + actionsStatsB[index]);
}
// Tags identifying which of the two Q-tables (A or B) a learning step updates.
// Mirrors a compiled TypeScript numeric enum: forward (name -> value) and
// reverse (value -> name) mappings live on the same object.
var SelectedUpdate = {
    A: 0,
    B: 1,
    0: "A",
    1: "B",
};
/**
 * Picks one of the two Q-tables uniformly at random for the next update.
 * Math.random() is uniform on [0, 1), so each branch has probability 0.5.
 *
 * @returns {number} SelectedUpdate.A or SelectedUpdate.B.
 */
function chooseRandomUpdate() {
    if (Math.random() >= 0.5) {
        return SelectedUpdate.A;
    }
    return SelectedUpdate.B;
}
/**
 * Double Q-learning agent layered on ml-q-learning's QLearningAgent.
 * Two Q-tables share the same memory store by prefixing each serialized state
 * with "A" or "B": action selection uses the sum of both tables, and each
 * learning step randomly updates one table using the OTHER table's value
 * estimate as the target (the decoupling that reduces maximization bias).
 */
class DoubleQLearningAgent extends ml_q_learning_1.QLearningAgent {
    /**
     * Picks an action for `stateSerialized` by combining both Q-tables.
     * Ensures the "A"/"B"-prefixed entries exist, fetches both stat arrays,
     * sums them element-wise, and delegates the choice to the configured
     * pickActionStrategy.
     *
     * @param {string} stateSerialized - Serialized environment state (unprefixed).
     * @returns {Promise<number>} Index of the chosen action.
     */
    chooseActionAlgorithm(stateSerialized) {
        return __awaiter(this, void 0, void 0, function* () {
            const stateSerializedA = `A${stateSerialized}`;
            const stateSerializedB = `B${stateSerialized}`;
            // Create both table entries up front so the getState calls below
            // cannot observe a missing state.
            yield Promise.all([
                this.createStateIfNotExist(stateSerializedA),
                this.createStateIfNotExist(stateSerializedB)
            ]);
            // NOTE(review): memory.getState presumably resolves to a number[] of
            // per-action values — confirm against ml-q-learning's memory contract.
            const [actionsStatsA, actionsStatsB] = yield Promise.all([
                this.memory.getState(stateSerializedA),
                this.memory.getState(stateSerializedB)
            ]);
            const actionsStats = sumActionsStats(actionsStatsA, actionsStatsB);
            const actionIndex = yield this.pickActionStrategy(actionsStats, this.episode);
            return actionIndex;
        });
    }
    /**
     * One double Q-learning update step. A coin flip selects which table to
     * update; the greedy action is chosen from the selected table at the next
     * state, but its value is read from the other table. Table-A case:
     *
     *   a* = argmax_a Q_A(s', a)
     *   Q_A(s, a) += learningRate * (reward + discountFactor * Q_B(s', a*) - Q_A(s, a))
     *
     * @param {number} action - Index of the action taken in `stateSerialized`.
     * @param {number} reward - Reward received for that action.
     * @param {string} stateSerialized - Serialized previous state (unprefixed).
     * @param {string} stateSerializedPrime - Serialized next state (unprefixed).
     * @param {Function} getState - Async accessor resolving a prefixed state
     *     key to its per-action stats array.
     * @returns {Promise<[string, number[]]>} The prefixed state key that was
     *     updated and its mutated stats array, for the caller to persist.
     */
    learningAlgorithm(action, reward, stateSerialized, stateSerializedPrime, getState) {
        return __awaiter(this, void 0, void 0, function* () {
            const selectedUpdate = chooseRandomUpdate();
            if (selectedUpdate === SelectedUpdate.A) {
                const stateSerializedAPrime = `A${stateSerializedPrime}`;
                const stateSerializedA = `A${stateSerialized}`;
                const stateSerializedBPrime = `B${stateSerializedPrime}`;
                // NOTE(review): unlike chooseActionAlgorithm, no createStateIfNotExist
                // here — assumes greedyPickAction / getState tolerate unseen states.
                // Confirm against the QLearningAgent base class.
                const [actionPrime, actionsStatsA, actionsStatsBPrime] = yield Promise.all([
                    this.greedyPickAction(stateSerializedAPrime),
                    getState(stateSerializedA),
                    getState(stateSerializedBPrime)
                ]);
                // In-place mutation of the fetched stats array; returned below so the
                // caller can write it back.
                actionsStatsA[action] = actionsStatsA[action] + this.learningRate * (reward + (this.discountFactor * actionsStatsBPrime[actionPrime]) - actionsStatsA[action]);
                return [stateSerializedA, actionsStatsA];
            }
            if (selectedUpdate === SelectedUpdate.B) {
                // Mirror image of the branch above with the roles of A and B swapped.
                const stateSerializedBPrime = `B${stateSerializedPrime}`;
                const stateSerializedB = `B${stateSerialized}`;
                const stateSerializedAPrime = `A${stateSerializedPrime}`;
                const [actionPrime, actionsStatsB, actionsStatsAPrime] = yield Promise.all([
                    this.greedyPickAction(stateSerializedBPrime),
                    getState(stateSerializedB),
                    getState(stateSerializedAPrime)
                ]);
                actionsStatsB[action] = actionsStatsB[action] + this.learningRate * (reward + (this.discountFactor * actionsStatsAPrime[actionPrime]) - actionsStatsB[action]);
                return [stateSerializedB, actionsStatsB];
            }
            // Defensive guard: unreachable while chooseRandomUpdate only returns A or B.
            throw new Error('The learning algorithm did not return anything.');
        });
    }
}
exports.DoubleQLearningAgent = DoubleQLearningAgent;