UNPKG

ml-q-learning

Version:

Library implementing the Q-learning algorithm and several exploration algorithms.

126 lines (125 loc) 4.47 kB
"use strict"; var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); const dexie_1 = __importDefault(require("dexie")); class MemoryDatabase extends dexie_1.default { constructor(name = 'MemoryDatabase') { super(name); this.version(1).stores({ states: '&stateSerialized', info: '++id' }); } } class IndexedDBMemory { constructor(dbName) { this.db = new MemoryDatabase(dbName); } size() { return __awaiter(this, void 0, void 0, function* () { return this.db.states.count(); }); } setState(stateSerialized, stateStats) { return __awaiter(this, void 0, void 0, function* () { yield this.db.states.put({ stateSerialized, stateStats }); }); } setStateBulk(states) { return __awaiter(this, void 0, void 0, function* () { const bulkPut = states.map(([stateSerialized, stateStats]) => ({ stateSerialized, stateStats })); yield this.db.states.bulkPut(bulkPut); }); } hasState(stateSerialized) { return __awaiter(this, void 0, void 0, function* () { const stateInfo = yield this.db.states.get(stateSerialized); return Boolean(stateInfo); }); } getState(stateSerialized) { return __awaiter(this, void 0, void 0, function* () { const stateInfo = yield this.db.states.get(stateSerialized); if (!stateInfo) { throw new Error(`Missing state ${stateSerialized}`); } return stateInfo.stateStats; }); } 
eachState(callback) { return __awaiter(this, void 0, void 0, function* () { yield this.db.states.each((stateInfo) => { callback(stateInfo.stateSerialized, stateInfo.stateStats); }); }); } setInfo(info) { return __awaiter(this, void 0, void 0, function* () { const hasInfo = yield this.hasInfo(); if (!hasInfo) { yield this.db.info.put(info); } else { yield this.db.info.update(1, info); } }); } hasInfo() { return __awaiter(this, void 0, void 0, function* () { const info = yield this.db.info.get(1); return Boolean(info); }); } getInfo() { return __awaiter(this, void 0, void 0, function* () { const info = yield this.db.info.get(1); if (!info) { throw new Error('Missing training info.'); } return info; }); } restore(content) { return __awaiter(this, void 0, void 0, function* () { const bulkPut = []; for (let index = 0; index < content.Q.length; index++) { const [stateSerialized, stateStats] = content.Q[index]; bulkPut.push({ stateSerialized, stateStats }); } yield this.db.states.bulkPut(bulkPut); yield this.setInfo(content.trainingInfo); }); } backup() { return __awaiter(this, void 0, void 0, function* () { const info = yield this.getInfo(); const states = yield this.db.states.toArray(); const Q = states.map((stateInfo) => ([stateInfo.stateSerialized, stateInfo.stateStats])); const context = { Q, trainingInfo: info }; return context; }); } } exports.IndexedDBMemory = IndexedDBMemory;