ml-q-learning
Version:
Library implementing the Q-learning algorithm and several exploration strategies.
134 lines (133 loc) • 6.12 kB
JavaScript
"use strict";
/**
 * Down-level async/await helper, presumably compiler-generated (it matches the
 * `__awaiter` helper TypeScript emits). Runs a generator function as if it were an
 * async function: every value it `yield`s is awaited, and the settled result is
 * sent back into the generator. Reuses an `__awaiter` already present on `this`
 * if one exists.
 *
 * @param thisArg - `this` binding used when invoking the generator function.
 * @param _arguments - Arguments for the generator function (`[]` when falsy).
 * @param P - Promise constructor to use; defaults to the global `Promise`.
 * @param generator - Generator function whose `yield`ed values are awaited.
 * @returns A promise (built with `P`) that resolves with the generator's return
 *   value, or rejects with any error the generator throws.
 */
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with a fulfilled value; an error escaping it rejects the outer promise.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Throw a rejection reason into the generator so its own try/catch can handle it.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Resolve when the generator is done; otherwise wrap the yielded value in P and continue once it settles.
function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
// Instantiate the generator (rebinding `generator` to the iterator object) and take the first step.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
Object.defineProperty(exports, "__esModule", { value: true });
const pick_action_strategy_1 = require("./pick-action-strategy");
const memory_1 = require("./memory");
/**
 * Tabular Q-learning agent.
 *
 * Each call to `play` appends a `[stateSerialized, actionIndex, reward]` entry to
 * `replayMemory`, and `reward` accumulates rewards onto such an entry. `learn`
 * replays the recorded trajectory and applies one Q-learning update per transition
 * that has a successor. It then writes the updated rows to `memory`, updates the
 * persisted training info and clears the replay memory.
 *
 * Q-values are kept per serialized state as an array with one number per action.
 */
class QLearningAgent {
    /**
     * @param {Array} actions - Actions the agent can choose from; `play` returns `actions[index]`.
     * @param {Function} [pickActionStrategy] - Called as `pickActionStrategy(actionsStats, episode)`;
     *   its (possibly async) result is used as the chosen action index.
     * @param {Object} [memory] - Storage for Q-value rows and training info
     *   (hasState/getState/setState/setStateBulk/hasInfo/getInfo/setInfo).
     * @param {number} [learningRate=0.1] - Step size of the Q-learning update.
     * @param {number} [discountFactor=0.99] - Discount applied to the successor state's value.
     */
    constructor(actions, pickActionStrategy = pick_action_strategy_1.greedyPickAction, memory = new memory_1.MapInMemory(), learningRate = 0.1, discountFactor = 0.99) {
        this.actions = actions;
        this.pickActionStrategy = pickActionStrategy;
        this.memory = memory;
        this.learningRate = learningRate;
        this.discountFactor = discountFactor;
        this.startEpisode = 0;
        this.replayMemory = [];
        this.episode = 0;
        this.trained = false;
        // Shared in-flight load of the persisted training info (see init()).
        this.initPromise = undefined;
        // Start loading eagerly. init() attaches its own rejection handler, and play()
        // awaits init() again, so a failure here is surfaced (and retried) there.
        this.init();
    }
    /**
     * Loads the persisted episode counter and `trained` flag from memory, once.
     *
     * All callers share a single load. A slow `memory` therefore cannot finish a
     * second, late load that overwrites an episode counter `play` has already
     * advanced. If the load fails, the cached promise is dropped so the next call retries.
     *
     * @returns {Promise<void>}
     */
    init() {
        if (!this.initPromise) {
            const pending = __awaiter(this, void 0, void 0, function* () {
                const info = (yield this.memory.hasInfo()) ? (yield this.memory.getInfo()) : undefined;
                this.startEpisode = info ? info.episode : 1;
                this.episode = this.startEpisode;
                this.trained = info ? info.trained : false;
            });
            this.initPromise = pending;
            pending.catch(() => {
                if (this.initPromise === pending) {
                    this.initPromise = undefined;
                }
            });
        }
        return this.initPromise;
    }
    /**
     * Chooses an action for `state` and records the step in the replay memory.
     *
     * @param {*} state - Current state; identified by `state.toString()`.
     * @returns {Promise<{action: *, historyIndex: number, trainingInfo: {episode: number, trained: *}}>}
     *   The chosen action, the step's index in the replay memory (pass the whole
     *   result to `reward`), and the episode counter after this step.
     */
    play(state) {
        return __awaiter(this, void 0, void 0, function* () {
            yield this.init();
            const stateSerialized = state.toString();
            this.episode += 1;
            const actionIndex = yield this.chooseActionAlgorithm(stateSerialized);
            // The reward starts at 0 and is accumulated later through reward().
            const index = this.replayMemory.push([stateSerialized, actionIndex, 0]) - 1;
            return {
                action: this.actions[actionIndex],
                historyIndex: index,
                trainingInfo: {
                    episode: this.episode,
                    trained: this.trained
                }
            };
        });
    }
    /**
     * Ensures a Q-value row exists for the state, then delegates the choice to
     * `pickActionStrategy` with that row and the current episode.
     *
     * @param {string} stateSerialized
     * @returns {Promise<number>} The chosen action index.
     */
    chooseActionAlgorithm(stateSerialized) {
        return __awaiter(this, void 0, void 0, function* () {
            yield this.createStateIfNotExist(stateSerialized);
            const actionsStats = yield this.memory.getState(stateSerialized);
            const actionIndex = yield this.pickActionStrategy(actionsStats, this.episode);
            return actionIndex;
        });
    }
    /**
     * Adds `reward` to the step returned by `play`.
     * `learn()` clears the replay memory, so rewarding a step recorded before the
     * last `learn()` indexes a missing entry and throws.
     *
     * @param {{historyIndex: number}} step - A value returned by `play`.
     * @param {number} reward
     */
    reward(step, reward) {
        this.replayMemory[step.historyIndex][2] += reward;
    }
    /**
     * Stores a zero-filled row (one entry per action) for a state not yet in memory.
     *
     * @param {string} stateSerialized
     * @returns {Promise<void>}
     */
    createStateIfNotExist(stateSerialized) {
        return __awaiter(this, void 0, void 0, function* () {
            if (!(yield this.memory.hasState(stateSerialized))) {
                yield this.memory.setState(stateSerialized, Array(this.actions.length).fill(0));
            }
        });
    }
    /**
     * Applies the library's greedy strategy to the stored row of `stateSerialized`.
     * Not used by `learn()`; kept for callers that rely on it.
     *
     * @param {string} stateSerialized
     * @returns {Promise<number>} The action index picked by the greedy strategy.
     */
    greedyPickAction(stateSerialized) {
        return __awaiter(this, void 0, void 0, function* () {
            const actionsStats = yield this.memory.getState(stateSerialized);
            return pick_action_strategy_1.greedyPickAction(actionsStats);
        });
    }
    /**
     * Applies one Q-learning update per recorded transition, persists the updated
     * rows, updates the training info and clears the replay memory.
     *
     * Entry `i` of the replay memory is `[state_i, action_i, reward_i]`, and its
     * successor state is the state of entry `i + 1`. The last entry has no successor,
     * so it only serves as the successor of the entry before it.
     *
     * @returns {Promise<void>}
     */
    learn() {
        return __awaiter(this, void 0, void 0, function* () {
            if (this.replayMemory.length === 0) {
                return;
            }
            // Rows updated during this pass; later transitions must see earlier updates.
            const updated = new Map();
            const getState = (stateSerialized) => __awaiter(this, void 0, void 0, function* () {
                return updated.has(stateSerialized)
                    ? updated.get(stateSerialized)
                    : this.memory.getState(stateSerialized);
            });
            for (let index = 0; index < this.replayMemory.length - 1; index++) {
                const [stateSerialized, action, reward] = this.replayMemory[index];
                const stateSerializedPrime = this.replayMemory[index + 1][0];
                const [stateSerializedToUpdate, actionsStats] = yield this.learningAlgorithm(action, reward, stateSerialized, stateSerializedPrime, getState);
                updated.set(stateSerializedToUpdate, actionsStats);
            }
            yield this.memory.setStateBulk(Array.from(updated));
            yield this.updateTrainingInfo();
            this.replayMemory = [];
        });
    }
    /**
     * Q-learning update for one transition:
     * Q(s, a) <- Q(s, a) + learningRate * (reward + discountFactor * max_a' Q(s', a') - Q(s, a)).
     *
     * Both rows are read through `getState`, so the successor's maximum is taken
     * over values already updated earlier in the same `learn()` pass.
     *
     * @param {number} action - Index of the action taken in `stateSerialized`.
     * @param {number} reward - Reward accumulated for that step.
     * @param {string} stateSerialized - State the action was taken in.
     * @param {string} stateSerializedPrime - Successor state.
     * @param {function(string): Promise<number[]>} getState - Reads a row of Q-values.
     * @returns {Promise<[string, number[]]>} The updated state's key and its row (mutated in place).
     */
    learningAlgorithm(action, reward, stateSerialized, stateSerializedPrime, getState) {
        return __awaiter(this, void 0, void 0, function* () {
            const actionsStats = yield getState(stateSerialized);
            const actionsStatsPrime = yield getState(stateSerializedPrime);
            const maxQPrime = Math.max(...actionsStatsPrime);
            const target = reward + this.discountFactor * maxQPrime;
            actionsStats[action] += this.learningRate * (target - actionsStats[action]);
            return [stateSerialized, actionsStats];
        });
    }
    /**
     * Adds the episodes played since the last sync (`episode - startEpisode`) to the
     * persisted episode counter, then rebases both local counters on the result.
     * Adding a delta rather than overwriting presumably lets several agents share
     * one memory.
     *
     * @returns {Promise<{episode: number, trained: *}>} The info written to memory.
     */
    updateTrainingInfo() {
        return __awaiter(this, void 0, void 0, function* () {
            if (!(yield this.memory.hasInfo())) {
                // Seed with the baseline, not the current counter: the delta added
                // below already accounts for this agent's episodes.
                yield this.memory.setInfo({
                    episode: this.startEpisode,
                    trained: false
                });
            }
            const info = yield this.memory.getInfo();
            const newEpisode = info.episode + this.episode - this.startEpisode;
            const newInfo = {
                episode: newEpisode,
                trained: info.trained
            };
            yield this.memory.setInfo(newInfo);
            this.startEpisode = newEpisode;
            this.episode = newEpisode;
            return newInfo;
        });
    }
}
exports.QLearningAgent = QLearningAgent;