UNPKG

stella-htn-js

Version:

A flexible Hierarchical Task Network (HTN) planner for creating complex AI behaviors in JavaScript.

891 lines (804 loc) 30.2 kB
/**
 * Provides a generic interface to interact with the world state.
 * This class acts as a bridge between the planner and the user's
 * specific state management system (e.g., a simple object, an ECS).
 *
 * @performance
 * The `clone` function is the most performance-critical part of this class.
 * For production game use, it is HIGHLY recommended to provide an efficient,
 * specialized cloning function. The modern `structuredClone()` is an excellent
 * default choice, being much faster than `JSON.parse(JSON.stringify(state))`.
 */
export class WorldStateProxy {
  /**
   * @param {object} hooks - User-supplied accessors for the underlying state.
   * @param {function(string=): *} hooks.getState - Reads the value stored under a key.
   *   The default `getCacheKey` calls it with no argument and serializes the result.
   * @param {function(string, *): void} hooks.setState - Stores a value under a key.
   * @param {function(): *} hooks.clone - Returns an independent copy of the world state.
   *   The planner calls `.get`, `.set` and `.clone` on the returned value, so it should
   *   itself be a WorldStateProxy (or expose the same methods).
   * @param {function(string, number): void} [hooks.incrementState] - Optional fast path
   *   used by `increment`.
   * @param {function(): string} [hooks.generateCacheKey] - Optional unique key for the
   *   current state, used for plan caching. If omitted, `getCacheKey` falls back to a
   *   slower `JSON.stringify` of the whole state.
   * @param {function(*, object): void} [hooks.updateObject] - Optional fast path used by
   *   `updateObject`.
   * @throws {Error} If `getState`, `setState` or `clone` is not a function.
   */
  constructor({
    getState,
    setState,
    clone,
    incrementState,
    generateCacheKey,
    updateObject,
  } = {}) {
    if (
      typeof getState !== "function" ||
      typeof setState !== "function" ||
      typeof clone !== "function"
    ) {
      throw new Error(
        "WorldStateProxy requires 'getState', 'setState', and 'clone' functions."
      );
    }
    this._getState = getState;
    this._setState = setState;
    this._clone = clone;
    this._incrementState = incrementState;
    this._generateCacheKey = generateCacheKey;
    this._updateObject = updateObject;
  }

  /**
   * Reads a value from the world state.
   * @param {string} key
   * @returns {*} Whatever the user-supplied `getState` returns for `key`.
   */
  get(key) {
    return this._getState(key);
  }

  /**
   * Writes a value into the world state via the user-supplied `setState`.
   * @param {string} key
   * @param {*} value
   */
  set(key, value) {
    this._setState(key, value);
  }

  /**
   * @returns {*} The result of the user-supplied `clone` hook.
   */
  clone() {
    return this._clone();
  }

  /**
   * Adds `value` to the number stored under `key`. Without an `incrementState` hook,
   * a missing or non-numeric current value is treated as 0.
   * @param {string} key
   * @param {number} [value=1]
   */
  increment(key, value = 1) {
    if (this._incrementState) {
      this._incrementState(key, value);
      return;
    }
    const currentValue = this.get(key);
    const base = typeof currentValue === "number" ? currentValue : 0;
    this.set(key, base + value);
  }

  /**
   * Shallow-merges `newProperties` into the entry of `worldObjects` whose `id` matches.
   * Without an `updateObject` hook, a warning is logged and nothing happens when
   * `worldObjects` is not an array; an unknown `id` is silently ignored.
   * @param {*} id
   * @param {object} newProperties
   */
  updateObject(id, newProperties) {
    if (this._updateObject) {
      this._updateObject(id, newProperties);
      return;
    }
    const objects = this.get("worldObjects");
    if (!Array.isArray(objects)) {
      console.warn(
        "WorldStateProxy: 'worldObjects' is not an array. Cannot update object."
      );
      return;
    }
    const objectIndex = objects.findIndex((o) => o.id === id);
    if (objectIndex === -1) {
      return;
    }
    // Copy-on-write: never mutate the array returned by `get`. With a shallow
    // `clone`, that array can be shared with planner backtracking snapshots.
    const updated = [...objects];
    updated[objectIndex] = { ...objects[objectIndex], ...newProperties };
    this.set("worldObjects", updated);
  }

  /**
   * @returns {string} A key identifying the current state for plan caching.
   */
  getCacheKey() {
    if (this._generateCacheKey) {
      return this._generateCacheKey();
    }
    // Fallback for caching if no custom key generator is provided.
    // WARNING: This can be slow for large states.
    return JSON.stringify(this._getState());
  }
}

/**
 * Base class for all tasks in the HTN.
 */
export class Task {
  /** @param {string} name - Unique name; also the key used by `Planner#registerTask`. */
  constructor(name) {
    this.name = name;
  }
}

/**
 * Represents a single, executable action in the world.
 */
export class PrimitiveTask extends Task {
  /**
   * @param {string} name
   * @param {object} [hooks]
   * @param {function(WorldStateProxy, object): boolean} [hooks.conditions] - Must be truthy
   *   for the task to be applicable (default: always true).
   * @param {function(WorldStateProxy, object): void} [hooks.effects] - Applied to the state
   *   during planning and again by `PlanExecutor` after the operator runs.
   * @param {function(object): void} [hooks.operator] - Performs the action at execution time.
   */
  constructor(
    name,
    { conditions = () => true, effects = () => {}, operator = () => {} } = {}
  ) {
    super(name);
    this.isPrimitive = true;
    this.conditions = conditions;
    this.effects = effects;
    this.operator = operator;
  }

  /**
   * Logs and runs the operator with the given context.
   * @param {object} context
   */
  executeOperator(context) {
    // In a real game engine, this would trigger animations, sounds, etc.
    console.log(`Executing operator for: ${this.name} with context:`, context);
    this.operator(context);
  }
}

/**
 * A specialized PrimitiveTask for handling Smart Object interactions.
 * It holds a reference to the specific object instance it interacts with.
 */
export class SmartObjectTask extends PrimitiveTask {
  /**
   * @param {string} name
   * @param {object} smartObject - The world object this task targets; passed as the third
   *   argument to `logic.conditions`/`logic.effects` and added to the operator context.
   * @param {object} [logic] - `{ conditions, effects, operator, actionType }`. Missing hooks
   *   fall back to the same no-op defaults that PrimitiveTask uses.
   */
  constructor(name, smartObject, logic = {}) {
    super(name, {
      // Hooks are invoked as `logic.x(...)` so `this` inside them remains `logic`.
      conditions: (ws, ctx) =>
        typeof logic.conditions === "function"
          ? logic.conditions(ws, ctx, smartObject)
          : true,
      effects: (ws, ctx) =>
        typeof logic.effects === "function"
          ? logic.effects(ws, ctx, smartObject)
          : undefined,
      operator: (ctx) => {
        // Automatically add the target object to the context for the operator.
        const operatorContext = { ...ctx, smartObject };
        if (typeof logic.operator === "function") {
          logic.operator(operatorContext);
        }
      },
    });
    this.smartObject = smartObject;
    this.actionType = logic.actionType; // Generic action type, kept for reference.
  }
}

/**
 * Represents a high-level task that can be decomposed into smaller subtasks.
 * Each method is a potential way to accomplish the task.
 *
 * @typedef {object} Method
 * @property {string} name - The name of the method for debugging.
 * @property {number} [priority=0] - The priority of the method. Higher numbers are chosen first.
 * @property {function(WorldStateProxy, object): boolean} [conditions] - A function that must return true for this method to be applicable.
 * @property {Array<Task|string>} subtasks - An array of subtasks to be executed if this method is chosen.
 */
export class CompoundTask extends Task {
  /**
   * @param {string} name - The name of the compound task.
   * @param {Array<Method>} [methods=[]] - An array of methods.
   * @throws {Error} If a method lacks a `name` or a `subtasks` array.
   */
  constructor(name, methods = []) {
    super(name);
    this.isPrimitive = false;

    methods.forEach((m, i) => {
      if (!m?.name || !Array.isArray(m.subtasks)) {
        throw new Error(
          `Invalid method structure at index ${i} in CompoundTask '${name}'. Methods must have a 'name' and a 'subtasks' array.`
        );
      }
    });

    // Sort once on creation, highest priority first. Array#sort is stable, so
    // methods with equal priority keep their declaration order.
    this.methods = methods
      .map((m) => ({ ...m, priority: m.priority || 0 }))
      .sort((a, b) => b.priority - a.priority);
  }

  /**
   * Finds the first applicable method based on pre-sorted priority.
   * @param {WorldStateProxy} worldState - The current state of the world.
   * @param {object} context - The current planning context.
   * @returns {{method: Method, index: number}|null} The best applicable method and its index, or null.
   */
  findApplicableMethod(worldState, context) {
    const index = this.methods.findIndex(
      (method) => !method.conditions || method.conditions(worldState, context)
    );
    return index === -1 ? null : { method: this.methods[index], index };
  }
}

// --- Custom Error Types ---

/** Base class for all errors raised by the planner. */
export class PlanningError extends Error {
  constructor(message) {
    super(message);
    this.name = "PlanningError";
  }
}

/** Raised when a task referenced by name is not registered. */
export class TaskNotFoundError extends PlanningError {
  constructor(taskName) {
    super(`Task '${taskName}' not found in registry.`);
    this.name = "TaskNotFoundError";
    this.taskName = taskName;
  }
}

/** Raised when planning exceeds `maxIterations` or `maxTime`. */
export class PlanningTimeoutError extends PlanningError {
  constructor(message) {
    super(message);
    this.name = "PlanningTimeoutError";
  }
}

/**
 * The main planner class that implements the forward-decomposition HTN algorithm.
 */
export class Planner {
  /**
   * @param {object} [config]
   * @param {number} [config.maxIterations=1000] - Hard cap on decomposition steps.
   * @param {number} [config.maxTime=Infinity] - Wall-clock cap in milliseconds for one
   *   `findPlan` run (includes time spent yielded between frames).
   * @param {boolean} [config.enablePlanCaching=true] - Cache results when the root is a Task.
   * @param {number} [config.frameBudget=4] - Milliseconds of work before yielding.
   */
  constructor(config = {}) {
    this.config = {
      maxIterations: 1000,
      maxTime: Infinity,
      enablePlanCaching: true,
      frameBudget: 4,
      ...config,
    };
    this.logCallback = () => {};
    this.taskRegistry = {};
    this.planCache = new Map();
    this.metrics = {};
  }

  /** Registers `task` under `task.name`, replacing any previous task with that name. */
  registerTask(task) {
    this.taskRegistry[task.name] = task;
  }

  /** @param {function(string, string, object): void} callback - Receives (type, message, data). */
  setLogCallback(callback) {
    this.logCallback = callback;
  }

  /** @returns {object} Metrics of the most recent `findPlan` run. */
  getMetrics() {
    return this.metrics;
  }

  /** Drops every cached plan. */
  clearCache() {
    this.planCache.clear();
  }

  log(type, message, data = {}) {
    this.logCallback(type, message, data);
  }

  /**
   * Resolves a task-list entry to a task object.
   * @param {Task|string} entry
   * @returns {Task}
   * @throws {TaskNotFoundError} If a name is not registered.
   * @throws {PlanningError} If the entry is neither a name nor an object.
   */
  _resolveTask(entry) {
    if (typeof entry === "string") {
      // Own-property check so names like "toString" never resolve to Object.prototype members.
      const task = Object.hasOwn(this.taskRegistry, entry)
        ? this.taskRegistry[entry]
        : undefined;
      if (!task) {
        throw new TaskNotFoundError(entry);
      }
      return task;
    }
    if (entry === null || typeof entry !== "object") {
      throw new PlanningError(`Invalid task entry: ${String(entry)}`);
    }
    return entry;
  }

  /**
   * Pops decomposition records until one has an untried applicable method, re-pushes it
   * with the new method index, and returns the state to resume from.
   * @param {Array<object>} decompositionHistory - Mutated in place.
   * @returns {{worldState: WorldStateProxy, context: object, planState: Array, tasksToProcess: Array}|null}
   *   `null` when every alternative is exhausted.
   */
  _findBacktrackPoint(decompositionHistory) {
    while (decompositionHistory.length > 0) {
      const lastDecomp = decompositionHistory.pop();
      const {
        compoundTask,
        lastMethodIndex,
        worldStateSnapshot,
        contextSnapshot,
        planState,
        remainingTasks,
      } = lastDecomp;

      for (let i = lastMethodIndex + 1; i < compoundTask.methods.length; i++) {
        const nextMethod = compoundTask.methods[i];
        if (
          !nextMethod.conditions ||
          nextMethod.conditions(worldStateSnapshot, contextSnapshot)
        ) {
          this.log(
            "info",
            `Found backtrack point. Trying method '${nextMethod.name}' for ${compoundTask.name}.`
          );
          decompositionHistory.push({ ...lastDecomp, lastMethodIndex: i });
          return {
            worldState: worldStateSnapshot,
            context: contextSnapshot,
            planState,
            tasksToProcess: [...nextMethod.subtasks, ...remainingTasks],
          };
        }
      }
      this.log(
        "info",
        `No more methods to try for ${compoundTask.name}. Backtracking further...`
      );
    }
    return null;
  }

  /**
   * Incrementally decomposes `tasks` into a sequence of primitive tasks.
   *
   * This is an async generator. It yields `{ status: "running", iterations }` whenever
   * `config.frameBudget` ms of work have elapsed, and on a cache hit yields
   * `{ status: "completed", fromCache: true }` once. Its return value is
   * `{ plan, context }`, or `null` when no plan exists.
   *
   * @param {object} request
   * @param {Task|string|Array<Task|string>} request.tasks - Root task(s) to plan for.
   * @param {WorldStateProxy} request.worldState - Starting state; cloned before any task
   *   effects are applied.
   * @param {object} [request.context={}] - Planning context, deep-copied with `structuredClone`.
   * @throws {PlanningTimeoutError} If `config.maxIterations` or `config.maxTime` is exceeded.
   * @throws {TaskNotFoundError} If a task name is not registered.
   * @throws {PlanningError} If a task entry is neither a task object nor a name.
   */
  async *findPlan({ tasks, worldState, context = {} }) {
    const overallStartTime = performance.now();
    const elapsed = () => performance.now() - overallStartTime;
    let frameStartTime = overallStartTime;

    // Reset metrics for every run.
    this.metrics = {
      iterations: 0,
      backtrackCount: 0,
      planningTime: 0,
      cacheHit: false,
    };

    TaskDiscoverer.discoverAndRegister(this, worldState);

    // NOTE(review): the key ignores `context`, so a plan whose conditions read the
    // context can be reused for a different context — confirm this is intended.
    const cacheKey =
      this.config.enablePlanCaching && tasks instanceof Task
        ? `${tasks.name}:${worldState.getCacheKey()}`
        : null;

    if (cacheKey && this.planCache.has(cacheKey)) {
      this.metrics.cacheHit = true;
      this.metrics.planningTime = elapsed();
      // Yield once to indicate completion, then return the cached value.
      yield { status: "completed", fromCache: true };
      const cached = this.planCache.get(cacheKey);
      // Hand out copies so callers cannot corrupt the cache entry.
      return { plan: [...cached.plan], context: structuredClone(cached.context) };
    }

    this.log("info", `Starting incremental planning...`);

    let tasksToProcess = Array.isArray(tasks) ? [...tasks] : [tasks];
    let workingWorldState = worldState.clone();
    let workingContext = structuredClone(context);
    let finalPlan = [];
    const decompositionHistory = [];

    while (tasksToProcess.length > 0) {
      this.metrics.iterations++;
      if (this.metrics.iterations > this.config.maxIterations) {
        this.metrics.planningTime = elapsed();
        throw new PlanningTimeoutError(
          `Planning timed out after ${this.config.maxIterations} iterations.`
        );
      }
      if (
        Number.isFinite(this.config.maxTime) &&
        elapsed() > this.config.maxTime
      ) {
        this.metrics.planningTime = elapsed();
        throw new PlanningTimeoutError(
          `Planning timed out after ${this.config.maxTime}ms.`
        );
      }

      // The frame budget is only sampled every 20 iterations.
      if (this.metrics.iterations % 20 === 0) {
        const now = performance.now();
        if (now - frameStartTime > this.config.frameBudget) {
          this.log(
            "info",
            `Frame budget of ${this.config.frameBudget}ms exceeded. Yielding...`
          );
          yield { status: "running", iterations: this.metrics.iterations };
          await new Promise((resolve) => setTimeout(resolve, 0));
          frameStartTime = performance.now();
        }
      }

      const currentTask = this._resolveTask(tasksToProcess.shift());
      let failureMessage = null;

      if (currentTask.isPrimitive) {
        if (currentTask.conditions(workingWorldState, workingContext)) {
          currentTask.effects(workingWorldState, workingContext);
          finalPlan.push(currentTask);
        } else {
          failureMessage = `Conditions FAILED for ${currentTask.name}. Backtracking...`;
        }
      } else {
        const result = currentTask.findApplicableMethod(
          workingWorldState,
          workingContext
        );
        if (result) {
          // Snapshot everything needed to retry this task with its next method.
          decompositionHistory.push({
            compoundTask: currentTask,
            lastMethodIndex: result.index,
            worldStateSnapshot: workingWorldState.clone(),
            contextSnapshot: structuredClone(workingContext),
            planState: [...finalPlan],
            remainingTasks: [...tasksToProcess],
          });
          tasksToProcess.unshift(...result.method.subtasks);
        } else {
          failureMessage = `No applicable method found for ${currentTask.name}. Backtracking...`;
        }
      }

      if (failureMessage === null) {
        continue;
      }

      this.log("fail", failureMessage);
      this.metrics.backtrackCount++;
      const backtrackState = this._findBacktrackPoint(decompositionHistory);
      if (!backtrackState) {
        this.log("fail", "Backtracking failed. No plan found.");
        this.metrics.planningTime = elapsed();
        return null;
      }
      // Clone again so the stored snapshot stays pristine for further backtracking.
      workingWorldState = backtrackState.worldState.clone();
      workingContext = structuredClone(backtrackState.context);
      finalPlan = [...backtrackState.planState];
      tasksToProcess = backtrackState.tasksToProcess;
    }

    this.metrics.planningTime = elapsed();
    const result = { plan: finalPlan, context: workingContext };
    if (cacheKey) {
      try {
        // Store a private snapshot; the returned object may be mutated by callers.
        this.planCache.set(cacheKey, {
          plan: [...finalPlan],
          context: structuredClone(workingContext),
        });
      } catch (error) {
        this.log("warn", "Plan not cached: context could not be cloned.", {
          error,
        });
      }
    }
    return result;
  }
}

/**
 * Manages interruption policies and prevents infinite re-planning loops.
 */
export class InterruptionManager {
  /**
   * @param {object} [config]
   * @param {number} [config.maxReplanAttempts=3] - Attempts allowed before `canReplan` refuses.
   * @param {number} [config.cooldownDuration=1000] - Base cooldown in ms after an attempt.
   * @param {number} [config.backoffMultiplier=2] - Cooldown grows by this factor per attempt.
   * @param {boolean} [config.enableCooldown=true]
   */
  constructor(config = {}) {
    this.config = {
      maxReplanAttempts: 3,
      cooldownDuration: 1000,
      backoffMultiplier: 2,
      enableCooldown: true,
      ...config,
    };
    this.replanHistory = new Map(); // taskName -> { count, lastAttempt, cooldownUntil }
    this.activeInterruptors = new Map(); // interruptorId -> interruptor state
    this.logCallback = () => {};
  }

  setLogCallback(callback) {
    this.logCallback = callback;
  }

  log(type, message, data = {}) {
    this.logCallback(type, message, data);
  }

  /**
   * Adds (or replaces) an interruptor. It must expose
   * `check(worldState, context, currentTask)` returning `{ interrupted, reason, ... }` or falsy.
   */
  registerInterruptor(id, interruptor) {
    this.activeInterruptors.set(id, {
      ...interruptor,
      id,
      lastTriggered: 0,
      triggerCount: 0,
    });
  }

  unregisterInterruptor(id) {
    this.activeInterruptors.delete(id);
  }

  /**
   * @param {string} taskName
   * @returns {boolean} False while in cooldown or once `maxReplanAttempts` is reached.
   */
  canReplan(taskName) {
    const now = Date.now();
    const history = this.replanHistory.get(taskName);
    if (!history) {
      return true;
    }
    if (this.config.enableCooldown && now < history.cooldownUntil) {
      this.log("warn", `Replan blocked for ${taskName}: still in cooldown`, {
        remainingCooldown: history.cooldownUntil - now,
      });
      return false;
    }
    if (history.count >= this.config.maxReplanAttempts) {
      this.log("warn", `Replan blocked for ${taskName}: max attempts exceeded`, {
        attempts: history.count,
        maxAttempts: this.config.maxReplanAttempts,
      });
      return false;
    }
    return true;
  }

  /**
   * Counts an attempt and, with cooldown enabled, sets
   * `cooldownUntil = now + cooldownDuration * backoffMultiplier^(count - 1)`.
   */
  recordReplanAttempt(taskName) {
    const now = Date.now();
    const history = this.replanHistory.get(taskName) || {
      count: 0,
      lastAttempt: 0,
      cooldownUntil: 0,
    };
    history.count++;
    history.lastAttempt = now;
    if (this.config.enableCooldown) {
      const cooldownDuration =
        this.config.cooldownDuration *
        Math.pow(this.config.backoffMultiplier, history.count - 1);
      history.cooldownUntil = now + cooldownDuration;
    }
    this.replanHistory.set(taskName, history);
    this.log("info", `Recorded replan attempt for ${taskName}`, {
      attempt: history.count,
      cooldownUntil: history.cooldownUntil,
    });
  }

  resetReplanHistory(taskName) {
    this.replanHistory.delete(taskName);
    this.log("info", `Reset replan history for ${taskName}`);
  }

  /**
   * Runs interruptors in registration order and returns the first triggered result
   * (augmented with `interruptorId` and `timestamp`), or `null`. An interruptor that
   * throws is logged and skipped.
   */
  checkInterruptions(worldState, context, currentTask) {
    const now = Date.now();
    for (const [id, interruptor] of this.activeInterruptors) {
      try {
        const result = interruptor.check(worldState, context, currentTask);
        if (result && result.interrupted) {
          interruptor.lastTriggered = now;
          interruptor.triggerCount++;
          this.log("info", `Interruption triggered by ${id}`, {
            reason: result.reason,
            triggerCount: interruptor.triggerCount,
          });
          return { ...result, interruptorId: id, timestamp: now };
        }
      } catch (error) {
        this.log("error", `Interruptor ${id} threw an error:`, error);
      }
    }
    return null;
  }

  /** Forgets replan history whose last attempt is older than `maxAge` ms. */
  cleanup(maxAge = 300000) {
    const cutoff = Date.now() - maxAge;
    for (const [taskName, history] of this.replanHistory) {
      if (history.lastAttempt < cutoff) {
        this.replanHistory.delete(taskName);
      }
    }
  }
}

/**
 * Central repository for the logic of smart actions.
 * Keys are `smartLink.actionType` values; values are `{ conditions, effects, operator }`.
 */
export const smartActionLibrary = {};

/**
 * Discovers tasks from smart objects in the world and registers them with the planner.
 */
export class TaskDiscoverer {
  /**
   * Registers a `SmartObjectTask` named `${actionType}_${obj.id}` for every entry of
   * `worldObjects` whose `smartLink.actionType` is an own key of `smartActionLibrary`.
   * Already-registered names are skipped.
   * @param {Planner} planner
   * @param {WorldStateProxy} worldState
   * @returns {number} How many new tasks were registered.
   */
  static discoverAndRegister(planner, worldState) {
    const worldObjects = worldState.get("worldObjects");
    if (!Array.isArray(worldObjects)) {
      return 0;
    }
    let discoveredCount = 0;
    for (const obj of worldObjects) {
      const actionType = obj?.smartLink?.actionType;
      if (!actionType || !Object.hasOwn(smartActionLibrary, actionType)) {
        continue;
      }
      const actionLogic = smartActionLibrary[actionType];
      if (!actionLogic) {
        continue;
      }
      const taskName = `${actionType}_${obj.id}`;
      // Avoid re-registering the same task instance.
      if (Object.hasOwn(planner.taskRegistry, taskName)) {
        continue;
      }
      planner.registerTask(
        new SmartObjectTask(taskName, obj, { ...actionLogic, actionType })
      );
      discoveredCount++;
    }
    return discoveredCount;
  }
}

/**
 * Plan executor with sophisticated interruption handling and automatic replanning.
 */
export class PlanExecutor {
  /**
   * @param {Array<PrimitiveTask>} plan - Copied; executed in order, one task per `tick`.
   * @param {object} context - Passed to operators, conditions and effects.
   * @param {object} [options]
   * @param {InterruptionManager} [options.interruptionManager] - Its log callback is
   *   replaced by `options.logCallback`.
   * @param {Planner} [options.planner] - Stored for reference only.
   * @param {function(object): (AsyncGenerator|Promise|object)} [options.replanFn] - Called
   *   with `{ tasks, worldState, context }`, e.g. `(args) => planner.findPlan(args)`.
   * @param {Task|string} [options.rootTask] - Root task to replan for.
   * @param {function(string, string, *): void} [options.logCallback]
   */
  constructor(plan, context, options = {}) {
    this.plan = [...plan];
    this.context = context;
    this.currentIndex = 0;
    this.interruptionManager =
      options.interruptionManager || new InterruptionManager();
    this.planner = options.planner;
    this.replanFn = options.replanFn;
    this.rootTask = options.rootTask;
    this.logCallback = options.logCallback || (() => {});
    this.state = "running"; // 'running', 'interrupted', 'completed', 'failed'
    this.lastInterruption = null;
    this.executionStartTime = Date.now();
    this.interruptionManager.setLogCallback(this.logCallback);
  }

  isDone() {
    return this.currentIndex >= this.plan.length || this.state === "completed";
  }

  isFailed() {
    return this.state === "failed";
  }

  isInterrupted() {
    return this.state === "interrupted";
  }

  /**
   * Executes the next task: checks interruptors, re-checks the task's conditions, runs
   * its operator and then its effects on `worldState`.
   * @param {WorldStateProxy} worldState
   * @param {object} [options] - `{ enableReplanning }` is forwarded to `handleInterruption`.
   * @returns {Promise<object>} A status object (`executing`, `completed`, `failed`, `replanned`).
   */
  async tick(worldState, options = {}) {
    if (this.isDone()) {
      if (this.state !== "completed") {
        this.state = "completed";
        this.logCallback("success", "Plan execution completed successfully.");
      }
      return {
        status: "completed",
        task: null,
        executionTime: Date.now() - this.executionStartTime,
      };
    }

    if (this.isFailed()) {
      return { status: "failed", task: null, reason: this.lastInterruption };
    }

    const currentTask = this.plan[this.currentIndex];
    const interruption = this.interruptionManager.checkInterruptions(
      worldState,
      this.context,
      currentTask
    );
    if (interruption) {
      return this.handleInterruption(interruption, worldState, options);
    }

    try {
      // The world may have changed since planning; re-validate before acting.
      if (
        currentTask.conditions &&
        !currentTask.conditions(worldState, this.context)
      ) {
        const conditionFailure = {
          interrupted: true,
          reason: "preconditions_failed",
          taskName: currentTask.name,
          message: `Preconditions failed for ${currentTask.name} just before execution.`,
        };
        return this.handleInterruption(conditionFailure, worldState, options);
      }

      currentTask.executeOperator(this.context);
      if (currentTask.effects) {
        currentTask.effects(worldState, this.context);
      }
      this.currentIndex++;

      if (this.isDone()) {
        this.state = "completed";
        this.logCallback("success", "Plan execution completed successfully.");
        return {
          status: "completed",
          task: currentTask,
          progress: 1,
          executionTime: Date.now() - this.executionStartTime,
        };
      }
      return {
        status: "executing",
        task: currentTask,
        progress: this.currentIndex / this.plan.length,
      };
    } catch (error) {
      this.logCallback("error", `Task execution failed: ${currentTask.name}`, error);
      const executionFailure = {
        interrupted: true,
        reason: "execution_failed",
        taskName: currentTask.name,
        // Non-Error throws (strings, null, ...) have no `.message`.
        error: error?.message ?? String(error),
      };
      return this.handleInterruption(executionFailure, worldState, options);
    }
  }

  /**
   * Invokes `replanFn` and resolves to its final result. Accepts an (async) generator
   * such as `Planner#findPlan`, which is drained and whose return value is used, or a
   * function returning the result (or a Promise of it) directly.
   * @param {object} request - `{ tasks, worldState, context }`.
   * @returns {Promise<{plan: Array, context: object}|null|undefined>}
   */
  async _runReplan(request) {
    const outcome = this.replanFn(request);
    if (outcome && typeof outcome.next === "function") {
      let step = await outcome.next();
      while (!step.done) {
        step = await outcome.next();
      }
      return step.value;
    }
    return outcome;
  }

  /**
   * Marks the executor interrupted and, when `options.enableReplanning` is set and the
   * interruption manager allows it, replaces the plan with a freshly computed one.
   * @returns {Promise<object>} `{ status: "replanned", ... }` or `{ status: "failed", reason, ... }`.
   */
  async handleInterruption(interruption, worldState, options = {}) {
    this.state = "interrupted";
    this.lastInterruption = interruption;
    this.logCallback("warn", "Plan interrupted", interruption);

    if (!options.enableReplanning || !this.replanFn || !this.rootTask) {
      this.state = "failed";
      return {
        status: "failed",
        interruption,
        reason: "replanning_not_available",
      };
    }

    const taskName =
      typeof this.rootTask === "string"
        ? this.rootTask
        : this.rootTask.name || "unknown_root_task";
    if (!this.interruptionManager.canReplan(taskName)) {
      this.state = "failed";
      return {
        status: "failed",
        interruption,
        reason: "replan_cooldown_or_max_attempts",
      };
    }

    try {
      this.interruptionManager.recordReplanAttempt(taskName);
      const replanResult = await this._runReplan({
        tasks: this.rootTask,
        worldState,
        context: this.context,
      });

      if (!Array.isArray(replanResult?.plan) || replanResult.plan.length === 0) {
        this.state = "failed";
        this.logCallback("error", "Replanning failed: No new plan could be found.");
        return {
          status: "failed",
          interruption,
          reason: "replan_failed_to_find_plan",
        };
      }

      this.logCallback("success", "Replanning successful. New plan generated.");
      this.plan = [...replanResult.plan];
      this.context = replanResult.context;
      this.reset();
      // NOTE(review): clearing history on every success means attempt limits and
      // cooldowns only apply to consecutive unsuccessful replans — confirm intended.
      this.interruptionManager.resetReplanHistory(taskName);
      return { status: "replanned", interruption, newPlan: this.plan };
    } catch (error) {
      this.state = "failed";
      this.logCallback("error", "Replanning failed with a critical error.", error);
      return {
        status: "failed",
        interruption,
        reason: "replanning_error",
        error,
      };
    }
  }

  /** Restarts execution of the current plan from its first task. */
  reset() {
    this.currentIndex = 0;
    this.state = "running";
    this.lastInterruption = null;
    this.executionStartTime = Date.now();
  }

  getStats() {
    return {
      totalTasks: this.plan.length,
      completedTasks: this.currentIndex,
      progress: this.plan.length > 0 ? this.currentIndex / this.plan.length : 0,
      state: this.state,
      executionTime: Date.now() - this.executionStartTime,
      lastInterruption: this.lastInterruption,
    };
  }
}