UNPKG

genaiscript

Version:

A CLI for GenAIScript, a generative AI scripting framework.

717 lines 31.1 kB
/* eslint-disable no-param-reassign */
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/* eslint-disable @typescript-eslint/explicit-function-return-type */
import { WebSocketServer } from "ws";
import { runPromptScriptTests } from "./test.js";
import { PROMPTFOO_VERSION } from "@genaiscript/runtime";
import { runScriptInternal } from "@genaiscript/api";
import {
  CORE_VERSION,
  LOG,
  MODEL_PROVIDER_GITHUB_COPILOT_CHAT,
  SERVER_PORT,
  TRACE_CHUNK,
  TRACE_FILENAME,
  UNHANDLED_ERROR_CODE,
  USER_CANCELLED_ERROR_CODE,
  WS_MAX_FRAME_CHUNK_LENGTH,
  WS_MAX_FRAME_LENGTH,
  AbortSignalCancellationController,
  MarkdownTrace,
  assert,
  chunkLines,
  chunkString,
  deleteUndefinedValues,
  generateId,
  genaiscriptDebug,
  isCancelError,
  logError,
  logVerbose,
  nodeTryReadPackage,
  randomHex,
  resolveLanguageModelConfigurations,
  serializeError,
  tryReadJSON,
  tryReadText,
  unthink,
  getModulePaths,
  sanitizeFilename,
  resolveRuntimeHost,
  tryStat,
} from "@genaiscript/core";
import { createReadStream } from "node:fs";
import { URL } from "node:url";
import { findOpenPort } from "./port.js";
import { applyRemoteOptions } from "./remote.js";
import * as http from "node:http";
import { startProjectWatcher } from "./watch.js";
import { extname, join, resolve, sep } from "node:path";
import { readFile, realpath } from "node:fs/promises";
import { collectRuns } from "./runs.js";
import { openaiApiChatCompletions, openaiApiModels } from "./openaiapi.js";
import { networkInterfaces } from "node:os";

const dbg = genaiscriptDebug("server");
const { __dirname } =
  typeof module !== "undefined" && module.filename
    ? getModulePaths(module)
    : // eslint-disable-next-line @typescript-eslint/ban-ts-comment
      // @ts-ignore
      getModulePaths(import.meta);

// Content types for the image extensions accepted by the image route.
// `extname` returns the extension WITH its leading dot, so the keys carry it too.
const IMAGE_CONTENT_TYPES = Object.freeze({
  ".png": "image/png",
  ".jpg": "image/jpeg",
  ".jpeg": "image/jpeg",
  ".gif": "image/gif",
  ".svg": "image/svg+xml",
});

/**
 * Streams a file to an HTTP response with status 200 and the given content type.
 *
 * A read failure (for example a missing file) is answered with 404 when no
 * header has been sent yet; otherwise the response is destroyed. Without the
 * `'error'` listener a failing read stream would raise an uncaught exception.
 *
 * @param res - The HTTP response to write to.
 * @param filePath - Path of the file to stream.
 * @param contentType - Value of the `Content-Type` header.
 */
function sendFile(res, filePath, contentType) {
  const stream = createReadStream(filePath);
  stream.on("error", (e) => {
    logVerbose(`http: cannot read ${filePath}: ${e?.message}`);
    if (res.headersSent) {
      res.destroy(e);
      return;
    }
    res.removeHeader("Content-Type");
    res.statusCode = 404;
    res.end();
  });
  res.setHeader("Content-Type", contentType);
  res.statusCode = 200;
  stream.pipe(res);
}

/**
 * Starts the HTTP + WebSocket server used for chat and script execution.
 *
 * The HTTP server serves the web UI assets, a small JSON API (`/api/*`), the
 * optional OpenAI-compatible endpoints (`/v1/*`), and upgrades requests on `/`
 * to WebSocket connections.
 *
 * @param options - Configuration options. The fields read by this function are:
 *   - apiKey: Optional API key; falls back to `GENAISCRIPT_API_KEY`.
 *   - cors: Optional CORS origin; falls back to `GENAISCRIPT_CORS_ORIGIN`.
 *   - network: When truthy, binds to `0.0.0.0` instead of `127.0.0.1`.
 *   - dispatchProgress: When truthy, per-connection messages are sent to all clients.
 *   - chat: When truthy, enables `POST /v1/chat/completions`.
 *   - runTrace: When truthy, a trace is recorded and streamed for script runs.
 *   - githubCopilotChatClient: When truthy, installs a client language model that
 *     forwards chat requests to a connected WebSocket client.
 *   - remote: Remote configuration, reported by the env endpoint and at startup.
 *   - remoteBranch: Optional branch name reported with `remote`.
 *   The whole object is also passed to `findOpenPort` and `applyRemoteOptions`
 *   (presumably that is where `port` / `httpPort` are consumed — verify there).
 */
export async function startServer(options) {
  const runtimeHost = resolveRuntimeHost();
  const corsOrigin = options.cors || process.env.GENAISCRIPT_CORS_ORIGIN;
  const apiKey = options.apiKey || process.env.GENAISCRIPT_API_KEY;
  const serverHost = options.network ? "0.0.0.0" : "127.0.0.1";
  const remote = options.remote;
  const dispatchProgress = !!options.dispatchProgress;
  const openAIChatCompletions = !!options.chat;
  const runTrace = !!options.runTrace;
  // Resolve the listening port, starting from the default server port.
  const port = await findOpenPort(SERVER_PORT, options);
  await applyRemoteOptions(options);
  const watcher = await startProjectWatcher({});
  // read current project info
  const { name, displayName, description, version, homepage, author } =
    (await nodeTryReadPackage()) || {};
  const readme = (await tryReadText("README.genai.md")) || (await tryReadText("README.md"));
  const wss = new WebSocketServer({ noServer: true });
  const dirname = resolve(__dirname, "..");
  // Result of the most recently completed run, replayed to newly connected clients.
  let lastRunResult = undefined;
  // Stores active script runs with their cancellation controllers and traces.
  const runs = {};
  // Stores active chat handlers, keyed by chat id.
  const chats = {};

  // Serializes a message, rejecting anything above the maximum frame length.
  const toPayload = (payload) => {
    const msg = JSON.stringify(payload);
    if (msg.length > WS_MAX_FRAME_LENGTH) {
      throw new Error(`server: message too large (${msg.length} > ${WS_MAX_FRAME_LENGTH})`);
    }
    return msg;
  };

  // Cancels all active runs and chats.
  const cancelAll = () => {
    for (const [runId, run] of Object.entries(runs)) {
      logVerbose(`abort run ${runId}`);
      run.canceller.abort("closing");
      delete runs[runId];
    }
    for (const chatId of Object.keys(chats)) {
      logVerbose(`abort chat ${chatId}`);
      // Only the first connected client is notified.
      for (const ws of wss.clients) {
        ws.send(
          toPayload({
            type: "chat.cancel",
            chatId,
          }),
        );
        break;
      }
      delete chats[chatId];
    }
  };

  // Handles incoming chat chunks and calls the appropriate handler.
  const handleChunk = async (chunk) => {
    const handler = chats[chunk.chatId];
    if (handler) {
      if (chunk.finishReason) delete chats[chunk.chatId];
      await handler(chunk);
    }
  };

  // Returns true when no API key is configured, or when the request carries the
  // key in the Authorization header (raw or as a Bearer token) or in the
  // `api-key` query parameter.
  const checkApiKey = (req) => {
    if (!apiKey) return true;
    const { authorization } = req.headers;
    if (authorization === apiKey || authorization === `Bearer ${apiKey}`) return true;
    const url = req.url.replace(/^[^?]*\?/, "");
    const search = new URLSearchParams(url);
    const hash = search.get("api-key");
    if (hash === apiKey) return true;
    logError(`clients: connection unauthorized ${url}`);
    // The configured key itself is deliberately not logged.
    logVerbose(`url :${req.url}`);
    logVerbose(`auth:${authorization}`);
    logVerbose(`hash:${hash}`);
    return false;
  };

  const serverVersion = () => ({
    ok: true,
    version: CORE_VERSION,
    node: process.version,
    platform: process.platform,
    arch: process.arch,
    pid: process.pid,
  });

  // Describes providers (with tokens stripped), model aliases, remote and project info.
  const serverEnv = async () => {
    const configurations = await resolveLanguageModelConfigurations(undefined, {
      token: true,
      error: true,
      models: true,
    });
    return deleteUndefinedValues({
      ok: true,
      providers: configurations.map(({ token, ...rest }) => rest),
      modelAliases: runtimeHost.modelAliases,
      remote: remote
        ? {
            url: remote,
            branch: options.remoteBranch,
          }
        : undefined,
      configuration: deleteUndefinedValues({
        name: displayName || name,
        description,
        version,
        homepage,
        author,
        readme,
      }),
    });
  };

  const scriptList = async () => {
    logVerbose(`project: list scripts`);
    const project = await watcher.project();
    const scripts = project?.scripts || [];
    logVerbose(
      `project: found ${scripts.filter((s) => !s.unlisted).length} scripts (${scripts.filter((s) => !!s.unlisted).length} unlisted)`,
    );
    return {
      ok: true,
      status: 0,
      project,
    };
  };

  // Configures the client language model with a completer function.
  if (options?.githubCopilotChatClient) {
    runtimeHost.clientLanguageModel = Object.freeze({
      id: MODEL_PROVIDER_GITHUB_COPILOT_CHAT,
      completer: async (req, connection, completionOptions, trace) => {
        const { messages, model } = req;
        const { partialCb, inner } = completionOptions;
        if (!wss.clients?.size) throw new Error("GitHub Copilot Chat Models not connected");
        return new Promise((resolveCompletion, rejectCompletion) => {
          let responseSoFar = "";
          let tokensSoFar = 0;
          let finishReason;
          const chatId = generateId();
          // Build the request first: if it is too large, toPayload throws and
          // no handler is left registered in `chats`.
          const payload = toPayload({
            type: "chat.start",
            chatId,
            model,
            messages,
          });
          // Add a handler for chat responses.
          chats[chatId] = async (chunk) => {
            if (!responseSoFar && chunk.model) {
              logVerbose(`chat model ${chunk.model}`);
              trace?.itemValue("chat model", chunk.model);
              trace?.appendContent("\n\n");
            }
            trace?.appendToken(chunk.chunk);
            responseSoFar += chunk.chunk ?? "";
            tokensSoFar += chunk.tokens ?? 0;
            partialCb?.({
              tokensSoFar,
              responseSoFar,
              responseChunk: chunk.chunk,
              inner,
            });
            finishReason = chunk.finishReason;
            if (finishReason) {
              trace?.appendContent("\n\n");
              trace?.itemValue(`finish reason`, finishReason);
              delete chats[chatId];
              if (chunk.error) {
                trace?.error(undefined, chunk.error);
                rejectCompletion(chunk.error);
              } else
                resolveCompletion({
                  text: responseSoFar,
                  finishReason,
                });
            }
          };
          // Send request to LLM clients (first connected client only).
          for (const ws of wss.clients) {
            ws.send(payload);
            break;
          }
        });
      },
    });
  }

  // Handle server shutdown by cancelling all activities.
  wss.on("close", () => {
    cancelAll();
  });

  // send logging messages
  runtimeHost.addEventListener(LOG, (ev) => {
    const lev = ev;
    const messages = chunkLines(lev.message, WS_MAX_FRAME_CHUNK_LENGTH);
    for (const message of messages) {
      const payload = toPayload({
        type: "log",
        level: lev.level,
        message: message,
      });
      for (const client of wss.clients) client.send(payload);
    }
  });

  // Manage new WebSocket connections.
  wss.on("connection", function connection(ws, req) {
    logVerbose(`clients: connected (${wss.clients.size} clients)`);
    ws.on("error", console.error);
    ws.on("close", () => logVerbose(`clients: closed (${wss.clients.size} clients)`));

    // Sends to every client when dispatchProgress is set, otherwise to this one.
    const send = (payload) => {
      const cmsg = toPayload(payload);
      if (dispatchProgress) for (const client of wss.clients) client.send(cmsg);
      else ws?.send(cmsg);
    };

    // Sends the last run result, or only its exit code when it would not fit in a frame.
    const sendLastRunResult = () => {
      if (!lastRunResult) return;
      if (JSON.stringify(lastRunResult).length < WS_MAX_FRAME_LENGTH - 200) send(lastRunResult);
      else {
        send({
          type: "script.end",
          runId: lastRunResult.runId,
          exitCode: lastRunResult.exitCode,
        });
      }
    };

    const sendProgress = (runId, payload) => {
      send({
        type: "script.progress",
        runId,
        ...payload,
      });
    };

    // send traces of in-flight runs
    const activeRuns = Object.entries(runs);
    if (activeRuns.length) {
      for (const [runId, run] of activeRuns) {
        chunkString(unthink(run.outputTrace.content), WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) =>
          ws.send(
            toPayload({
              type: "script.progress",
              runId,
              output: c,
            }),
          ),
        );
        if (run?.trace) {
          chunkString(run.trace.content, WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) =>
            ws.send(
              toPayload({
                type: "script.progress",
                runId,
                trace: c,
              }),
            ),
          );
        }
      }
    } else if (lastRunResult) {
      sendLastRunResult();
    }

    // Handle incoming messages based on their type.
    ws.on("message", async (msg) => {
      // A malformed frame must not reject this async listener: nothing awaits
      // it, so the rejection would be unhandled.
      let data;
      try {
        data = JSON.parse(msg.toString());
      } catch (e) {
        logError(`clients: invalid message`);
        logError(e);
        return;
      }
      if (data === null || typeof data !== "object") {
        logError(`clients: invalid message`);
        return;
      }
      const { id, type } = data;
      dbg(`%s: %O`, type, data);
      let response;
      try {
        switch (type) {
          // Handle version request
          case "server.version": {
            logVerbose(`server: version ${CORE_VERSION}`);
            response = serverVersion();
            break;
          }
          // Handle environment request
          case "server.env": {
            logVerbose(`server: env`);
            response = await serverEnv();
            break;
          }
          // Handle server kill request
          case "server.kill": {
            logVerbose(`server: kill`);
            process.exit(0);
            break;
          }
          // Handle model configuration request
          case "model.configuration": {
            const { model, token } = data;
            logVerbose(`model: lookup configuration ${model}`);
            try {
              const info = await runtimeHost.getLanguageModelConfiguration(model, { token });
              response = {
                ok: true,
                info,
              };
            } catch (e) {
              response = {
                ok: false,
              };
            }
            break;
          }
          case "script.list": {
            response = await scriptList();
            break;
          }
          // Handle test run request
          case "tests.run": {
            logVerbose(`tests: run ${data.scripts?.join(", ") || "*"}`);
            await runtimeHost.readConfig();
            response = await runPromptScriptTests(data.scripts, {
              ...(data.options || {}),
              verbose: true,
              promptfooVersion: PROMPTFOO_VERSION,
            });
            break;
          }
          // Handle script start request
          case "script.start": {
            const { script, files = [], options: runOptions = {}, runId } = data;
            if (!script) throw new Error("missing script");
            if (files.some((f) => !f)) throw new Error("invalid file");
            // Cancel any active scripts
            cancelAll();
            const canceller = new AbortSignalCancellationController();
            const cancellationToken = canceller.token;
            const trace = runTrace ? new MarkdownTrace({ cancellationToken }) : undefined;
            const outputTrace = new MarkdownTrace({
              cancellationToken,
            });
            if (runTrace && trace) {
              trace.addEventListener(TRACE_CHUNK, (ev) => {
                const tev = ev;
                chunkString(tev.chunk, WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) =>
                  sendProgress(runId, {
                    trace: c,
                    inner: tev.inner,
                  }),
                );
              });
            }
            outputTrace.addEventListener(TRACE_CHUNK, (ev) => {
              const tev = ev;
              chunkString(tev.chunk, WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) =>
                sendProgress(runId, {
                  output: c,
                  inner: tev.inner,
                }),
              );
            });
            logVerbose(`run ${runId}: starting ${script}`);
            await runtimeHost.readConfig();
            const runner = runScriptInternal(script, files, {
              ...runOptions,
              runId,
              trace,
              runOutputTrace: outputTrace,
              runTrace: false,
              cancellationToken: canceller.token,
              infoCb: ({ text }) => {
                sendProgress(runId, { progress: text });
              },
              partialCb: ({
                responseChunk,
                responseSoFar,
                reasoningSoFar,
                tokensSoFar,
                responseTokens,
                inner,
              }) => {
                sendProgress(runId, {
                  response: responseSoFar,
                  reasoning: reasoningSoFar,
                  responseChunk,
                  tokens: tokensSoFar,
                  responseTokens,
                  inner,
                });
              },
            })
              .then(({ exitCode, result }) => {
                delete runs[runId];
                logVerbose(`\nrun ${runId}: completed with ${exitCode}`);
                lastRunResult = {
                  type: "script.end",
                  runId,
                  exitCode,
                  result,
                  trace: trace?.content || "",
                };
                sendLastRunResult();
              })
              .catch((e) => {
                // An aborted run was already removed from `runs` by whoever aborted it.
                if (canceller.controller.signal.aborted) return;
                // A failed run is no longer in flight; drop it so it is not
                // replayed to newly connected clients.
                delete runs[runId];
                if (!isCancelError(e)) trace?.error(e);
                logError(`\nrun ${runId}: failed`);
                logError(e);
                send({
                  type: "script.end",
                  runId,
                  result: {
                    status: "error",
                    error: serializeError(e),
                  },
                  exitCode: isCancelError(e) ? USER_CANCELLED_ERROR_CODE : UNHANDLED_ERROR_CODE,
                });
              });
            runs[runId] = {
              runner,
              canceller,
              trace,
              outputTrace,
            };
            response = {
              ok: true,
              status: 0,
              runId,
            };
            break;
          }
          // Handle script abort request
          case "script.abort": {
            const { runId, reason } = data;
            logVerbose(`run ${runId}: abort (${reason})`);
            const run = runs[runId];
            if (run) {
              delete runs[runId];
              run.canceller.abort(reason);
            }
            response = {
              ok: true,
              status: 0,
              runId,
            };
            break;
          }
          // Handle chat chunk requests
          case "chat.chunk": {
            await handleChunk(data);
            response = { ok: true };
            break;
          }
          default:
            throw new Error(`unknown message type ${type}`);
        }
      } catch (e) {
        response = { ok: false, error: serializeError(e) };
      } finally {
        assert(!!response);
        if (response.error) logError(response.error);
        try {
          send({ id, type, response });
        } catch (e) {
          // e.g. the serialized response exceeds the maximum frame length
          logError(e);
        }
      }
    });
  });

  const setCORSHeaders = (res) => {
    res.setHeader("Access-Control-Allow-Origin", corsOrigin);
    res.setHeader("Access-Control-Allow-Methods", "OPTIONS, GET");
    res.setHeader("Access-Control-Max-Age", 24 * 3600); // 1 day
    res.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept");
  };

  const runRx = /^\/api\/runs\/(?<runId>[A-Za-z0-9_-]{12,256})$/;
  const imageRx =
    /^\/\.genaiscript\/(images|runs\/.*?)\/[a-z0-9]{12,128}\.(png|jpg|jpeg|gif|svg)$/;
  const ROOT = process.cwd();
  // Compare against ROOT plus a separator: a bare prefix test would also accept
  // sibling directories such as `${ROOT}-other`.
  const rootPrefix = ROOT.endsWith(sep) ? ROOT : ROOT + sep;

  // Web UI assets served from the package directory, keyed by route.
  const staticAssets = new Map([
    ["/dist/markdown.css", { file: "markdown.css", contentType: "text/css" }],
    ["/dist/codicon.css", { file: "codicon.css", contentType: "text/css" }],
    ["/dist/codicon.ttf", { file: "codicon.ttf", contentType: "font/ttf" }],
    ["/dist/web.mjs", { file: "web.mjs", contentType: "application/javascript" }],
    ["/favicon.svg", { file: "favicon.svg", contentType: "image/svg+xml" }],
  ]);

  // Handles a single HTTP request; errors are handled by the caller.
  const handleRequest = async (req, res) => {
    const { url, method } = req;
    const route = url?.replace(/\?.*$/, "");
    if (method === "OPTIONS") {
      if (!corsOrigin) {
        res.statusCode = 405;
        res.end();
      } else {
        setCORSHeaders(res);
        res.statusCode = 204;
        res.end();
      }
      return;
    }
    if (corsOrigin) setCORSHeaders(res);
    res.setHeader("Cache-Control", "no-store");
    const asset = method === "GET" ? staticAssets.get(route) : undefined;
    if (method === "GET" && route === "/") {
      res.setHeader("Content-Type", "text/html");
      res.statusCode = 200;
      const cspUrl = new URL(`http://${req.headers.host}`).origin;
      const wsCspUrl = new URL(`ws://${req.headers.host}`).origin;
      const nonce = randomHex(32);
      const cspDirectives = [
        `default-src 'none';`,
        `frame-src ${cspUrl} https://*.github.dev/ https://github.dev/ https:;`,
        `img-src ${cspUrl} https://*.github.dev/ https://github.dev/ https: data:;`,
        `media-src ${cspUrl} https://*.github.dev/ https://github.dev/ https: data:;`,
        `connect-src ${cspUrl} ${wsCspUrl} https://*.github.dev/ wss://*.github.dev/ https://github.dev/;`,
        `script-src ${cspUrl} https://*.github.dev/ https://github.dev/ 'nonce-${nonce}';`,
        `style-src 'unsafe-inline' ${cspUrl} https://*.github.dev/ https://github.dev/;`,
        `font-src ${cspUrl} https://*.github.dev/ https://github.dev/;`,
      ].join(" ");
      const csp = `<meta http-equiv="Content-Security-Policy" content=" ${cspDirectives} "/> <script nonce=${nonce}> window.litNonce = ${JSON.stringify(nonce)}; window.vscodeWebviewPlaygroundNonce = ${JSON.stringify(nonce)}; </script> `;
      const filePath = join(dirname, "index.html");
      const template = await readFile(filePath, { encoding: "utf8" });
      // A function replacer keeps `$` sequences in the CSP from being interpreted.
      const html = template.replace("<!--csp-->", () => csp);
      res.write(html);
      res.end();
    } else if (asset) {
      sendFile(res, join(dirname, asset.file), asset.contentType);
    } else if (method === "GET" && route === "/dist/web.mjs.map") {
      const filePath = join(dirname, "web.mjs.map");
      if (await tryStat(filePath)) {
        sendFile(res, filePath, "application/json");
      } else {
        res.statusCode = 404;
        res.end();
      }
    } else if (method === "GET" && imageRx.test(route)) {
      try {
        // realpath resolves symlinks so the containment check sees the real target.
        const filePath = await realpath(resolve(ROOT, sanitizeFilename(route)));
        if (!filePath.startsWith(rootPrefix)) throw new Error(`invalid path ${filePath}`);
        sendFile(res, filePath, IMAGE_CONTENT_TYPES[extname(route)]);
      } catch {
        res.statusCode = 404;
        res.end();
      }
    } else {
      // api, validate apikey
      if (!checkApiKey(req)) {
        console.debug(`401: missing or invalid api-key`);
        res.statusCode = 401;
        res.end();
        return;
      }
      let response;
      if (method === "GET" && route === "/api/version") response = serverVersion();
      else if (method === "GET" && route === "/api/scripts") {
        response = await scriptList();
      } else if (method === "GET" && route === "/api/env") {
        response = await serverEnv();
      } else if (method === "GET" && route === "/api/runs") {
        const allRuns = await collectRuns();
        response = {
          ok: true,
          // NOTE(review): `creationTme` looks like a typo but must match the
          // property produced by collectRuns — confirm before renaming.
          runs: allRuns.map(({ scriptId, runId, creationTme: creationTime }) => ({
            scriptId,
            runId,
            creationTime,
          })),
        };
      } else if (method === "POST" && route === "/v1/chat/completions") {
        if (!openAIChatCompletions) {
          console.debug(`403: chat completions not enabled`);
          res.statusCode = 403;
          res.end();
          return;
        }
        await openaiApiChatCompletions(req, res);
        return;
      } else if (method === "GET" && route === "/v1/models") {
        await openaiApiModels(req, res);
        return;
      } else if (method === "GET" && runRx.test(route)) {
        const { runId } = runRx.exec(route).groups;
        logVerbose(`run: get ${runId}`);
        // shortcut to last run
        if (runId === lastRunResult?.runId)
          response = {
            ok: true,
            ...lastRunResult,
          };
        else {
          const allRuns = await collectRuns();
          const run = allRuns.find((r) => r.runId === runId);
          if (run) {
            const runResult = (await tryReadJSON(join(run.dir, "res.json"))) || {};
            const traceText = (await tryReadText(join(run.dir, TRACE_FILENAME))) || "";
            response = {
              ok: true,
              type: "script.end",
              runId,
              exitCode: runResult.exitCode,
              result: runResult,
              trace: traceText,
            };
          }
        }
      }
      if (response === undefined) {
        console.debug(`404: ${method} ${url}`);
        res.statusCode = 404;
        res.end();
      } else {
        res.statusCode = 200;
        res.setHeader("Content-Type", "application/json");
        res.end(JSON.stringify(response));
      }
    }
  };

  // Create an HTTP server to handle basic requests. Failures in the async
  // handler are turned into a 500 instead of an unhandled rejection.
  const httpServer = http.createServer((req, res) => {
    handleRequest(req, res).catch((e) => {
      logError(e);
      if (!res.headersSent) res.statusCode = 500;
      if (!res.writableEnded) res.end();
    });
  });

  // Upgrade HTTP requests on the `/` route to WebSocket connections.
  httpServer.on("upgrade", (req, socket, head) => {
    let pathname;
    try {
      pathname = new URL(req.url, `http://${req.headers.host}`).pathname;
    } catch (e) {
      // malformed url or Host header
      socket.destroy();
      return;
    }
    if (pathname === "/" && checkApiKey(req)) {
      wss.handleUpgrade(req, socket, head, (ws) => {
        wss.emit("connection", ws, req);
      });
    } else socket.destroy();
  });

  // Start the HTTP server on the specified port.
  const serverHash = apiKey ? `#api-key:${encodeURIComponent(apiKey)}` : "";
  httpServer.listen(port, serverHost, () => {
    console.log(`GenAIScript server v${CORE_VERSION}`);
    if (remote)
      console.log(`│ Remote: ${remote}${options.remoteBranch ? `#${options.remoteBranch}` : ""}`);
    console.log(`│ Local http://${serverHost}:${port}/${serverHash}`);
    if (options.network) {
      console.log(`│ Host http://localhost:${port}/${serverHash}`);
      const interfaces = networkInterfaces();
      for (const ifaces of Object.values(interfaces)) {
        for (const iface of ifaces) {
          if (iface.family === "IPv4" && !iface.internal) {
            console.log(`│ Network http://${iface.address}:${port}/${serverHash}`);
          }
        }
      }
    }
  });
}
//# sourceMappingURL=server.js.map