/*
 * ml-q-learning
 * Version:
 * Library implementing the Q-learning algorithm and several exploration algorithms.
 * 126 lines (125 loc) • 4.47 kB
 * JavaScript
 */
"use strict";
// TypeScript-emitted async/await downlevel helper. It runs `generator` (a generator
// function standing in for an async function body) and returns a promise built with
// constructor P (the global Promise when P is not supplied). Each yielded value is
// wrapped in a P and, once settled, fed back via generator.next(value) or
// generator.throw(reason). The outer promise resolves with the generator's return
// value, or rejects if the generator throws.
// An existing __awaiter found on `this` is reused instead of this definition.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with a fulfilled value; a synchronous throw rejects the outer promise.
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
// Resume the generator by throwing the rejection reason into it.
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
// Finish when the generator is done; otherwise wait for the yielded value and resume.
function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
// Start the generator bound to thisArg with the original arguments.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
/**
 * TypeScript interop helper: returns ES-module namespaces (objects flagged with a
 * truthy `__esModule`) unchanged and wraps anything else as `{ default: mod }`.
 * Reuses an existing helper found on `this` when one is present.
 */
var __importDefault = (this && this.__importDefault) || function (mod) {
    const isEsModule = Boolean(mod && mod.__esModule);
    if (isEsModule) {
        return mod;
    }
    return { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
const dexie_1 = __importDefault(require("dexie"));
/**
 * Dexie database with two object stores:
 *  - `states`: primary key `stateSerialized` (declared with the `&` unique marker);
 *  - `info`:   auto-incremented primary key `id`.
 */
class MemoryDatabase extends dexie_1.default {
    /**
     * @param {string} [name='MemoryDatabase'] - IndexedDB database name.
     */
    constructor(name = 'MemoryDatabase') {
        super(name);
        const schemaV1 = {
            states: '&stateSerialized',
            info: '++id'
        };
        this.version(1).stores(schemaV1);
    }
}
// Fixed primary key of the single row in the `info` table that holds training info.
const INFO_RECORD_ID = 1;
/**
 * IndexedDB-backed storage (via Dexie, see MemoryDatabase) for serialized states
 * with their statistics, plus one training-info record.
 */
class IndexedDBMemory {
    /**
     * @param {string} [dbName] - Database name; MemoryDatabase uses 'MemoryDatabase'
     *   when this is undefined.
     */
    constructor(dbName) {
        this.db = new MemoryDatabase(dbName);
    }
    /**
     * @returns {Promise<number>} Number of stored states.
     */
    size() {
        return __awaiter(this, void 0, void 0, function* () {
            return this.db.states.count();
        });
    }
    /**
     * Inserts or replaces the statistics for one state.
     * @param {string} stateSerialized - State key.
     * @param {*} stateStats - Stored as-is; its shape is defined by the caller.
     */
    setState(stateSerialized, stateStats) {
        return __awaiter(this, void 0, void 0, function* () {
            yield this.db.states.put({
                stateSerialized,
                stateStats
            });
        });
    }
    /**
     * Inserts or replaces many states in one bulk write.
     * @param {Array<[string, *]>} states - `[stateSerialized, stateStats]` pairs.
     */
    setStateBulk(states) {
        return __awaiter(this, void 0, void 0, function* () {
            const rows = states.map(([stateSerialized, stateStats]) => ({
                stateSerialized,
                stateStats
            }));
            yield this.db.states.bulkPut(rows);
        });
    }
    /**
     * @param {string} stateSerialized - State key.
     * @returns {Promise<boolean>} Whether a row exists for the key.
     */
    hasState(stateSerialized) {
        return __awaiter(this, void 0, void 0, function* () {
            const stateInfo = yield this.db.states.get(stateSerialized);
            return Boolean(stateInfo);
        });
    }
    /**
     * @param {string} stateSerialized - State key.
     * @returns {Promise<*>} The stored statistics.
     * @throws {Error} (as a rejection) when the state is not stored.
     */
    getState(stateSerialized) {
        return __awaiter(this, void 0, void 0, function* () {
            const stateInfo = yield this.db.states.get(stateSerialized);
            if (!stateInfo) {
                throw new Error(`Missing state ${stateSerialized}`);
            }
            return stateInfo.stateStats;
        });
    }
    /**
     * Calls `callback(stateSerialized, stateStats)` for every stored state.
     * The callback's return value is ignored, so a returned promise is not awaited.
     * @param {function(string, *): void} callback
     */
    eachState(callback) {
        return __awaiter(this, void 0, void 0, function* () {
            yield this.db.states.each((stateInfo) => {
                callback(stateInfo.stateSerialized, stateInfo.stateStats);
            });
        });
    }
    /**
     * Stores training info in the single info row (key INFO_RECORD_ID).
     * Inserts the row when it does not exist yet. Otherwise merges the given
     * properties into it (`update` only changes the keys it is given).
     * @param {object} info - Training info; copied, never mutated.
     */
    setInfo(info) {
        return __awaiter(this, void 0, void 0, function* () {
            const hasInfo = yield this.hasInfo();
            if (!hasInfo) {
                // Pin the key explicitly. A bare put(info) lets the `++id` generator (or
                // a stale `id` carried by `info`) choose the key. IndexedDB key generators
                // are not reset by delete/clear, so that key may not be 1, and
                // hasInfo()/getInfo() would never find the row.
                const record = Object.assign({}, info, { id: INFO_RECORD_ID });
                yield this.db.info.put(record);
            }
            else {
                // Merge only the data fields; never pass the primary key as a change.
                const changes = Object.assign({}, info);
                delete changes.id;
                yield this.db.info.update(INFO_RECORD_ID, changes);
            }
        });
    }
    /**
     * @returns {Promise<boolean>} Whether the training-info row exists.
     */
    hasInfo() {
        return __awaiter(this, void 0, void 0, function* () {
            const info = yield this.db.info.get(INFO_RECORD_ID);
            return Boolean(info);
        });
    }
    /**
     * @returns {Promise<object>} The stored training-info row (includes its `id`).
     * @throws {Error} (as a rejection) when no training info is stored.
     */
    getInfo() {
        return __awaiter(this, void 0, void 0, function* () {
            const info = yield this.db.info.get(INFO_RECORD_ID);
            if (!info) {
                throw new Error('Missing training info.');
            }
            return info;
        });
    }
    /**
     * Loads a snapshot produced by backup(). States are upserted; existing states
     * that are absent from the snapshot are kept, because the table is not cleared.
     * @param {{Q: Array<[string, *]>, trainingInfo: object}} content
     */
    restore(content) {
        return __awaiter(this, void 0, void 0, function* () {
            yield this.setStateBulk(content.Q);
            yield this.setInfo(content.trainingInfo);
        });
    }
    /**
     * Produces a snapshot of all states and the training info. The two reads are
     * separate operations; no explicit transaction is used here.
     * @returns {Promise<{Q: Array<[string, *]>, trainingInfo: object}>}
     * @throws {Error} (as a rejection) when no training info is stored.
     */
    backup() {
        return __awaiter(this, void 0, void 0, function* () {
            const trainingInfo = yield this.getInfo();
            const rows = yield this.db.states.toArray();
            const Q = rows.map((row) => [row.stateSerialized, row.stateStats]);
            return {
                Q,
                trainingInfo
            };
        });
    }
}
exports.IndexedDBMemory = IndexedDBMemory;