UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

206 lines 6.8 kB
import { vi } from 'vitest';

/**
 * Current time as an ISO-8601 string; used to stamp every emitted event.
 * @returns {string}
 */
function isoNow() {
    return new Date().toISOString();
}

/**
 * Split text into consecutive pieces of at most `size` code points.
 *
 * The `s` flag lets `.` match line terminators (without it, newlines are
 * silently dropped and the deltas no longer concatenate back to the source
 * text). The `u` flag makes `.` consume whole code points so surrogate pairs
 * are never split across two deltas.
 *
 * @param {string} text - Text to split.
 * @param {number} size - Maximum number of code points per piece.
 * @returns {string[]} Pieces whose concatenation equals `text`; empty for ''.
 */
function chunkText(text, size) {
    return text.match(new RegExp(`.{1,${size}}`, 'gsu')) ?? [];
}

/**
 * Scripted replacement for a streaming request function.
 *
 * Holds a list of response descriptors and replays them as a stream of
 * events. Each descriptor may carry any of: `delay` (ms to wait before the
 * response), `error` (string or object with a `message`), `thinking`
 * (string), `message` (string) and `toolCalls` (array of
 * `{ name, arguments }`).
 */
export class EnhancedRequestMock {
    /** Response descriptors, always stored as an array. */
    responses;
    /** Not read or updated by this class; retained for compatibility. */
    callIndex = 0;

    /**
     * @param {object|object[]} responses - One descriptor or an array of them.
     */
    constructor(responses) {
        this.responses = Array.isArray(responses) ? responses : [responses];
    }

    /**
     * Build the function to install in place of the real request function.
     * `model` and `messages` are ignored; every call replays the whole script
     * from the first response.
     *
     * @returns {(model: *, messages: *, options?: object) => AsyncGenerator<object>}
     */
    getMock() {
        return (model, messages, options) => {
            return this.createAsyncGenerator(options);
        };
    }

    /**
     * Emit the scripted events.
     *
     * For each response, in order: wait `delay` if set; if `error` is set,
     * yield a single `error` event and end the stream immediately (no
     * `stream_end` follows); otherwise yield the thinking events (only when
     * `options.includeThinking` is truthy), then the message events, then one
     * `tool_start` event carrying ALL tool calls of the response as an array.
     * After the last response a `stream_end` event is yielded.
     *
     * @param {object} [options]
     * @param {boolean} [options.includeThinking] - Emit thinking events.
     * @param {(call: object) => void} [options.onToolCall] - Called
     *   synchronously, once per tool call, after `tool_start` was yielded.
     * @yields {object} Event objects with `type` and `timestamp` fields.
     */
    async *createAsyncGenerator(options) {
        for (const response of this.responses) {
            if (response.delay) {
                await new Promise(resolve => setTimeout(resolve, response.delay));
            }

            if (response.error) {
                const error = typeof response.error === 'string'
                    ? new Error(response.error)
                    : response.error;
                yield {
                    type: 'error',
                    error: error.message,
                    timestamp: isoNow(),
                };
                return;
            }

            if (response.thinking && options?.includeThinking) {
                yield { type: 'thinking_start', timestamp: isoNow() };
                for (const chunk of chunkText(response.thinking, 10)) {
                    yield {
                        type: 'thinking_delta',
                        delta: chunk,
                        timestamp: isoNow(),
                    };
                }
                yield {
                    type: 'thinking_complete',
                    content: response.thinking,
                    timestamp: isoNow(),
                };
            }

            if (response.message) {
                yield { type: 'message_start', timestamp: isoNow() };
                for (const chunk of chunkText(response.message, 5)) {
                    yield {
                        type: 'text_delta',
                        delta: chunk,
                        timestamp: isoNow(),
                    };
                }
                yield {
                    type: 'message_complete',
                    content: response.message,
                    timestamp: isoNow(),
                };
            }

            if (response.toolCalls && response.toolCalls.length > 0) {
                const toolCallEvents = response.toolCalls.map((call, index) => ({
                    id: `call_${Date.now()}_${index}`,
                    type: 'function',
                    function: {
                        name: call.name,
                        arguments: JSON.stringify(call.arguments),
                    },
                }));
                yield {
                    type: 'tool_start',
                    tool_call: toolCallEvents,
                    timestamp: isoNow(),
                };
                if (options?.onToolCall) {
                    for (const call of toolCallEvents) {
                        options.onToolCall(call);
                    }
                }
            }
        }

        yield { type: 'stream_end', timestamp: isoNow() };
    }

    /**
     * Mock that emits `message` followed by a `task_complete` tool call.
     * @param {string} [message]
     * @param {string} [result] - Passed as the tool call's `result` argument.
     * @returns {EnhancedRequestMock}
     */
    static success(message = 'Success', result = 'Task completed') {
        return new EnhancedRequestMock({
            message,
            toolCalls: [{ name: 'task_complete', arguments: { result } }],
        });
    }

    /**
     * Mock that emits `message` followed by a `task_fatal_error` tool call.
     * @param {string} [message]
     * @param {string} [error] - Passed as the tool call's `error` argument.
     * @returns {EnhancedRequestMock}
     */
    static error(message = 'Error occurred', error = 'Task failed') {
        return new EnhancedRequestMock({
            message,
            toolCalls: [{ name: 'task_fatal_error', arguments: { error } }],
        });
    }

    /**
     * Mock whose stream consists of a single `error` event. Despite the name,
     * the generator does not throw; it yields the error and ends.
     * @param {string|Error} error
     * @returns {EnhancedRequestMock}
     */
    static throws(error) {
        return new EnhancedRequestMock({
            error: typeof error === 'string' ? new Error(error) : error,
        });
    }

    /**
     * Mock with thinking content followed by a message.
     * @param {string} thinking
     * @param {string} message
     * @returns {EnhancedRequestMock}
     */
    static thinking(thinking, message) {
        return new EnhancedRequestMock({ thinking, message });
    }

    /**
     * Mock that emits only tool calls (the empty message produces no events).
     * @param {...{name: string, arguments: *}} calls
     * @returns {EnhancedRequestMock}
     */
    static toolCalls(...calls) {
        return new EnhancedRequestMock({ message: '', toolCalls: calls });
    }

    /**
     * Mock that plays several responses back to back in one stream.
     * @param {...object} responses
     * @returns {EnhancedRequestMock}
     */
    static sequence(...responses) {
        return new EnhancedRequestMock(responses);
    }
}

/**
 * Create a context-like object with state fields and vitest spy methods.
 *
 * The methods are `vi.fn` wrappers around regular functions that mutate
 * `this`, so they must be invoked as methods of the returned object
 * (`ctx.halt()`), not as detached references.
 *
 * @param {object} [overrides] - Properties spread last; they replace any
 *   default field or method of the same name.
 * @returns {object}
 */
export function createMockContext(overrides = {}) {
    return {
        shouldContinue: true,
        metadata: {},
        toolCallCount: 0,
        turnCount: 0,
        startTime: Date.now(),
        messages: [],
        isPaused: false,
        isHalted: false,
        halt: vi.fn(function () {
            this.shouldContinue = false;
            this.isHalted = true;
        }),
        pause: vi.fn(function () {
            this.isPaused = true;
        }),
        resume: vi.fn(function () {
            this.isPaused = false;
        }),
        setMetadata: vi.fn(function (key, value) {
            this.metadata[key] = value;
        }),
        getMetadata: vi.fn(function (key) {
            return this.metadata[key];
        }),
        addMessage: vi.fn(function (message) {
            this.messages.push(message);
        }),
        getHistory: vi.fn(function () {
            return this.messages;
        }),
        ...overrides,
    };
}

/**
 * Collects every event of an async event stream and offers query helpers.
 *
 * Consumption starts in the constructor. Call `waitForCompletion()` before
 * querying to be sure the whole stream has been collected.
 */
export class StreamAssertions {
    /** Events received so far, in arrival order. */
    events = [];
    /** Promise that settles (never rejects) once the stream is exhausted. */
    #drained;
    /** `{ error }` wrapper when iterating the stream threw, otherwise null. */
    #failure = null;

    /**
     * @param {AsyncIterable<object>} eventGenerator - Stream to consume.
     */
    constructor(eventGenerator) {
        // Keep the consumer promise instead of letting it float, and capture a
        // stream failure so it surfaces in waitForCompletion() rather than as
        // an unhandled rejection.
        this.#drained = (async () => {
            try {
                for await (const event of eventGenerator) {
                    this.events.push(event);
                }
            } catch (error) {
                this.#failure = { error };
            }
        })();
    }

    /**
     * Resolve once the stream has been fully consumed, however long that
     * takes (previously a fixed 100 ms sleep, which missed slower streams).
     *
     * @returns {Promise<void>}
     * @throws Whatever the stream threw while being iterated.
     */
    async waitForCompletion() {
        await this.#drained;
        if (this.#failure !== null) {
            throw this.#failure.error;
        }
    }

    /**
     * @param {string} type
     * @returns {boolean} True when at least one collected event has this type.
     */
    hasEvent(type) {
        return this.events.some(e => e.type === type);
    }

    /**
     * @param {string} type
     * @returns {object[]} All collected events of this type, in order.
     */
    getEvents(type) {
        return this.events.filter(e => e.type === type);
    }

    /**
     * Check whether any `tool_start` event carries a call to the named tool.
     * `tool_call` may be a single call object or an array of them;
     * EnhancedRequestMock emits an array.
     *
     * @param {string} name - Tool (function) name to look for.
     * @returns {boolean}
     */
    hasToolCall(name) {
        return this.getEvents('tool_start').some(event => {
            if (!('tool_call' in event) || !event.tool_call) {
                return false;
            }
            const calls = Array.isArray(event.tool_call)
                ? event.tool_call
                : [event.tool_call];
            return calls.some(call => call?.function?.name === name);
        });
    }

    /**
     * @returns {*} `content` of the last `message_complete` event, or
     *   undefined when there is none.
     */
    getFinalMessage() {
        const messageEvents = this.getEvents('message_complete');
        if (messageEvents.length > 0) {
            const lastEvent = messageEvents[messageEvents.length - 1];
            return 'content' in lastEvent ? lastEvent.content : undefined;
        }
        return undefined;
    }

    /**
     * @returns {boolean} True when an `error` event was collected.
     */
    hasError() {
        return this.hasEvent('error');
    }

    /**
     * @returns {*} `error` field of the first `error` event, or undefined
     *   when there is none.
     */
    getError() {
        const errorEvents = this.getEvents('error');
        if (errorEvents.length > 0) {
            const errorEvent = errorEvents[0];
            return 'error' in errorEvent ? errorEvent.error : undefined;
        }
        return undefined;
    }
}
//# sourceMappingURL=test_utils.js.map