UNPKG

ml-double-q-learning

Version:
163 lines (162 loc) 6.04 kB
"use strict"; var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; Object.defineProperty(exports, "__esModule", { value: true }); const pick_action_strategy_1 = require("ml-q-learning/lib/pick-action-strategy"); const double_q_learning_agent_1 = require("../double-q-learning-agent"); var Action; (function (Action) { Action[Action["Left"] = 0] = "Left"; Action[Action["Right"] = 1] = "Right"; Action[Action["Up"] = 2] = "Up"; Action[Action["Down"] = 3] = "Down"; })(Action || (Action = {})); var MazeElements; (function (MazeElements) { MazeElements["Player"] = "P"; MazeElements["EmptySpace"] = "."; MazeElements["Wall"] = "#"; MazeElements["Treasure"] = "R"; MazeElements["Trap"] = "X"; MazeElements["Finish"] = "F"; })(MazeElements || (MazeElements = {})); const maze = [ ['P', '.', '.', '#', '.', '.', '.', '#', 'R'], ['.', '#', '.', '#', '.', '.', '.', '#', '.'], ['.', '#', '.', '#', '.', '#', '.', '#', '.'], ['.', '#', 'X', '#', '.', '#', '.', '.', '.'], ['.', '#', '#', '#', 'F', '#', '.', '.', '.'], ['.', '#', '.', '#', '#', '#', '.', '#', 'X'], ['.', '.', 'X', '.', '.', '.', '.', '#', '.'], ['.', '.', '.', '.', '#', '.', '.', '#', 'R'] ]; class MazeGame { constructor(maze) { this.maze = this.copyMaze(maze); } copyMaze(maze) { return JSON.parse(JSON.stringify(maze)); } findPlayerPosition() { let column = -1; const row = this.maze.findIndex((row) => { const index = row.findIndex((item, index) => { return item === MazeElements.Player; }); if (index !== -1) { column = index; return true; } return false; }); return [row, column]; } canMoveHere([x, y]) { const maze = this.maze; return x !== -1 && x !== maze.length && y !== -1 && y !== maze[x].length && maze[x][y] !== MazeElements.Wall; } calcReward([x, y]) { let reward = -1; const maze = this.maze; if (this.canMoveHere([x, y])) { switch (maze[x][y]) { case MazeElements.Treasure: reward = 200; break; case MazeElements.Finish: reward = 1000; break; case MazeElements.Trap: reward = -200; break; } } else { reward = -10; } return reward; } move([pX, pY], [x, y]) { this.maze[pX][pY] = '.'; this.maze[x][y] = 'P'; } performAction(action) { const playerPosition = this.findPlayerPosition(); if (playerPosition[0] === -1 || playerPosition[1] === -1) { throw new Error('Missing player in maze'); } let positionAfterMove; switch (action) { case Action.Left: positionAfterMove = [playerPosition[0], playerPosition[1] - 1]; break; case Action.Down: positionAfterMove = [playerPosition[0] + 1, playerPosition[1]]; break; case Action.Right: positionAfterMove = [playerPosition[0], playerPosition[1] + 1]; break; case Action.Up: positionAfterMove = [playerPosition[0] - 1, playerPosition[1]]; break; default: throw new Error('Missing action'); } const reward = this.calcReward(positionAfterMove); let finish = false; if (this.canMoveHere(positionAfterMove)) { finish = this.maze[positionAfterMove[0]][positionAfterMove[1]] === MazeElements.Finish; if (finish) { } this.move(playerPosition, positionAfterMove); } return [this.maze, reward, finish]; } } function main() { return __awaiter(this, void 0, void 0, function* () { const agent = new double_q_learning_agent_1.DoubleQLearningAgent([Action.Left, Action.Right, Action.Up, Action.Down], pick_action_strategy_1.decayingEpsilonSoftmaxGreedyPickAction(0.05, 0.99, 3000)); let betsScore = -Infinity; console.log('Start maze'); console.log(maze); for (let numberOfPlay = 0; numberOfPlay < Infinity; numberOfPlay++) { let score = 0; const game = new MazeGame(maze); let endGame = false; const maxSteps = 10000; let stepCount = 0; while (!endGame) { const step = yield agent.play(game.maze.toString()); const [maze, reward, finish] = game.performAction(step.action); yield agent.reward(step, reward); score += reward; if (finish && betsScore < score) { betsScore = score; const memorySize = yield agent.memory.size(); console.log(` ------------------------------- numberOfPlay: ${numberOfPlay}, score: ${score} episode: ${agent.episode} memorySize: ${memorySize} ------------------------------- `); console.log(maze); } stepCount += 1; if (stepCount > maxSteps) { break; } endGame = finish; } yield agent.learn(); } }); } main();