UNPKG

ml-q-learning

Version:

Library implementing the Q-learning algorithm and several exploration algorithms.

134 lines (133 loc) 6.12 kB
"use strict"; var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; Object.defineProperty(exports, "__esModule", { value: true }); const pick_action_strategy_1 = require("./pick-action-strategy"); const memory_1 = require("./memory"); class QLearningAgent { constructor(actions, pickActionStrategy = pick_action_strategy_1.greedyPickAction, memory = new memory_1.MapInMemory(), learningRate = 0.1, discountFactor = 0.99) { this.actions = actions; this.pickActionStrategy = pickActionStrategy; this.memory = memory; this.learningRate = learningRate; this.discountFactor = discountFactor; this.startEpisode = 0; this.replayMemory = []; this.episode = 0; this.trained = false; this.init(); } init() { return __awaiter(this, void 0, void 0, function* () { if (!this.episode) { this.startEpisode = (yield this.memory.hasInfo()) ? (yield this.memory.getInfo()).episode : 1; this.episode = this.startEpisode; this.trained = (yield this.memory.hasInfo()) ? 
(yield this.memory.getInfo()).trained : false; } }); } play(state) { return __awaiter(this, void 0, void 0, function* () { yield this.init(); const stateSerialized = state.toString(); this.episode += 1; const actionIndex = yield this.chooseActionAlgorithm(stateSerialized); const index = this.replayMemory.push([stateSerialized, actionIndex, 0]) - 1; return { action: this.actions[actionIndex], historyIndex: index, trainingInfo: { episode: this.episode, trained: this.trained } }; }); } chooseActionAlgorithm(stateSerialized) { return __awaiter(this, void 0, void 0, function* () { yield this.createStateIfNotExist(stateSerialized); const actionsStats = yield this.memory.getState(stateSerialized); const actionIndex = yield this.pickActionStrategy(actionsStats, this.episode); return actionIndex; }); } reward(step, reward) { this.replayMemory[step.historyIndex][2] += reward; } createStateIfNotExist(stateSerialized) { return __awaiter(this, void 0, void 0, function* () { if (!(yield this.memory.hasState(stateSerialized))) { yield this.memory.setState(stateSerialized, Array(this.actions.length).fill(0)); } }); } greedyPickAction(stateSerialized) { return __awaiter(this, void 0, void 0, function* () { const actionsStats = yield this.memory.getState(stateSerialized); return pick_action_strategy_1.greedyPickAction(actionsStats); }); } learn() { return __awaiter(this, void 0, void 0, function* () { if (this.replayMemory.length === 0) { return; } const map = new Map(); const getState = (stateSerialized) => __awaiter(this, void 0, void 0, function* () { const value = map.get(stateSerialized); if (!value) { return this.memory.getState(stateSerialized); } return value; }); let stateSerialized = this.replayMemory[0][0]; for (let index = 1; index < this.replayMemory.length - 1; index++) { const action = this.replayMemory[index][1]; const reward = this.replayMemory[index][2]; const stateSerializedPrime = this.replayMemory[index + 1][0]; const [stateSerializedToUpdate, actionsStats] = 
yield this.learningAlgorithm(action, reward, stateSerialized, stateSerializedPrime, getState); map.set(stateSerializedToUpdate, actionsStats); stateSerialized = stateSerializedPrime; } yield this.memory.setStateBulk(Array.from(map)); yield this.updateTrainingInfo(); this.replayMemory = []; }); } learningAlgorithm(action, reward, stateSerialized, stateSerializedPrime, getState) { return __awaiter(this, void 0, void 0, function* () { const actionPrime = yield this.greedyPickAction(stateSerializedPrime); const actionsStats = yield getState(stateSerialized); const actionsStatsPrime = yield getState(stateSerializedPrime); actionsStats[action] = actionsStats[action] + this.learningRate * (reward + (this.discountFactor * actionsStatsPrime[actionPrime]) - actionsStats[action]); return [stateSerialized, actionsStats]; }); } updateTrainingInfo() { return __awaiter(this, void 0, void 0, function* () { if (!(yield this.memory.hasInfo())) { yield this.memory.setInfo({ episode: this.episode, trained: false }); } const info = yield this.memory.getInfo(); const newEpisode = info.episode + this.episode - this.startEpisode; const newInfo = { episode: newEpisode, trained: info.trained }; yield this.memory.setInfo(newInfo); this.startEpisode = newEpisode; this.episode = newEpisode; return newInfo; }); } } exports.QLearningAgent = QLearningAgent;