UNPKG

@artinet/sdk

Version:
247 lines 11.3 kB
import { defaultCancelTaskMethod, createJSONRPCMethod, defaultGetTaskPushNotificationMethod, defaultSetTaskPushNotificationMethod, defaultGetTaskMethod, defaultSendTaskMethod, } from "../../server/index.js";
import { TaskState, } from "../../types/index.js";
import { Protocol } from "../../types/services/protocol.js";
import { loadState, processUpdate } from "../../server/lib/state.js";
import { FINAL_STATES, getCurrentTimestamp, INVALID_PARAMS, INVALID_REQUEST, logError, METHOD_NOT_FOUND, TASK_NOT_FOUND, validateSendMessageParams, WORKING_UPDATE, } from "../../utils/index.js";
import { processTaskStream } from "../../transport/streaming/stream.js";
import { sendSSEEvent, setupSseStream } from "../../index.js";
import { A2ARepository } from "./repository.js";
/**
 * Protocol service that routes A2A JSON-RPC methods to task handlers.
 *
 * Non-streaming methods are delegated to `A2AService.dispatchMethod`; the
 * streaming methods (`message/stream`, `tasks/resubscribe`) are served over an
 * SSE stream set up with `setupSseStream`. Task state, stream tracking and
 * cancellation bookkeeping live in an `A2ARepository`.
 */
export class A2AService {
    name = "a2a";
    protocol = Protocol.A2A;
    engine;
    state;
    /**
     * @param options Service options. `options.engine` is stored as the agent
     *   engine used by the streaming handlers; the whole object is also passed
     *   to the `A2ARepository` constructor.
     */
    constructor(options) {
        this.engine = options.engine;
        this.state = new A2ARepository(options);
    }
    /**
     * Handles the `message/stream` method.
     *
     * Opens an SSE stream announcing a `submitted` status, loads (or creates)
     * the task, records and emits a `working` status update, then hands the
     * task to `processTaskStream`.
     *
     * @param req The JSON-RPC request; `req.params` carries `message` and
     *   `metadata`, and `message.taskId` is required.
     * @param res The response object the SSE stream is written to.
     * @throws The `INVALID_PARAMS` error when `message.taskId` is missing (and
     *   whatever `validateSendMessageParams` throws for invalid params).
     */
    async handleSendStreamingMessage(req, res) {
        validateSendMessageParams(req.params);
        const { message, metadata } = req.params;
        if (!message.taskId) {
            throw INVALID_PARAMS("Missing task ID");
        }
        const taskId = message.taskId;
        let contextId = message.contextId ?? "unknown";
        const executionContext = {
            id: taskId,
            protocol: Protocol.A2A,
            getRequestParams: () => req.params,
            isCancelled: () => this.state.activeCancellations.has(taskId),
        };
        // Set up SSE stream with initial status
        setupSseStream(res, taskId, {
            taskId: taskId,
            contextId: contextId,
            kind: "status-update",
            status: {
                state: TaskState.Submitted,
                timestamp: getCurrentTimestamp(),
            },
            final: false,
        }, this.state.addStreamForTask.bind(this.state));
        // Load or create task
        let currentData = await loadState(this.state.getTaskStore(), message, metadata, taskId, contextId);
        // Create task context
        const context = this.state.createTaskContext(currentData.task, message, currentData.history);
        // Prefer the stored task's context id once the task has been loaded.
        contextId = currentData.task.contextId || contextId;
        const workingUpdate = WORKING_UPDATE(taskId, contextId);
        currentData = await processUpdate(this.state.getTaskStore(), {
            context: context,
            current: currentData,
            update: workingUpdate,
        });
        // Send the working status
        sendSSEEvent(res, currentData.task.id, workingUpdate);
        // Process the task using the shared method
        await processTaskStream(context, this.state.getTaskStore(), this.engine, res, taskId, currentData, this.state.onCancel.bind(this.state), this.state.onEnd.bind(this.state), executionContext);
    }
    /**
     * Handles the `tasks/resubscribe` method.
     *
     * Re-opens an SSE stream for an existing task. If the task is already in
     * one of `FINAL_STATES`, its artifacts are replayed as `artifact-update`
     * events and the stream is closed; otherwise processing resumes from the
     * most recent user message in the task history.
     *
     * @param req The JSON-RPC request; `req.params.id` is the task id.
     * @param res The response object the SSE stream is written to.
     * @throws The `INVALID_PARAMS` error when the task id is missing, the
     *   `TASK_NOT_FOUND` error when the store has no such task, and the
     *   `INVALID_REQUEST` error when a non-final task has no user message.
     */
    async handleTaskResubscribe(req, res) {
        const { id: taskId } = req.params;
        if (!taskId) {
            logError("A2AService", "Task ID is required", req);
            throw INVALID_PARAMS("Missing task ID");
        }
        const executionContext = {
            id: taskId,
            protocol: Protocol.A2A,
            getRequestParams: () => req.params,
            isCancelled: () => this.state.activeCancellations.has(taskId),
        };
        // Try to load the task
        const data = await this.state.getTaskStore().load(taskId);
        if (!data) {
            throw TASK_NOT_FOUND("Task Id: " + taskId);
        }
        const contextId = data.task.contextId || "unknown";
        // Set up SSE stream with current task status
        setupSseStream(res, taskId, {
            taskId: taskId,
            contextId: contextId,
            kind: "status-update",
            status: data.task.status,
            final: false,
            metadata: data.task.metadata,
        }, this.state.addStreamForTask.bind(this.state));
        if (FINAL_STATES.includes(data.task.status.state)) {
            // The task is already complete: replay every artifact, then close.
            for (const artifact of data.task.artifacts ?? []) {
                sendSSEEvent(res, taskId, {
                    taskId: taskId,
                    contextId: contextId,
                    kind: "artifact-update",
                    artifact,
                    lastChunk: true,
                    metadata: data.task.metadata,
                });
            }
            // Remove from tracking and close
            this.state.removeStreamForTask(taskId, res);
            res.write("event: close\ndata: {}\n\n");
            res.end();
            return;
        }
        // Non-final task: resume from the most recent user message.
        const lastUserMessage = (data.history ?? [])
            .filter((msg) => msg.role === "user")
            .pop();
        if (!lastUserMessage) {
            // The stream was registered by setupSseStream above; untrack it before
            // failing so the repository does not keep a response that is about to
            // be closed by the error path.
            this.state.removeStreamForTask(taskId, res);
            throw INVALID_REQUEST("No user message found");
        }
        const context = this.state.createTaskContext(data.task, lastUserMessage, data.history);
        // Continue processing the task using the shared method
        await processTaskStream(context, this.state.getTaskStore(), this.engine, res, taskId, data, this.state.onCancel.bind(this.state), this.state.onEnd.bind(this.state), executionContext);
    }
    /**
     * Executes the JSON-RPC method named by `executionContext.requestContext`.
     *
     * Non-streaming methods are dispatched through `A2AService.dispatchMethod`
     * and answered with a single JSON-RPC response; streaming methods are
     * handled by the SSE handlers above. Handler failures are logged and
     * reported to the client as JSON-RPC errors.
     *
     * @param executionContext The execution context. Its `requestContext` must
     *   provide `method`, `params`, `request.body` and `response` (assumed to be
     *   an Express-style response with `status`/`send`/`write`/`end`).
     * @param engine Passed as `taskHandler` to the non-streaming methods.
     *   NOTE(review): the streaming handlers use `this.engine` instead of this
     *   argument — confirm that is intended.
     * @throws The `INVALID_REQUEST` error when `requestContext` is missing, and
     *   the `METHOD_NOT_FOUND` error when it has no `method`.
     */
    async execute({ executionContext, engine, }) {
        const requestContext = executionContext.requestContext;
        if (!requestContext) {
            throw INVALID_REQUEST({
                message: "Invalid request",
                data: {
                    method: "unknown",
                    params: executionContext.getRequestParams(),
                },
            });
        }
        const method = requestContext.method;
        if (!method) {
            throw METHOD_NOT_FOUND({ method: "unknown" });
        }
        //todo better callback sanitization
        let closeConnection = false;
        const callback = (error, result) => {
            const responseHandler = requestContext.response;
            if (error) {
                closeConnection = true;
                if (responseHandler.writableEnded) {
                    // Writing to an ended Node response emits an 'error' event, which can
                    // crash the process when unhandled; nothing more can be delivered.
                    logError("A2AService", "Response already ended; unable to send error", error);
                    return;
                }
                if (responseHandler.headersSent) {
                    // Streaming methods have already started an SSE response, so
                    // status()/send() would throw ERR_HTTP_HEADERS_SENT. Deliver the
                    // JSON-RPC error as an SSE data frame instead.
                    responseHandler.write(`data: ${JSON.stringify({
                        jsonrpc: "2.0",
                        id: executionContext.id,
                        error: error,
                    })}\n\n`);
                }
                else {
                    responseHandler.status(200);
                    responseHandler.send({
                        jsonrpc: "2.0",
                        id: executionContext.id,
                        error: error,
                    });
                }
            }
            else {
                responseHandler.status(200);
                responseHandler.send({
                    jsonrpc: "2.0",
                    id: executionContext.id,
                    result,
                });
            }
            if (closeConnection) {
                responseHandler.end();
            }
        };
        // Shared failure path: log, then report the error to the client and close.
        const onError = (error) => {
            logError("A2AService", `Error dispatching method: ${method}`, error);
            closeConnection = true;
            callback(error, null);
        };
        switch (method) {
            case "message/send":
            case "tasks/get":
            case "tasks/cancel":
            case "tasks/pushNotificationConfig/set":
            case "tasks/pushNotificationConfig/get":
                // Single-response methods: close the connection once answered.
                closeConnection = true;
                return await A2AService.dispatchMethod(method, requestContext.params, callback, {
                    taskStore: this.state.getTaskStore(),
                    card: this.state.getCard(),
                    taskHandler: engine,
                    activeCancellations: this.state.activeCancellations,
                    createTaskContext: this.state.createTaskContext.bind(this.state),
                    closeStreamsForTask: this.state.closeStreamsForTask.bind(this.state),
                }).catch(onError);
            case "message/stream":
                //todo make the following functions leverage callback
                return await this.handleSendStreamingMessage(requestContext.request.body, requestContext.response).catch(onError);
            case "tasks/resubscribe":
                return await this.handleTaskResubscribe(requestContext.request.body, requestContext.response).catch(onError);
            default:
                logError("A2AService", `Unknown method: ${method}`, null);
                callback(METHOD_NOT_FOUND({ method: method }), null);
                break;
        }
    }
    /**
     * Stops the service by destroying its repository state.
     */
    async stop() {
        await this.state.destroy();
    }
    /**
     * Dispatches a non-streaming A2A method to its default implementation.
     *
     * @param method The JSON-RPC method name.
     * @param params The method's JSON-RPC params.
     * @param callback Node-style `(error, result)` callback that receives the outcome.
     * @param deps Dependencies forwarded to `createJSONRPCMethod`.
     * @returns Whatever the created JSON-RPC method resolves to.
     * @throws {Error} When `method` is not one of the supported names.
     */
    static async dispatchMethod(method, params, callback, deps) {
        switch (method) {
            case "message/send":
                return await createJSONRPCMethod(deps, defaultSendTaskMethod, method)(params, callback);
            case "tasks/get":
                return await createJSONRPCMethod(deps, defaultGetTaskMethod, method)(params, callback);
            case "tasks/cancel":
                return await createJSONRPCMethod(deps, defaultCancelTaskMethod, method)(params, callback);
            case "tasks/pushNotificationConfig/set":
                return await createJSONRPCMethod(deps, defaultSetTaskPushNotificationMethod, method)(params, callback);
            case "tasks/pushNotificationConfig/get":
                return await createJSONRPCMethod(deps, defaultGetTaskPushNotificationMethod, method)(params, callback);
            default:
                throw new Error(`Unknown method: ${method}`);
        }
    }
}
//# sourceMappingURL=service.js.map