// ml-double-q-learning — compiled CommonJS bundle (served via UNPKG)
// Version: not captured in this snapshot
// 78 lines (77 loc), 4.18 kB
"use strict";
// Double Q-learning agent built on ml-q-learning's QLearningAgent.
// Double Q-learning keeps two independent value tables (prefixed "A" and "B")
// and, on each learning step, updates one table using the other table's
// estimate of the greedily-picked next action, which reduces the
// maximisation bias of plain Q-learning.
Object.defineProperty(exports, "__esModule", { value: true });
const ml_q_learning_1 = require("ml-q-learning");

/**
 * Element-wise sum of two action-statistics arrays.
 *
 * @param {number[]} actionsStatsA - Action values from table A.
 * @param {number[]} actionsStatsB - Action values from table B (same length assumed).
 * @returns {number[]} actionsStatsA[i] + actionsStatsB[i] for every index i.
 */
function sumActionsStats(actionsStatsA, actionsStatsB) {
    // map is the idiomatic 1:1 transform (the original used reduce + push).
    return actionsStatsA.map((statA, index) => statA + actionsStatsB[index]);
}

// Which of the two value tables receives the update on a given step.
// (Compiled-TS-enum shape kept for internal consistency with the original.)
var SelectedUpdate;
(function (SelectedUpdate) {
    SelectedUpdate[SelectedUpdate["A"] = 0] = "A";
    SelectedUpdate[SelectedUpdate["B"] = 1] = "B";
})(SelectedUpdate || (SelectedUpdate = {}));

/** Pick table A or B with equal probability (unbiased 50/50 coin flip). */
function chooseRandomUpdate() {
    return Math.random() >= 0.5 ? SelectedUpdate.A : SelectedUpdate.B;
}

class DoubleQLearningAgent extends ml_q_learning_1.QLearningAgent {
    /**
     * Choose an action for `stateSerialized` using the SUM of both value
     * tables (Q_A + Q_B), as double Q-learning prescribes for action
     * selection.
     *
     * @param {string} stateSerialized - Serialized state (unprefixed).
     * @returns {Promise<number>} Index of the chosen action.
     */
    async chooseActionAlgorithm(stateSerialized) {
        const stateSerializedA = `A${stateSerialized}`;
        const stateSerializedB = `B${stateSerialized}`;
        // Both prefixed states must exist before they are read.
        await Promise.all([
            this.createStateIfNotExist(stateSerializedA),
            this.createStateIfNotExist(stateSerializedB),
        ]);
        const [actionsStatsA, actionsStatsB] = await Promise.all([
            this.memory.getState(stateSerializedA),
            this.memory.getState(stateSerializedB),
        ]);
        const actionsStats = sumActionsStats(actionsStatsA, actionsStatsB);
        return this.pickActionStrategy(actionsStats, this.episode);
    }

    /**
     * Double Q-learning update rule. A fair coin picks the table to update
     * ("own"); the greedy next action is chosen from the own table's next
     * state but EVALUATED with the other table's estimate:
     *
     *   Q_own(s, a) += lr * (r + gamma * Q_other(s', argmax_a' Q_own(s', a')) - Q_own(s, a))
     *
     * The two branches of the original were exact mirror images (swap A/B),
     * so they are collapsed into one parameterized path.
     *
     * @param {number} action - Index of the action taken in `stateSerialized`.
     * @param {number} reward - Reward received for that action.
     * @param {string} stateSerialized - Serialized current state (unprefixed).
     * @param {string} stateSerializedPrime - Serialized next state (unprefixed).
     * @param {(state: string) => Promise<number[]>} getState - Reads a table row.
     * @returns {Promise<[string, number[]]>} The prefixed state key that was
     *   updated and its new action-statistics row.
     */
    async learningAlgorithm(action, reward, stateSerialized, stateSerializedPrime, getState) {
        const selectedUpdate = chooseRandomUpdate();
        // "own" is the table being updated; "other" evaluates the greedy action.
        const own = selectedUpdate === SelectedUpdate.A ? 'A' : 'B';
        const other = selectedUpdate === SelectedUpdate.A ? 'B' : 'A';
        const stateSerializedOwn = `${own}${stateSerialized}`;
        const stateSerializedOwnPrime = `${own}${stateSerializedPrime}`;
        const stateSerializedOtherPrime = `${other}${stateSerializedPrime}`;
        const [actionPrime, actionsStatsOwn, actionsStatsOtherPrime] = await Promise.all([
            this.greedyPickAction(stateSerializedOwnPrime),
            getState(stateSerializedOwn),
            getState(stateSerializedOtherPrime),
        ]);
        // Temporal-difference update with cross-table evaluation of a'.
        const tdTarget = reward + this.discountFactor * actionsStatsOtherPrime[actionPrime];
        actionsStatsOwn[action] += this.learningRate * (tdTarget - actionsStatsOwn[action]);
        return [stateSerializedOwn, actionsStatsOwn];
    }
}
exports.DoubleQLearningAgent = DoubleQLearningAgent;