@copilotkit/runtime
Version:
<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />
295 lines (293 loc) • 8.57 kB
JavaScript
import "reflect-metadata";
import { runOnBeforeHandler, runOnError, runOnRequest, runOnResponse } from "./hooks.mjs";
import { addCorsHeaders, handleCors } from "./fetch-cors.mjs";
import { matchRoute } from "./fetch-router.mjs";
import { callAfterRequestMiddleware, callBeforeRequestMiddleware } from "./middleware.mjs";
import { handleRunAgent } from "../handlers/handle-run.mjs";
import { handleConnectAgent } from "../handlers/handle-connect.mjs";
import { handleStopAgent } from "../handlers/handle-stop.mjs";
import { handleGetRuntimeInfo } from "../handlers/get-runtime-info.mjs";
import { handleTranscribe } from "../handlers/handle-transcribe.mjs";
import { handleDebugEvents } from "../handlers/handle-debug-events.mjs";
import { handleArchiveThread, handleClearThreads, handleDeleteThread, handleGetThreadEvents, handleGetThreadMessages, handleGetThreadState, handleListThreads, handleSubscribeToThreads, handleUpdateThread } from "../handlers/intelligence/threads.mjs";
import { createJsonRequest, expectString, parseMethodCall } from "../endpoints/single-route-helpers.mjs";
import { fireInstanceCreatedTelemetry } from "../telemetry/instance-created.mjs";
import { logger } from "@copilotkit/shared";
//#region src/v2/runtime/core/fetch-handler.ts
/**
 * Build a Fetch-API style handler (`(request) => Promise<Response>`) for a CopilotKit runtime.
 *
 * Per request the handler:
 * 1. answers CORS preflights itself (when CORS is enabled);
 * 2. runs the request hook (`runOnRequest`) and the before-request middleware, either of which
 *    may substitute the request;
 * 3. resolves the route — from the URL in "multi-route" mode, or from a POSTed method envelope
 *    in "single-route" mode;
 * 4. runs `runOnBeforeHandler`, then dispatches to the route handler;
 * 5. runs `runOnResponse`, adds CORS headers, and fires the after-request middleware without
 *    awaiting it.
 *
 * A `Response` thrown anywhere in the pipeline becomes the reply (it still goes through
 * `runOnResponse`). Any other error goes to `runOnError`; if that yields nothing, the reply is a
 * JSON 500.
 *
 * @param {object} options
 * @param {object} options.runtime - runtime passed to every hook, middleware call and handler.
 * @param {string} [options.basePath] - path prefix the routes are served under.
 * @param {string} [options.mode="multi-route"] - "single-route" selects envelope routing; any
 *   other value uses URL routing.
 * @param {boolean|object} [options.cors] - `true` for default CORS settings, an object for
 *   custom settings, falsy to disable.
 * @param {object} [options.hooks] - lifecycle hooks consumed by runOnRequest /
 *   runOnBeforeHandler / runOnResponse / runOnError.
 * @returns {(request: Request) => Promise<Response>}
 */
function createCopilotRuntimeHandler(options) {
  const { runtime, basePath, mode = "multi-route", cors, hooks } = options;
  fireInstanceCreatedTelemetry({ runtime });
  const corsConfig = resolveCorsConfig(cors);
  return async (request) => {
    const path = new URL(request.url, "http://localhost").pathname;
    const requestOrigin = request.headers.get("origin");
    const baseCtx = {
      request,
      path,
      runtime
    };
    // Set as soon as routing succeeds, so the error paths can pass it to the hooks.
    let route;
    try {
      if (corsConfig) {
        // Preflight requests are answered here and bypass hooks and middleware.
        const preflight = handleCors(request, corsConfig);
        if (preflight) return preflight;
      }
      request = await runOnRequest(hooks, {
        ...baseCtx,
        request
      });
      try {
        const maybeModified = await callBeforeRequestMiddleware({
          runtime,
          request,
          path
        });
        if (maybeModified) request = maybeModified;
      } catch (mwError) {
        logger.error({
          err: mwError,
          url: request.url,
          path
        }, "Error running before request middleware");
        // A Response thrown by middleware is sent as-is (plus CORS), skipping runOnResponse.
        if (mwError instanceof Response) return maybeAddCors(mwError, corsConfig, requestOrigin);
        throw mwError;
      }
      let response;
      if (mode === "single-route") {
        const resolved = await resolveSingleRoute(request, basePath, path);
        route = resolved.route;
        const { methodCall } = resolved;
        request = await runOnBeforeHandler(hooks, {
          request,
          path,
          runtime,
          route
        });
        // These handlers are given a request rebuilt around the envelope's `body` payload
        // (presumably because parseMethodCall already read the original body — not visible here).
        if (route.method === "agent/run" || route.method === "agent/connect" || route.method === "transcribe") request = createJsonRequest(request, methodCall.body);
        response = await dispatchRoute(runtime, request, route);
      } else {
        const matched = matchRoute(path, basePath);
        if (!matched) throw jsonResponse({ error: "Not found" }, 404);
        // Record the route before method validation so a 405 reaches the hooks with its route.
        route = matched;
        const methodError = validateHttpMethod(request.method, matched);
        if (methodError) throw methodError;
        request = await runOnBeforeHandler(hooks, {
          request,
          path,
          runtime,
          route
        });
        response = await dispatchRoute(runtime, request, route);
      }
      response = await runOnResponse(hooks, {
        request,
        response,
        path,
        runtime,
        route
      });
      response = maybeAddCors(response, corsConfig, requestOrigin);
      // Fire-and-forget: the middleware gets a clone so it cannot consume the body returned to
      // the client, and its failures are only logged.
      callAfterRequestMiddleware({
        runtime,
        response: response.clone(),
        path
      }).catch((error) => {
        logger.error({
          err: error,
          url: request.url,
          path
        }, "Error running after request middleware");
      });
      return response;
    } catch (error) {
      if (error instanceof Response) {
        // A thrown Response (404, 405, handler short-circuit, ...) is the reply itself.
        let hooked = error;
        try {
          hooked = await runOnResponse(hooks, {
            request,
            response: error,
            path,
            runtime,
            route: route ?? { method: "info" }
          });
        } catch (hookError) {
          // Without this guard a throwing onResponse hook would reject the whole handler and
          // the client would get no reply; fall back to the un-hooked response instead.
          logger.error({
            err: hookError,
            url: request.url,
            path
          }, "onResponse hook threw");
        }
        return maybeAddCors(hooked, corsConfig, requestOrigin);
      }
      try {
        const errorResponse = await runOnError(hooks, {
          request,
          error,
          path,
          runtime,
          route
        });
        if (errorResponse) return maybeAddCors(errorResponse, corsConfig, requestOrigin);
      } catch (hookError) {
        logger.error({
          err: hookError,
          originalErr: error,
          url: request.url,
          path
        }, "onError hook threw");
      }
      logger.error({
        err: error,
        url: request.url,
        path
      }, "Unhandled error in CopilotKit runtime handler");
      return maybeAddCors(jsonResponse({ error: "internal_error" }, 500), corsConfig, requestOrigin);
    }
  };
}
// Route method -> invoker for the matching handler. A Map (rather than a plain object) so that
// inherited keys such as "toString" never resolve to a handler.
const ROUTE_DISPATCHERS = new Map([
  ["agent/run", (runtime, request, route) => handleRunAgent({ runtime, request, agentId: route.agentId })],
  ["agent/connect", (runtime, request, route) => handleConnectAgent({ runtime, request, agentId: route.agentId })],
  ["agent/stop", (runtime, request, route) => handleStopAgent({ runtime, request, agentId: route.agentId, threadId: route.threadId })],
  ["info", (runtime, request) => handleGetRuntimeInfo({ runtime, request })],
  ["transcribe", (runtime, request) => handleTranscribe({ runtime, request })],
  ["threads/clear", (runtime, request) => Promise.resolve(handleClearThreads({ runtime, request }))],
  ["threads/list", (runtime, request) => handleListThreads({ runtime, request })],
  ["threads/subscribe", (runtime, request) => handleSubscribeToThreads({ runtime, request })],
  ["threads/update", (runtime, request, route) => {
    // The same route serves both PATCH (update) and DELETE (delete).
    const isDelete = request.method.toUpperCase() === "DELETE";
    const args = { runtime, request, threadId: route.threadId };
    return isDelete ? handleDeleteThread(args) : handleUpdateThread(args);
  }],
  ["threads/archive", (runtime, request, route) => handleArchiveThread({ runtime, request, threadId: route.threadId })],
  ["threads/messages", (runtime, request, route) => handleGetThreadMessages({ runtime, request, threadId: route.threadId })],
  ["threads/events", (runtime, request, route) => handleGetThreadEvents({ runtime, request, threadId: route.threadId })],
  ["threads/state", (runtime, request, route) => handleGetThreadState({ runtime, request, threadId: route.threadId })],
  ["cpk-debug-events", (runtime, request) => Promise.resolve(handleDebugEvents({ runtime, request }))]
]);

/**
 * Invoke the handler registered for `route.method`.
 *
 * @param {object} runtime - runtime forwarded to the handler.
 * @param {Request} request - request forwarded to the handler.
 * @param {object} route - resolved route; `agentId` / `threadId` are forwarded where relevant.
 * @returns {*} whatever the handler returns ("threads/clear" and "cpk-debug-events" are wrapped
 *   in Promise.resolve), or `undefined` when no handler is registered for the method.
 */
function dispatchRoute(runtime, request, route) {
  const invoke = ROUTE_DISPATCHERS.get(route.method);
  if (!invoke) return undefined;
  return invoke(runtime, request, route);
}
/**
 * Whether `pathname` lies at or below `basePath`, matching whole path segments.
 * For example, "/api" accepts "/api" and "/api/x" but not "/apix".
 * A trailing slash on `basePath` is ignored, and a basePath of "/" matches every path.
 *
 * @param {string} pathname - request path.
 * @param {string} basePath - configured prefix (non-empty).
 * @returns {boolean}
 */
function isUnderBasePath(pathname, basePath) {
  const normalizedBase = basePath.length > 1 && basePath.endsWith("/") ? basePath.slice(0, -1) : basePath;
  if (normalizedBase === "/") return true;
  return pathname === normalizedBase || pathname.startsWith(`${normalizedBase}/`);
}

/**
 * Resolve the target route in "single-route" mode.
 * In this mode every call is a POST to one endpoint, and the target is named by the method
 * envelope that `parseMethodCall` extracts from the request.
 *
 * @param {Request} request - incoming request.
 * @param {string} [basePath] - when set, the request path must be at or below it.
 * @param {string} pathname - request path.
 * @returns {Promise<{route: object, methodCall: object}>} the resolved route plus the parsed
 *   envelope, which the caller uses for `methodCall.body`.
 * @throws {Response} 404 when the path is outside `basePath` or the envelope names a method
 *   this mode does not serve; 405 (Allow: POST) for non-POST requests. `parseMethodCall` and
 *   `expectString` may also throw on malformed input (their error type is not visible here).
 */
async function resolveSingleRoute(request, basePath, pathname) {
  if (basePath && !isUnderBasePath(pathname, basePath)) throw jsonResponse({ error: "Not found" }, 404);
  if (request.method !== "POST") throw jsonResponse({ error: "Method not allowed" }, 405, { Allow: "POST" });
  const methodCall = await parseMethodCall(request);
  let route;
  switch (methodCall.method) {
    case "agent/run":
      route = {
        method: "agent/run",
        agentId: expectString(methodCall.params, "agentId")
      };
      break;
    case "agent/connect":
      route = {
        method: "agent/connect",
        agentId: expectString(methodCall.params, "agentId")
      };
      break;
    case "agent/stop":
      route = {
        method: "agent/stop",
        agentId: expectString(methodCall.params, "agentId"),
        threadId: expectString(methodCall.params, "threadId")
      };
      break;
    case "info":
      route = { method: "info" };
      break;
    case "transcribe":
      route = { method: "transcribe" };
      break;
    default:
      // Without this, an unsupported method leaves `route` undefined and the caller crashes on
      // `route.method`, answering 500 instead of a client error.
      throw jsonResponse({ error: "Not found" }, 404);
  }
  return {
    route,
    methodCall
  };
}
// Read-only routes: these accept GET and nothing else.
const GET_ONLY_ROUTES = new Set([
  "info",
  "threads/list",
  "threads/messages",
  "threads/events",
  "threads/state",
  "cpk-debug-events"
]);

/**
 * Check the HTTP verb against the verbs a multi-route route accepts.
 * Read-only routes take GET, "threads/update" takes PATCH or DELETE, and every other route
 * takes POST. The verb is compared case-insensitively.
 *
 * @param {string} httpMethod - the request's HTTP method.
 * @param {object} route - matched route; only `route.method` is read.
 * @returns {Response|null} `null` when the verb is allowed, otherwise a 405 JSON response whose
 *   `Allow` header lists the accepted verbs.
 */
function validateHttpMethod(httpMethod, route) {
  const verb = httpMethod.toUpperCase();
  let allowed = ["POST"];
  if (GET_ONLY_ROUTES.has(route.method)) {
    allowed = ["GET"];
  } else if (route.method === "threads/update") {
    allowed = ["PATCH", "DELETE"];
  }
  if (allowed.includes(verb)) return null;
  return jsonResponse({ error: "Method not allowed" }, 405, { Allow: allowed.join(", ") });
}
/**
 * Normalize the user-facing `cors` option.
 *
 * @param {boolean|object|undefined} cors - `true`, a config object, or a falsy value.
 * @returns {object|null} `{}` for `true`, the object itself when one was given, and `null` when
 *   CORS is disabled (any falsy value).
 */
function resolveCorsConfig(cors) {
  if (cors === true) return {};
  return cors ? cors : null;
}
/**
 * Add CORS headers via `addCorsHeaders` when a CORS config is active.
 *
 * @param {Response} response - response to decorate.
 * @param {object|null} config - resolved CORS config; falsy means CORS is disabled.
 * @param {string|null} requestOrigin - value of the request's `Origin` header.
 * @returns {Response} the response unchanged when CORS is disabled, otherwise what
 *   `addCorsHeaders` returns.
 */
function maybeAddCors(response, config, requestOrigin) {
  return config ? addCorsHeaders(response, config, requestOrigin) : response;
}
/**
 * Build a JSON `Response`.
 *
 * @param {*} body - value serialized with JSON.stringify.
 * @param {number} status - HTTP status code.
 * @param {object} [extraHeaders] - additional headers; applied after, and so able to override,
 *   the default `Content-Type: application/json`.
 * @returns {Response}
 */
function jsonResponse(body, status, extraHeaders) {
  const headers = Object.assign({ "Content-Type": "application/json" }, extraHeaders);
  const payload = JSON.stringify(body);
  return new Response(payload, { status, headers });
}
//#endregion
export { createCopilotRuntimeHandler };
//# sourceMappingURL=fetch-handler.mjs.map