UNPKG

@astreus-ai/astreus

Version:

AI Agent Framework with Chat Management

643 lines (561 loc) 21.5 kB
import { v4 as uuidv4 } from "uuid";
import {
  TaskManagerInstance,
  TaskConfig,
  TaskInstance,
  TaskResult,
  TaskManagerConfig,
  TaskStatus,
} from "../types";
import { MemoryInstance, DatabaseInstance, ProviderModel } from "../types";
import { createDatabase } from "../database";
import { logger } from "../utils";
import { Task } from "./task";
import { validateRequiredParam } from "../utils/validation";

/**
 * Task Manager class that manages multiple tasks
 *
 * This class provides functionality for:
 * - Creating and managing multiple tasks
 * - Executing tasks with dependencies in order
 * - Restoring previously persisted tasks from the database
 */
export class TaskManager implements TaskManagerInstance {
  private readonly tasks = new Map<string, TaskInstance>();
  /** Set to true once a database load has completed without error. */
  private tasksLoaded = false;
  /** The in-flight database load shared by all waiters; null when no load is running. */
  private loadingPromise: Promise<void> | null = null;
  private config: TaskManagerConfig;
  private agentId?: string;
  private sessionId?: string;
  private memory?: MemoryInstance;
  private database?: DatabaseInstance;
  private providerModel?: ProviderModel;

  /**
   * Create a new TaskManager instance
   * @param config Configuration options for the task manager (defaults to `{ concurrency: 5 }`)
   */
  constructor(config?: TaskManagerConfig) {
    this.config = config || { concurrency: 5 };
    this.agentId = config?.agentId;
    this.sessionId = config?.sessionId;
    this.memory = config?.memory;
    this.database = config?.database;
    this.providerModel = config?.providerModel;

    // Start loading persisted tasks right away; the promise is cached so that
    // waitForTasksLoaded() awaits this same load instead of starting another one.
    this.startLoadingTasks().catch((err) => {
      logger.error("Error loading tasks from database:", err);
    });

    logger.info("Task manager initialized");
  }

  /**
   * Add an existing task to the manager
   *
   * A `Task` instance is registered as-is. A plain configuration is turned into a new
   * `Task`, unless a task with the same `id` is already registered, in which case the
   * registered instance is returned. The caller's configuration object is never mutated.
   *
   * @param task Task instance or configuration to add
   * @param model Optional LLM model to use for tool selection. Precedence when building
   *   from a config: this argument, then the config's `model`, then the manager's provider model.
   * @returns The added (or already registered) task instance
   * @throws Re-throws any error raised while constructing the task, after logging it
   */
  public addExistingTask(task: TaskInstance | TaskConfig, model?: ProviderModel): TaskInstance {
    try {
      let taskInstance: TaskInstance;

      if (task instanceof Task) {
        taskInstance = task;
      } else {
        // Reuse an instance we already manage rather than creating a duplicate
        if (task.id) {
          const existingTask = this.tasks.get(task.id);
          if (existingTask) {
            logger.debug(`Task ${task.id} already exists in manager, returning existing instance`);
            return existingTask;
          }
        }

        let taskModel = model;
        if (!taskModel && "model" in task) {
          taskModel = task.model;
        }
        if (!taskModel && this.providerModel) {
          taskModel = this.providerModel;
        }

        const taskConfig: TaskConfig = {
          ...(task as TaskConfig),
          agentId: task.agentId || this.agentId,
          sessionId: task.sessionId || this.sessionId,
          model: taskModel, // The model is carried in the config, so no separate model argument below
        };

        // Only string plugin identifiers are accepted. Sanitize our copy, not the caller's object.
        if (task.plugins) {
          const validPlugins = task.plugins.filter((plugin) => typeof plugin === "string");
          if (validPlugins.length !== task.plugins.length) {
            logger.warn(`TaskManager: Filtered out ${task.plugins.length - validPlugins.length} invalid plugins from task config`);
            taskConfig.plugins = validPlugins;
          }
        }

        taskInstance = new Task(taskConfig, this.memory, undefined, this.database);
      }

      // Make sure the task shares the manager's memory instance
      if (this.memory && taskInstance instanceof Task) {
        taskInstance.setMemory(this.memory);
      }

      this.tasks.set(taskInstance.id, taskInstance);

      // Warn if the task's first dependency is not registered
      this.addTaskWithParent(taskInstance);

      logger.debug(`Task "${taskInstance.config.name}" (${taskInstance.id}) added to manager`);
      return taskInstance;
    } catch (error) {
      logger.error("Error adding existing task:", error);
      throw error;
    }
  }

  /**
   * Warn when the task's first declared dependency (treated as its parent) is not registered.
   * Errors are logged and swallowed because this is a diagnostic check only.
   * @param task Task to check for parent dependencies
   */
  private addTaskWithParent(task: TaskInstance): void {
    try {
      if (task.config.dependencies && task.config.dependencies.length > 0) {
        const parentId = task.config.dependencies[0];
        if (this.getTask(parentId)) {
          return;
        }
        logger.warn(`Parent task ${parentId} not found for task ${task.id}`);
      }
    } catch (error) {
      logger.error(`Error checking parent dependencies for task ${task.id}:`, error);
    }
  }

  /**
   * Get a task by ID
   * @param id ID of the task to retrieve
   * @returns Task instance if found, undefined otherwise
   */
  public getTask(id: string): TaskInstance | undefined {
    return this.tasks.get(id);
  }

  /**
   * Get all tasks managed by this instance
   * @returns Array of all task instances
   */
  public getAllTasks(): TaskInstance[] {
    return Array.from(this.tasks.values());
  }

  /**
   * Create a new task via `Task.createTask` and register it with the manager
   * @param config Configuration for the new task
   * @param model Optional LLM model to use for tool selection. Precedence: this argument,
   *   then `config.model`, then the manager's provider model.
   * @returns The created task instance
   * @throws Re-throws any error from task creation, after logging it
   */
  public async createTask(config: TaskConfig, model?: ProviderModel): Promise<TaskInstance> {
    try {
      let taskModel = model;
      if (!taskModel && config.model) {
        taskModel = config.model;
      }
      if (!taskModel && this.providerModel) {
        taskModel = this.providerModel;
      }

      // Fill in the manager's agent/session IDs when the config does not provide them
      const updatedConfig = {
        ...config,
        agentId: config.agentId || this.agentId,
        sessionId: config.sessionId || this.sessionId,
        model: taskModel,
      };

      const task = await Task.createTask(updatedConfig, this.memory, undefined, this.database);

      this.addExistingTask(task);

      logger.info(`Created new task: "${task.config.name}" (${task.id})`);
      return task;
    } catch (error) {
      logger.error("Error creating task:", error);
      throw error;
    }
  }

  /**
   * Cancel a task by ID
   * @param id ID of the task to cancel
   * @returns true if the task was found and canceled; false if it was not found or cancel threw
   */
  public cancelTask(id: string): boolean {
    try {
      const task = this.getTask(id);
      if (task) {
        task.cancel();
        logger.info(`Task ${id} canceled`);
        return true;
      }
      logger.warn(`Task ${id} not found for cancellation`);
      return false;
    } catch (error) {
      logger.error(`Error canceling task ${id}:`, error);
      return false;
    }
  }

  /**
   * Start a database load, or return the one already in flight.
   * Once the load settles, the cached promise is cleared. If the load did not succeed,
   * `tasksLoaded` stays false, so the next waiter can try again.
   */
  private startLoadingTasks(): Promise<void> {
    if (!this.loadingPromise) {
      this.loadingPromise = (async () => {
        try {
          await this.loadTasksFromDatabase();
        } finally {
          this.loadingPromise = null;
        }
      })();
    }
    return this.loadingPromise;
  }

  /**
   * Load tasks from the database into the manager.
   *
   * Records whose ID is already registered are skipped. This keeps the live instance
   * callers hold (for example, one created before loading finished) instead of replacing
   * it with a new copy rebuilt from the stored row. Errors are logged, not thrown.
   * @returns Promise that resolves when the load attempt finishes
   */
  private async loadTasksFromDatabase(): Promise<void> {
    try {
      // Use the configured database if provided, otherwise create a new one
      const db = this.database || await createDatabase();
      const tasksTable = db.getTable('tasks');

      const taskRecords = await tasksTable.find();
      let loadedCount = 0;

      for (const record of taskRecords) {
        if (this.tasks.has(record.id)) {
          continue;
        }
        try {
          // JSON columns are stored as strings
          const plugins = record.plugins ? JSON.parse(record.plugins) : [];
          const input = record.input ? JSON.parse(record.input) : null;
          const dependencies = record.dependencies ? JSON.parse(record.dependencies) : [];
          const result = record.result ? JSON.parse(record.result) : null;

          const taskConfig: TaskConfig = {
            id: record.id,
            name: record.name,
            description: record.description,
            plugins,
            input,
            dependencies,
            agentId: record.agentId,
            sessionId: record.sessionId,
          };

          const task = new Task(taskConfig, this.memory, this.providerModel, db);

          // Restore runtime state from the stored row
          task.status = record.status as TaskStatus;
          task.retries = record.retries || 0;
          task.createdAt = new Date(record.createdAt);
          task.startedAt = record.startedAt ? new Date(record.startedAt) : undefined;
          task.completedAt = record.completedAt ? new Date(record.completedAt) : undefined;
          task.result = result;
          task.agentId = record.agentId;
          task.sessionId = record.sessionId;
          task.contextId = record.contextId;

          this.tasks.set(task.id, task);
          loadedCount++;
          logger.debug(`Loaded task ${task.id} from database with status: ${task.status}`);
        } catch (error) {
          // A malformed row must not prevent the remaining tasks from loading
          logger.error(`Error loading task ${record.id} from database:`, error);
        }
      }

      this.tasksLoaded = true;
      logger.info(`Loaded ${loadedCount} tasks from database`);
    } catch (error) {
      logger.error("Error loading tasks from database:", error);
      // Don't throw - allow the task manager to continue without database tasks
    }
  }

  /**
   * Wait for tasks to be loaded from the database.
   * Returns immediately after a successful load, joins a load that is already in flight,
   * or starts a new load if none has succeeded yet.
   * @returns Promise that resolves when tasks are loaded (or the load attempt has failed and been logged)
   */
  public async waitForTasksLoaded(): Promise<void> {
    if (this.tasksLoaded) {
      return;
    }
    await this.startLoadingTasks();
  }

  /**
   * Execute a specific task by ID
   * @param id ID of the task to execute
   * @param input Optional input data to pass to the task
   * @returns Promise that resolves with the task execution result
   * @throws If the task is not found or its execution rejects (the error is logged first)
   */
  public async executeTask(id: string, input?: any): Promise<TaskResult> {
    try {
      await this.waitForTasksLoaded();

      const task = this.getTask(id);
      if (!task) {
        throw new Error(`Task with ID ${id} not found`);
      }

      logger.info(`Executing task: "${task.config.name}" (${id})`);
      // Await here so a rejection is handled (and logged) by the catch below
      return await task.execute(input);
    } catch (error) {
      logger.error(`Error executing task ${id}:`, error);
      throw error;
    }
  }

  /**
   * Get all tasks
   * @deprecated Use getAllTasks() instead for newer implementations
   * @returns Array of all task instances
   */
  public getTasks(): TaskInstance[] {
    logger.debug("DEPRECATED: getTasks() called, use getAllTasks() instead");
    return this.getAllTasks();
  }

  /**
   * Get tasks by agent ID
   * @param agentId Agent ID to filter by (compared with `task.config.agentId`)
   * @returns Array of tasks belonging to the specified agent
   */
  public getTasksByAgent(agentId: string): TaskInstance[] {
    return this.getAllTasks().filter((task) => task.config.agentId === agentId);
  }

  /**
   * Get tasks by session ID
   * @param sessionId Session ID to filter by (compared with `task.config.sessionId`)
   * @returns Array of tasks belonging to the specified session
   */
  public getTasksBySession(sessionId: string): TaskInstance[] {
    return this.getAllTasks().filter((task) => task.config.sessionId === sessionId);
  }

  /**
   * Set the agent ID used for tasks created after this call
   * @param agentId Agent ID to set
   */
  public setAgentId(agentId: string): void {
    this.agentId = agentId;
    logger.debug(`Task manager agent ID set to ${agentId}`);
  }

  /**
   * Set the session ID used for tasks created after this call
   * @param sessionId Session ID to set
   */
  public setSessionId(sessionId: string): void {
    this.sessionId = sessionId;
    logger.debug(`Task manager session ID set to ${sessionId}`);
  }

  /**
   * Set the memory instance and propagate it to every registered `Task`
   * @param memory Memory instance to use for task context storage
   * @throws Re-throws any error raised while updating tasks, after logging it
   */
  public setMemory(memory: MemoryInstance): void {
    try {
      this.memory = memory;
      for (const task of this.tasks.values()) {
        if (task instanceof Task) {
          task.setMemory(memory);
        }
      }
      logger.debug("Memory instance set for task manager");
    } catch (error) {
      logger.error("Error setting memory for task manager:", error);
      throw error;
    }
  }

  /**
   * Set the provider model used as the fallback model for new tasks
   * @param model Provider model to use
   */
  public setProviderModel(model: ProviderModel): void {
    this.providerModel = model;
    logger.debug(`Task manager provider model set to ${model.name}`);
  }

  /**
   * Get the current provider model
   * @returns The current provider model or undefined
   */
  public getProviderModel(): ProviderModel | undefined {
    return this.providerModel;
  }

  /**
   * Run specified tasks (or all if no IDs are provided) in dependency order.
   *
   * Tasks run in batches of at most `config.concurrency` (default 5, minimum 1). A task
   * becomes ready once every dependency that is registered in the manager has finished,
   * whether it succeeded or failed. Successful dependency outputs are passed to the task
   * under `_dependencyOutputs`. Tasks that never become ready (circular dependencies, or
   * dependencies outside this run) get a failed result.
   *
   * @param taskIds Optional array of task IDs to run; unknown IDs are ignored and duplicates run once
   * @returns Promise that resolves with a map of task IDs to results
   */
  public async run(taskIds?: string[]): Promise<Map<string, TaskResult>> {
    await this.waitForTasksLoaded();

    const results = new Map<string, TaskResult>();
    const tasksToRun = this.selectTasksToRun(taskIds);
    const { taskDependencyMap, dependentTasksMap } = this.buildDependencyGraph(tasksToRun);

    const completedTasks = new Set<string>();
    // Every ID that has ever been queued, so a task is never scheduled twice
    const queuedTasks = new Set<string>();
    const taskQueue: TaskInstance[] = [];
    const enqueue = (task: TaskInstance): void => {
      queuedTasks.add(task.id);
      taskQueue.push(task);
    };

    for (const task of tasksToRun) {
      if ((taskDependencyMap.get(task.id)?.size ?? 0) === 0) {
        enqueue(task);
      }
    }

    logger.debug(`Initial task queue: ${taskQueue.length} tasks ready to run`);

    // splice(0, n) with n < 1 would take nothing and loop forever, so clamp to at least 1
    const concurrency = Math.max(1, this.config.concurrency || 5);

    const runOne = async (task: TaskInstance): Promise<void> => {
      try {
        const result = await this.executeWithDependencyOutputs(task, taskDependencyMap.get(task.id), results);
        results.set(task.id, result);
      } catch (error) {
        logger.error(`Error executing task ${task.id}:`, error);
        results.set(task.id, {
          success: false,
          error: error instanceof Error ? error : new Error(String(error)),
        });
      }

      // A failed task still counts as completed. Dependents are released on both paths,
      // so the outcome does not depend on which sibling dependency finishes first.
      completedTasks.add(task.id);

      for (const dependentId of dependentTasksMap.get(task.id) ?? []) {
        if (completedTasks.has(dependentId) || queuedTasks.has(dependentId)) {
          continue;
        }
        const dependencies = taskDependencyMap.get(dependentId) ?? new Set<string>();
        const ready = Array.from(dependencies).every((depId) => completedTasks.has(depId));
        const dependentTask = this.getTask(dependentId);
        if (ready && dependentTask) {
          enqueue(dependentTask);
          logger.debug(`Task ${dependentId} unblocked and added to queue`);
        }
      }
    };

    while (taskQueue.length > 0) {
      const batchTasks = taskQueue.splice(0, concurrency);
      // Newly unblocked tasks are queued while this batch runs and picked up by the next iteration
      await Promise.all(batchTasks.map(runOne));
    }

    // Anything left over never became ready (circular or out-of-run dependencies)
    const unresolvedTasks = tasksToRun.filter((t) => !completedTasks.has(t.id));
    if (unresolvedTasks.length > 0) {
      logger.warn(`${unresolvedTasks.length} tasks were not executed due to dependency issues or circular dependencies`);
      for (const task of unresolvedTasks) {
        results.set(task.id, {
          success: false,
          error: new Error("Task not executed due to unresolved dependencies"),
        });
      }
    }

    logger.info(`Completed running ${completedTasks.size} tasks`);
    return results;
  }

  /**
   * Resolve the tasks a `run` call should execute: the registered tasks among `taskIds`
   * (deduplicated, in first-seen order), or every registered task when no IDs are given.
   */
  private selectTasksToRun(taskIds?: string[]): TaskInstance[] {
    if (taskIds && taskIds.length > 0) {
      const specified = Array.from(new Set(taskIds))
        .map((id) => this.getTask(id))
        .filter((task): task is TaskInstance => !!task);
      logger.info(`Running ${specified.length} specified tasks`);
      return specified;
    }
    const all = this.getAllTasks();
    logger.info(`Running all ${all.length} tasks`);
    return all;
  }

  /**
   * Build forward (task -> its dependencies) and reverse (dependency -> its dependents) maps.
   * Both `dependencies` and `dependsOn` are honoured; IDs not registered in the manager are ignored.
   */
  private buildDependencyGraph(tasksToRun: TaskInstance[]): {
    taskDependencyMap: Map<string, Set<string>>;
    dependentTasksMap: Map<string, Set<string>>;
  } {
    const taskDependencyMap = new Map<string, Set<string>>();
    const dependentTasksMap = new Map<string, Set<string>>();

    for (const task of tasksToRun) {
      const declared = [...(task.config.dependencies ?? []), ...(task.config.dependsOn ?? [])];
      const dependencies = new Set<string>(declared.filter((depId) => this.getTask(depId) !== undefined));
      taskDependencyMap.set(task.id, dependencies);

      for (const depId of dependencies) {
        let dependents = dependentTasksMap.get(depId);
        if (!dependents) {
          dependents = new Set<string>();
          dependentTasksMap.set(depId, dependents);
        }
        dependents.add(task.id);
      }
    }

    return { taskDependencyMap, dependentTasksMap };
  }

  /**
   * Execute a task and return its result.
   * If any dependency produced a successful, truthy output, the task gets its own configured
   * input merged with `_dependencyOutputs` (keyed by dependency ID). Otherwise it is executed
   * without arguments.
   */
  private async executeWithDependencyOutputs(
    task: TaskInstance,
    dependencies: Set<string> | undefined,
    results: Map<string, TaskResult>
  ): Promise<TaskResult> {
    const dependencyOutputs: Record<string, unknown> = {};
    for (const depId of dependencies ?? []) {
      const depResult = results.get(depId);
      if (depResult && depResult.success && depResult.output) {
        dependencyOutputs[depId] = depResult.output;
      }
    }

    const outputCount = Object.keys(dependencyOutputs).length;
    if (outputCount === 0) {
      return task.execute();
    }

    // Preserve the task's original input alongside the dependency outputs
    const mergedInput = {
      ...(task.config.input || {}),
      _dependencyOutputs: dependencyOutputs,
    };
    logger.debug(`Task ${task.id} received outputs from ${outputCount} dependencies`);
    return task.execute(mergedInput);
  }

  /**
   * Cancel a specific task
   * @deprecated Use cancelTask() instead for newer implementations
   * @param taskId ID of the task to cancel
   * @returns true if task was found and canceled, false otherwise
   */
  public cancel(taskId: string): boolean {
    logger.debug("DEPRECATED: cancel() called, use cancelTask() instead");
    return this.cancelTask(taskId);
  }

  /**
   * Cancel all registered tasks. An error from any `cancel()` call is logged and stops the loop.
   */
  public cancelAll(): void {
    try {
      for (const task of this.tasks.values()) {
        task.cancel();
      }
      logger.info(`Canceled all ${this.tasks.size} tasks`);
    } catch (error) {
      logger.error("Error canceling all tasks:", error);
    }
  }
}