@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
189 lines (158 loc) • 5.5 kB
text/typescript
import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
import { Readable } from 'stream';
import { describe, expect, it, vi } from 'vitest';
import * as uuidModule from '@/utils/uuid';
import { AWSBedrockLlamaStream } from './llama';
describe('AWSBedrockLlamaStream', () => {
it('should transform Bedrock Llama stream to protocol stream', async () => {
vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
const mockBedrockStream = new ReadableStream({
start(controller) {
controller.enqueue({ generation: 'Hello', generation_token_count: 1 });
controller.enqueue({ generation: ' world!', generation_token_count: 2 });
controller.enqueue({ stop_reason: 'stop' });
controller.close();
},
});
const onStartMock = vi.fn();
const onTextMock = vi.fn();
const onCompletionMock = vi.fn();
const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
onStart: onStartMock,
onText: onTextMock,
onCompletion: onCompletionMock,
});
const decoder = new TextDecoder();
const chunks = [];
// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}
expect(chunks).toEqual([
'id: chat_1\n',
'event: text\n',
`data: "Hello"\n\n`,
'id: chat_1\n',
'event: text\n',
`data: " world!"\n\n`,
'id: chat_1\n',
'event: stop\n',
`data: "finished"\n\n`,
]);
expect(onStartMock).toHaveBeenCalledTimes(1);
expect(onTextMock).toHaveBeenNthCalledWith(1, 'Hello');
expect(onTextMock).toHaveBeenNthCalledWith(2, ' world!');
expect(onCompletionMock).toHaveBeenCalledTimes(1);
});
it('should transform Bedrock Llama AsyncIterator to protocol stream', async () => {
vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
body: {
// @ts-ignore
async *[Symbol.asyncIterator]() {
yield { generation: 'Hello', generation_token_count: 1 };
yield { generation: ' world!', generation_token_count: 2 };
yield { stop_reason: 'stop' };
},
},
};
const onStartMock = vi.fn();
const onTextMock = vi.fn();
const onCompletionMock = vi.fn();
const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
onStart: onStartMock,
onText: onTextMock,
onCompletion: onCompletionMock,
});
const decoder = new TextDecoder();
const chunks = [];
// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}
expect(chunks).toEqual([
'id: chat_1\n',
'event: text\n',
`data: "Hello"\n\n`,
'id: chat_1\n',
'event: text\n',
`data: " world!"\n\n`,
'id: chat_1\n',
'event: stop\n',
`data: "finished"\n\n`,
]);
expect(onStartMock).toHaveBeenCalledTimes(1);
expect(onTextMock).toHaveBeenNthCalledWith(1, 'Hello');
expect(onTextMock).toHaveBeenNthCalledWith(2, ' world!');
expect(onCompletionMock).toHaveBeenCalledTimes(1);
});
it('should handle Bedrock response with chunk property', async () => {
vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('2');
const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
contentType: 'any',
body: {
// @ts-ignore
async *[Symbol.asyncIterator]() {
yield {
chunk: {
bytes: new TextEncoder().encode('{"generation":"Hello","generation_token_count":1}'),
},
};
yield {
chunk: {
bytes: new TextEncoder().encode(
'{"generation":" world!","generation_token_count":2}',
),
},
};
yield { chunk: { bytes: new TextEncoder().encode('{"stop_reason":"stop"}') } };
},
},
};
const onStartMock = vi.fn();
const onTextMock = vi.fn();
const onTokenMock = vi.fn();
const onCompletionMock = vi.fn();
const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
onStart: onStartMock,
onText: onTextMock,
onCompletion: onCompletionMock,
});
const decoder = new TextDecoder();
const chunks = [];
// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}
expect(chunks).toEqual([
'id: chat_2\n',
'event: text\n',
`data: "Hello"\n\n`,
'id: chat_2\n',
'event: text\n',
`data: " world!"\n\n`,
'id: chat_2\n',
'event: stop\n',
`data: "finished"\n\n`,
]);
expect(onStartMock).toHaveBeenCalledTimes(1);
expect(onTextMock).toHaveBeenNthCalledWith(1, 'Hello');
expect(onTextMock).toHaveBeenNthCalledWith(2, ' world!');
expect(onCompletionMock).toHaveBeenCalledTimes(1);
});
it('should handle empty stream', async () => {
const mockBedrockStream = new ReadableStream({
start(controller) {
controller.close();
},
});
const protocolStream = AWSBedrockLlamaStream(mockBedrockStream);
const decoder = new TextDecoder();
const chunks = [];
// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}
expect(chunks).toEqual([]);
});
});