UNPKG

@woosh/meep-engine

Version:

Pure JavaScript game engine. Fully featured and production ready.

276 lines (204 loc) 7.48 kB
import { assert } from "../../../core/assert.js";
import { returnZero } from "../../../core/function/returnZero.js";
import { mix } from "../../../core/math/mix.js";
import { seededRandom } from "../../../core/math/random/seededRandom.js";
import { MoveEdge } from "./MoveEdge.js";
import { StateNode, StateType } from "./StateNode.js";

/**
 * Exploration constant used to weight the UCB1 exploration term.
 *
 * From: A Survey of Monte Carlo Tree Search Methods
 * The value Cp = 1/√2 was shown by Kocsis and Szepesvari [120] to satisfy the Hoeffding inequality with rewards in the range [0, 1]
 * @type {number}
 */
const C_ks = 1 / Math.sqrt(2);

/**
 * Scores a materialized child for selection: an exploitation estimate plus a UCB1-style exploration bonus.
 *
 * @param {StateNode} parent node whose `playouts` count feeds the exploration term
 * @param {StateNode} child node being scored; reads `playouts`, `wins` and `heuristicValue`
 * @returns {number} selection score, 0 when the child has no playouts
 */
function computeNodeSelectionScore(parent, child) {
    const playouts = child.playouts;

    if (playouts === 0) {
        // child has 0 playouts, this can only happen if it ended up with no moves, avoid division by 0
        return 0;
    }

    // Exploitation heuristic: observed result ratio blended with the stored heuristic value
    const Q = mix((child.wins + 1) / playouts, child.heuristicValue, 0.65);

    // Based on UCB1
    // exploration heuristic
    const u = Math.sqrt((2 * Math.log(parent.playouts)) / playouts);

    return Q + C_ks * u;
}

/**
 * Monte Carlo Tree Search driver. Call {@link MonteCarloTreeSearch#initialize} once,
 * then call {@link MonteCarloTreeSearch#playout} repeatedly to grow the tree under {@link MonteCarloTreeSearch#root}.
 *
 * @template S
 * @author Alex Goldring
 * @copyright Company Named Limited (c) 2025
 */
export class MonteCarloTreeSearch {
    /**
     * State that every playout starts from (a clone is taken per playout)
     * @type {S}
     */
    rootState = null;

    /**
     * Root of the search tree, created by {@link MonteCarloTreeSearch#initialize}
     * @type {StateNode|null}
     */
    root = null;

    /**
     * Produces the moves available from a state, passed to `StateNode#expand`
     * @type {function(state:S, source:StateNode):MoveEdge[]}
     */
    computeValidMoves = null;

    /**
     * Classifies a state as one of {@link StateType}
     * @type {function(state:S):StateType}
     */
    computeTerminalFlag = null;

    /**
     * Produces an independent copy of a state, so that playouts do not modify {@link MonteCarloTreeSearch#rootState}
     * @type {function(S):S}
     */
    cloneState = null;

    /**
     * Estimation function for evaluation of intermediate states, guides exploration
     * @type {function(StateNode, S):number}
     */
    heuristic = returnZero;

    /**
     * Depth to which plays will be explored
     * @type {number}
     */
    maxExplorationDepth = 1000;

    /**
     * Random number source, invoked without arguments; its result is used as a numeric score component
     * @type {function():number}
     */
    random = seededRandom(0);

    /**
     * Configures the search and creates a fresh root node.
     *
     * @param {Object} options
     * @param {S} options.rootState
     * @param {function(state:S, source:StateNode):MoveEdge[]} options.computeValidMoves
     * @param {function(state:S):StateType} options.computeTerminalFlag
     * @param {function(S):S} options.cloneState
     * @param {function(StateNode, S):number} [options.heuristic] Estimation function for evaluation of intermediate states, guides exploration. Defaults to a function returning 0
     */
    initialize(
        {
            rootState,
            computeValidMoves,
            computeTerminalFlag,
            cloneState,
            heuristic = returnZero
        }
    ) {
        assert.isFunction(computeValidMoves, `computeValidMoves`);
        assert.isFunction(computeTerminalFlag, `computeTerminalFlag`);
        assert.isFunction(cloneState, `cloneState`);
        // validated like the other callbacks, so a bad value fails here rather than deep inside a playout
        assert.isFunction(heuristic, `heuristic`);

        this.computeValidMoves = computeValidMoves;
        this.computeTerminalFlag = computeTerminalFlag;
        this.cloneState = cloneState;
        this.heuristic = heuristic;

        this.rootState = rootState;
        this.root = new StateNode();
    }

    /**
     * Walks down the tree from `node`, following the highest-scoring move at each step,
     * for as long as the current node is expanded, has moves and is not terminal.
     * Each followed move is applied to `state`.
     *
     * NOTE(review): return values of the move calls are discarded here, so this appears to rely on
     * `MoveEdge#move` updating `state` in place — confirm against MoveEdge.
     *
     * @param {StateNode} node starting node
     * @param {S} state state matching `node`, advanced along the chosen path
     * @returns {StateNode} node where the walk stopped; the same `node` if no step was taken
     */
    selectRandom(node, state) {
        const random = this.random;

        while (
            node.isExpanded()
            && node.moves.length > 0
            && !node.isTerminal()
        ) {
            let bestScore = Number.NEGATIVE_INFINITY;
            let bestMove = null;

            const moves = node.moves;
            const numMoves = moves.length;

            assert.notEqual(numMoves, 0, 'number of moves is 0, this is invalid state');

            for (let i = 0; i < numMoves; i++) {
                const move = moves[i];

                // exactly one roll is drawn per move, before branching
                const randomRoll = random();

                let score;

                if (move.isTargetMaterialized()) {
                    const child = move.target;

                    const s = computeNodeSelectionScore(node, child);

                    assert.notNaN(s, 'computed Node score');
                    assert.isFiniteNumber(s, `computed Node score`);

                    score = s + randomRoll;
                } else {
                    //use a constant value for unexplored nodes
                    score = randomRoll * 100;
                }

                if (score > bestScore) {
                    bestScore = score;
                    bestMove = move;
                }
            }

            if (!bestMove.isTargetMaterialized()) {
                //materialize the target state
                materializedEdgeTarget(state, node, bestMove, this.computeTerminalFlag, this.heuristic);
            } else {
                //just follow the edge
                bestMove.move(state);
            }

            node = bestMove.target;
        }

        return node;
    }

    /**
     * Perform a playout from the root node
     * @returns {S} final state of the playout
     */
    playout() {
        const computeValidMoves = this.computeValidMoves;
        const computeTerminalFlag = this.computeTerminalFlag;

        const state = this.cloneState(this.rootState);

        assert.notEqual(state, this.rootState, 'cloneState must produce a new state object, instead it produced the same one');

        let node = this.root;

        while (!node.isTerminal() && node.depth < this.maxExplorationDepth) {

            if (!node.isExpanded()) {
                node.expand(state, computeValidMoves, computeTerminalFlag);
            }

            const child = this.selectRandom(node, state);

            if (child === node) {
                // prevent infinite recursion
                // this should not happen?
                break;
            }

            node = child;
        }

        if (!node.isTerminal() && node.depth >= this.maxExplorationDepth) {
            //cap the state by depth, propagate heuristic score
            node.type = StateType.DepthCapped;
        }

        // record play-through
        const terminalFlag = node.type;

        if (terminalFlag === StateType.Win) {
            node.addPlayouts(1, 1, 0);
        } else if (terminalFlag === StateType.Loss) {
            node.addPlayouts(1, 0, 1);
        } else if (terminalFlag === StateType.Tie || terminalFlag === StateType.DepthCapped) {
            node.addPlayouts(1, 0, 0);
        }
        // any other node type records nothing

        return state;
    }
}

/**
 * Creates the child node that `edge` leads to, applies the move to `state`,
 * then stores the resulting terminal flag and heuristic value on the new node.
 *
 * @template S
 * @param {S} state state at `source`, passed to `edge.move`
 * @param {StateNode} source node the edge originates from
 * @param {MoveEdge} edge edge to materialize; its `target` is set to the new node
 * @param {function(S):StateType} computeTerminalFlag
 * @param {function(StateNode, S)} heuristic
 * @returns {S} the value returned by `edge.move(state)`
 */
function materializedEdgeTarget(state, source, edge, computeTerminalFlag, heuristic) {
    const child = new StateNode();

    child.parent = source;
    child.depth = source.depth + 1;

    const computedState = edge.move(state);

    const terminalFlag = computeTerminalFlag(computedState);

    assert.enum(terminalFlag, StateType, 'terminalFlag');

    child.type = terminalFlag;

    edge.target = child;

    const childHeuristicScore = heuristic(child, computedState);

    assert.notNaN(childHeuristicScore, 'childHeuristicScore');

    child.heuristicValue = childHeuristicScore;

    // bubble the heuristic score up the tree
    // child.bubbleUpHeuristicScore(); // heuristic value changes sign depending on the team making the move, so aggregation becomes tricky

    return computedState;
}