@genkit-ai/core
Version:
Genkit AI framework core libraries.
515 lines (464 loc) • 14.3 kB
text/typescript
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import WebSocket from 'ws';
import { StatusCodes, type Status } from './action.js';
import { GENKIT_REFLECTION_API_SPEC_VERSION, GENKIT_VERSION } from './index.js';
import { logger } from './logging.js';
import {
ReflectionCancelActionParamsSchema,
ReflectionConfigureParamsSchema,
ReflectionListActionsResponse,
ReflectionListValuesParamsSchema,
ReflectionListValuesResponseSchema,
ReflectionRegisterParams,
ReflectionRunActionParamsSchema,
ReflectionRunActionStateParamsSchema,
ReflectionStreamChunkParamsSchema,
} from './reflection-types.js';
import type { Registry } from './registry.js';
import { toJsonSchema } from './schema.js';
import { flushTracing, setTelemetryServerUrl } from './tracing.js';
// Module-level counter. Every ReflectionServerV2 instance captures the current
// value as its `index` and then increments it; the index is appended to the
// process id in `runtimeId` for every instance after the first.
let apiIndex = 0;
/**
 * A JSON-RPC 2.0 request or notification exchanged over the reflection
 * WebSocket.
 *
 * In this file a message without an `id` is treated as a notification: the
 * handlers never send a response or an error for it.
 */
interface JsonRpcRequest {
// Protocol version marker; always the literal '2.0'.
jsonrpc: '2.0';
// Name of the remote method, e.g. 'listActions' or 'runAction'.
method: string;
// Method-specific payload; handlers validate it with the matching schema.
params?: any;
// Correlation id; omitted for notifications.
id?: string;
}
/**
 * A JSON-RPC 2.0 response. The senders in this file populate either `result`
 * (success) or `error` (failure), never both.
 */
interface JsonRpcResponse {
// Protocol version marker; always the literal '2.0'.
jsonrpc: '2.0';
// Successful outcome of the call.
result?: any;
// Failure description: numeric code, human-readable message, optional extra data.
error?: {
code: number;
message: string;
data?: any;
};
// Id of the request this response answers.
id: string;
}
// Any frame that can be written to the socket: a request/notification or a response.
type JsonRpcMessage = JsonRpcRequest | JsonRpcResponse;
/**
 * Construction options for {@link ReflectionServerV2}.
 */
export interface ReflectionServerV2Options {
// Environments reported during registration; the constructor defaults this to ['dev'].
configuredEnvs?: string[];
// Display name reported during registration; falls back to the runtime id when unset.
name?: string;
// WebSocket URL the runtime connects to.
url: string;
}
/**
 * Reflection API V2 endpoint for a Genkit runtime.
 *
 * Despite the name, this class dials *out*: it opens a WebSocket to
 * `options.url`, registers the runtime, and then answers JSON-RPC 2.0 requests
 * (`listActions`, `listValues`, `runAction`, `configure`, `cancelAction`)
 * against the supplied registry. When the connection drops it reconnects with
 * exponential backoff until {@link ReflectionServerV2.stop} is called.
 */
export class ReflectionServerV2 {
  private registry: Registry;
  private options: ReflectionServerV2Options;
  /** The socket currently in use, or null before start() / after stop(). */
  private ws: WebSocket | null = null;
  private url: string;
  /** Per-process instance number, used to build a unique runtime id. */
  private index = apiIndex++;
  /** In-flight actions keyed by trace id, so they can be cancelled remotely. */
  private activeActions = new Map<
    string,
    {
      abortController: AbortController;
      startTime: Date;
    }
  >();
  /** Number of consecutive reconnect attempts; reset once a socket opens. */
  private reconnectCount = 0;
  private isStopped = false;
  private reconnectTimeout: NodeJS.Timeout | null = null;
  /** Delay before the first reconnect attempt; doubles with each attempt. */
  private baseDelayMs = 500;
  /** Upper bound for the reconnect delay. */
  private maxDelayMs = 5000;
  /** Outgoing requests awaiting a response, keyed by request id. */
  private pendingRequests = new Map<
    string,
    {
      resolve: (value: any) => void;
      reject: (reason?: any) => void;
    }
  >();
  private requestIdCounter = 0;

  /**
   * @param registry Registry used to list and look up actions and values.
   * @param options Connection options; `configuredEnvs` defaults to `['dev']`.
   */
  constructor(registry: Registry, options: ReflectionServerV2Options) {
    this.registry = registry;
    this.options = {
      configuredEnvs: ['dev'],
      ...options,
    };
    // The URL should be provided via environment variable by the CLI manager
    this.url = this.options.url;
  }

  /** Clears the stopped flag and the backoff counter, then opens the connection. */
  async start() {
    this.isStopped = false;
    this.reconnectCount = 0;
    await this.connect();
  }

  /**
   * Opens a new WebSocket and wires up its event handlers.
   *
   * Handlers ignore events from a socket that is no longer `this.ws`, so a late
   * `close` from a socket discarded by stop() cannot reject requests that
   * belong to a newer connection or schedule a second, duplicate reconnect.
   */
  private async connect() {
    if (this.isStopped) return;

    // Replacing a socket that is still alive (e.g. start() called twice):
    // settle its outstanding requests and close it instead of leaking it.
    const previous = this.ws;
    if (previous && previous.readyState !== WebSocket.CLOSED) {
      this.rejectPendingRequests();
      previous.close();
    }

    logger.debug(`Connecting to Reflection V2 server at ${this.url}`);
    const ws = new WebSocket(this.url);
    this.ws = ws;
    const isCurrent = () => this.ws === ws;

    ws.on('open', async () => {
      if (!isCurrent()) return;
      logger.debug('Connected to Reflection V2 server.');
      this.reconnectCount = 0;
      await this.register();
    });

    ws.on('message', async (data) => {
      if (!isCurrent()) return;
      try {
        const message = JSON.parse(data.toString()) as any;
        if ('method' in message) {
          await this.handleRequest(message);
        } else if ('id' in message) {
          this.handleResponse(message);
        }
      } catch (error) {
        logger.error(`Failed to parse message: ${error}`);
      }
    });

    ws.on('error', (error) => {
      logger.error(`Reflection V2 WebSocket error: ${error}`);
    });

    ws.on('close', (code, reason) => {
      logger.debug(
        `Reflection V2 WebSocket closed. Code: ${code}, Reason: ${reason}`
      );
      // A stale socket has already been accounted for by stop()/connect().
      if (!isCurrent()) return;
      this.rejectPendingRequests();
      if (!this.isStopped) {
        this.scheduleReconnect();
      }
    });
  }

  /** Rejects and forgets every outgoing request that is still awaiting a response. */
  private rejectPendingRequests() {
    for (const [id, resolver] of this.pendingRequests.entries()) {
      resolver.reject(
        new Error(`Connection closed before response was received (id: ${id})`)
      );
    }
    this.pendingRequests.clear();
  }

  /**
   * Schedules a single reconnect attempt. The delay starts at `baseDelayMs`,
   * doubles per attempt and is capped at `maxDelayMs`. No-op when an attempt is
   * already scheduled.
   */
  private scheduleReconnect() {
    if (this.reconnectTimeout) return;
    const delay = Math.min(
      this.baseDelayMs * 2 ** this.reconnectCount,
      this.maxDelayMs
    );
    this.reconnectCount++;
    logger.debug(
      `Scheduling reconnection in ${delay}ms (attempt ${this.reconnectCount})`
    );
    this.reconnectTimeout = setTimeout(() => {
      this.reconnectTimeout = null;
      // connect() can reject (e.g. the WebSocket constructor throws); log it
      // rather than leaving an unhandled rejection inside the timer callback.
      this.connect().catch((err) => {
        logger.error(`Reflection V2 reconnect failed: ${err}`);
      });
    }, delay);
  }

  /**
   * Stops the server: cancels any scheduled reconnect, rejects outstanding
   * requests and closes the socket. Does not wait for the close handshake.
   */
  async stop() {
    this.isStopped = true;
    if (this.reconnectTimeout) {
      clearTimeout(this.reconnectTimeout);
      this.reconnectTimeout = null;
    }
    const ws = this.ws;
    // Detach first so the socket's late 'close' event is recognised as stale.
    this.ws = null;
    this.rejectPendingRequests();
    if (ws) {
      ws.close();
    }
  }

  /**
   * Serialises and writes a message if the socket is open.
   *
   * @returns true when the message was handed to the socket, false when it was
   *   dropped because there is no open connection.
   */
  private send(message: JsonRpcMessage): boolean {
    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
      this.ws.send(JSON.stringify(message));
      return true;
    }
    return false;
  }

  /** Sends a successful JSON-RPC response for request `id`. */
  private sendResponse(id: string, result: any) {
    this.send({
      jsonrpc: '2.0',
      result,
      id,
    });
  }

  /** Sends a JSON-RPC error response for request `id`. */
  private sendError(id: string, code: number, message: string, data?: any) {
    this.send({
      jsonrpc: '2.0',
      error: { code, message, data },
      id,
    });
  }

  /** Sends a JSON-RPC notification (a request without an id). */
  private sendNotification(method: string, params: any) {
    this.send({
      jsonrpc: '2.0',
      method,
      params,
    });
  }

  /**
   * Sends a JSON-RPC request and resolves with the peer's `result`.
   *
   * Rejects with the peer's `error` object, with a connection-closed error when
   * the socket closes first, or immediately when the socket is not open (the
   * message would otherwise be dropped and the promise would never settle).
   */
  private sendRequest(method: string, params: any): Promise<any> {
    return new Promise((resolve, reject) => {
      const id = (++this.requestIdCounter).toString();
      this.pendingRequests.set(id, { resolve, reject });
      const sent = this.send({
        jsonrpc: '2.0',
        id,
        method,
        params,
      });
      if (!sent) {
        this.pendingRequests.delete(id);
        reject(
          new Error(
            `Cannot send '${method}' request: Reflection V2 WebSocket is not open`
          )
        );
      }
    });
  }

  /**
   * Announces this runtime to the peer. If the response carries a
   * `telemetryServerUrl` and GENKIT_TELEMETRY_SERVER is not set, that URL is
   * installed as the telemetry server. Failures are logged, not thrown.
   */
  private async register() {
    const params: ReflectionRegisterParams = {
      id: process.env.GENKIT_RUNTIME_ID || this.runtimeId,
      pid: process.pid,
      name: this.options.name || this.runtimeId,
      genkitVersion: GENKIT_VERSION,
      reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION,
      envs: this.options.configuredEnvs,
    };
    try {
      const response = await this.sendRequest('register', params);
      const telemetryServerUrl = response && response.telemetryServerUrl;
      if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) {
        setTelemetryServerUrl(telemetryServerUrl);
        logger.debug(
          `Connected to telemetry server on ${telemetryServerUrl} via handshake`
        );
      }
    } catch (err) {
      logger.error(`Failed to register with CLI: ${err}`);
    }
  }

  /** The process id, suffixed with `-<index>` for every instance after the first. */
  get runtimeId() {
    return `${process.pid}${this.index ? `-${this.index}` : ''}`;
  }

  /** Settles the pending request matching `response.id`, if any. */
  private handleResponse(response: any) {
    const resolver = this.pendingRequests.get(response.id);
    if (!resolver) {
      logger.error(`Unknown response ID: ${response.id}`);
      return;
    }
    this.pendingRequests.delete(response.id);
    if ('error' in response) {
      resolver.reject(response.error);
    } else {
      resolver.resolve(response.result);
    }
  }

  /**
   * Extracts a message and stack from an arbitrary thrown value. Values that
   * carry a string `message` keep it; anything else is stringified so the
   * resulting JSON-RPC error always has a message.
   */
  private static describeError(err: unknown): {
    message: string;
    stack?: string;
  } {
    const candidate = err as
      | { message?: unknown; stack?: unknown }
      | null
      | undefined;
    return {
      message:
        typeof candidate?.message === 'string'
          ? candidate.message
          : String(err),
      stack: typeof candidate?.stack === 'string' ? candidate.stack : undefined,
    };
  }

  /**
   * Routes an incoming request to its handler. Unknown methods get a -32601
   * error and handler exceptions a -32000 error; in both cases only when the
   * request carries an id.
   */
  private async handleRequest(request: JsonRpcRequest) {
    try {
      switch (request.method) {
        case 'listActions':
          await this.handleListActions(request);
          break;
        case 'listValues':
          await this.handleListValues(request);
          break;
        case 'runAction':
          await this.handleRunAction(request);
          break;
        case 'configure':
          this.handleConfigure(request);
          break;
        case 'cancelAction':
          await this.handleCancelAction(request);
          break;
        case 'sendInputStreamChunk':
          this.handleSendInputStreamChunk(request);
          break;
        case 'endInputStream':
          this.handleEndInputStream(request);
          break;
        default:
          if (request.id) {
            this.sendError(
              request.id,
              -32601,
              `Method not found: ${request.method}`
            );
          }
      }
    } catch (error: unknown) {
      if (request.id) {
        const { message, stack } = ReflectionServerV2.describeError(error);
        this.sendError(request.id, -32000, message, { stack });
      }
    }
  }

  /**
   * Responds with every resolvable action's key, name, description and
   * metadata, plus JSON-schema forms of its input/output schemas when present.
   */
  private async handleListActions(request: JsonRpcRequest) {
    if (!request.id) return; // Should be a request
    const actions = await this.registry.listResolvableActions();
    const convertedActions: Record<string, any> = {};
    for (const key of Object.keys(actions)) {
      const action = actions[key];
      const entry: Record<string, any> = {
        key,
        name: action.name,
        description: action.description,
        metadata: action.metadata,
      };
      if (action.inputSchema || action.inputJsonSchema) {
        entry.inputSchema = toJsonSchema({
          schema: action.inputSchema,
          jsonSchema: action.inputJsonSchema,
        });
      }
      if (action.outputSchema || action.outputJsonSchema) {
        entry.outputSchema = toJsonSchema({
          schema: action.outputSchema,
          jsonSchema: action.outputJsonSchema,
        });
      }
      convertedActions[key] = entry;
    }
    this.sendResponse(request.id, {
      actions: convertedActions,
    } as ReflectionListActionsResponse);
  }

  /**
   * Responds with the registry values of the requested type. Only
   * 'defaultModel' and 'middleware' are accepted; other types get a -32602
   * error. Values exposing a `toJson()` method are serialised through it.
   */
  private async handleListValues(request: JsonRpcRequest) {
    if (!request.id) return;
    const { type } = ReflectionListValuesParamsSchema.parse(request.params);
    if (type !== 'defaultModel' && type !== 'middleware') {
      this.sendError(
        request.id,
        -32602,
        `'type' ${type} is not supported. Only 'defaultModel' and 'middleware' are supported`
      );
      return;
    }
    const values = await this.registry.listValues(type);
    const mappedValues: Record<string, any> = {};
    for (const [key, value] of Object.entries(values)) {
      const hasToJson =
        value &&
        typeof value === 'object' &&
        'toJson' in value &&
        typeof (value as any).toJson === 'function';
      mappedValues[key] = hasToJson ? (value as any).toJson() : value;
    }
    this.sendResponse(
      request.id,
      ReflectionListValuesResponseSchema.parse({ values: mappedValues })
    );
  }

  /**
   * Runs the requested action and responds with its result and trace id.
   *
   * While running, a `runActionState` notification announces the trace id as
   * soon as it is known, and `streamChunk` notifications carry chunks when the
   * request asked for streaming. Failures are reported as a -32000 error whose
   * data is a Status (CANCELLED for aborts, INTERNAL otherwise).
   */
  private async handleRunAction(request: JsonRpcRequest) {
    if (!request.id) return;
    const requestId = request.id;
    const { key, input, context, telemetryLabels, stream } =
      ReflectionRunActionParamsSchema.parse(request.params);
    const action = await this.registry.lookupAction(key);
    if (!action) {
      this.sendError(requestId, -32602, `action ${key} not found`);
      return;
    }

    const abortController = new AbortController();
    let traceId: string | undefined;

    const onTraceStart = ({ traceId: tid }: { traceId: string }) => {
      traceId = tid;
      // Track the run so a later cancelAction can find its abort controller.
      this.activeActions.set(tid, {
        abortController,
        startTime: new Date(),
      });
      // Send early trace ID notification
      this.sendNotification(
        'runActionState',
        ReflectionRunActionStateParamsSchema.parse({
          requestId,
          state: { traceId: tid },
        })
      );
    };

    const onChunk = (chunk: any) => {
      this.sendNotification(
        'streamChunk',
        ReflectionStreamChunkParamsSchema.parse({
          requestId,
          chunk,
        })
      );
    };

    try {
      const result = await action.run(input, {
        context,
        // Only pass a chunk callback when the caller asked for streaming.
        ...(stream ? { onChunk } : {}),
        telemetryLabels,
        onTraceStart,
        abortSignal: abortController.signal,
      });
      await flushTracing();
      // Send final result
      this.sendResponse(requestId, {
        result: result.result,
        telemetry: {
          traceId: result.telemetry.traceId,
        },
      });
    } catch (err: unknown) {
      const thrown = err as
        | { name?: unknown; traceId?: string }
        | null
        | undefined;
      const isAbort = thrown?.name === 'AbortError';
      const { message, stack } = ReflectionServerV2.describeError(err);
      const errorResponse: Status = {
        code: isAbort ? StatusCodes.CANCELLED : StatusCodes.INTERNAL,
        message: isAbort ? 'Action was cancelled' : message,
        details: {
          stack,
        },
      };
      const failedTraceId = thrown?.traceId || traceId;
      if (failedTraceId) {
        errorResponse.details.traceId = failedTraceId;
      }
      this.sendError(requestId, -32000, errorResponse.message, errorResponse);
    } finally {
      if (traceId) {
        this.activeActions.delete(traceId);
      }
    }
  }

  /**
   * Applies the telemetry server URL from the request unless
   * GENKIT_TELEMETRY_SERVER is set. No response is sent.
   */
  private handleConfigure(request: JsonRpcRequest) {
    const { telemetryServerUrl } = ReflectionConfigureParamsSchema.parse(
      request.params
    );
    if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) {
      setTelemetryServerUrl(telemetryServerUrl);
      logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`);
    }
  }

  /**
   * Aborts the running action identified by `traceId`; responds with a -32602
   * error when no such action is being tracked.
   */
  private async handleCancelAction(request: JsonRpcRequest) {
    if (!request.id) return;
    const { traceId } = ReflectionCancelActionParamsSchema.parse(
      request.params
    );
    const activeAction = this.activeActions.get(traceId);
    if (!activeAction) {
      this.sendError(
        request.id,
        -32602,
        'Action not found or already completed'
      );
      return;
    }
    activeAction.abortController.abort();
    this.activeActions.delete(traceId);
    this.sendResponse(request.id, { message: 'Action cancelled' });
  }

  /** Placeholder: input streaming is not supported yet, so this always throws. */
  private handleSendInputStreamChunk(request: JsonRpcRequest) {
    // ReflectionSendInputStreamChunkParamsSchema.parse(request.params);
    throw new Error('Not implemented');
  }

  /** Placeholder: input streaming is not supported yet, so this always throws. */
  private handleEndInputStream(request: JsonRpcRequest) {
    // ReflectionEndInputStreamParamsSchema.parse(request.params);
    throw new Error('Not implemented');
  }
}