@just-every/ensemble
Version:
LLM provider abstraction layer with unified streaming interface
206 lines • 6.8 kB
JavaScript
import { vi } from 'vitest';
/**
 * Scripted stand-in for a streaming LLM request function.
 *
 * Each configured response is replayed in order as stream events:
 * optional thinking, optional message text, then optional tool calls.
 * After the last response a `stream_end` event is emitted. A response with an
 * `error` emits a single `error` event and ends the stream immediately
 * (no `stream_end`).
 */
export class EnhancedRequestMock {
    responses;
    // Not read anywhere in this class; kept so existing callers that reference it keep working.
    callIndex = 0;
    /**
     * @param {object|object[]} responses - A single response or an array of responses to replay in order.
     */
    constructor(responses) {
        this.responses = Array.isArray(responses) ? responses : [responses];
    }
    /**
     * Returns a request function suitable for injecting in place of the real one.
     * `model` and `messages` are accepted but ignored; only `options` is used.
     *
     * @returns {(model: unknown, messages: unknown, options?: object) => AsyncGenerator<object>}
     */
    getMock() {
        return (model, messages, options) => {
            return this.createAsyncGenerator(options);
        };
    }
    /**
     * Replays the configured responses as stream events.
     *
     * @param {object} [options]
     * @param {boolean} [options.includeThinking] - Emit thinking events for responses that have `thinking`.
     * @param {(call: object) => void} [options.onToolCall] - Invoked once per tool call, after its `tool_start` event.
     */
    async *createAsyncGenerator(options) {
        for (const response of this.responses) {
            if (response.delay) {
                await new Promise(resolve => setTimeout(resolve, response.delay));
            }
            if (response.error) {
                const error = typeof response.error === 'string' ? new Error(response.error) : response.error;
                yield {
                    type: 'error',
                    error: error.message,
                    timestamp: new Date().toISOString(),
                };
                // An error terminates the stream: later responses and `stream_end` are skipped.
                return;
            }
            if (response.thinking && options?.includeThinking) {
                yield {
                    type: 'thinking_start',
                    timestamp: new Date().toISOString(),
                };
                for (const chunk of this.#splitIntoChunks(response.thinking, 10)) {
                    yield {
                        type: 'thinking_delta',
                        delta: chunk,
                        timestamp: new Date().toISOString(),
                    };
                }
                yield {
                    type: 'thinking_complete',
                    content: response.thinking,
                    timestamp: new Date().toISOString(),
                };
            }
            if (response.message) {
                yield {
                    type: 'message_start',
                    timestamp: new Date().toISOString(),
                };
                for (const chunk of this.#splitIntoChunks(response.message, 5)) {
                    yield {
                        type: 'text_delta',
                        delta: chunk,
                        timestamp: new Date().toISOString(),
                    };
                }
                yield {
                    type: 'message_complete',
                    content: response.message,
                    timestamp: new Date().toISOString(),
                };
            }
            if (response.toolCalls && response.toolCalls.length > 0) {
                const toolCallEvents = response.toolCalls.map((call, index) => ({
                    id: `call_${Date.now()}_${index}`,
                    type: 'function',
                    function: {
                        name: call.name,
                        arguments: JSON.stringify(call.arguments),
                    },
                }));
                // A single `tool_start` event carries every tool call of this response as an array.
                yield {
                    type: 'tool_start',
                    tool_call: toolCallEvents,
                    timestamp: new Date().toISOString(),
                };
                if (options?.onToolCall) {
                    for (const call of toolCallEvents) {
                        options.onToolCall(call);
                    }
                }
            }
        }
        yield {
            type: 'stream_end',
            timestamp: new Date().toISOString(),
        };
    }
    /**
     * Splits `text` into consecutive pieces of at most `size` code points.
     * Unlike `text.match(/.{1,n}/g)`, this keeps line terminators (which `.` never
     * matches) and never splits a surrogate pair. The pieces therefore always
     * concatenate back to `text`.
     */
    #splitIntoChunks(text, size) {
        const codePoints = Array.from(text);
        const chunks = [];
        for (let start = 0; start < codePoints.length; start += size) {
            chunks.push(codePoints.slice(start, start + size).join(''));
        }
        return chunks;
    }
    /** A response with a message plus a `task_complete` tool call carrying `result`. */
    static success(message = 'Success', result = 'Task completed') {
        return new EnhancedRequestMock({
            message,
            toolCalls: [{ name: 'task_complete', arguments: { result } }],
        });
    }
    /** A response with a message plus a `task_fatal_error` tool call carrying `error`. */
    static error(message = 'Error occurred', error = 'Task failed') {
        return new EnhancedRequestMock({
            message,
            toolCalls: [{ name: 'task_fatal_error', arguments: { error } }],
        });
    }
    /** A response that makes the stream emit an `error` event (string or Error accepted). */
    static throws(error) {
        return new EnhancedRequestMock({
            error: typeof error === 'string' ? new Error(error) : error,
        });
    }
    /** A response with thinking text (emitted only when `includeThinking` is set) and a message. */
    static thinking(thinking, message) {
        return new EnhancedRequestMock({
            thinking,
            message,
        });
    }
    /** A response consisting only of the given tool calls (`{ name, arguments }` objects). */
    static toolCalls(...calls) {
        return new EnhancedRequestMock({
            message: '',
            toolCalls: calls,
        });
    }
    /** Replays several responses in order within one stream. */
    static sequence(...responses) {
        return new EnhancedRequestMock(responses);
    }
}
/**
 * Builds a mock task context whose methods are `vi.fn` spies.
 *
 * The spy implementations close over the returned object instead of relying on
 * `this`. Methods can therefore be destructured or passed as callbacks
 * (e.g. `const { halt } = ctx; halt();`) and still update the context.
 * Properties in `overrides` replace the defaults. Because the methods read the
 * context object at call time, an overridden `metadata` or `messages` is
 * honoured by them.
 *
 * @param {object} [overrides] - Properties (or replacement methods) to merge over the defaults.
 * @returns {object} The mock context.
 */
export function createMockContext(overrides = {}) {
    const context = {
        shouldContinue: true,
        metadata: {},
        toolCallCount: 0,
        turnCount: 0,
        startTime: Date.now(),
        messages: [],
        isPaused: false,
        isHalted: false,
        halt: vi.fn(() => {
            context.shouldContinue = false;
            context.isHalted = true;
        }),
        pause: vi.fn(() => {
            context.isPaused = true;
        }),
        resume: vi.fn(() => {
            context.isPaused = false;
        }),
        setMetadata: vi.fn((key, value) => {
            context.metadata[key] = value;
        }),
        getMetadata: vi.fn((key) => context.metadata[key]),
        addMessage: vi.fn((message) => {
            context.messages.push(message);
        }),
        getHistory: vi.fn(() => context.messages),
        ...overrides,
    };
    return context;
}
/**
 * Collects every event from an async event stream and offers query helpers
 * for assertions.
 *
 * Consumption starts in the constructor. Call `await waitForCompletion()`
 * before querying so that all events have been recorded.
 */
export class StreamAssertions {
    events = [];
    // Settles once the stream has been fully consumed (or has failed).
    #completion;
    #streamFailed = false;
    #streamError;
    /**
     * @param {AsyncIterable<object>} eventGenerator - The event stream to consume.
     */
    constructor(eventGenerator) {
        this.#completion = (async () => {
            for await (const event of eventGenerator) {
                this.events.push(event);
            }
        })().catch((err) => {
            // Record the failure instead of leaving an unhandled rejection;
            // waitForCompletion() rethrows it to the test.
            this.#streamFailed = true;
            this.#streamError = err;
        });
    }
    /**
     * Resolves once the stream has ended (rather than after a fixed delay,
     * which raced with slow streams).
     *
     * @throws {*} Whatever the underlying stream threw while being consumed.
     */
    async waitForCompletion() {
        await this.#completion;
        if (this.#streamFailed) {
            throw this.#streamError;
        }
    }
    /** True when at least one event of `type` was received. */
    hasEvent(type) {
        return this.events.some(e => e.type === type);
    }
    /** All received events of `type`, in arrival order. */
    getEvents(type) {
        return this.events.filter(e => e.type === type);
    }
    /**
     * True when any `tool_start` event includes a tool call named `name`.
     * `tool_call` may be a single call object or an array of calls (as emitted
     * by EnhancedRequestMock); both shapes are handled.
     */
    hasToolCall(name) {
        return this.getEvents('tool_start').some((event) => {
            const calls = Array.isArray(event.tool_call) ? event.tool_call : [event.tool_call];
            return calls.some(call => call?.function?.name === name);
        });
    }
    /** Content of the last `message_complete` event, or undefined if there is none. */
    getFinalMessage() {
        const messageEvents = this.getEvents('message_complete');
        if (messageEvents.length === 0) {
            return undefined;
        }
        const lastEvent = messageEvents[messageEvents.length - 1];
        return 'content' in lastEvent ? lastEvent.content : undefined;
    }
    /** True when at least one `error` event was received. */
    hasError() {
        return this.hasEvent('error');
    }
    /** The `error` field of the first `error` event, or undefined if there is none. */
    getError() {
        const errorEvents = this.getEvents('error');
        if (errorEvents.length === 0) {
            return undefined;
        }
        const errorEvent = errorEvents[0];
        return 'error' in errorEvent ? errorEvent.error : undefined;
    }
}
//# sourceMappingURL=test_utils.js.map