UNPKG

@genkit-ai/core

Version:

Genkit AI framework core libraries.

570 lines (540 loc) 16.9 kB
/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import * as assert from 'assert';
import type { AddressInfo } from 'node:net';
import { afterEach, beforeEach, describe, it } from 'node:test';
import { WebSocket, WebSocketServer } from 'ws';
import { z } from 'zod';
import { action } from '../src/action.js';
import { initNodeFeatures } from '../src/node.js';
import { ReflectionServerV2 } from '../src/reflection-v2.js';
import { Registry } from '../src/registry.js';

initNodeFeatures();

/** Time budget for a single request/response exchange with the client. */
const EXCHANGE_TIMEOUT_MS = 2000;

/** Longer budget for the test that waits through a disconnect + reconnect. */
const RECONNECT_TIMEOUT_MS = 5000;

/**
 * A parsed JSON-RPC message. Payload shapes differ per method, so the tests
 * inspect fields loosely.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type RpcMessage = any;

/**
 * Creates a promise that resolves when `body` calls `pass`, rejects when it
 * calls `fail`, or rejects with `timeoutMessage` after `timeoutMs`.
 * The timer is cleared on every outcome so no test leaves it pending.
 *
 * @param timeoutMs How long to wait before failing.
 * @param timeoutMessage Message of the error raised on timeout.
 * @param body Runs synchronously; wires up listeners that call pass/fail.
 */
function expectWithin(
  timeoutMs: number,
  timeoutMessage: string,
  body: (pass: () => void, fail: (err: unknown) => void) => void
): Promise<void> {
  return new Promise<void>((resolve, reject) => {
    const timer = setTimeout(
      () => reject(new Error(timeoutMessage)),
      timeoutMs
    );
    body(
      () => {
        clearTimeout(timer);
        resolve();
      },
      (err) => {
        clearTimeout(timer);
        reject(err);
      }
    );
  });
}

/**
 * Subscribes `handler` to every JSON message received on `ws`.
 * Parse errors and exceptions thrown by `handler` (including assertion
 * failures) go to `fail`. Otherwise they would escape the 'message'
 * listener as uncaught exceptions.
 */
function onRpcMessage(
  ws: WebSocket,
  fail: (err: unknown) => void,
  handler: (msg: RpcMessage) => void
): void {
  ws.on('message', (data) => {
    try {
      handler(JSON.parse(data.toString()));
    } catch (e) {
      fail(e);
    }
  });
}

/** Sends a JSON-RPC 2.0 message (request or response) over `ws`. */
function sendRpc(ws: WebSocket, message: Record<string, unknown>): void {
  ws.send(JSON.stringify({ jsonrpc: '2.0', ...message }));
}

/** Replies to a `register` request with an empty successful result. */
function ackRegister(ws: WebSocket, id: unknown): void {
  sendRpc(ws, { result: {}, id });
}

/**
 * Returns the `result` of a JSON-RPC response.
 *
 * @throws Error labelled with `label` when the response carries an `error`.
 */
function resultOf(msg: RpcMessage, label: string): RpcMessage {
  if (msg.error) {
    throw new Error(`${label} error: ${JSON.stringify(msg.error)}`);
  }
  return msg.result;
}

describe('ReflectionServerV2', () => {
  // `wss` is a local WebSocket endpoint that the ReflectionServerV2 under
  // test connects to (via `url`). Each test drives the JSON-RPC exchange
  // from this side.
  let wss: WebSocketServer;
  let server: ReflectionServerV2 | undefined;
  let registry: Registry;
  let port: number;
  let connections: WebSocket[] = [];

  beforeEach(() => {
    registry = new Registry();
    return new Promise<void>((resolve) => {
      // Port 0 lets the OS choose a free port; read it back once listening.
      wss = new WebSocketServer({ port: 0 });
      wss.on('listening', () => {
        port = (wss.address() as AddressInfo).port;
        resolve();
      });
      // Track all connections so afterEach can terminate them.
      wss.on('connection', (ws) => {
        connections.push(ws);
      });
    });
  });

  afterEach(async () => {
    try {
      await server?.stop();
    } finally {
      // Tear down even if stop() throws. Resetting `server` keeps a later
      // test that fails before creating its own server from re-stopping this one.
      server = undefined;
      // Terminate all connections to let wss.close() proceed.
      for (const ws of connections) {
        ws.terminate();
      }
      connections = [];
      await new Promise<void>((resolve) => {
        wss.close(() => resolve());
      });
    }
  });

  it('should connect to the server and register', async () => {
    const connected = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'connect timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              assert.strictEqual(msg.params.name, 'test-app');
              ackRegister(ws, msg.id);
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
      name: 'test-app',
    });
    // Await both together so a failure in either surfaces immediately and
    // neither promise is left unobserved.
    await Promise.all([server.start(), connected]);
  });

  it('should handle listActions', async () => {
    const testAction = action(
      {
        name: 'testAction',
        description: 'A test action',
        inputSchema: z.object({ foo: z.string() }),
        outputSchema: z.object({ bar: z.string() }),
        actionType: 'custom',
      },
      async (input) => ({ bar: input.foo })
    );
    registry.registerAction('custom', testAction);

    const gotListActions = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'listActions timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              // After registration, request listActions.
              sendRpc(ws, { method: 'listActions', id: '123' });
            } else if (msg.id === '123') {
              const actions = resultOf(msg, 'listActions').actions;
              assert.ok(actions['/custom/testAction']);
              assert.strictEqual(
                actions['/custom/testAction'].name,
                'testAction'
              );
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), gotListActions]);
  });

  it('should handle listValues', async () => {
    registry.registerValue('middleware', 'my-mw', { template: 'foo' });

    const gotListValues = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'listValues timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              sendRpc(ws, {
                method: 'listValues',
                params: { type: 'middleware' },
                id: '124',
              });
            } else if (msg.id === '124') {
              const values = resultOf(msg, 'listValues').values;
              assert.ok(values['my-mw']);
              assert.strictEqual(values['my-mw'].template, 'foo');
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), gotListValues]);
  });

  it('should handle listValues with toJson mapping', async () => {
    // mw1 exposes toJson(); mw2 is a plain object. Both should serialize
    // with a `name` field.
    registry.registerValue('middleware', 'mw1', {
      toJson: () => ({ name: 'mw1' }),
    });
    registry.registerValue('middleware', 'mw2', {
      name: 'mw2',
    });

    const gotListValues = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'listValues toJson timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              sendRpc(ws, {
                method: 'listValues',
                params: { type: 'middleware' },
                id: '125',
              });
            } else if (msg.id === '125') {
              const values = resultOf(msg, 'listValues').values;
              assert.ok(values['mw1']);
              assert.strictEqual(values['mw1'].name, 'mw1');
              assert.ok(values['mw2']);
              assert.strictEqual(values['mw2'].name, 'mw2');
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), gotListValues]);
  });

  it('should reject unsupported type parameter for listValues in V2', async () => {
    const gotError = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'listValues error timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              sendRpc(ws, {
                method: 'listValues',
                params: { type: 'unsupported_type' },
                id: '126',
              });
            } else if (msg.id === '126') {
              // -32602 is the JSON-RPC "Invalid params" error code.
              assert.ok(msg.error);
              assert.strictEqual(msg.error.code, -32602);
              assert.match(msg.error.message, /is not supported/);
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), gotError]);
  });

  it('should handle runAction', async () => {
    const testAction = action(
      {
        name: 'testAction',
        inputSchema: z.object({ foo: z.string() }),
        outputSchema: z.object({ bar: z.string() }),
        actionType: 'custom',
      },
      async (input) => ({ bar: input.foo })
    );
    registry.registerAction('custom', testAction);

    const actionRun = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'runAction timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              sendRpc(ws, {
                method: 'runAction',
                params: {
                  key: '/custom/testAction',
                  input: { foo: 'baz' },
                },
                id: '456',
              });
            } else if (msg.id === '456') {
              assert.strictEqual(
                resultOf(msg, 'runAction').result.bar,
                'baz'
              );
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), actionRun]);
  });

  it('should handle streaming runAction', async () => {
    const streamAction = action(
      {
        name: 'streamAction',
        inputSchema: z.object({ foo: z.string() }),
        outputSchema: z.string(),
        actionType: 'custom',
      },
      async (_input, { sendChunk }) => {
        sendChunk('chunk1');
        sendChunk('chunk2');
        return 'done';
      }
    );
    registry.registerAction('custom', streamAction);

    const chunks: unknown[] = [];
    const actionRun = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'streamAction timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              sendRpc(ws, {
                method: 'runAction',
                params: {
                  key: '/custom/streamAction',
                  input: { foo: 'baz' },
                  stream: true,
                },
                id: '789',
              });
            } else if (msg.method === 'streamChunk') {
              chunks.push(msg.params.chunk);
            } else if (msg.id === '789') {
              // The final response must arrive after all chunks were seen.
              assert.strictEqual(resultOf(msg, 'streamAction').result, 'done');
              assert.deepStrictEqual(chunks, ['chunk1', 'chunk2']);
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), actionRun]);
  });

  it('should handle cancelAction', async () => {
    let cancelSignal: AbortSignal | undefined;
    const longAction = action(
      {
        name: 'longAction',
        inputSchema: z.any(),
        outputSchema: z.any(),
        actionType: 'custom',
      },
      async (_, { abortSignal }) => {
        cancelSignal = abortSignal;
        // Wait up to 5s, or fail as soon as the abort signal fires.
        await new Promise((resolve, reject) => {
          const timer = setTimeout(resolve, 5000);
          const onAbort = () => {
            clearTimeout(timer);
            reject(new Error('Action cancelled'));
          };
          if (abortSignal.aborted) {
            onAbort();
            return;
          }
          abortSignal.addEventListener('abort', onAbort, { once: true });
        });
      }
    );
    registry.registerAction('custom', longAction);

    const actionCancelled = expectWithin(
      EXCHANGE_TIMEOUT_MS,
      'cancelAction timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method === 'register') {
              ackRegister(ws, msg.id);
              // Start the long-running action.
              sendRpc(ws, {
                method: 'runAction',
                params: {
                  key: '/custom/longAction',
                  input: {},
                },
                id: '999',
              });
            } else if (msg.method === 'runActionState') {
              // The state notification carries the traceId used to cancel.
              sendRpc(ws, {
                method: 'cancelAction',
                params: { traceId: msg.params.state.traceId },
                id: '1000',
              });
            } else if (msg.id === '1000') {
              // Response to the cancelAction request.
              assert.strictEqual(msg.result.message, 'Action cancelled');
            } else if (msg.id === '999') {
              // The runAction request itself must end in an error. Only the
              // message is checked (/cancelled/), not the error code.
              if (!msg.error) {
                throw new Error('Action should have failed');
              }
              assert.match(msg.error.message, /cancelled/);
              assert.ok(cancelSignal?.aborted);
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    await Promise.all([server.start(), actionCancelled]);
  });

  it('should reconnect when lost connection and register again', async () => {
    let connectionCount = 0;
    const reconnected = expectWithin(
      RECONNECT_TIMEOUT_MS,
      'reconnect timeout',
      (pass, fail) => {
        wss.on('connection', (ws) => {
          // Capture this connection's ordinal, so later connections cannot
          // change how this socket's messages are handled.
          const attempt = ++connectionCount;
          onRpcMessage(ws, fail, (msg) => {
            if (msg.method !== 'register') {
              return;
            }
            if (attempt === 1) {
              ws.terminate(); // Simulate a dropped connection.
            } else if (attempt === 2) {
              ackRegister(ws, msg.id);
              pass();
            }
          });
        });
      }
    );

    server = new ReflectionServerV2(registry, {
      url: `ws://localhost:${port}`,
    });
    // Fast for testing. This presumably shrinks the reconnect backoff base.
    // The field is reached through `any` because it is not in the typed
    // constructor options.
    (server as any).baseDelayMs = 10;
    await Promise.all([server.start(), reconnected]);
  });
});