UNPKG

@xynehq/jaf

Version:

Juspay Agent Framework - A purely functional agent framework with immutable state and composable tools

435 lines 19.3 kB
/**
 * A2A Redis Task Provider for JAF
 * Pure functional Redis-based storage for A2A tasks
 */
import { createA2ATaskNotFoundError, createA2ATaskStorageError, createSuccess, createFailure } from '../types.js';
import { serializeA2ATask, deserializeA2ATask, sanitizeTask } from '../serialization.js';

/**
 * Create a Redis-based A2A task provider.
 *
 * Keys written under `config.keyPrefix` (default `'jaf:a2a:tasks:'`):
 * - `task:<taskId>`       hash holding the serialized task
 * - `context:<contextId>` set of task ids belonging to a context
 * - `state:<state>`       set of task ids currently in a state
 * - `stats`               hash of counters (`totalTasks`, `state:<state>`)
 *
 * Only the task hash is given a TTL (via `metadata.expiresAt` or `config.defaultTtl`);
 * index sets and counters are reconciled by `cleanupExpiredTasks` after Redis
 * expires a task hash.
 *
 * @param {object} config - Reads `keyPrefix` and `defaultTtl` (seconds, passed to EXPIRE).
 * @param {object} redisClient - Client exposing the commands used here
 *   (multi/exists/hgetall/smembers/keys/ping, and hmset/expire/sadd/srem/hincrby/del on `multi()`).
 *   NOTE(review): presumably ioredis-compatible — confirm against the RedisClient type.
 * @returns {Promise<object>} The provider. Each method catches errors and resolves to a
 *   result built with `createSuccess`/`createFailure` instead of rejecting.
 */
export const createA2ARedisTaskProvider = async (config, redisClient) => {
  const keyPrefix = config.keyPrefix || 'jaf:a2a:tasks:';
  const taskKeyPrefix = `${keyPrefix}task:`;
  const contextKeyPrefix = `${keyPrefix}context:`;
  const stateKeyPrefix = `${keyPrefix}state:`;

  // Pure functions for key generation
  const getTaskKey = (taskId) => `${taskKeyPrefix}${taskId}`;
  const getContextIndexKey = (contextId) => `${contextKeyPrefix}${contextId}`;
  const getStateIndexKey = (state) => `${stateKeyPrefix}${state}`;
  const getStatsKey = () => `${keyPrefix}stats`;

  // Pure function to convert Redis hash to serialized task
  const hashToSerializedTask = (hash) => ({
    taskId: hash.taskId,
    contextId: hash.contextId,
    state: hash.state,
    taskData: hash.taskData,
    statusMessage: hash.statusMessage,
    createdAt: hash.createdAt,
    updatedAt: hash.updatedAt,
    metadata: hash.metadata
  });

  // Pure function to convert serialized task to Redis hash.
  // Optional fields are omitted rather than written as empty values.
  const serializedTaskToHash = (serialized) => {
    const hash = {
      taskId: serialized.taskId,
      contextId: serialized.contextId,
      state: serialized.state,
      taskData: serialized.taskData,
      createdAt: serialized.createdAt,
      updatedAt: serialized.updatedAt
    };
    if (serialized.statusMessage) hash.statusMessage = serialized.statusMessage;
    if (serialized.metadata) hash.metadata = serialized.metadata;
    return hash;
  };

  // Milliseconds of a task's status timestamp. Missing or unparseable timestamps map to 0
  // so the sort comparator never returns NaN (those tasks sort after every dated task).
  const statusTimeMs = (task) => {
    const ms = Date.parse(task.status.timestamp || '');
    return Number.isNaN(ms) ? 0 : ms;
  };

  // Parse a Redis counter value; absent fields count as 0.
  const parseCount = (value) => Number.parseInt(value || '0', 10);

  // Forward declare provider for recursive calls
  const provider = {};
  const providerImpl = {
    /**
     * Store a new task, index it by context and state, and bump the stats counters.
     * @param {object} task - A2A task; `id`, `contextId` and `status.state` are read.
     * @param {object} [metadata] - Optional metadata; `expiresAt` (Date) sets the TTL.
     */
    storeTask: async (task, metadata) => {
      try {
        // Validate and sanitize task
        const sanitizeResult = sanitizeTask(task);
        if (!sanitizeResult.success) {
          return sanitizeResult;
        }

        // Serialize task
        const serializeResult = serializeA2ATask(sanitizeResult.data, metadata);
        if (!serializeResult.success) {
          return serializeResult;
        }
        const serialized = serializeResult.data;
        const taskKey = getTaskKey(task.id);

        // Use Redis transaction for atomicity
        const multi = redisClient.multi();
        multi.hmset(taskKey, serializedTaskToHash(serialized));

        // Set TTL if specified; an expiresAt already in the past sets no TTL
        if (metadata?.expiresAt) {
          const ttlSeconds = Math.floor((metadata.expiresAt.getTime() - Date.now()) / 1000);
          if (ttlSeconds > 0) {
            multi.expire(taskKey, ttlSeconds);
          }
        } else if (config.defaultTtl) {
          multi.expire(taskKey, config.defaultTtl);
        }

        // Add to indices
        multi.sadd(getContextIndexKey(task.contextId), task.id);
        multi.sadd(getStateIndexKey(task.status.state), task.id);

        // Update stats
        multi.hincrby(getStatsKey(), 'totalTasks', 1);
        multi.hincrby(getStatsKey(), `state:${task.status.state}`, 1);

        await multi.exec();
        return createSuccess(undefined);
      } catch (error) {
        // `task?.id`: the catch must not throw again when `task` itself is nullish.
        return createFailure(createA2ATaskStorageError('store', 'redis', task?.id, error));
      }
    },

    /**
     * Load a task by id.
     * @returns Success with the task, or Success(null) when the hash is missing or has no taskData.
     */
    getTask: async (taskId) => {
      try {
        const taskKey = getTaskKey(taskId);
        const exists = await redisClient.exists(taskKey);
        if (!exists) {
          return createSuccess(null);
        }

        const hash = await redisClient.hgetall(taskKey);
        if (!hash || !hash.taskData) {
          return createSuccess(null);
        }

        const deserializeResult = deserializeA2ATask(hashToSerializedTask(hash));
        if (!deserializeResult.success) {
          return deserializeResult;
        }
        return createSuccess(deserializeResult.data);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('get', 'redis', taskId, error));
      }
    },

    /**
     * Overwrite an existing task, merging `metadata` into the stored metadata and
     * moving the id between state/context index sets when those changed.
     * @returns Failure(not-found) when no task hash exists for `task.id`.
     */
    updateTask: async (task, metadata) => {
      try {
        const taskKey = getTaskKey(task.id);
        const exists = await redisClient.exists(taskKey);
        if (!exists) {
          return createFailure(createA2ATaskNotFoundError(task.id, 'redis'));
        }

        // Get existing task to check for state/context changes
        const existingHash = await redisClient.hgetall(taskKey);
        const oldState = existingHash.state;
        const oldContextId = existingHash.contextId;

        // Validate and sanitize task
        const sanitizeResult = sanitizeTask(task);
        if (!sanitizeResult.success) {
          return sanitizeResult;
        }

        // Merge metadata (stored metadata is a JSON string in the hash)
        const existingMetadata = existingHash.metadata ? JSON.parse(existingHash.metadata) : {};
        const mergedMetadata = { ...existingMetadata, ...metadata };

        // Serialize updated task
        const serializeResult = serializeA2ATask(sanitizeResult.data, mergedMetadata);
        if (!serializeResult.success) {
          return serializeResult;
        }
        const serialized = serializeResult.data;

        const multi = redisClient.multi();

        // Update task data; HMSET keeps any TTL already set on the key
        multi.hmset(taskKey, serializedTaskToHash(serialized));

        // Update indices if state changed
        if (oldState !== task.status.state) {
          multi.srem(getStateIndexKey(oldState), task.id);
          multi.sadd(getStateIndexKey(task.status.state), task.id);

          // Update stats
          multi.hincrby(getStatsKey(), `state:${oldState}`, -1);
          multi.hincrby(getStatsKey(), `state:${task.status.state}`, 1);
        }

        // Keep the context index in sync; otherwise the task would keep showing up
        // under its old context and never under its new one.
        if (oldContextId !== task.contextId) {
          if (oldContextId) {
            multi.srem(getContextIndexKey(oldContextId), task.id);
          }
          multi.sadd(getContextIndexKey(task.contextId), task.id);
        }

        await multi.exec();
        return createSuccess(undefined);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('update', 'redis', task.id, error));
      }
    },

    /**
     * Change only the status of a stored task, then persist it via `updateTask`.
     * A falsy `statusMessage` keeps the current message; a falsy `timestamp` uses "now".
     */
    updateTaskStatus: async (taskId, newState, statusMessage, timestamp) => {
      try {
        const taskKey = getTaskKey(taskId);
        const exists = await redisClient.exists(taskKey);
        if (!exists) {
          return createFailure(createA2ATaskNotFoundError(taskId, 'redis'));
        }

        // Get existing task
        const hash = await redisClient.hgetall(taskKey);
        const deserializeResult = deserializeA2ATask(hashToSerializedTask(hash));
        if (!deserializeResult.success) {
          return deserializeResult;
        }
        const task = deserializeResult.data;

        // Update task status
        const updatedTask = {
          ...task,
          status: {
            ...task.status,
            state: newState,
            message: statusMessage || task.status.message,
            timestamp: timestamp || new Date().toISOString()
          }
        };
        return providerImpl.updateTask(updatedTask);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('update-status', 'redis', taskId, error));
      }
    },

    /**
     * Find tasks matching every supplied filter.
     * @param {object} query - Reads `contextId`, `state`, `taskId`, `since`/`until` (Dates compared
     *   to createdAt), `offset`, `limit`.
     * @returns Success with matching tasks, newest status timestamp first, paginated.
     */
    findTasks: async (query) => {
      try {
        let taskIds;
        if (query.contextId) {
          // Get tasks by context
          taskIds = await redisClient.smembers(getContextIndexKey(query.contextId));
        } else if (query.state) {
          // Get tasks by state
          taskIds = await redisClient.smembers(getStateIndexKey(query.state));
        } else if (query.taskId) {
          // A single known id needs no key scan; the existence check below drops it if absent.
          taskIds = [query.taskId];
        } else {
          // Get all task keys and extract IDs (KEYS is O(N) over the whole keyspace)
          const keys = await redisClient.keys(`${taskKeyPrefix}*`);
          taskIds = keys.map((key) => key.slice(taskKeyPrefix.length));
        }

        // Filter by specific task ID if provided
        if (query.taskId) {
          taskIds = taskIds.filter((id) => id === query.taskId);
        }

        // Fetch tasks and apply additional filters
        const results = [];
        for (const taskId of taskIds) {
          const taskKey = getTaskKey(taskId);
          const exists = await redisClient.exists(taskKey);
          if (!exists) continue;

          const hash = await redisClient.hgetall(taskKey);
          if (!hash || !hash.taskData) continue;

          // Only one index was consulted above; when both contextId and state were
          // requested, the state must still be enforced here.
          if (query.state && hash.state !== query.state) continue;

          // Apply date filters
          if (query.since || query.until) {
            const createdAt = new Date(hash.createdAt);
            if (query.since && createdAt < query.since) continue;
            if (query.until && createdAt > query.until) continue;
          }

          const deserializeResult = deserializeA2ATask(hashToSerializedTask(hash));
          if (deserializeResult.success) {
            results.push(deserializeResult.data);
          }
        }

        // Sort by timestamp (newest first)
        results.sort((a, b) => statusTimeMs(b) - statusTimeMs(a));

        // Apply pagination
        const offset = query.offset || 0;
        const limit = query.limit || results.length;
        return createSuccess(results.slice(offset, offset + limit));
      } catch (error) {
        return createFailure(createA2ATaskStorageError('find', 'redis', undefined, error));
      }
    },

    getTasksByContext: async (contextId, limit) => {
      return providerImpl.findTasks({ contextId, limit });
    },

    /**
     * Delete a task, remove it from its indices and decrement the counters.
     * @returns Success(true) if a task was deleted, Success(false) if none existed.
     */
    deleteTask: async (taskId) => {
      try {
        const taskKey = getTaskKey(taskId);
        const exists = await redisClient.exists(taskKey);
        if (!exists) {
          return createSuccess(false);
        }

        // Get task data for index cleanup
        const hash = await redisClient.hgetall(taskKey);
        const { contextId, state } = hash;

        const multi = redisClient.multi();
        multi.del(taskKey);

        // Remove from indices
        if (contextId) {
          multi.srem(getContextIndexKey(contextId), taskId);
        }
        if (state) {
          multi.srem(getStateIndexKey(state), taskId);
        }

        // Update stats
        multi.hincrby(getStatsKey(), 'totalTasks', -1);
        if (state) {
          multi.hincrby(getStatsKey(), `state:${state}`, -1);
        }

        await multi.exec();
        return createSuccess(true);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('delete', 'redis', taskId, error));
      }
    },

    /**
     * Delete every task indexed under a context, one `deleteTask` call per id.
     * @returns Success with the number of tasks actually deleted.
     */
    deleteTasksByContext: async (contextId) => {
      try {
        const taskIds = await redisClient.smembers(getContextIndexKey(contextId));
        if (taskIds.length === 0) {
          return createSuccess(0);
        }

        let deletedCount = 0;
        for (const taskId of taskIds) {
          const deleteResult = await providerImpl.deleteTask(taskId);
          if (deleteResult.success && deleteResult.data) {
            deletedCount++;
          }
        }
        return createSuccess(deletedCount);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('delete-by-context', 'redis', undefined, error));
      }
    },

    /**
     * Reconcile indices and counters with task hashes that Redis has expired.
     *
     * Redis removes an expired task hash by itself, but its id stays in the
     * `state:*` / `context:*` sets and in the `stats` counters. This removes every
     * indexed id whose task hash no longer exists and decrements `totalTasks` and
     * `state:<state>` once per orphaned id found in a state set.
     *
     * NOTE: best-effort. A task re-stored with the same id between the existence
     * check and the SREM could be dropped from an index. Uses KEYS (O(N)) like findTasks.
     *
     * @returns Success with the number of distinct orphaned task ids cleaned.
     */
    cleanupExpiredTasks: async () => {
      try {
        const cleanedIds = new Set();
        const statsKey = getStatsKey();

        const stateKeys = await redisClient.keys(`${stateKeyPrefix}*`);
        for (const stateKey of stateKeys) {
          const state = stateKey.slice(stateKeyPrefix.length);
          const taskIds = await redisClient.smembers(stateKey);
          const multi = redisClient.multi();
          let hasOrphans = false;
          for (const taskId of taskIds) {
            if (await redisClient.exists(getTaskKey(taskId))) continue;
            hasOrphans = true;
            multi.srem(stateKey, taskId);
            multi.hincrby(statsKey, `state:${state}`, -1);
            // Count each orphaned task once even if the indices listed it in several states.
            if (!cleanedIds.has(taskId)) {
              multi.hincrby(statsKey, 'totalTasks', -1);
              cleanedIds.add(taskId);
            }
          }
          if (hasOrphans) {
            await multi.exec();
          }
        }

        const contextKeys = await redisClient.keys(`${contextKeyPrefix}*`);
        for (const contextKey of contextKeys) {
          const taskIds = await redisClient.smembers(contextKey);
          const multi = redisClient.multi();
          let hasOrphans = false;
          for (const taskId of taskIds) {
            if (await redisClient.exists(getTaskKey(taskId))) continue;
            hasOrphans = true;
            multi.srem(contextKey, taskId);
            cleanedIds.add(taskId);
          }
          if (hasOrphans) {
            await multi.exec();
          }
        }

        return createSuccess(cleanedIds.size);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('cleanup', 'redis', undefined, error));
      }
    },

    /**
     * Task statistics.
     * With `contextId`: counts are computed from the context's live task hashes, including
     * oldest/newest `createdAt`. Without it: counts come from the global `stats` hash and the
     * date range is left undefined (it would require scanning every task).
     */
    getTaskStats: async (contextId) => {
      try {
        const tasksByState = {
          submitted: 0,
          working: 0,
          'input-required': 0,
          completed: 0,
          canceled: 0,
          failed: 0,
          rejected: 0,
          'auth-required': 0,
          unknown: 0
        };
        let totalTasks = 0;
        let oldestTask;
        let newestTask;

        if (contextId) {
          const taskIds = await redisClient.smembers(getContextIndexKey(contextId));
          for (const taskId of taskIds) {
            const taskKey = getTaskKey(taskId);
            const exists = await redisClient.exists(taskKey);
            if (!exists) continue;

            const hash = await redisClient.hgetall(taskKey);
            if (!hash) continue;

            totalTasks++;
            const { state } = hash;
            if (state) {
              // A state outside the predefined keys would otherwise become NaN (undefined + 1).
              const current = Object.prototype.hasOwnProperty.call(tasksByState, state)
                ? tasksByState[state]
                : 0;
              tasksByState[state] = current + 1;
            }

            const createdAt = new Date(hash.createdAt);
            // Skip unparseable dates so they cannot pin oldest/newest to an Invalid Date.
            if (Number.isNaN(createdAt.getTime())) continue;
            if (!oldestTask || createdAt < oldestTask) {
              oldestTask = createdAt;
            }
            if (!newestTask || createdAt > newestTask) {
              newestTask = createdAt;
            }
          }
        } else {
          // Some clients return null from HGETALL on a missing key (e.g. before any task is stored).
          const stats = (await redisClient.hgetall(getStatsKey())) || {};
          totalTasks = parseCount(stats.totalTasks);
          for (const state of Object.keys(tasksByState)) {
            tasksByState[state] = parseCount(stats[`state:${state}`]);
          }
        }

        return createSuccess({ totalTasks, tasksByState, oldestTask, newestTask });
      } catch (error) {
        return createFailure(createA2ATaskStorageError('stats', 'redis', undefined, error));
      }
    },

    /**
     * PING Redis and report round-trip latency; failures resolve as `healthy: false`.
     */
    healthCheck: async () => {
      try {
        const startTime = Date.now();
        await redisClient.ping();
        const latencyMs = Date.now() - startTime;
        return createSuccess({ healthy: true, latencyMs });
      } catch (error) {
        return createSuccess({ healthy: false, error: error.message });
      }
    },

    close: async () => {
      try {
        // Redis client cleanup is typically handled externally.
        // We don't close the client here as it might be shared.
        return createSuccess(undefined);
      } catch (error) {
        return createFailure(createA2ATaskStorageError('close', 'redis', undefined, error));
      }
    }
  };

  // Set up the provider variable
  Object.assign(provider, providerImpl);
  return provider;
};
//# sourceMappingURL=redis.js.map