@genkit-ai/core
Version:
Genkit AI framework core libraries.
426 lines • 13.7 kB
JavaScript
"use strict";
// ---------------------------------------------------------------------------
// Bundler-generated CommonJS <-> ESM interop helpers.
// NOTE(review): this file is build output (see the sourceMappingURL at the
// bottom); changes belong in the original TypeScript source.
// ---------------------------------------------------------------------------
var __create = Object.create;
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __getProtoOf = Object.getPrototypeOf;
var __hasOwnProp = Object.prototype.hasOwnProperty;
// Defines every entry of `all` on `target` as an enumerable getter, so each
// export is resolved lazily at access time rather than when __export runs.
var __export = (target, all) => {
for (var name in all)
__defProp(target, name, { get: all[name], enumerable: true });
};
// Copies each own property name of `from` onto `to` as a getter that reads
// through to `from`. Keys already owned by `to` and the `except` key are
// skipped; each copy is enumerable unless the source property is not.
// `from` is only processed when it is a non-null object or a function.
// `desc` is a scratch parameter reused to hold the property descriptor.
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
// Wraps a `require()` result so it can be consumed like an ES module
// namespace. It creates an object sharing `mod`'s prototype (or `{}` when
// `mod` is null/undefined) and copies `mod`'s properties onto it. It also
// sets `default` to `mod` itself, unless `mod` is a transpiled ESM module
// (`__esModule` set) and `isNodeMode` is falsy.
var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps(
// If the importer is in node compatibility mode or this is not an ESM
// file that has been converted to a CommonJS file using a Babel-
// compatible transform (i.e. "__esModule" has not been set), then set
// "default" to the CommonJS "module.exports" for node compatibility.
isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target,
mod
));
// Builds the CommonJS `module.exports` object: a fresh object flagged with a
// non-enumerable `__esModule: true`, carrying getters for every export of `mod`.
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// Export table for this module. `ReflectionServerV2` is exposed through a lazy
// getter (see __export), so referencing the class declared further down is
// safe. `module.exports` is then built from this table via __toCommonJS.
var reflection_v2_exports = {};
__export(reflection_v2_exports, {
ReflectionServerV2: () => ReflectionServerV2
});
module.exports = __toCommonJS(reflection_v2_exports);
var import_ws = __toESM(require("ws"));
var import_action = require("./action.js");
var import_index = require("./index.js");
var import_logging = require("./logging.js");
var import_reflection_types = require("./reflection-types.js");
var import_schema = require("./schema.js");
var import_tracing = require("./tracing.js");
// Module-level instance counter. Each ReflectionServerV2 takes the next value
// as its `index`; `runtimeId` appends that index to the pid for every instance
// except the first (index 0).
let apiIndex = 0;
/**
 * Reflection API v2 client.
 *
 * Opens a WebSocket connection to `options.url`, registers this runtime, and
 * then answers JSON-RPC 2.0 requests received over that socket
 * (`listActions`, `listValues`, `runAction`, `configure`, `cancelAction`).
 * When the connection drops, it reconnects with capped exponential backoff
 * until `stop()` is called.
 */
class ReflectionServerV2 {
  registry;
  options;
  /** Socket currently owned by this instance, or null while disconnected. */
  ws = null;
  url;
  // Distinguishes multiple instances in one process (used by `runtimeId`).
  index = apiIndex++;
  /** traceId -> { abortController, startTime } for in-flight runAction calls. */
  activeActions = /* @__PURE__ */ new Map();
  reconnectCount = 0;
  isStopped = false;
  reconnectTimeout = null;
  baseDelayMs = 500;
  maxDelayMs = 5e3;
  /** Outgoing request id (string) -> { resolve, reject } awaiting a response. */
  pendingRequests = /* @__PURE__ */ new Map();
  requestIdCounter = 0;

  /**
   * @param registry - Registry used to list and resolve actions and values.
   * @param options - `url` of the Reflection V2 server, optional `name`, and
   *   optional `configuredEnvs` (defaults to `["dev"]`).
   */
  constructor(registry, options) {
    this.registry = registry;
    this.options = {
      configuredEnvs: ["dev"],
      ...options
    };
    this.url = this.options.url;
  }

  /** Starts, or restarts after `stop()`, the connection loop. */
  async start() {
    this.isStopped = false;
    this.reconnectCount = 0;
    await this.connect();
  }

  /**
   * Creates a new WebSocket and wires its handlers. Resolves once the socket
   * object exists; registration happens later, on "open".
   */
  async connect() {
    if (this.isStopped) return;
    import_logging.logger.debug(`Connecting to Reflection V2 server at ${this.url}`);
    const ws = new import_ws.default(this.url);
    this.ws = ws;
    // Events from a socket that has since been replaced (e.g. stop() followed
    // by start()) must not touch the new connection's state. Otherwise a late
    // "close" would reject its pending requests and schedule a second reconnect.
    const isCurrent = () => this.ws === ws;
    ws.on("open", async () => {
      if (!isCurrent()) return;
      import_logging.logger.debug("Connected to Reflection V2 server.");
      this.reconnectCount = 0;
      await this.register();
    });
    ws.on("message", async (data) => {
      if (!isCurrent()) return;
      await this.handleMessage(data);
    });
    ws.on("error", (error) => {
      import_logging.logger.error(`Reflection V2 WebSocket error: ${error}`);
    });
    ws.on("close", (code, reason) => {
      import_logging.logger.debug(
        `Reflection V2 WebSocket closed. Code: ${code}, Reason: ${reason}`
      );
      if (!isCurrent()) return;
      this.ws = null;
      this.rejectPendingRequests("Connection closed before response was received");
      if (!this.isStopped) {
        this.scheduleReconnect();
      }
    });
  }

  /**
   * Parses one incoming frame and dispatches it. Frames with `method` are
   * requests or notifications; frames with only `id` are responses.
   * Parse and dispatch failures are logged separately and never thrown.
   */
  async handleMessage(data) {
    let message;
    try {
      message = JSON.parse(data.toString());
    } catch (error) {
      import_logging.logger.error(`Failed to parse message: ${error}`);
      return;
    }
    // `in` throws on primitives and null, so reject them explicitly.
    if (message === null || typeof message !== "object") {
      import_logging.logger.error(`Ignoring non-object message: ${data}`);
      return;
    }
    try {
      if ("method" in message) {
        await this.handleRequest(message);
      } else if ("id" in message) {
        this.handleResponse(message);
      }
    } catch (error) {
      import_logging.logger.error(`Failed to handle message: ${error}`);
    }
  }

  /**
   * Schedules a single reconnect attempt. The delay is
   * min(baseDelayMs * 2^attempt, maxDelayMs). Does nothing if an attempt is
   * already scheduled.
   */
  scheduleReconnect() {
    if (this.reconnectTimeout) return;
    const delay = Math.min(
      this.baseDelayMs * 2 ** this.reconnectCount,
      this.maxDelayMs
    );
    this.reconnectCount++;
    import_logging.logger.debug(
      `Scheduling reconnection in ${delay}ms (attempt ${this.reconnectCount})`
    );
    this.reconnectTimeout = setTimeout(() => {
      this.reconnectTimeout = null;
      // Nothing awaits this timer callback, so handle the rejection here.
      this.connect().catch((error) => {
        import_logging.logger.error(`Reflection V2 reconnection failed: ${error}`);
      });
    }, delay);
  }

  /**
   * Stops reconnecting, rejects outstanding outgoing requests, and closes the
   * current socket.
   */
  async stop() {
    this.isStopped = true;
    if (this.reconnectTimeout) {
      clearTimeout(this.reconnectTimeout);
      this.reconnectTimeout = null;
    }
    const ws = this.ws;
    // Detach first so the socket's own "close" event is treated as stale.
    this.ws = null;
    this.rejectPendingRequests(
      "Reflection V2 server stopped before response was received"
    );
    if (ws) {
      ws.close();
    }
  }

  /**
   * Serializes and sends `message` if the socket is open.
   * @returns {boolean} true if the message was handed to the socket.
   */
  send(message) {
    if (this.ws && this.ws.readyState === import_ws.default.OPEN) {
      this.ws.send(JSON.stringify(message));
      return true;
    }
    return false;
  }

  /** Sends a JSON-RPC success response. */
  sendResponse(id, result) {
    this.send({
      jsonrpc: "2.0",
      result,
      id
    });
  }

  /** Sends a JSON-RPC error response. */
  sendError(id, code, message, data) {
    this.send({
      jsonrpc: "2.0",
      error: { code, message, data },
      id
    });
  }

  /** Sends a JSON-RPC notification (no id, no response expected). */
  sendNotification(method, params) {
    this.send({
      jsonrpc: "2.0",
      method,
      params
    });
  }

  /**
   * Sends a JSON-RPC request and resolves with the response's `result`.
   * Rejects immediately if the socket is not open, because the message would
   * otherwise be dropped and the promise would never settle.
   */
  sendRequest(method, params) {
    return new Promise((resolve, reject) => {
      const id = String(++this.requestIdCounter);
      this.pendingRequests.set(id, { resolve, reject });
      const sent = this.send({
        jsonrpc: "2.0",
        id,
        method,
        params
      });
      if (!sent) {
        this.pendingRequests.delete(id);
        reject(
          new Error(
            `Cannot send request '${method}': Reflection V2 WebSocket is not open`
          )
        );
      }
    });
  }

  /** Rejects and forgets every outstanding outgoing request. */
  rejectPendingRequests(reason) {
    for (const [id, { reject }] of this.pendingRequests) {
      reject(new Error(`${reason} (id: ${id})`));
    }
    this.pendingRequests.clear();
  }

  /**
   * Sends a `register` request describing this runtime. If the response
   * carries `telemetryServerUrl` and GENKIT_TELEMETRY_SERVER is not set, that
   * URL is adopted. Failures are logged, not thrown.
   */
  async register() {
    const params = {
      id: process.env.GENKIT_RUNTIME_ID || this.runtimeId,
      pid: process.pid,
      name: this.options.name || this.runtimeId,
      genkitVersion: import_index.GENKIT_VERSION,
      reflectionApiSpecVersion: import_index.GENKIT_REFLECTION_API_SPEC_VERSION,
      envs: this.options.configuredEnvs
    };
    try {
      const response = await this.sendRequest("register", params);
      if (response && response.telemetryServerUrl) {
        if (!process.env.GENKIT_TELEMETRY_SERVER) {
          (0, import_tracing.setTelemetryServerUrl)(response.telemetryServerUrl);
          import_logging.logger.debug(
            `Connected to telemetry server on ${response.telemetryServerUrl} via handshake`
          );
        }
      }
    } catch (err) {
      import_logging.logger.error(`Failed to register with CLI: ${err}`);
    }
  }

  /** The process pid, suffixed with `-<index>` for every instance after the first. */
  get runtimeId() {
    return `${process.pid}${this.index ? `-${this.index}` : ""}`;
  }

  /**
   * Settles the pending request matching `response.id`. The id is compared as
   * a string, because that is how outgoing ids are generated. Error responses
   * reject with an Error carrying the JSON-RPC `code` and `data`.
   */
  handleResponse(response) {
    const id = String(response.id);
    const resolver = this.pendingRequests.get(id);
    if (!resolver) {
      import_logging.logger.error(`Unknown response ID: ${response.id}`);
      return;
    }
    this.pendingRequests.delete(id);
    if ("error" in response) {
      const { code, message, data } = response.error ?? {};
      const error = new Error(message ?? `Request ${id} failed`);
      error.code = code;
      error.data = data;
      resolver.reject(error);
    } else {
      resolver.resolve(response.result);
    }
  }

  /**
   * True when `request` expects a response. JSON-RPC ids may legitimately be
   * 0 or "", so a truthiness check is not enough.
   */
  hasRequestId(request) {
    return request.id !== undefined && request.id !== null;
  }

  /**
   * Dispatches an incoming request by method name. Unknown methods get a
   * -32601 error. Handler exceptions become a -32000 error carrying the stack
   * (only when the request has an id).
   */
  async handleRequest(request) {
    try {
      switch (request.method) {
        case "listActions":
          await this.handleListActions(request);
          break;
        case "listValues":
          await this.handleListValues(request);
          break;
        case "runAction":
          await this.handleRunAction(request);
          break;
        case "configure":
          this.handleConfigure(request);
          break;
        case "cancelAction":
          await this.handleCancelAction(request);
          break;
        case "sendInputStreamChunk":
          this.handleSendInputStreamChunk(request);
          break;
        case "endInputStream":
          this.handleEndInputStream(request);
          break;
        default:
          if (this.hasRequestId(request)) {
            this.sendError(
              request.id,
              -32601,
              `Method not found: ${request.method}`
            );
          }
      }
    } catch (error) {
      if (this.hasRequestId(request)) {
        this.sendError(request.id, -32e3, error?.message ?? String(error), {
          stack: error?.stack
        });
      }
    }
  }

  /** Responds with every resolvable action, with schemas converted to JSON Schema. */
  async handleListActions(request) {
    if (!this.hasRequestId(request)) return;
    const actions = await this.registry.listResolvableActions();
    const convertedActions = {};
    for (const [key, action] of Object.entries(actions)) {
      const converted = {
        key,
        name: action.name,
        description: action.description,
        metadata: action.metadata
      };
      if (action.inputSchema || action.inputJsonSchema) {
        converted.inputSchema = (0, import_schema.toJsonSchema)({
          schema: action.inputSchema,
          jsonSchema: action.inputJsonSchema
        });
      }
      if (action.outputSchema || action.outputJsonSchema) {
        converted.outputSchema = (0, import_schema.toJsonSchema)({
          schema: action.outputSchema,
          jsonSchema: action.outputJsonSchema
        });
      }
      convertedActions[key] = converted;
    }
    this.sendResponse(request.id, {
      actions: convertedActions
    });
  }

  /**
   * Responds with registry values of type "defaultModel" or "middleware".
   * Values exposing a `toJson()` method are serialized through it.
   */
  async handleListValues(request) {
    if (!this.hasRequestId(request)) return;
    const { type } = import_reflection_types.ReflectionListValuesParamsSchema.parse(request.params);
    if (type !== "defaultModel" && type !== "middleware") {
      this.sendError(
        request.id,
        -32602,
        `'type' ${type} is not supported. Only 'defaultModel' and 'middleware' are supported`
      );
      return;
    }
    const values = await this.registry.listValues(type);
    const mappedValues = {};
    for (const [key, value] of Object.entries(values)) {
      mappedValues[key] = value && typeof value === "object" && "toJson" in value && typeof value.toJson === "function" ? value.toJson() : value;
    }
    this.sendResponse(
      request.id,
      import_reflection_types.ReflectionListValuesResponseSchema.parse({ values: mappedValues })
    );
  }

  /**
   * Runs an action.
   *
   * - When the trace starts, sends a `runActionState` notification with the
   *   traceId and registers the run so `cancelAction` can abort it.
   * - With `stream`, forwards each chunk as a `streamChunk` notification.
   * - Replies with `{ result, telemetry: { traceId } }` after flushing traces,
   *   or with a -32000 error whose data carries a status code, stack and
   *   traceId.
   */
  async handleRunAction(request) {
    if (!this.hasRequestId(request)) return;
    const { key, input, context, telemetryLabels, stream } = import_reflection_types.ReflectionRunActionParamsSchema.parse(request.params);
    const action = await this.registry.lookupAction(key);
    if (!action) {
      this.sendError(request.id, -32602, `action ${key} not found`);
      return;
    }
    const abortController = new AbortController();
    let traceId;
    const runOptions = {
      context,
      telemetryLabels,
      onTraceStart: ({ traceId: startedTraceId }) => {
        traceId = startedTraceId;
        this.activeActions.set(startedTraceId, {
          abortController,
          startTime: new Date()
        });
        this.sendNotification(
          "runActionState",
          import_reflection_types.ReflectionRunActionStateParamsSchema.parse({
            requestId: request.id,
            state: { traceId: startedTraceId }
          })
        );
      },
      abortSignal: abortController.signal
    };
    // Only streaming runs receive an onChunk callback; the key is omitted otherwise.
    if (stream) {
      runOptions.onChunk = (chunk) => {
        this.sendNotification(
          "streamChunk",
          import_reflection_types.ReflectionStreamChunkParamsSchema.parse({
            requestId: request.id,
            chunk
          })
        );
      };
    }
    try {
      const result = await action.run(input, runOptions);
      await (0, import_tracing.flushTracing)();
      this.sendResponse(request.id, {
        result: result.result,
        telemetry: {
          traceId: result.telemetry.traceId
        }
      });
    } catch (err) {
      // `err` may be any thrown value (including non-Errors), so read it defensively.
      const isAbort = err?.name === "AbortError";
      const details = { stack: err?.stack };
      const failedTraceId = err?.traceId || traceId;
      if (failedTraceId) {
        details.traceId = failedTraceId;
      }
      const errorResponse = {
        code: isAbort ? import_action.StatusCodes.CANCELLED : import_action.StatusCodes.INTERNAL,
        message: isAbort ? "Action was cancelled" : (err?.message ?? String(err)),
        details
      };
      this.sendError(request.id, -32e3, errorResponse.message, errorResponse);
    } finally {
      if (traceId) {
        this.activeActions.delete(traceId);
      }
    }
  }

  /**
   * Adopts `telemetryServerUrl` unless GENKIT_TELEMETRY_SERVER is set.
   * Sends no response.
   */
  handleConfigure(request) {
    const { telemetryServerUrl } = import_reflection_types.ReflectionConfigureParamsSchema.parse(
      request.params
    );
    if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) {
      (0, import_tracing.setTelemetryServerUrl)(telemetryServerUrl);
      import_logging.logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`);
    }
  }

  /** Aborts the in-flight run with the given traceId, or errors with -32602 if there is none. */
  async handleCancelAction(request) {
    if (!this.hasRequestId(request)) return;
    const { traceId } = import_reflection_types.ReflectionCancelActionParamsSchema.parse(
      request.params
    );
    const activeAction = this.activeActions.get(traceId);
    if (activeAction) {
      activeAction.abortController.abort();
      this.activeActions.delete(traceId);
      this.sendResponse(request.id, { message: "Action cancelled" });
    } else {
      this.sendError(
        request.id,
        -32602,
        "Action not found or already completed"
      );
    }
  }

  /** Not supported yet; the thrown error becomes a -32000 response in handleRequest. */
  handleSendInputStreamChunk(request) {
    throw new Error("Not implemented");
  }

  /** Not supported yet; the thrown error becomes a -32000 response in handleRequest. */
  handleEndInputStream(request) {
    throw new Error("Not implemented");
  }
}
// Annotate the CommonJS export names for ESM import in node:
// `0 && (...)` short-circuits and never executes. It exists only so that
// static analysis of this CommonJS file can discover the named export
// `ReflectionServerV2` (NOTE(review): presumably Node's cjs-module-lexer when
// this module is imported from ESM — confirm against the bundler's docs).
0 && (module.exports = {
ReflectionServerV2
});
//# sourceMappingURL=reflection-v2.js.map