UNPKG

cmte

Version:

Design by Committee™ except it's just you and LLMs

270 lines (248 loc) 8.96 kB
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { BaseLLMClient } from "../base-llm-client.js";
import { TextEncoder, TextDecoder } from 'util';

// Mock logger so the code under test does not write real log output.
vi.mock('../../../utils/logger.js', () => ({
  logger: {
    debug: vi.fn(),
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn()
  }
}));

// Mock TextEncoder/Decoder for Node environment
global.TextEncoder = TextEncoder;
global.TextDecoder = TextDecoder;

/**
 * Concrete BaseLLMClient implementation used only by this test file.
 *
 * It stubs the completion methods and exposes base-class helpers through
 * `test*` wrappers so the suites below can call them directly.
 */
class TestLLMClient extends BaseLLMClient {
  /** @returns {Promise<boolean>} Always resolves to `true`. */
  async healthCheck() {
    return true;
  }

  /**
   * Records the merged config on `this.lastConfig` and returns a canned reply.
   *
   * @param {Array<{role: string, content: string}>} messages - Ignored by this stub.
   * @param {object} [config] - Per-call overrides passed to `getMergedConfig`.
   * @returns {Promise<string>} The literal string 'test response'.
   */
  async completeMessages(messages, config) {
    this.lastConfig = this.getMergedConfig(config);
    return 'test response';
  }

  /**
   * Wraps a prompt string in a single user message and delegates to
   * `completeMessages`.
   *
   * @param {string} prompt - Prompt text.
   * @param {object} [config] - Per-call overrides.
   * @returns {Promise<string>} Whatever `completeMessages` resolves to.
   */
  async completePrompt(prompt, config) {
    return this.completeMessages([{ role: 'user', content: prompt }], config);
  }

  // Expose protected methods for testing

  /** Pass-through to `handleStreamingResponse`. */
  async testHandleStreamingResponse(response) {
    return this.handleStreamingResponse(response);
  }

  /** Pass-through to `withExponentialBackoff`. */
  async testWithExponentialBackoff(fn) {
    return this.withExponentialBackoff(fn);
  }

  /** Pass-through to `getMergedConfig`. */
  testGetMergedConfig(config) {
    return this.getMergedConfig(config);
  }
}

describe('BaseLLMClient', () => {
  let client;

  beforeEach(() => {
    client = new TestLLMClient({
      provider: 'anthropic',
      model: 'test-model',
      haiku: false,
      apiDryRun: false
    });
  });

  describe('handleStreamingResponse', () => {
    it('should handle streaming responses correctly', async () => {
      const chunks = ['Hello', ' world'];
      const encoder = new TextEncoder();
      let chunkIndex = 0;

      // Minimal stand-in for a fetch Response whose body reader yields the
      // encoded chunks one at a time and then reports `done`.
      const mockResponse = {
        body: {
          getReader: () => ({
            read: async () => {
              if (chunkIndex >= chunks.length) {
                return { done: true, value: undefined };
              }
              return { done: false, value: encoder.encode(chunks[chunkIndex++]) };
            },
            releaseLock: () => {}
          })
        }
      };

      const result = await client.testHandleStreamingResponse(mockResponse);
      expect(result).toBe('Hello world');
    });

    it('should throw error when response body is null', async () => {
      const mockResponse = { body: null };
      await expect(client.testHandleStreamingResponse(mockResponse)).rejects.toThrow('Response body is null');
    });
  });

  describe('withExponentialBackoff', () => {
    // Fake timers let the tests skip the backoff delays instead of waiting.
    beforeEach(() => {
      vi.useFakeTimers();
    });

    afterEach(() => {
      vi.useRealTimers();
    });

    it('should retry failed operations', async () => {
      const fn = vi.fn()
        .mockRejectedValueOnce(new Error('First failure'))
        .mockRejectedValueOnce(new Error('Second failure'))
        .mockResolvedValueOnce('success');

      const promise = client.testWithExponentialBackoff(fn);

      // Fast-forward through all timeouts
      await vi.runAllTimersAsync();

      const result = await promise;
      expect(result).toBe('success');
      expect(fn).toHaveBeenCalledTimes(3);
    });

    it('should throw after max retries', async () => {
      const fn = vi.fn().mockRejectedValue(new Error('Always fails'));
      client.maxRetries = 3; // Explicitly set for this test case

      // Attach the rejection assertion before advancing timers so the
      // rejection always has a handler when it happens.
      const assertion = expect(client.testWithExponentialBackoff(fn)).rejects.toThrow('Always fails');

      // Run timers to simulate delays
      await vi.runAllTimersAsync();

      // Wait for the promise to settle
      await assertion;

      // Check if the function was called the expected number of times (maxRetries)
      expect(fn).toHaveBeenCalledTimes(3);
    });

    it('should not retry 4xx errors except 429', async () => {
      const error = new Response(null, { status: 400 });
      const fn = vi.fn().mockRejectedValue(error);

      // `.rejects` makes the test fail if the call unexpectedly resolves,
      // and checks that the very same object is propagated to the caller.
      const assertion = expect(client.testWithExponentialBackoff(fn)).rejects.toBe(error);

      // Fast-forward through all timeouts
      await vi.runAllTimersAsync();
      await assertion;

      expect(fn).toHaveBeenCalledTimes(1);
    });

    it('should retry 429 errors', async () => {
      const error429 = new Response(null, { status: 429 });
      const fn = vi.fn()
        .mockRejectedValueOnce(error429)
        .mockResolvedValueOnce('success');

      const promise = client.testWithExponentialBackoff(fn);

      // Fast-forward through all timeouts
      await vi.runAllTimersAsync();

      const result = await promise;
      expect(result).toBe('success');
      expect(fn).toHaveBeenCalledTimes(2);
    });
  });

  describe('getMergedConfig', () => {
    it('should use default values when no config provided', () => {
      const config = client.testGetMergedConfig();
      expect(config).toEqual({
        temperature: 0.7,
        maxTokens: 10000,
        maxParallelRequests: 10,
        model: 'test-model',
        topP: 1
      });
    });

    it('should merge provided config with defaults', () => {
      const config = client.testGetMergedConfig({ temperature: 0.5, maxTokens: 1000 });
      expect(config).toEqual({
        temperature: 0.5,
        maxTokens: 1000,
        maxParallelRequests: 10,
        model: 'test-model',
        topP: 1
      });
    });

    it('should override model from base config', () => {
      const config = client.testGetMergedConfig({ model: 'override-model' });
      expect(config.model).toBe('override-model');
    });
  });

  describe('enqueueRequest', () => {
    // Named differently from the outer `client` to avoid shadowing it.
    let queueClient;
    const MAX_PARALLEL = 3; // Use a small number for testing

    beforeEach(() => {
      queueClient = new BaseLLMClient({});
      // Override the default for testing
      queueClient.maxParallelRequests = MAX_PARALLEL;
      // Also set the initial value for the adaptive limit used by the queue
      queueClient._currentConcurrencyLimit = MAX_PARALLEL;
      // Initialize with minimum delay for testing predictable timing
      queueClient._minInterRequestDelayMs = 1; // Use 1ms to avoid potential zero-delay issues
      queueClient._interRequestDelayMs = queueClient._minInterRequestDelayMs;
      queueClient.activeRequests = 0;
      queueClient.requestQueue = [];
    });

    it('should limit concurrent requests to maxParallelRequests', async () => {
      let currentConcurrent = 0;
      let maxObservedConcurrent = 0;
      const requestDuration = 50; // ms
      const totalRequests = MAX_PARALLEL * 3; // More requests than concurrency limit
      const promises = [];

      // Each mock request tracks how many requests are in flight at once.
      const createMockRequestFn = (id) => {
        return async () => {
          currentConcurrent++;
          maxObservedConcurrent = Math.max(maxObservedConcurrent, currentConcurrent);
          // Simulate work
          await new Promise(resolve => setTimeout(resolve, requestDuration));
          currentConcurrent--;
          return `Result ${id}`;
        };
      };

      // Enqueue all requests
      for (let i = 0; i < totalRequests; i++) {
        promises.push(queueClient.enqueueRequest(createMockRequestFn(i)));
      }

      // Wait for all requests to complete
      const results = await Promise.all(promises);

      // Verify all requests completed
      expect(results.length).toBe(totalRequests);

      // Verify the concurrency limit was respected
      expect(maxObservedConcurrent).toBe(MAX_PARALLEL);

      // Verify the queue was emptied and counter reset
      expect(queueClient.requestQueue.length).toBe(0);
      expect(queueClient.activeRequests).toBe(0);
    });

    it('should process queued requests when slots become available', async () => {
      const requestDuration = 50;
      const totalRequests = MAX_PARALLEL + 2; // A few more than the limit
      let completedCount = 0;

      const createMockRequestFn = (id) => {
        return async () => {
          await new Promise(resolve => setTimeout(resolve, requestDuration));
          completedCount++;
          return `Result ${id}`;
        };
      };

      const promises = [];
      for (let i = 0; i < totalRequests; i++) {
        promises.push(queueClient.enqueueRequest(createMockRequestFn(i)));
      }

      // Wait a bit longer than one batch duration, but less than total sequential duration
      // This ensures the queue had time to process some items after the first batch finished.
      await new Promise(resolve => setTimeout(resolve, requestDuration * 1.5));

      // Check that *some* but not *all* requests are done (meaning queuing worked)
      expect(completedCount).toBeGreaterThanOrEqual(MAX_PARALLEL);
      expect(completedCount).toBeLessThan(totalRequests);

      // Wait for all to finish
      await Promise.all(promises);
      expect(completedCount).toBe(totalRequests);
    });
  });
});