/*
 * @genkit-ai/core
 * Version:
 * Genkit AI framework core libraries.
 * 400 lines • 11.9 kB
 * JavaScript
 */
import WebSocket from "ws";
import { StatusCodes } from "./action.mjs";
import { GENKIT_REFLECTION_API_SPEC_VERSION, GENKIT_VERSION } from "./index.mjs";
import { logger } from "./logging.mjs";
import {
ReflectionCancelActionParamsSchema,
ReflectionConfigureParamsSchema,
ReflectionListValuesParamsSchema,
ReflectionListValuesResponseSchema,
ReflectionRunActionParamsSchema,
ReflectionRunActionStateParamsSchema,
ReflectionStreamChunkParamsSchema
} from "./reflection-types.mjs";
import { toJsonSchema } from "./schema.mjs";
import { flushTracing, setTelemetryServerUrl } from "./tracing.mjs";
// Process-wide counter; each ReflectionServerV2 instance takes the next value as its `index`.
let apiIndex = 0;

/**
 * WebSocket client side of the Genkit Reflection API v2.
 *
 * Connects to the server at `options.url`, registers this runtime, and then
 * answers JSON-RPC 2.0 requests (list actions/values, run or cancel an action,
 * configure telemetry). A dropped connection is retried with exponential
 * backoff until `stop()` is called.
 */
class ReflectionServerV2 {
  /** Registry used to resolve actions and values. */
  registry;
  /** Caller options merged over the default `{ configuredEnvs: ["dev"] }`. */
  options;
  /** The current WebSocket, or null before `connect()` / after `stop()`. */
  ws = null;
  /** Server URL, copied from `options.url`. */
  url;
  /** Per-process instance index; used to build `runtimeId`. */
  index = apiIndex++;
  /** traceId -> { abortController, startTime } for actions currently running. */
  activeActions = /* @__PURE__ */ new Map();
  /** Consecutive reconnect attempts since the last successful open. */
  reconnectCount = 0;
  isStopped = false;
  /** Handle of the pending reconnect timer, or null when none is scheduled. */
  reconnectTimeout = null;
  baseDelayMs = 500;
  maxDelayMs = 5e3;
  /** Outgoing request id -> { resolve, reject } awaiting a response. */
  pendingRequests = /* @__PURE__ */ new Map();
  requestIdCounter = 0;

  /**
   * @param registry Object providing `listResolvableActions`, `listValues` and `lookupAction`.
   * @param options Settings; `url` is required, `name` and `configuredEnvs` are optional.
   */
  constructor(registry, options) {
    this.registry = registry;
    this.options = {
      configuredEnvs: ["dev"],
      ...options
    };
    this.url = this.options.url;
  }

  /** Clears the stopped flag and the backoff counter, then opens the connection. */
  async start() {
    this.isStopped = false;
    this.reconnectCount = 0;
    await this.connect();
  }

  /**
   * Opens a new WebSocket to `this.url` and wires up its event handlers.
   * Does nothing once the server has been stopped.
   */
  async connect() {
    if (this.isStopped) return;
    logger.debug(`Connecting to Reflection V2 server at ${this.url}`);
    const ws = new WebSocket(this.url);
    this.ws = ws;
    ws.on("open", async () => {
      logger.debug("Connected to Reflection V2 server.");
      this.reconnectCount = 0;
      await this.register();
    });
    ws.on("message", async (data) => {
      try {
        const message = JSON.parse(data.toString());
        // A message with `method` is a request/notification; otherwise one
        // with `id` is the response to a request this side sent.
        if ("method" in message) {
          await this.handleRequest(message);
        } else if ("id" in message) {
          this.handleResponse(message);
        }
      } catch (error) {
        logger.error(`Failed to parse message: ${error}`);
      }
    });
    ws.on("error", (error) => {
      logger.error(`Reflection V2 WebSocket error: ${error}`);
    });
    ws.on("close", (code, reason) => {
      logger.debug(
        `Reflection V2 WebSocket closed. Code: ${code}, Reason: ${reason}`
      );
      // Ignore sockets that are no longer current (replaced by a later
      // connect(), or released by stop()). Otherwise a late close event from
      // an old socket would reject the new connection's pending requests and
      // schedule a second, duplicate connection.
      if (this.ws !== ws) return;
      this.rejectPendingRequests();
      if (!this.isStopped) {
        this.scheduleReconnect();
      }
    });
  }

  /**
   * Schedules a single reconnect attempt after `baseDelayMs * 2^reconnectCount`
   * milliseconds, capped at `maxDelayMs`. No-op if one is already scheduled.
   */
  scheduleReconnect() {
    if (this.reconnectTimeout) return;
    const delay = Math.min(
      this.baseDelayMs * Math.pow(2, this.reconnectCount),
      this.maxDelayMs
    );
    this.reconnectCount++;
    logger.debug(
      `Scheduling reconnection in ${delay}ms (attempt ${this.reconnectCount})`
    );
    this.reconnectTimeout = setTimeout(async () => {
      this.reconnectTimeout = null;
      await this.connect();
    }, delay);
  }

  /**
   * Stops the server: cancels any scheduled reconnect, closes the socket and
   * rejects every request that is still waiting for a response.
   */
  async stop() {
    this.isStopped = true;
    if (this.reconnectTimeout) {
      clearTimeout(this.reconnectTimeout);
      this.reconnectTimeout = null;
    }
    if (this.ws) {
      this.ws.close();
      this.ws = null;
    }
    // The close handler skips sockets that are no longer current, so settle
    // outstanding requests here rather than relying on the close event.
    this.rejectPendingRequests();
  }

  /** Rejects and forgets every request still waiting for a response. */
  rejectPendingRequests() {
    for (const [id, resolver] of this.pendingRequests.entries()) {
      resolver.reject(
        new Error(`Connection closed before response was received (id: ${id})`)
      );
    }
    this.pendingRequests.clear();
  }

  /**
   * Serializes `message` as JSON and writes it to the socket.
   *
   * @returns {boolean} true if the message was handed to an open socket,
   *   false if it was dropped because there is no open socket.
   */
  send(message) {
    if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
      return false;
    }
    this.ws.send(JSON.stringify(message));
    return true;
  }

  /** Sends a JSON-RPC success response for request `id`. */
  sendResponse(id, result) {
    this.send({
      jsonrpc: "2.0",
      result,
      id
    });
  }

  /** Sends a JSON-RPC error response for request `id`. */
  sendError(id, code, message, data) {
    this.send({
      jsonrpc: "2.0",
      error: { code, message, data },
      id
    });
  }

  /** Sends a JSON-RPC notification (a request without an id). */
  sendNotification(method, params) {
    this.send({
      jsonrpc: "2.0",
      method,
      params
    });
  }

  /**
   * Sends a JSON-RPC request and resolves with the `result` of its response.
   *
   * @returns {Promise<unknown>} Rejects with the response's `error`, with an
   *   Error if the connection closes first, or immediately with an Error when
   *   the message could not be sent at all.
   */
  sendRequest(method, params) {
    return new Promise((resolve, reject) => {
      const id = (++this.requestIdCounter).toString();
      this.pendingRequests.set(id, { resolve, reject });
      let sent;
      try {
        sent = this.send({
          jsonrpc: "2.0",
          id,
          method,
          params
        });
      } catch (err) {
        // e.g. params that cannot be serialized; do not leak the entry.
        this.pendingRequests.delete(id);
        reject(err);
        return;
      }
      if (sent === false) {
        // Nothing went out, so no response can ever arrive for this id.
        this.pendingRequests.delete(id);
        reject(
          new Error(`Cannot send '${method}' request: WebSocket is not open`)
        );
      }
    });
  }

  /**
   * Sends the `register` request describing this runtime. If the response
   * carries `telemetryServerUrl` and GENKIT_TELEMETRY_SERVER is not set, that
   * URL is applied. Failures are logged, never thrown.
   */
  async register() {
    const params = {
      id: process.env.GENKIT_RUNTIME_ID || this.runtimeId,
      pid: process.pid,
      name: this.options.name || this.runtimeId,
      genkitVersion: GENKIT_VERSION,
      reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION,
      envs: this.options.configuredEnvs
    };
    try {
      const response = await this.sendRequest("register", params);
      const telemetryServerUrl = response?.telemetryServerUrl;
      if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) {
        setTelemetryServerUrl(telemetryServerUrl);
        logger.debug(
          `Connected to telemetry server on ${telemetryServerUrl} via handshake`
        );
      }
    } catch (err) {
      logger.error(`Failed to register with CLI: ${err}`);
    }
  }

  /** The process id, suffixed with `-<index>` for every instance after the first. */
  get runtimeId() {
    return `${process.pid}${this.index ? `-${this.index}` : ""}`;
  }

  /** Settles the pending request whose id matches `response.id`. */
  handleResponse(response) {
    const resolver = this.pendingRequests.get(response.id);
    if (!resolver) {
      logger.error(`Unknown response ID: ${response.id}`);
      return;
    }
    this.pendingRequests.delete(response.id);
    if ("error" in response) {
      resolver.reject(response.error);
    } else {
      resolver.resolve(response.result);
    }
  }

  /**
   * Dispatches an incoming request or notification by `request.method`.
   *
   * Errors are only reported back when the message has an id. The id is
   * compared against null/undefined rather than tested for truthiness because
   * `0` and `""` are valid JSON-RPC ids.
   */
  async handleRequest(request) {
    const expectsResponse = request.id != null;
    try {
      switch (request.method) {
        case "listActions":
          await this.handleListActions(request);
          break;
        case "listValues":
          await this.handleListValues(request);
          break;
        case "runAction":
          await this.handleRunAction(request);
          break;
        case "configure":
          this.handleConfigure(request);
          break;
        case "cancelAction":
          await this.handleCancelAction(request);
          break;
        case "sendInputStreamChunk":
          this.handleSendInputStreamChunk(request);
          break;
        case "endInputStream":
          this.handleEndInputStream(request);
          break;
        default:
          if (expectsResponse) {
            this.sendError(
              request.id,
              -32601,
              `Method not found: ${request.method}`
            );
          }
      }
    } catch (error) {
      if (expectsResponse) {
        this.sendError(
          request.id,
          -32e3,
          error?.message ?? String(error),
          { stack: error?.stack }
        );
      }
    }
  }

  /**
   * Responds with every resolvable action, keyed by action key, including
   * JSON schemas for input/output where the action declares them.
   */
  async handleListActions(request) {
    if (request.id == null) return;
    const actions = await this.registry.listResolvableActions();
    const convertedActions = {};
    for (const [key, action] of Object.entries(actions)) {
      const converted = {
        key,
        name: action.name,
        description: action.description,
        metadata: action.metadata
      };
      if (action.inputSchema || action.inputJsonSchema) {
        converted.inputSchema = toJsonSchema({
          schema: action.inputSchema,
          jsonSchema: action.inputJsonSchema
        });
      }
      if (action.outputSchema || action.outputJsonSchema) {
        converted.outputSchema = toJsonSchema({
          schema: action.outputSchema,
          jsonSchema: action.outputJsonSchema
        });
      }
      convertedActions[key] = converted;
    }
    this.sendResponse(request.id, {
      actions: convertedActions
    });
  }

  /**
   * Responds with the registry values of the requested `type`. Only
   * 'defaultModel' and 'middleware' are accepted; values exposing a `toJson()`
   * method are replaced by its result.
   */
  async handleListValues(request) {
    if (request.id == null) return;
    const { type } = ReflectionListValuesParamsSchema.parse(request.params);
    if (type !== "defaultModel" && type !== "middleware") {
      this.sendError(
        request.id,
        -32602,
        `'type' ${type} is not supported. Only 'defaultModel' and 'middleware' are supported`
      );
      return;
    }
    const values = await this.registry.listValues(type);
    const mappedValues = {};
    for (const [key, value] of Object.entries(values)) {
      const hasToJson =
        value &&
        typeof value === "object" &&
        "toJson" in value &&
        typeof value.toJson === "function";
      mappedValues[key] = hasToJson ? value.toJson() : value;
    }
    this.sendResponse(
      request.id,
      ReflectionListValuesResponseSchema.parse({ values: mappedValues })
    );
  }

  /**
   * Runs the action named by `params.key` and responds with its result and
   * trace id.
   *
   * While running, a `runActionState` notification announces the trace id, and
   * when `params.stream` is set each chunk is forwarded as a `streamChunk`
   * notification. The action is tracked in `activeActions` under its trace id
   * so that `cancelAction` can abort it.
   */
  async handleRunAction(request) {
    if (request.id == null) return;
    const { key, input, context, telemetryLabels, stream } =
      ReflectionRunActionParamsSchema.parse(request.params);
    const action = await this.registry.lookupAction(key);
    if (!action) {
      this.sendError(request.id, -32602, `action ${key} not found`);
      return;
    }
    const abortController = new AbortController();
    let traceId;
    try {
      const runOptions = {
        context,
        telemetryLabels,
        onTraceStart: ({ traceId: tid }) => {
          traceId = tid;
          this.activeActions.set(tid, {
            abortController,
            startTime: /* @__PURE__ */ new Date()
          });
          this.sendNotification(
            "runActionState",
            ReflectionRunActionStateParamsSchema.parse({
              requestId: request.id,
              state: { traceId: tid }
            })
          );
        },
        abortSignal: abortController.signal
      };
      if (stream) {
        // Only streaming runs get a chunk callback; non-streaming runs are
        // invoked without an `onChunk` option at all.
        runOptions.onChunk = (chunk) => {
          this.sendNotification(
            "streamChunk",
            ReflectionStreamChunkParamsSchema.parse({
              requestId: request.id,
              chunk
            })
          );
        };
      }
      const result = await action.run(input, runOptions);
      await flushTracing();
      this.sendResponse(request.id, {
        result: result.result,
        telemetry: {
          traceId: result.telemetry.traceId
        }
      });
    } catch (err) {
      // Covers DOMException aborts too: their `name` is also "AbortError".
      const isAbort = err?.name === "AbortError";
      const errorResponse = {
        code: isAbort ? StatusCodes.CANCELLED : StatusCodes.INTERNAL,
        message: isAbort
          ? "Action was cancelled"
          : (err?.message ?? String(err)),
        details: {
          stack: err?.stack
        }
      };
      const failedTraceId = err?.traceId || traceId;
      if (failedTraceId) {
        errorResponse.details.traceId = failedTraceId;
      }
      this.sendError(request.id, -32e3, errorResponse.message, errorResponse);
    } finally {
      if (traceId) {
        this.activeActions.delete(traceId);
      }
    }
  }

  /**
   * Applies `params.telemetryServerUrl` unless GENKIT_TELEMETRY_SERVER is set
   * in the environment. No response is sent.
   */
  handleConfigure(request) {
    const { telemetryServerUrl } = ReflectionConfigureParamsSchema.parse(
      request.params
    );
    if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) {
      setTelemetryServerUrl(telemetryServerUrl);
      logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`);
    }
  }

  /**
   * Aborts the running action identified by `params.traceId`, or responds
   * with an error when no such action is active.
   */
  async handleCancelAction(request) {
    if (request.id == null) return;
    const { traceId } = ReflectionCancelActionParamsSchema.parse(
      request.params
    );
    const activeAction = this.activeActions.get(traceId);
    if (!activeAction) {
      this.sendError(
        request.id,
        -32602,
        "Action not found or already completed"
      );
      return;
    }
    activeAction.abortController.abort();
    this.activeActions.delete(traceId);
    this.sendResponse(request.id, { message: "Action cancelled" });
  }

  /** Input streaming is not supported yet; always throws. */
  handleSendInputStreamChunk(request) {
    throw new Error("Not implemented");
  }

  /** Input streaming is not supported yet; always throws. */
  handleEndInputStream(request) {
    throw new Error("Not implemented");
  }
}
export {
ReflectionServerV2
};
//# sourceMappingURL=reflection-v2.mjs.map