UNPKG

@genkit-ai/core

Version:

Genkit AI framework core libraries.

400 lines 11.9 kB
import WebSocket from "ws";
import { StatusCodes } from "./action.mjs";
import {
  GENKIT_REFLECTION_API_SPEC_VERSION,
  GENKIT_VERSION
} from "./index.mjs";
import { logger } from "./logging.mjs";
import {
  ReflectionCancelActionParamsSchema,
  ReflectionConfigureParamsSchema,
  ReflectionListValuesParamsSchema,
  ReflectionListValuesResponseSchema,
  ReflectionRunActionParamsSchema,
  ReflectionRunActionStateParamsSchema,
  ReflectionStreamChunkParamsSchema
} from "./reflection-types.mjs";
import { toJsonSchema } from "./schema.mjs";
import { flushTracing, setTelemetryServerUrl } from "./tracing.mjs";

/** Per-process counter; each ReflectionServerV2 instance takes the next value as its `index`. */
let apiIndex = 0;

/** JSON-RPC 2.0 error code: the requested method does not exist. */
const JSON_RPC_METHOD_NOT_FOUND = -32601;
/** JSON-RPC 2.0 error code: invalid method parameters. */
const JSON_RPC_INVALID_PARAMS = -32602;
/** JSON-RPC 2.0 error code from the implementation-defined server-error range. */
const JSON_RPC_SERVER_ERROR = -32000;

/**
 * Tells whether a JSON-RPC message carries an id, i.e. expects a response.
 *
 * A plain truthiness check is not enough: `0` and `""` are legal JSON-RPC ids
 * but are falsy in JavaScript.
 *
 * @param {{ id?: string | number | null }} request - Incoming JSON-RPC message.
 * @returns {boolean} `true` when `request.id` is neither `undefined` nor `null`.
 */
function hasRequestId(request) {
  return request.id !== undefined && request.id !== null;
}

/**
 * WebSocket client that connects to a "Reflection V2" server and speaks
 * JSON-RPC 2.0 with it: it registers this runtime, answers incoming requests
 * (`listActions`, `listValues`, `runAction`, `configure`, `cancelAction`) and
 * reconnects with exponential backoff when the connection drops.
 */
class ReflectionServerV2 {
  /** Registry queried for actions and values. */
  registry;
  /** Constructor options merged over the default `{ configuredEnvs: ["dev"] }`. */
  options;
  /** Current WebSocket, or `null` before the first connect / after `stop()`. */
  ws = null;
  /** URL passed to the WebSocket constructor (taken from `options.url`). */
  url;
  /** Index of this instance within the process; used to build `runtimeId`. */
  index = apiIndex++;
  /** Running actions keyed by trace id: `{ abortController, startTime }`. */
  activeActions = /* @__PURE__ */ new Map();
  /** Number of reconnect attempts scheduled since the last successful open. */
  reconnectCount = 0;
  /** Set by `stop()`; prevents further connects and reconnects. */
  isStopped = false;
  /** Handle of the pending reconnect timer, or `null` when none is scheduled. */
  reconnectTimeout = null;
  /** First reconnect delay in milliseconds. */
  baseDelayMs = 500;
  /** Upper bound for the reconnect delay in milliseconds. */
  maxDelayMs = 5e3;
  /** Outgoing requests awaiting a response, keyed by id: `{ resolve, reject }`. */
  pendingRequests = /* @__PURE__ */ new Map();
  /** Source of ids for outgoing requests. */
  requestIdCounter = 0;

  /**
   * @param {object} registry - Object providing `listResolvableActions()`,
   *   `listValues(type)` and `lookupAction(key)`.
   * @param {object} options - Server options; `url` is read here, `name` and
   *   `configuredEnvs` are read during registration.
   */
  constructor(registry, options) {
    this.registry = registry;
    this.options = {
      configuredEnvs: ["dev"],
      ...options
    };
    this.url = this.options.url;
  }

  /**
   * Clears the stopped flag and the reconnect counter, then connects.
   *
   * @returns {Promise<void>}
   */
  async start() {
    this.isStopped = false;
    this.reconnectCount = 0;
    await this.connect();
  }

  /**
   * Opens a new WebSocket to `this.url` and wires up its event handlers.
   * Does nothing when the server has been stopped.
   *
   * @returns {Promise<void>}
   */
  async connect() {
    if (this.isStopped) return;
    logger.debug(`Connecting to Reflection V2 server at ${this.url}`);
    const ws = new WebSocket(this.url);
    this.ws = ws;

    ws.on("open", async () => {
      logger.debug("Connected to Reflection V2 server.");
      this.reconnectCount = 0;
      await this.register();
    });

    ws.on("message", async (data) => {
      try {
        const message = JSON.parse(data.toString());
        // Messages with a method are requests/notifications from the peer;
        // messages with only an id are responses to our own requests.
        if ("method" in message) {
          await this.handleRequest(message);
        } else if ("id" in message) {
          this.handleResponse(message);
        }
      } catch (error) {
        logger.error(`Failed to parse message: ${error}`);
      }
    });

    ws.on("error", (error) => {
      logger.error(`Reflection V2 WebSocket error: ${error}`);
    });

    ws.on("close", (code, reason) => {
      logger.debug(
        `Reflection V2 WebSocket closed. Code: ${code}, Reason: ${reason}`
      );
      // A socket that has already been replaced by a newer connection must not
      // reject the newer connection's requests or trigger a second reconnect.
      if (this.ws !== null && this.ws !== ws) return;
      this.rejectPendingRequests();
      if (!this.isStopped) {
        this.scheduleReconnect();
      }
    });
  }

  /**
   * Schedules a single reconnect attempt after
   * `min(baseDelayMs * 2^reconnectCount, maxDelayMs)` milliseconds.
   * Does nothing when an attempt is already scheduled.
   */
  scheduleReconnect() {
    if (this.reconnectTimeout) return;
    const delay = Math.min(
      this.baseDelayMs * Math.pow(2, this.reconnectCount),
      this.maxDelayMs
    );
    this.reconnectCount++;
    logger.debug(
      `Scheduling reconnection in ${delay}ms (attempt ${this.reconnectCount})`
    );
    this.reconnectTimeout = setTimeout(() => {
      this.reconnectTimeout = null;
      // Nothing awaits a timer callback, so a rejection must be handled here.
      this.connect().catch((error) => {
        logger.error(`Reflection V2 reconnection attempt failed: ${error}`);
      });
    }, delay);
  }

  /**
   * Stops the server: cancels any scheduled reconnect, closes the socket and
   * rejects requests that are still awaiting a response.
   *
   * @returns {Promise<void>}
   */
  async stop() {
    this.isStopped = true;
    if (this.reconnectTimeout) {
      clearTimeout(this.reconnectTimeout);
      this.reconnectTimeout = null;
    }
    if (this.ws) {
      this.ws.close();
      this.ws = null;
    }
    // Reject here rather than waiting for the socket's close event, which is
    // ignored if a new connection has been started in the meantime.
    this.rejectPendingRequests();
  }

  /** Rejects every pending outgoing request and empties the pending map. */
  rejectPendingRequests() {
    for (const [id, resolver] of this.pendingRequests.entries()) {
      resolver.reject(
        new Error(`Connection closed before response was received (id: ${id})`)
      );
    }
    this.pendingRequests.clear();
  }

  /**
   * @returns {boolean} `true` when a socket exists and its state is OPEN.
   */
  isSocketOpen() {
    return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
  }

  /**
   * Serializes `message` as JSON and sends it; silently drops the message when
   * the socket is not open.
   *
   * @param {object} message - JSON-RPC message to send.
   */
  send(message) {
    if (this.isSocketOpen()) {
      this.ws.send(JSON.stringify(message));
    }
  }

  /**
   * Sends a JSON-RPC success response.
   *
   * @param {string|number} id - Id of the request being answered.
   * @param {*} result - Result payload.
   */
  sendResponse(id, result) {
    this.send({ jsonrpc: "2.0", result, id });
  }

  /**
   * Sends a JSON-RPC error response.
   *
   * @param {string|number} id - Id of the request being answered.
   * @param {number} code - JSON-RPC error code.
   * @param {string} message - Error message.
   * @param {*} [data] - Optional additional error data.
   */
  sendError(id, code, message, data) {
    this.send({ jsonrpc: "2.0", error: { code, message, data }, id });
  }

  /**
   * Sends a JSON-RPC notification (a message without an id).
   *
   * @param {string} method - Notification method name.
   * @param {*} params - Notification parameters.
   */
  sendNotification(method, params) {
    this.send({ jsonrpc: "2.0", method, params });
  }

  /**
   * Sends a JSON-RPC request and resolves with the peer's `result`.
   *
   * @param {string} method - Request method name.
   * @param {*} params - Request parameters.
   * @returns {Promise<*>} Resolves with the response result. Rejects with the
   *   response's `error` object, or with an `Error` when the socket is not
   *   open, sending fails, or the connection closes before a response arrives.
   */
  sendRequest(method, params) {
    return new Promise((resolve, reject) => {
      const id = (++this.requestIdCounter).toString();
      // Without an open socket the message would be dropped and no close event
      // would ever settle this promise, so fail fast instead.
      if (!this.isSocketOpen()) {
        reject(
          new Error(`Cannot send '${method}' request: WebSocket is not open`)
        );
        return;
      }
      this.pendingRequests.set(id, { resolve, reject });
      try {
        this.send({ jsonrpc: "2.0", id, method, params });
      } catch (error) {
        this.pendingRequests.delete(id);
        reject(error);
      }
    });
  }

  /**
   * Sends the `register` request describing this runtime. When the response
   * carries a `telemetryServerUrl` and the `GENKIT_TELEMETRY_SERVER`
   * environment variable is not set, that URL is applied via
   * `setTelemetryServerUrl`. Failures are logged, not thrown.
   *
   * @returns {Promise<void>}
   */
  async register() {
    const params = {
      id: process.env.GENKIT_RUNTIME_ID || this.runtimeId,
      pid: process.pid,
      name: this.options.name || this.runtimeId,
      genkitVersion: GENKIT_VERSION,
      reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION,
      envs: this.options.configuredEnvs
    };
    try {
      const response = await this.sendRequest("register", params);
      if (
        response?.telemetryServerUrl &&
        !process.env.GENKIT_TELEMETRY_SERVER
      ) {
        setTelemetryServerUrl(response.telemetryServerUrl);
        logger.debug(
          `Connected to telemetry server on ${response.telemetryServerUrl} via handshake`
        );
      }
    } catch (err) {
      logger.error(`Failed to register with CLI: ${err}`);
    }
  }

  /**
   * Runtime id: the process id, suffixed with `-<index>` for every instance
   * after the first one created in this process.
   *
   * @returns {string}
   */
  get runtimeId() {
    return `${process.pid}${this.index ? `-${this.index}` : ""}`;
  }

  /**
   * Settles the pending request matching `response.id`.
   *
   * @param {object} response - JSON-RPC response message.
   */
  handleResponse(response) {
    const resolver = this.pendingRequests.get(response.id);
    if (!resolver) {
      logger.error(`Unknown response ID: ${response.id}`);
      return;
    }
    this.pendingRequests.delete(response.id);
    if ("error" in response) {
      resolver.reject(response.error);
    } else {
      resolver.resolve(response.result);
    }
  }

  /**
   * Dispatches an incoming JSON-RPC request or notification to its handler.
   * Handler failures are reported to the peer as a server error when the
   * message has an id, and logged otherwise.
   *
   * @param {object} request - JSON-RPC request message.
   * @returns {Promise<void>}
   */
  async handleRequest(request) {
    try {
      switch (request.method) {
        case "listActions":
          await this.handleListActions(request);
          break;
        case "listValues":
          await this.handleListValues(request);
          break;
        case "runAction":
          await this.handleRunAction(request);
          break;
        case "configure":
          this.handleConfigure(request);
          break;
        case "cancelAction":
          await this.handleCancelAction(request);
          break;
        case "sendInputStreamChunk":
          this.handleSendInputStreamChunk(request);
          break;
        case "endInputStream":
          this.handleEndInputStream(request);
          break;
        default:
          if (hasRequestId(request)) {
            this.sendError(
              request.id,
              JSON_RPC_METHOD_NOT_FOUND,
              `Method not found: ${request.method}`
            );
          }
      }
    } catch (error) {
      // Thrown values are not guaranteed to be Error instances.
      const message = error?.message ?? String(error);
      if (hasRequestId(request)) {
        this.sendError(request.id, JSON_RPC_SERVER_ERROR, message, {
          stack: error?.stack
        });
      } else {
        // A notification has no response channel; log instead of dropping it.
        logger.error(
          `Failed to handle notification '${request.method}': ${message}`
        );
      }
    }
  }

  /**
   * Responds with every resolvable action from the registry, including JSON
   * schemas for input/output where the action declares them.
   *
   * @param {object} request - JSON-RPC request; ignored when it has no id.
   * @returns {Promise<void>}
   */
  async handleListActions(request) {
    if (!hasRequestId(request)) return;
    const actions = await this.registry.listResolvableActions();
    const convertedActions = {};
    for (const [key, action] of Object.entries(actions)) {
      const converted = {
        key,
        name: action.name,
        description: action.description,
        metadata: action.metadata
      };
      if (action.inputSchema || action.inputJsonSchema) {
        converted.inputSchema = toJsonSchema({
          schema: action.inputSchema,
          jsonSchema: action.inputJsonSchema
        });
      }
      if (action.outputSchema || action.outputJsonSchema) {
        converted.outputSchema = toJsonSchema({
          schema: action.outputSchema,
          jsonSchema: action.outputJsonSchema
        });
      }
      convertedActions[key] = converted;
    }
    this.sendResponse(request.id, { actions: convertedActions });
  }

  /**
   * Responds with the registry values of the requested type. Only the types
   * `defaultModel` and `middleware` are accepted. Values exposing a `toJson()`
   * method are converted with it.
   *
   * @param {object} request - JSON-RPC request; ignored when it has no id.
   * @returns {Promise<void>}
   */
  async handleListValues(request) {
    if (!hasRequestId(request)) return;
    const { type } = ReflectionListValuesParamsSchema.parse(request.params);
    if (type !== "defaultModel" && type !== "middleware") {
      this.sendError(
        request.id,
        JSON_RPC_INVALID_PARAMS,
        `'type' ${type} is not supported. Only 'defaultModel' and 'middleware' are supported`
      );
      return;
    }
    const values = await this.registry.listValues(type);
    const mappedValues = {};
    for (const [key, value] of Object.entries(values)) {
      const hasToJson =
        value &&
        typeof value === "object" &&
        "toJson" in value &&
        typeof value.toJson === "function";
      mappedValues[key] = hasToJson ? value.toJson() : value;
    }
    this.sendResponse(
      request.id,
      ReflectionListValuesResponseSchema.parse({ values: mappedValues })
    );
  }

  /**
   * Looks up and runs an action, then responds with its result and trace id.
   *
   * While the action runs it is tracked in `activeActions` under its trace id
   * (so `cancelAction` can abort it), a `runActionState` notification announces
   * the trace id, and, when `stream` is set, each chunk is forwarded as a
   * `streamChunk` notification. Tracing is flushed before the response is sent.
   *
   * @param {object} request - JSON-RPC request; ignored when it has no id.
   * @returns {Promise<void>}
   */
  async handleRunAction(request) {
    if (!hasRequestId(request)) return;
    const { key, input, context, telemetryLabels, stream } =
      ReflectionRunActionParamsSchema.parse(request.params);
    const action = await this.registry.lookupAction(key);
    if (!action) {
      this.sendError(
        request.id,
        JSON_RPC_INVALID_PARAMS,
        `action ${key} not found`
      );
      return;
    }

    const abortController = new AbortController();
    let traceId;
    try {
      const runOptions = {
        context,
        telemetryLabels,
        onTraceStart: ({ traceId: tid }) => {
          traceId = tid;
          this.activeActions.set(tid, {
            abortController,
            startTime: /* @__PURE__ */ new Date()
          });
          this.sendNotification(
            "runActionState",
            ReflectionRunActionStateParamsSchema.parse({
              requestId: request.id,
              state: { traceId: tid }
            })
          );
        },
        abortSignal: abortController.signal
      };
      // Only streaming runs receive an onChunk callback; non-streaming runs
      // get an options object without that key.
      if (stream) {
        runOptions.onChunk = (chunk) => {
          this.sendNotification(
            "streamChunk",
            ReflectionStreamChunkParamsSchema.parse({
              requestId: request.id,
              chunk
            })
          );
        };
      }

      const result = await action.run(input, runOptions);
      await flushTracing();
      this.sendResponse(request.id, {
        result: result.result,
        telemetry: { traceId: result.telemetry.traceId }
      });
    } catch (err) {
      const isAbort = err?.name === "AbortError";
      const errorResponse = {
        code: isAbort ? StatusCodes.CANCELLED : StatusCodes.INTERNAL,
        message: isAbort
          ? "Action was cancelled"
          : err?.message ?? String(err),
        details: { stack: err?.stack }
      };
      const failedTraceId = err?.traceId || traceId;
      if (failedTraceId) {
        errorResponse.details.traceId = failedTraceId;
      }
      this.sendError(
        request.id,
        JSON_RPC_SERVER_ERROR,
        errorResponse.message,
        errorResponse
      );
    } finally {
      if (traceId) {
        this.activeActions.delete(traceId);
      }
    }
  }

  /**
   * Applies a `configure` message: sets the telemetry server URL unless the
   * `GENKIT_TELEMETRY_SERVER` environment variable is set. Sends no response.
   *
   * @param {object} request - JSON-RPC request or notification.
   */
  handleConfigure(request) {
    const { telemetryServerUrl } = ReflectionConfigureParamsSchema.parse(
      request.params
    );
    if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) {
      setTelemetryServerUrl(telemetryServerUrl);
      logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`);
    }
  }

  /**
   * Aborts the running action identified by `params.traceId`, or reports an
   * error when no such action is active.
   *
   * @param {object} request - JSON-RPC request; ignored when it has no id.
   * @returns {Promise<void>}
   */
  async handleCancelAction(request) {
    if (!hasRequestId(request)) return;
    const { traceId } = ReflectionCancelActionParamsSchema.parse(
      request.params
    );
    const activeAction = this.activeActions.get(traceId);
    if (!activeAction) {
      this.sendError(
        request.id,
        JSON_RPC_INVALID_PARAMS,
        "Action not found or already completed"
      );
      return;
    }
    activeAction.abortController.abort();
    this.activeActions.delete(traceId);
    this.sendResponse(request.id, { message: "Action cancelled" });
  }

  /**
   * Placeholder for input-stream chunks.
   *
   * @param {object} request - JSON-RPC request.
   * @throws {Error} Always; the method is not implemented.
   */
  handleSendInputStreamChunk(request) {
    throw new Error("Not implemented");
  }

  /**
   * Placeholder for the end-of-input-stream signal.
   *
   * @param {object} request - JSON-RPC request.
   * @throws {Error} Always; the method is not implemented.
   */
  handleEndInputStream(request) {
    throw new Error("Not implemented");
  }
}

export { ReflectionServerV2 };
//# sourceMappingURL=reflection-v2.mjs.map