UNPKG

@copilotkit/runtime

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

295 lines (293 loc) • 8.57 kB
import "reflect-metadata";
import { runOnBeforeHandler, runOnError, runOnRequest, runOnResponse } from "./hooks.mjs";
import { addCorsHeaders, handleCors } from "./fetch-cors.mjs";
import { matchRoute } from "./fetch-router.mjs";
import { callAfterRequestMiddleware, callBeforeRequestMiddleware } from "./middleware.mjs";
import { handleRunAgent } from "../handlers/handle-run.mjs";
import { handleConnectAgent } from "../handlers/handle-connect.mjs";
import { handleStopAgent } from "../handlers/handle-stop.mjs";
import { handleGetRuntimeInfo } from "../handlers/get-runtime-info.mjs";
import { handleTranscribe } from "../handlers/handle-transcribe.mjs";
import { handleDebugEvents } from "../handlers/handle-debug-events.mjs";
import {
  handleArchiveThread,
  handleClearThreads,
  handleDeleteThread,
  handleGetThreadEvents,
  handleGetThreadMessages,
  handleGetThreadState,
  handleListThreads,
  handleSubscribeToThreads,
  handleUpdateThread,
} from "../handlers/intelligence/threads.mjs";
import { createJsonRequest, expectString, parseMethodCall } from "../endpoints/single-route-helpers.mjs";
import { fireInstanceCreatedTelemetry } from "../telemetry/instance-created.mjs";
import { logger } from "@copilotkit/shared";

//#region src/v2/runtime/core/fetch-handler.ts

/**
 * Build a fetch-style request handler for a CopilotKit runtime.
 *
 * For every request the returned handler:
 *   1. answers CORS preflights (only when CORS is enabled);
 *   2. runs `runOnRequest` and the before-request middleware, either of which may
 *      replace the request;
 *   3. resolves a route — from the URL path in "multi-route" mode, or from the
 *      JSON method-call envelope in "single-route" mode;
 *   4. runs `runOnBeforeHandler`, dispatches to the matching handler, runs
 *      `runOnResponse` and adds CORS headers;
 *   5. schedules the after-request middleware without awaiting it.
 *
 * Control-flow convention: code inside the handler may `throw` a `Response` to
 * short-circuit with that response (e.g. 404 / 405); such responses still pass
 * through `runOnResponse`. Any other thrown value is offered to `runOnError`; if
 * that produces no response, a 500 `{ error: "internal_error" }` JSON body is
 * returned. The handler itself therefore normally resolves rather than rejects.
 *
 * @param {object} options
 * @param {object} options.runtime - Runtime instance passed to telemetry, hooks,
 *   middleware and every route handler.
 * @param {string} [options.basePath] - Path prefix the endpoint(s) are mounted under.
 * @param {"multi-route"|"single-route"} [options.mode="multi-route"] - Routing strategy.
 * @param {boolean|object} [options.cors] - Falsy disables CORS handling, `true` uses
 *   an empty config object, an object is used as the CORS config as-is.
 * @param {object} [options.hooks] - Lifecycle hooks forwarded to the `runOn*` runners.
 * @returns {(request: Request) => Promise<Response>}
 */
function createCopilotRuntimeHandler(options) {
  const { runtime, basePath, mode = "multi-route", cors, hooks } = options;
  fireInstanceCreatedTelemetry({ runtime });
  const corsConfig = resolveCorsConfig(cors);

  return async (request) => {
    const path = new URL(request.url, "http://localhost").pathname;
    // Read from the incoming request: hooks and middleware may replace `request`
    // below, but CORS headers must reflect the caller's real origin.
    const requestOrigin = request.headers.get("origin");
    // Declared outside `try` so the error paths can tell hooks which route failed.
    let route;

    try {
      if (corsConfig) {
        const preflight = handleCors(request, corsConfig);
        if (preflight) return preflight;
      }

      request = await runOnRequest(hooks, { request, path, runtime });

      try {
        const maybeModified = await callBeforeRequestMiddleware({ runtime, request, path });
        if (maybeModified) request = maybeModified;
      } catch (mwError) {
        logger.error({ err: mwError, url: request.url, path }, "Error running before request middleware");
        // A Response thrown by middleware is returned directly (with CORS); note
        // that, unlike Responses thrown later, it does not pass through runOnResponse.
        if (mwError instanceof Response) return maybeAddCors(mwError, corsConfig, requestOrigin);
        throw mwError;
      }

      let response;
      if (mode === "single-route") {
        const resolved = await resolveSingleRoute(request, basePath, path);
        route = resolved.route;
        request = await runOnBeforeHandler(hooks, { request, path, runtime, route });
        // These handlers read the request body, so rebuild the request around the
        // `body` field of the already-parsed method-call envelope.
        if (route.method === "agent/run" || route.method === "agent/connect" || route.method === "transcribe") {
          request = createJsonRequest(request, resolved.methodCall.body);
        }
        response = await dispatchRoute(runtime, request, route);
      } else {
        const matched = matchRoute(path, basePath);
        if (!matched) throw jsonResponse({ error: "Not found" }, 404);
        // Record the route before validating so a 405 is reported against it.
        route = matched;
        const methodError = validateHttpMethod(request.method, matched);
        if (methodError) throw methodError;
        request = await runOnBeforeHandler(hooks, { request, path, runtime, route });
        response = await dispatchRoute(runtime, request, route);
      }

      response = await runOnResponse(hooks, { request, response, path, runtime, route });
      response = maybeAddCors(response, corsConfig, requestOrigin);

      // Fire-and-forget: the middleware receives a clone so it cannot consume the
      // body returned to the client, and its failures are only logged.
      callAfterRequestMiddleware({ runtime, response: response.clone(), path }).catch((error) => {
        logger.error({ err: error, url: request.url, path }, "Error running after request middleware");
      });
      return response;
    } catch (error) {
      if (error instanceof Response) {
        // A thrown Response is an intentional short-circuit, not a failure.
        const hookedResponse = await runOnResponse(hooks, {
          request,
          response: error,
          path,
          runtime,
          // Always hand the hook a route object, even if routing never completed.
          route: route ?? { method: "info" },
        });
        return maybeAddCors(hookedResponse, corsConfig, requestOrigin);
      }

      try {
        const errorResponse = await runOnError(hooks, { request, error, path, runtime, route });
        if (errorResponse) return maybeAddCors(errorResponse, corsConfig, requestOrigin);
      } catch (hookError) {
        logger.error({ err: hookError, originalErr: error, url: request.url, path }, "onError hook threw");
      }

      logger.error({ err: error, url: request.url, path }, "Unhandled error in CopilotKit runtime handler");
      return maybeAddCors(jsonResponse({ error: "internal_error" }, 500), corsConfig, requestOrigin);
    }
  };
}

/**
 * Invoke the handler that serves `route`.
 *
 * @param {object} runtime - Runtime forwarded to the handler.
 * @param {Request} request - Request forwarded to the handler.
 * @param {{ method: string, agentId?: string, threadId?: string }} route - Resolved route.
 * @returns {Promise<Response>} Whatever the selected handler produces.
 * @throws {Response} 404 JSON response when `route.method` has no handler.
 */
function dispatchRoute(runtime, request, route) {
  switch (route.method) {
    case "agent/run":
      return handleRunAgent({ runtime, request, agentId: route.agentId });
    case "agent/connect":
      return handleConnectAgent({ runtime, request, agentId: route.agentId });
    case "agent/stop":
      return handleStopAgent({ runtime, request, agentId: route.agentId, threadId: route.threadId });
    case "info":
      return handleGetRuntimeInfo({ runtime, request });
    case "transcribe":
      return handleTranscribe({ runtime, request });
    case "threads/clear":
      return Promise.resolve(handleClearThreads({ runtime, request }));
    case "threads/list":
      return handleListThreads({ runtime, request });
    case "threads/subscribe":
      return handleSubscribeToThreads({ runtime, request });
    case "threads/update":
      // One route serves both verbs; validateHttpMethod admits only PATCH and DELETE here.
      if (request.method.toUpperCase() === "DELETE") {
        return handleDeleteThread({ runtime, request, threadId: route.threadId });
      }
      return handleUpdateThread({ runtime, request, threadId: route.threadId });
    case "threads/archive":
      return handleArchiveThread({ runtime, request, threadId: route.threadId });
    case "threads/messages":
      return handleGetThreadMessages({ runtime, request, threadId: route.threadId });
    case "threads/events":
      return handleGetThreadEvents({ runtime, request, threadId: route.threadId });
    case "threads/state":
      return handleGetThreadState({ runtime, request, threadId: route.threadId });
    case "cpk-debug-events":
      return Promise.resolve(handleDebugEvents({ runtime, request }));
    default:
      // Without this, an unknown route would produce an `undefined` response that
      // only fails later (on `response.clone()`) as a generic 500.
      throw jsonResponse({ error: "Not found" }, 404);
  }
}

/**
 * True when `pathname` is `base` itself or lies below it on a segment boundary,
 * so a base of "/api" accepts "/api" and "/api/x" but not "/apix".
 *
 * @param {string} pathname - Request path.
 * @param {string} base - Normalized base path.
 * @returns {boolean}
 */
function isPathWithinBase(pathname, base) {
  // A base that already ends in "/" (e.g. "/") is itself a segment boundary.
  if (base.endsWith("/")) return pathname.startsWith(base);
  return pathname === base || pathname.startsWith(`${base}/`);
}

/**
 * Resolve the route for single-route mode, where every call is a POST whose JSON
 * body names the method to invoke.
 *
 * @param {Request} request - Incoming request; its body is read by `parseMethodCall`.
 * @param {string} [basePath] - Optional mount path; one trailing "/" is ignored.
 * @param {string} pathname - Path component of the request URL.
 * @returns {Promise<{ route: object, methodCall: object }>}
 * @throws {Response} 404 when the path is outside `basePath` or the method is not
 *   supported in single-route mode; 405 (`Allow: POST`) for non-POST requests.
 *   `parseMethodCall` / `expectString` may throw as well (presumably Responses for
 *   malformed envelopes — confirm in single-route-helpers).
 */
async function resolveSingleRoute(request, basePath, pathname) {
  if (basePath) {
    const normalizedBase = basePath.length > 1 && basePath.endsWith("/") ? basePath.slice(0, -1) : basePath;
    if (!isPathWithinBase(pathname, normalizedBase)) throw jsonResponse({ error: "Not found" }, 404);
  }
  if (request.method !== "POST") throw jsonResponse({ error: "Method not allowed" }, 405, { Allow: "POST" });

  const methodCall = await parseMethodCall(request);
  let route;
  switch (methodCall.method) {
    case "agent/run":
      route = { method: "agent/run", agentId: expectString(methodCall.params, "agentId") };
      break;
    case "agent/connect":
      route = { method: "agent/connect", agentId: expectString(methodCall.params, "agentId") };
      break;
    case "agent/stop":
      route = {
        method: "agent/stop",
        agentId: expectString(methodCall.params, "agentId"),
        threadId: expectString(methodCall.params, "threadId"),
      };
      break;
    case "info":
      route = { method: "info" };
      break;
    case "transcribe":
      route = { method: "transcribe" };
      break;
    default:
      // Leaving `route` undefined would crash the caller on `route.method`.
      throw jsonResponse({ error: "Not found" }, 404);
  }
  return { route, methodCall };
}

/**
 * Check the HTTP verb against what the multi-route `route` accepts.
 *
 * @param {string} httpMethod - Request method (case-insensitive).
 * @param {{ method: string }} route - Matched route.
 * @returns {Response|null} `null` when allowed, otherwise a 405 JSON response with
 *   an `Allow` header listing the accepted verb(s).
 */
function validateHttpMethod(httpMethod, route) {
  const method = httpMethod.toUpperCase();
  switch (route.method) {
    case "info":
    case "threads/list":
    case "threads/messages":
    case "threads/events":
    case "threads/state":
    case "cpk-debug-events":
      if (method === "GET") return null;
      return jsonResponse({ error: "Method not allowed" }, 405, { Allow: "GET" });
    case "threads/update":
      if (method === "PATCH" || method === "DELETE") return null;
      return jsonResponse({ error: "Method not allowed" }, 405, { Allow: "PATCH, DELETE" });
    default:
      if (method === "POST") return null;
      return jsonResponse({ error: "Method not allowed" }, 405, { Allow: "POST" });
  }
}

/**
 * Normalize the `cors` option: falsy → `null` (disabled), `true` → `{}`, any
 * other value is returned unchanged.
 */
function resolveCorsConfig(cors) {
  if (!cors) return null;
  if (cors === true) return {};
  return cors;
}

/** Add CORS headers to `response` via `addCorsHeaders`, or return it untouched when CORS is disabled. */
function maybeAddCors(response, config, requestOrigin) {
  if (!config) return response;
  return addCorsHeaders(response, config, requestOrigin);
}

/**
 * Build a JSON `Response`.
 *
 * @param {unknown} body - Value serialized with `JSON.stringify`.
 * @param {number} status - HTTP status code.
 * @param {Record<string, string>} [extraHeaders] - Extra headers; they can override Content-Type.
 * @returns {Response}
 */
function jsonResponse(body, status, extraHeaders) {
  return new Response(JSON.stringify(body), {
    status,
    headers: { "Content-Type": "application/json", ...extraHeaders },
  });
}

//#endregion
export { createCopilotRuntimeHandler };
//# sourceMappingURL=fetch-handler.mjs.map