@llumiverse/drivers
Version:
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
251 lines (211 loc) • 9.66 kB
text/typescript
import {
AIModel,
Completion,
CompletionChunkObject,
DriverOptions,
EmbeddingsOptions,
EmbeddingsResult,
ExecutionOptions,
ModelSearchPayload,
PromptRole,
PromptSegment,
} from '@llumiverse/common';
import { AbstractDriver } from '@llumiverse/core';
import { beforeEach, describe, expect, it } from 'vitest';
import { BedrockDriver } from './index.js';
// ---------------------------------------------------------------------------
// Unit tests: getExtractedStream tool use handling
// ---------------------------------------------------------------------------
describe('BedrockDriver getExtractedStream — tool use', () => {
    let driver: BedrockDriver;
    // Registry of in-flight tool-use blocks keyed by contentBlockIndex; recreated per test.
    let toolBlocks: Map<number, { id: string; name: string }>;

    /**
     * Run one Converse stream event through the driver's (non-public) extractor,
     * leaving the two middle arguments undefined and sharing this test's registry.
     * Both `driver` and `toolBlocks` are read at call time, so beforeEach resets apply.
     */
    const extract = (event: Parameters<BedrockDriver['getExtractedStream']>[0]) =>
        driver['getExtractedStream'](event, undefined, undefined, toolBlocks);

    beforeEach(() => {
        driver = new BedrockDriver({ region: 'us-east-1' });
        toolBlocks = new Map();
    });

    it('emits an initial tool_use chunk on contentBlockStart', () => {
        const out = extract({
            contentBlockStart: {
                contentBlockIndex: 1,
                start: { toolUse: { toolUseId: 'tool-abc', name: 'my_tool' } },
            },
        });

        expect(out.tool_use).toHaveLength(1);
        expect(out.tool_use?.[0]).toMatchObject({ id: 'tool-abc', tool_name: 'my_tool', tool_input: '' });
        // The start event must register the block so later deltas can be attributed to it.
        expect(toolBlocks.get(1)).toEqual({ id: 'tool-abc', name: 'my_tool' });
    });

    it('emits a delta tool_use chunk on contentBlockDelta', () => {
        toolBlocks.set(1, { id: 'tool-abc', name: 'my_tool' });

        const out = extract({
            contentBlockDelta: {
                contentBlockIndex: 1,
                delta: { toolUse: { input: '{"key":' } },
            },
        });

        expect(out.tool_use).toHaveLength(1);
        // Deltas carry the id of the registered block but an empty tool name.
        expect(out.tool_use?.[0]).toMatchObject({ id: 'tool-abc', tool_name: '', tool_input: '{"key":' });
    });

    it('removes the block from the map on contentBlockStop', () => {
        toolBlocks.set(1, { id: 'tool-abc', name: 'my_tool' });

        extract({ contentBlockStop: { contentBlockIndex: 1 } });

        expect(toolBlocks.has(1)).toBe(false);
    });

    it('tracks two interleaved tool calls by independent contentBlockIndex', () => {
        const starts = [
            { index: 1, toolUseId: 'id-1', name: 'tool_a' },
            { index: 3, toolUseId: 'id-2', name: 'tool_b' },
        ];
        for (const { index, toolUseId, name } of starts) {
            extract({ contentBlockStart: { contentBlockIndex: index, start: { toolUse: { toolUseId, name } } } });
        }

        expect(toolBlocks.get(1)).toEqual({ id: 'id-1', name: 'tool_a' });
        expect(toolBlocks.get(3)).toEqual({ id: 'id-2', name: 'tool_b' });

        // A delta for index 3 must be attributed to the second call, not the first.
        const out = extract({ contentBlockDelta: { contentBlockIndex: 3, delta: { toolUse: { input: '"val"' } } } });
        expect(out.tool_use![0].id).toBe('id-2');
    });

    it('still extracts text deltas when no tool use is present', () => {
        const out = extract({ contentBlockDelta: { contentBlockIndex: 0, delta: { text: 'hello' } } });

        expect(out.result).toEqual([{ type: 'text', value: 'hello' }]);
        expect(out.tool_use).toBeUndefined();
    });

    it('emits finish_reason "tool_use" from messageStop', () => {
        const out = extract({ messageStop: { stopReason: 'tool_use' } });

        expect(out.finish_reason).toBe('tool_use');
    });
});
// ---------------------------------------------------------------------------
// Integration tests: full accumulation via driver.stream()
// ---------------------------------------------------------------------------
/**
 * Minimal AbstractDriver whose streaming endpoint replays a preset list of chunks.
 * Every other capability is stubbed: the non-streaming and embedding calls throw,
 * model listing is empty, and connection validation always succeeds.
 */
class FakeDriver extends AbstractDriver<DriverOptions, string> {
    provider = 'fake';

    /** Chunks yielded, in order, by requestTextCompletionStream. Set by each test before streaming. */
    chunks: CompletionChunkObject[] = [];

    async requestTextCompletion(_prompt: string, _options: ExecutionOptions): Promise<Completion> {
        throw new Error('not implemented');
    }

    async requestTextCompletionStream(_prompt: string, _options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        // Capture the array reference now; the generator replays its elements lazily.
        const replay = this.chunks;
        async function* emit() {
            yield* replay;
        }
        return emit();
    }

    async listModels(_params?: ModelSearchPayload): Promise<AIModel[]> {
        return [];
    }

    async validateConnection(): Promise<boolean> {
        return true;
    }

    async generateEmbeddings(_options: EmbeddingsOptions): Promise<EmbeddingsResult> {
        throw new Error('not implemented');
    }
}

/** Single user-turn prompt handed to driver.stream() in the accumulation tests. */
const FAKE_SEGMENTS: PromptSegment[] = [
    { role: PromptRole.user, content: 'test' },
];
describe('driver.stream() — Bedrock tool use accumulation', () => {
    /**
     * Build one streamed tool-use chunk. While streaming, `tool_input` carries a raw
     * JSON *fragment* (a string) rather than the parsed object the field is typed as,
     * so the unavoidable cast lives here instead of in every fixture.
     */
    const toolChunk = (id: string, toolName: string, inputFragment: string): CompletionChunkObject => ({
        result: [],
        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- partial JSON string mid-stream
        tool_use: [{ id, tool_name: toolName, tool_input: inputFragment as any }],
    });

    /**
     * Replay `chunks` through FakeDriver.stream(), consume the stream fully, and
     * return the accumulated completion.
     *
     * @throws Error with an explicit message if the stream produced no completion,
     *   instead of the opaque TypeError a non-null assertion would give.
     */
    async function collect(chunks: CompletionChunkObject[]) {
        const driver = new FakeDriver({});
        driver.chunks = chunks;
        const options: ExecutionOptions = { model: 'test-model' };

        const stream = await driver.stream(FAKE_SEGMENTS, options);
        for await (const _chunk of stream) {
            // drain: the completion is only read after the stream is exhausted
        }

        const { completion } = stream;
        if (!completion) {
            throw new Error('stream finished without an accumulated completion');
        }
        return completion;
    }

    it('assembles and JSON-parses tool_input from streamed chunks', async () => {
        // Simulate what the fixed getExtractedStream emits for one tool call
        const completion = await collect([
            toolChunk('tool-1', 'do_thing', ''),
            toolChunk('tool-1', '', '{"param"'),
            toolChunk('tool-1', '', ':"hello"}'),
            { result: [], finish_reason: 'tool_use' },
        ]);

        expect(completion.finish_reason).toBe('tool_use');
        expect(completion.tool_use).toHaveLength(1);
        expect(completion.tool_use?.[0]).toMatchObject({
            id: 'tool-1',
            tool_name: 'do_thing',
            tool_input: { param: 'hello' },
        });
    });

    it('handles two simultaneous tool calls', async () => {
        const completion = await collect([
            toolChunk('id-a', 'tool_a', ''),
            toolChunk('id-b', 'tool_b', ''),
            toolChunk('id-a', '', '{"x":1}'),
            toolChunk('id-b', '', '{"y":2}'),
            { result: [], finish_reason: 'tool_use' },
        ]);

        const toolUse = completion.tool_use ?? [];
        expect(toolUse).toHaveLength(2);
        // Optional chaining: a missing call fails the assertion rather than throwing.
        expect(toolUse.find((t) => t.id === 'id-a')?.tool_input).toEqual({ x: 1 });
        expect(toolUse.find((t) => t.id === 'id-b')?.tool_input).toEqual({ y: 2 });
    });

    it('drops truncated tool calls when finish_reason is length', async () => {
        const completion = await collect([
            toolChunk('trunc', 'tool_c', ''),
            toolChunk('trunc', '', '{"incomplete'),
            { result: [], finish_reason: 'length' },
        ]);

        expect(completion.tool_use).toBeUndefined();
    });
});
describe('BedrockDriver buildStreamingConversation', () => {
    it('writes streamed text and tool use blocks back into the assistant message', () => {
        const model = 'anthropic.claude-sonnet';
        const userTurn = { role: 'user', content: [{ text: 'What is the weather in Paris?' }] };
        const prompt = { modelId: model, messages: [userTurn] };
        const streamedText = [{ type: 'text', value: 'Let me check.' }];

        const driver = new BedrockDriver({ region: 'us-east-1' });
        // Casts: the fixtures are partial shapes of the driver's prompt/result types.
        const conversation = driver.buildStreamingConversation(
            prompt as any,
            streamedText as any,
            [{ id: 'tool-1', tool_name: 'get_weather', tool_input: { location: 'Paris' } }],
            { model } as ExecutionOptions
        ) as any;

        // Expect the original user turn followed by one appended assistant turn.
        expect(conversation.messages).toHaveLength(2);
        const [firstTurn, assistantTurn] = conversation.messages;
        expect(firstTurn).toEqual(userTurn);

        // Streamed text becomes a text block; the tool call becomes a Bedrock toolUse block.
        const expectedAssistant = {
            role: 'assistant',
            content: [
                { text: 'Let me check.' },
                {
                    toolUse: {
                        toolUseId: 'tool-1',
                        name: 'get_weather',
                        input: { location: 'Paris' },
                    },
                },
            ],
        };
        expect(assistantTurn).toEqual(expectedAssistant);
    });
});