genaiscript
Version:
A CLI for GenAIScript, a generative AI scripting framework.
717 lines • 31.1 kB
JavaScript
/* eslint-disable no-param-reassign */
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/* eslint-disable @typescript-eslint/explicit-function-return-type */
import { WebSocketServer } from "ws";
import { runPromptScriptTests } from "./test.js";
import { PROMPTFOO_VERSION } from "@genaiscript/runtime";
import { runScriptInternal } from "@genaiscript/api";
import { CORE_VERSION, LOG, MODEL_PROVIDER_GITHUB_COPILOT_CHAT, SERVER_PORT, TRACE_CHUNK, TRACE_FILENAME, UNHANDLED_ERROR_CODE, USER_CANCELLED_ERROR_CODE, WS_MAX_FRAME_CHUNK_LENGTH, WS_MAX_FRAME_LENGTH, AbortSignalCancellationController, MarkdownTrace, assert, chunkLines, chunkString, deleteUndefinedValues, generateId, genaiscriptDebug, isCancelError, logError, logVerbose, nodeTryReadPackage, randomHex, resolveLanguageModelConfigurations, serializeError, tryReadJSON, tryReadText, unthink, getModulePaths, sanitizeFilename, resolveRuntimeHost, } from "@genaiscript/core";
import { createReadStream } from "node:fs";
import { URL } from "node:url";
import { findOpenPort } from "./port.js";
import { applyRemoteOptions } from "./remote.js";
import * as http from "node:http";
import { startProjectWatcher } from "./watch.js";
import { extname, join, resolve } from "node:path";
import { readFile, realpath } from "node:fs/promises";
import { tryStat } from "@genaiscript/core";
import { collectRuns } from "./runs.js";
import { openaiApiChatCompletions, openaiApiModels } from "./openaiapi.js";
import { networkInterfaces } from "node:os";
/** Debug logger for the "server" namespace. */
const dbg = genaiscriptDebug("server");
// Pick the module reference used to locate this file on disk: the CommonJS
// `module` object when it exists and has a filename, otherwise `import.meta`.
const moduleRef = typeof module === "undefined" || !module.filename
    ? // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-ignore
    import.meta
    : module;
const { __dirname } = getModulePaths(moduleRef);
/**
 * Image file extensions served from `.genaiscript/`, mapped to their MIME types.
 */
const IMAGE_CONTENT_TYPES = new Map([
    [".png", "image/png"],
    [".jpg", "image/jpeg"],
    [".jpeg", "image/jpeg"],
    [".gif", "image/gif"],
    [".svg", "image/svg+xml"],
]);
/**
 * Static assets bundled with the server, keyed by request route.
 * Each value is `[file name relative to the directory above this module, Content-Type]`.
 */
const STATIC_FILES = new Map([
    ["/dist/markdown.css", ["markdown.css", "text/css"]],
    ["/dist/codicon.css", ["codicon.css", "text/css"]],
    ["/dist/codicon.ttf", ["codicon.ttf", "font/ttf"]],
    ["/dist/web.mjs", ["web.mjs", "application/javascript"]],
    ["/favicon.svg", ["favicon.svg", "image/svg+xml"]],
]);
/**
 * Returns the Content-Type for an image route based on its file extension.
 *
 * @param route - request path such as `/.genaiscript/images/abc123def456.png`
 * @returns the MIME type, or `application/octet-stream` for unknown extensions
 */
function imageContentType(route) {
    return IMAGE_CONTENT_TYPES.get(extname(route).toLowerCase()) ?? "application/octet-stream";
}
/**
 * Extracts the query string (the text after the first `?`) from a request URL.
 *
 * @param url - raw request URL (may be undefined)
 * @returns the query string without the leading `?`, or "" when there is none
 */
function requestQueryString(url) {
    const index = url?.indexOf("?") ?? -1;
    return index < 0 ? "" : url.slice(index + 1);
}
/**
 * Checks whether a request carries the expected API key, either in the
 * `Authorization` header (raw key or `Bearer <key>`) or in an `api-key`
 * query-string parameter. Every request is accepted when no key is configured.
 *
 * @param req - incoming request; only `headers.authorization` and `url` are read
 * @param apiKey - expected key; a falsy value disables the check
 * @returns true when the request is authorized
 */
function isApiKeyAuthorized(req, apiKey) {
    if (!apiKey)
        return true;
    const authorization = req.headers?.authorization;
    // Both accepted header forms must be compared explicitly against the header value.
    if (authorization === apiKey || authorization === `Bearer ${apiKey}`)
        return true;
    return new URLSearchParams(requestQueryString(req.url)).get("api-key") === apiKey;
}
/**
 * Streams a file to the response with a 200 status and the given Content-Type.
 * If the file cannot be read, replies 404 (or destroys the response if data was
 * already sent) rather than leaving the read stream's "error" event unhandled,
 * which would crash the process.
 *
 * @param res - HTTP response to write to
 * @param filePath - absolute path of the file to send
 * @param contentType - Content-Type header value
 */
function serveFile(res, filePath, contentType) {
    const stream = createReadStream(filePath);
    stream.on("error", () => {
        if (res.headersSent) {
            res.destroy();
            return;
        }
        res.removeHeader("Content-Type");
        res.statusCode = 404;
        res.end();
    });
    res.setHeader("Content-Type", contentType);
    res.statusCode = 200;
    stream.pipe(res);
}
/**
 * Starts the GenAIScript server: an HTTP server (web UI, static assets, JSON API,
 * optional OpenAI-compatible endpoints) whose `/` route also accepts WebSocket
 * upgrades for chat and script execution.
 *
 * @param options - Configuration options including:
 * - port: Preferred port; `options` is passed to findOpenPort together with SERVER_PORT.
 * - apiKey: Optional API key for authentication (falls back to GENAISCRIPT_API_KEY).
 * - cors: Optional CORS origin (falls back to GENAISCRIPT_CORS_ORIGIN).
 * - network: Listen on 0.0.0.0 instead of 127.0.0.1 and print the network addresses.
 * - dispatchProgress: Whether to dispatch progress updates (and replies) to all clients.
 * - chat: Whether to enable the `POST /v1/chat/completions` endpoint.
 * - runTrace: Whether to stream the detailed run trace to clients.
 * - githubCopilotChatClient: Whether to enable GitHub Copilot Chat client integration.
 * - remote: Remote configuration options.
 * - remoteBranch: Optional branch name for remote configuration.
 */
export async function startServer(options) {
    const runtimeHost = resolveRuntimeHost();
    const corsOrigin = options.cors || process.env.GENAISCRIPT_CORS_ORIGIN;
    const apiKey = options.apiKey || process.env.GENAISCRIPT_API_KEY;
    const serverHost = options.network ? "0.0.0.0" : "127.0.0.1";
    const remote = options.remote;
    const dispatchProgress = !!options.dispatchProgress;
    const openAIChatCompletions = !!options.chat;
    const runTrace = !!options.runTrace;
    // Resolve the server port, using a default if not specified.
    const port = await findOpenPort(SERVER_PORT, options);
    await applyRemoteOptions(options);
    const watcher = await startProjectWatcher({});
    // read current project info
    const { name, displayName, description, version, homepage, author } = (await nodeTryReadPackage()) || {};
    const readme = (await tryReadText("README.genai.md")) || (await tryReadText("README.md"));
    const wss = new WebSocketServer({ noServer: true });
    // Static assets (index.html, css, fonts, bundle) are read from one level above this module.
    const dirname = resolve(__dirname, "..");
    // Result of the most recently completed run; replayed to clients that connect later.
    let lastRunResult = undefined;
    // Active script runs keyed by runId, with their cancellation controllers and traces.
    const runs = {};
    // Active chat chunk handlers keyed by chatId.
    const chats = {};
    // Serializes a WebSocket payload, refusing frames larger than WS_MAX_FRAME_LENGTH.
    const toPayload = (payload) => {
        const msg = JSON.stringify(payload);
        if (msg.length > WS_MAX_FRAME_LENGTH) {
            throw new Error(`server: message too large (${msg.length} > ${WS_MAX_FRAME_LENGTH})`);
        }
        return msg;
    };
    // Cancels all active runs and chats.
    const cancelAll = () => {
        for (const [runId, run] of Object.entries(runs)) {
            logVerbose(`abort run ${runId}`);
            run.canceller.abort("closing");
            delete runs[runId];
        }
        for (const chatId of Object.keys(chats)) {
            logVerbose(`abort chat ${chatId}`);
            // Chats are served by a single client: notify the first connected one.
            for (const ws of wss.clients) {
                ws.send(toPayload({
                    type: "chat.cancel",
                    chatId,
                }));
                break;
            }
            delete chats[chatId];
        }
    };
    // Handles incoming chat chunks and calls the appropriate handler.
    const handleChunk = async (chunk) => {
        const handler = chats[chunk.chatId];
        if (handler) {
            if (chunk.finishReason)
                delete chats[chunk.chatId];
            await handler(chunk);
        }
    };
    // Validates the request's API key; on failure logs diagnostics, but never the expected key.
    const checkApiKey = (req) => {
        if (isApiKeyAuthorized(req, apiKey))
            return true;
        const url = requestQueryString(req.url);
        const hash = new URLSearchParams(url).get("api-key");
        logError(`clients: connection unauthorized ${url}`);
        logVerbose(`url :${req.url}`);
        logVerbose(`auth:${req.headers.authorization}`);
        logVerbose(`hash:${hash}`);
        return false;
    };
    const serverVersion = () => ({
        ok: true,
        version: CORE_VERSION,
        node: process.version,
        platform: process.platform,
        arch: process.arch,
        pid: process.pid,
    });
    // Describes the environment; provider tokens are stripped before being returned.
    const serverEnv = async () => {
        return deleteUndefinedValues({
            ok: true,
            providers: (await resolveLanguageModelConfigurations(undefined, {
                token: true,
                error: true,
                models: true,
            })).map(({ token, ...rest }) => rest),
            modelAliases: runtimeHost.modelAliases,
            remote: remote
                ? {
                    url: remote,
                    branch: options.remoteBranch,
                }
                : undefined,
            configuration: deleteUndefinedValues({
                name: displayName || name,
                description,
                version,
                homepage,
                author,
                readme,
            }),
        });
    };
    const scriptList = async () => {
        logVerbose(`project: list scripts`);
        const project = await watcher.project();
        const scripts = project?.scripts || [];
        logVerbose(`project: found ${scripts.filter((s) => !s.unlisted).length} scripts (${scripts.filter((s) => !!s.unlisted).length} unlisted)`);
        return {
            ok: true,
            status: 0,
            project,
        };
    };
    // Configures the client language model with a completer function.
    if (options?.githubCopilotChatClient) {
        runtimeHost.clientLanguageModel = Object.freeze({
            id: MODEL_PROVIDER_GITHUB_COPILOT_CHAT,
            completer: async (req, connection, options, trace) => {
                const { messages, model } = req;
                const { partialCb, inner } = options;
                if (!wss.clients?.size)
                    throw new Error("GitHub Copilot Chat Models not connected");
                return new Promise((resolve, reject) => {
                    let responseSoFar = "";
                    let tokensSoFar = 0;
                    let finishReason;
                    // Add a handler for chat responses.
                    const chatId = generateId();
                    chats[chatId] = async (chunk) => {
                        if (!responseSoFar && chunk.model) {
                            logVerbose(`chat model ${chunk.model}`);
                            trace?.itemValue("chat model", chunk.model);
                            trace?.appendContent("\n\n");
                        }
                        trace?.appendToken(chunk.chunk);
                        responseSoFar += chunk.chunk ?? "";
                        tokensSoFar += chunk.tokens ?? 0;
                        partialCb?.({
                            tokensSoFar,
                            responseSoFar,
                            responseChunk: chunk.chunk,
                            inner,
                        });
                        finishReason = chunk.finishReason;
                        if (finishReason) {
                            trace?.appendContent("\n\n");
                            trace?.itemValue(`finish reason`, finishReason);
                            delete chats[chatId];
                            if (chunk.error) {
                                trace?.error(undefined, chunk.error);
                                reject(chunk.error);
                            }
                            else
                                resolve({
                                    text: responseSoFar,
                                    finishReason,
                                });
                        }
                    };
                    // Send request to the first connected LLM client.
                    const payload = toPayload({
                        type: "chat.start",
                        chatId,
                        model,
                        messages,
                    });
                    for (const ws of wss.clients) {
                        ws.send(payload);
                        break;
                    }
                });
            },
        });
    }
    // Handle server shutdown by cancelling all activities.
    wss.on("close", () => {
        cancelAll();
    });
    // send logging messages
    runtimeHost.addEventListener(LOG, (ev) => {
        const lev = ev;
        const messages = chunkLines(lev.message, WS_MAX_FRAME_CHUNK_LENGTH);
        for (const message of messages) {
            const payload = toPayload({
                type: "log",
                level: lev.level,
                message: message,
            });
            for (const client of wss.clients)
                client.send(payload);
        }
    });
    // Manage new WebSocket connections.
    wss.on("connection", function connection(ws, req) {
        logVerbose(`clients: connected (${wss.clients.size} clients)`);
        ws.on("error", console.error);
        ws.on("close", () => logVerbose(`clients: closed (${wss.clients.size} clients)`));
        // Sends to this client, or to every client when dispatchProgress is enabled.
        const send = (payload) => {
            const cmsg = toPayload(payload);
            if (dispatchProgress)
                for (const client of wss.clients)
                    client.send(cmsg);
            else
                ws?.send(cmsg);
        };
        // Sends the last run result, dropping result/trace when too large for one frame.
        const sendLastRunResult = () => {
            if (!lastRunResult)
                return;
            if (JSON.stringify(lastRunResult).length < WS_MAX_FRAME_LENGTH - 200)
                send(lastRunResult);
            else {
                send({
                    type: "script.end",
                    runId: lastRunResult.runId,
                    exitCode: lastRunResult.exitCode,
                });
            }
        };
        const sendProgress = (runId, payload) => {
            send({
                type: "script.progress",
                runId,
                ...payload,
            });
        };
        // send traces of in-flight runs
        const activeRuns = Object.entries(runs);
        if (activeRuns.length) {
            for (const [runId, run] of activeRuns) {
                chunkString(unthink(run.outputTrace.content), WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) => ws.send(toPayload({
                    type: "script.progress",
                    runId,
                    output: c,
                })));
                if (run?.trace) {
                    chunkString(run.trace.content, WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) => ws.send(toPayload({
                        type: "script.progress",
                        runId,
                        trace: c,
                    })));
                }
            }
        }
        else if (lastRunResult) {
            sendLastRunResult();
        }
        // Handle incoming messages based on their type.
        ws.on("message", async (msg) => {
            let id;
            let type;
            let response;
            try {
                // Parse inside the try so malformed JSON produces an error reply
                // instead of an unhandled rejection from this async listener.
                const data = JSON.parse(msg.toString());
                ({ id, type } = data);
                dbg(`%s: %O`, type, data);
                switch (type) {
                    // Handle version request
                    case "server.version": {
                        logVerbose(`server: version ${CORE_VERSION}`);
                        response = serverVersion();
                        break;
                    }
                    // Handle environment request
                    case "server.env": {
                        logVerbose(`server: env`);
                        response = await serverEnv();
                        break;
                    }
                    // Handle server kill request
                    case "server.kill": {
                        logVerbose(`server: kill`);
                        process.exit(0);
                        break;
                    }
                    // Handle model configuration request
                    case "model.configuration": {
                        const { model, token } = data;
                        logVerbose(`model: lookup configuration ${model}`);
                        try {
                            const info = await runtimeHost.getLanguageModelConfiguration(model, { token });
                            response = {
                                ok: true,
                                info,
                            };
                        }
                        catch (e) {
                            response = {
                                ok: false,
                            };
                        }
                        break;
                    }
                    case "script.list": {
                        response = await scriptList();
                        break;
                    }
                    // Handle test run request
                    case "tests.run": {
                        logVerbose(`tests: run ${data.scripts?.join(", ") || "*"}`);
                        await runtimeHost.readConfig();
                        response = await runPromptScriptTests(data.scripts, {
                            ...(data.options || {}),
                            verbose: true,
                            promptfooVersion: PROMPTFOO_VERSION,
                        });
                        break;
                    }
                    // Handle script start request
                    case "script.start": {
                        const { script, files = [], options: runOptions = {}, runId } = data;
                        if (!script)
                            throw new Error("missing script");
                        if (files.some((f) => !f))
                            throw new Error("invalid file");
                        // Cancel any active scripts
                        cancelAll();
                        const canceller = new AbortSignalCancellationController();
                        const cancellationToken = canceller.token;
                        const trace = runTrace ? new MarkdownTrace({ cancellationToken }) : undefined;
                        const outputTrace = new MarkdownTrace({
                            cancellationToken,
                        });
                        if (runTrace && trace) {
                            trace.addEventListener(TRACE_CHUNK, (ev) => {
                                const tev = ev;
                                chunkString(tev.chunk, WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) => sendProgress(runId, {
                                    trace: c,
                                    inner: tev.inner,
                                }));
                            });
                        }
                        outputTrace.addEventListener(TRACE_CHUNK, (ev) => {
                            const tev = ev;
                            chunkString(tev.chunk, WS_MAX_FRAME_CHUNK_LENGTH).forEach((c) => sendProgress(runId, {
                                output: c,
                                inner: tev.inner,
                            }));
                        });
                        logVerbose(`run ${runId}: starting ${script}`);
                        await runtimeHost.readConfig();
                        // The run continues in the background; completion is reported via "script.end".
                        const runner = runScriptInternal(script, files, {
                            ...runOptions,
                            runId,
                            trace,
                            runOutputTrace: outputTrace,
                            runTrace: false,
                            cancellationToken: canceller.token,
                            infoCb: ({ text }) => {
                                sendProgress(runId, { progress: text });
                            },
                            partialCb: ({ responseChunk, responseSoFar, reasoningSoFar, tokensSoFar, responseTokens, inner, }) => {
                                sendProgress(runId, {
                                    response: responseSoFar,
                                    reasoning: reasoningSoFar,
                                    responseChunk,
                                    tokens: tokensSoFar,
                                    responseTokens,
                                    inner,
                                });
                            },
                        })
                            .then(({ exitCode, result }) => {
                            delete runs[runId];
                            logVerbose(`\nrun ${runId}: completed with ${exitCode}`);
                            lastRunResult = {
                                type: "script.end",
                                runId,
                                exitCode,
                                result,
                                trace: trace?.content || "",
                            };
                            sendLastRunResult();
                        })
                            .catch((e) => {
                            // Aborted runs (script.abort / cancelAll) end silently.
                            if (canceller.controller.signal.aborted)
                                return;
                            if (!isCancelError(e))
                                trace?.error(e);
                            logError(`\nrun ${runId}: failed`);
                            logError(e);
                            send({
                                type: "script.end",
                                runId,
                                result: {
                                    status: "error",
                                    error: serializeError(e),
                                },
                                exitCode: isCancelError(e) ? USER_CANCELLED_ERROR_CODE : UNHANDLED_ERROR_CODE,
                            });
                        });
                        runs[runId] = {
                            runner,
                            canceller,
                            trace,
                            outputTrace,
                        };
                        response = {
                            ok: true,
                            status: 0,
                            runId,
                        };
                        break;
                    }
                    // Handle script abort request
                    case "script.abort": {
                        const { runId, reason } = data;
                        logVerbose(`run ${runId}: abort (${reason})`);
                        const run = runs[runId];
                        if (run) {
                            delete runs[runId];
                            run.canceller.abort(reason);
                        }
                        response = {
                            ok: true,
                            status: 0,
                            runId,
                        };
                        break;
                    }
                    // Handle chat chunk requests
                    case "chat.chunk": {
                        await handleChunk(data);
                        response = { ok: true };
                        break;
                    }
                    default:
                        throw new Error(`unknown message type ${type}`);
                }
            }
            catch (e) {
                response = { ok: false, error: serializeError(e) };
            }
            finally {
                assert(!!response);
                if (response.error)
                    logError(response.error);
                try {
                    send({ id, type, response });
                }
                catch (e) {
                    // e.g. the response exceeds WS_MAX_FRAME_LENGTH: report that to the client
                    // instead of rejecting this (unawaited) listener.
                    logError(e);
                    send({ id, type, response: { ok: false, error: serializeError(e) } });
                }
            }
        });
    });
    const setCORSHeaders = (res) => {
        res.setHeader("Access-Control-Allow-Origin", corsOrigin);
        // POST is required by the OpenAI-compatible /v1/chat/completions endpoint.
        res.setHeader("Access-Control-Allow-Methods", "OPTIONS, GET, POST");
        res.setHeader("Access-Control-Max-Age", 24 * 3600); // 1 day
        res.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept");
    };
    const runRx = /^\/api\/runs\/(?<runId>[A-Za-z0-9_-]{12,256})$/;
    const imageRx = /^\/\.genaiscript\/(images|runs\/.*?)\/[a-z0-9]{12,128}\.(png|jpg|jpeg|gif|svg)$/;
    const ROOT = process.cwd();
    // Routes a single HTTP request: CORS preflight, web UI, static assets,
    // generated images, then API-key protected JSON/OpenAI endpoints.
    const handleHttpRequest = async (req, res) => {
        const { url, method } = req;
        const route = url?.replace(/\?.*$/, "");
        if (method === "OPTIONS") {
            if (!corsOrigin) {
                res.statusCode = 405;
                res.end();
            }
            else {
                setCORSHeaders(res);
                res.statusCode = 204;
                res.end();
            }
            return;
        }
        if (corsOrigin)
            setCORSHeaders(res);
        res.setHeader("Cache-Control", "no-store");
        if (method === "GET" && route === "/") {
            res.setHeader("Content-Type", "text/html");
            res.setHeader("Cache-Control", "no-store");
            const cspUrl = new URL(`http://${req.headers.host}`).origin;
            const wsCspUrl = new URL(`ws://${req.headers.host}`).origin;
            const nonce = randomHex(32);
            const csp = `<meta http-equiv="Content-Security-Policy" content="
default-src 'none';
frame-src ${cspUrl} https://*.github.dev/ https://github.dev/ https:;
img-src ${cspUrl} https://*.github.dev/ https://github.dev/ https: data:;
media-src ${cspUrl} https://*.github.dev/ https://github.dev/ https: data:;
connect-src ${cspUrl} ${wsCspUrl} https://*.github.dev/ wss://*.github.dev/ https://github.dev/;
script-src ${cspUrl} https://*.github.dev/ https://github.dev/ 'nonce-${nonce}';
style-src 'unsafe-inline' ${cspUrl} https://*.github.dev/ https://github.dev/;
font-src ${cspUrl} https://*.github.dev/ https://github.dev/;
"/>
<script nonce=${nonce}>
window.litNonce = ${JSON.stringify(nonce)};
window.vscodeWebviewPlaygroundNonce = ${JSON.stringify(nonce)};
</script>
`;
            const filePath = join(dirname, "index.html");
            const html = (await readFile(filePath, { encoding: "utf8" })).replace("<!--csp-->", csp);
            res.statusCode = 200;
            res.end(html);
            return;
        }
        if (method === "GET" && STATIC_FILES.has(route)) {
            const [fileName, contentType] = STATIC_FILES.get(route);
            serveFile(res, join(dirname, fileName), contentType);
            return;
        }
        if (method === "GET" && route === "/dist/web.mjs.map") {
            const filePath = join(dirname, "web.mjs.map");
            if (await tryStat(filePath))
                serveFile(res, filePath, "application/json");
            else {
                res.statusCode = 404;
                res.end();
            }
            return;
        }
        if (method === "GET" && imageRx.test(route)) {
            try {
                // realpath resolves symlinks so the ROOT containment check cannot be bypassed.
                const filePath = await realpath(resolve(ROOT, sanitizeFilename(route)));
                if (!filePath.startsWith(ROOT))
                    throw new Error(`invalid path ${filePath}`);
                serveFile(res, filePath, imageContentType(route));
            }
            catch {
                res.statusCode = 404;
                res.end();
            }
            return;
        }
        // api, validate apikey
        if (!checkApiKey(req)) {
            console.debug(`401: missing or invalid api-key`);
            res.statusCode = 401;
            res.end();
            return;
        }
        let response;
        if (method === "GET" && route === "/api/version")
            response = serverVersion();
        else if (method === "GET" && route === "/api/scripts") {
            response = await scriptList();
        }
        else if (method === "GET" && route === "/api/env") {
            response = await serverEnv();
        }
        else if (method === "GET" && route === "/api/runs") {
            const storedRuns = await collectRuns();
            response = {
                ok: true,
                // NOTE(review): collectRuns() entries appear to expose `creationTme` (sic) — confirm.
                runs: storedRuns.map(({ scriptId, runId, creationTme: creationTime }) => ({
                    scriptId,
                    runId,
                    creationTime,
                })),
            };
        }
        else if (method === "POST" && route === "/v1/chat/completions") {
            if (!openAIChatCompletions) {
                console.debug(`403: chat completions not enabled`);
                res.statusCode = 403;
                res.end();
                return;
            }
            await openaiApiChatCompletions(req, res);
            return;
        }
        else if (method === "GET" && route === "/v1/models") {
            await openaiApiModels(req, res);
            return;
        }
        else if (method === "GET" && runRx.test(route)) {
            const { runId } = runRx.exec(route).groups;
            logVerbose(`run: get ${runId}`);
            // shortcut to last run
            if (runId === lastRunResult?.runId)
                response = {
                    ok: true,
                    ...lastRunResult,
                };
            else {
                const storedRuns = await collectRuns();
                const run = storedRuns.find((r) => r.runId === runId);
                if (run) {
                    const runResult = (await tryReadJSON(join(run.dir, "res.json"))) || {};
                    const runTraceText = (await tryReadText(join(run.dir, TRACE_FILENAME))) || "";
                    response = {
                        ok: true,
                        type: "script.end",
                        runId,
                        exitCode: runResult.exitCode,
                        result: runResult,
                        trace: runTraceText,
                    };
                }
            }
        }
        if (response === undefined) {
            console.debug(`404: ${method} ${url}`);
            res.statusCode = 404;
            res.end();
        }
        else {
            res.statusCode = 200;
            res.setHeader("Content-Type", "application/json");
            res.end(JSON.stringify(response));
        }
    };
    // Create an HTTP server; unexpected handler failures become a 500 instead of
    // an unhandled promise rejection.
    const httpServer = http.createServer((req, res) => {
        handleHttpRequest(req, res).catch((e) => {
            logError(e);
            if (res.headersSent)
                res.destroy();
            else {
                res.statusCode = 500;
                res.end();
            }
        });
    });
    // Upgrade HTTP server to handle WebSocket connections on the / route.
    httpServer.on("upgrade", (req, socket, head) => {
        const pathname = new URL(req.url, `http://${req.headers.host}`).pathname;
        if (pathname === "/" && checkApiKey(req)) {
            wss.handleUpgrade(req, socket, head, (ws) => {
                wss.emit("connection", ws, req);
            });
        }
        else
            socket.destroy();
    });
    // Start the HTTP server on the specified port.
    const serverHash = apiKey ? `#api-key:${encodeURIComponent(apiKey)}` : "";
    httpServer.listen(port, serverHost, () => {
        console.log(`GenAIScript server v${CORE_VERSION}`);
        if (remote)
            console.log(`│ Remote: ${remote}${options.remoteBranch ? `#${options.remoteBranch}` : ""}`);
        console.log(`│ Local http://${serverHost}:${port}/${serverHash}`);
        if (options.network) {
            console.log(`│ Host http://localhost:${port}/${serverHash}`);
            const interfaces = networkInterfaces();
            for (const ifaces of Object.values(interfaces)) {
                for (const iface of ifaces) {
                    if (iface.family === "IPv4" && !iface.internal) {
                        console.log(`│ Network http://${iface.address}:${port}/${serverHash}`);
                    }
                }
            }
        }
    });
}
//# sourceMappingURL=server.js.map