@lobehub/chat

Lobe Chat is an open-source, high-performance chatbot framework that supports speech synthesis, multimodal interaction, and an extensible Function Call plugin system, with one-click free deployment of your private ChatGPT/LLM web application.

import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
import { describe, expect, it, vi } from 'vitest';

import * as uuidModule from '@/utils/uuid';

import { AWSBedrockLlamaStream } from './llama';

describe('AWSBedrockLlamaStream', () => {
  it('should transform Bedrock Llama stream to protocol stream', async () => {
    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');

    const mockBedrockStream = new ReadableStream({
      start(controller) {
        controller.enqueue({ generation: 'Hello', generation_token_count: 1 });
        controller.enqueue({ generation: ' world!', generation_token_count: 2 });
        controller.enqueue({ stop_reason: 'stop' });
        controller.close();
      },
    });

    const onStartMock = vi.fn();
    const onTextMock = vi.fn();
    const onCompletionMock = vi.fn();

    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
      onStart: onStartMock,
      onText: onTextMock,
      onCompletion: onCompletionMock,
    });

    const decoder = new TextDecoder();
    const chunks = [];

    // @ts-ignore
    for await (const chunk of protocolStream) {
      chunks.push(decoder.decode(chunk, { stream: true }));
    }

    expect(chunks).toEqual([
      'id: chat_1\n',
      'event: text\n',
      `data: "Hello"\n\n`,
      'id: chat_1\n',
      'event: text\n',
      `data: " world!"\n\n`,
      'id: chat_1\n',
      'event: stop\n',
      `data: "finished"\n\n`,
    ]);

    expect(onStartMock).toHaveBeenCalledTimes(1);
    expect(onTextMock).toHaveBeenNthCalledWith(1, 'Hello');
    expect(onTextMock).toHaveBeenNthCalledWith(2, ' world!');
    expect(onCompletionMock).toHaveBeenCalledTimes(1);
  });

  it('should transform Bedrock Llama AsyncIterator to protocol stream', async () => {
    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');

    const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
      body: {
        // @ts-ignore
        async *[Symbol.asyncIterator]() {
          yield { generation: 'Hello', generation_token_count: 1 };
          yield { generation: ' world!', generation_token_count: 2 };
          yield { stop_reason: 'stop' };
        },
      },
    };

    const onStartMock = vi.fn();
    const onTextMock = vi.fn();
    const onCompletionMock = vi.fn();

    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
      onStart: onStartMock,
      onText: onTextMock,
      onCompletion: onCompletionMock,
    });

    const decoder = new TextDecoder();
    const chunks = [];

    // @ts-ignore
    for await (const chunk of protocolStream) {
      chunks.push(decoder.decode(chunk, { stream: true }));
    }

    expect(chunks).toEqual([
      'id: chat_1\n',
      'event: text\n',
      `data: "Hello"\n\n`,
      'id: chat_1\n',
      'event: text\n',
      `data: " world!"\n\n`,
      'id: chat_1\n',
      'event: stop\n',
      `data: "finished"\n\n`,
    ]);

    expect(onStartMock).toHaveBeenCalledTimes(1);
    expect(onTextMock).toHaveBeenNthCalledWith(1, 'Hello');
    expect(onTextMock).toHaveBeenNthCalledWith(2, ' world!');
    expect(onCompletionMock).toHaveBeenCalledTimes(1);
  });

  it('should handle Bedrock response with chunk property', async () => {
    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('2');

    // Raw InvokeModelWithResponseStream events wrap each JSON payload
    // as { chunk: { bytes: Uint8Array } }.
    const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
      contentType: 'any',
      body: {
        // @ts-ignore
        async *[Symbol.asyncIterator]() {
          yield {
            chunk: {
              bytes: new TextEncoder().encode('{"generation":"Hello","generation_token_count":1}'),
            },
          };
          yield {
            chunk: {
              bytes: new TextEncoder().encode(
                '{"generation":" world!","generation_token_count":2}',
              ),
            },
          };
          yield { chunk: { bytes: new TextEncoder().encode('{"stop_reason":"stop"}') } };
        },
      },
    };

    const onStartMock = vi.fn();
    const onTextMock = vi.fn();
    const onCompletionMock = vi.fn();

    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
      onStart: onStartMock,
      onText: onTextMock,
      onCompletion: onCompletionMock,
    });

    const decoder = new TextDecoder();
    const chunks = [];

    // @ts-ignore
    for await (const chunk of protocolStream) {
      chunks.push(decoder.decode(chunk, { stream: true }));
    }

    expect(chunks).toEqual([
      'id: chat_2\n',
      'event: text\n',
      `data: "Hello"\n\n`,
      'id: chat_2\n',
      'event: text\n',
      `data: " world!"\n\n`,
      'id: chat_2\n',
      'event: stop\n',
      `data: "finished"\n\n`,
    ]);

    expect(onStartMock).toHaveBeenCalledTimes(1);
    expect(onTextMock).toHaveBeenNthCalledWith(1, 'Hello');
    expect(onTextMock).toHaveBeenNthCalledWith(2, ' world!');
    expect(onCompletionMock).toHaveBeenCalledTimes(1);
  });

  it('should handle empty stream', async () => {
    const mockBedrockStream = new ReadableStream({
      start(controller) {
        controller.close();
      },
    });

    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream);

    const decoder = new TextDecoder();
    const chunks = [];

    // @ts-ignore
    for await (const chunk of protocolStream) {
      chunks.push(decoder.decode(chunk, { stream: true }));
    }

    expect(chunks).toEqual([]);
  });
});
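
For reference, here is a minimal sketch of a transformer that would satisfy the assertions above. This is an inference from the tests, not the actual './llama' implementation (which may be built on lobe-chat's internal stream helpers). It assumes only `nanoid` from '@/utils/uuid' (the module the tests spy on) and Node 18+'s async-iterable ReadableStream, which the tests themselves rely on; the exported name `bedrockLlamaStreamSketch` is hypothetical.

import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';

import { nanoid } from '@/utils/uuid';

interface StreamCallbacks {
  onCompletion?: () => void;
  onStart?: () => void;
  onText?: (text: string) => void;
}

// Sketch only: one plausible shape for AWSBedrockLlamaStream, inferred from
// the test assertions. Not the library's actual code.
export const bedrockLlamaStreamSketch = (
  source: ReadableStream | InvokeModelWithResponseStreamResponse,
  cb: StreamCallbacks = {},
): ReadableStream<Uint8Array> => {
  const id = `chat_${nanoid()}`;
  const encoder = new TextEncoder();
  const decoder = new TextDecoder();

  // Normalize both accepted shapes (a web ReadableStream of parsed objects,
  // or the SDK response whose `body` yields { chunk: { bytes } } events)
  // into one async iterator of parsed Llama chunks. Relies on Node 18+
  // ReadableStream being async iterable, as the tests do.
  async function* parsedChunks(): AsyncGenerator<any> {
    const iterable: any = source instanceof ReadableStream ? source : source.body ?? [];
    for await (const item of iterable) {
      yield item?.chunk?.bytes ? JSON.parse(decoder.decode(item.chunk.bytes)) : item;
    }
  }

  return new ReadableStream<Uint8Array>({
    async start(controller) {
      cb.onStart?.();

      // The tests expect each SSE field ("id:", "event:", "data:") to arrive
      // as its own chunk, so enqueue the three lines of a frame separately.
      const frame = (event: string, data: unknown) => {
        controller.enqueue(encoder.encode(`id: ${id}\n`));
        controller.enqueue(encoder.encode(`event: ${event}\n`));
        controller.enqueue(encoder.encode(`data: ${JSON.stringify(data)}\n\n`));
      };

      for await (const chunk of parsedChunks()) {
        if (typeof chunk.generation === 'string') {
          cb.onText?.(chunk.generation);
          frame('text', chunk.generation);
        } else if (chunk.stop_reason) {
          frame('stop', 'finished');
        }
      }

      cb.onCompletion?.();
      controller.close();
    },
  });
};

Under these assumptions, feeding the first mock stream through the sketch yields exactly the nine chunks the first test expects ('id: chat_1\n', 'event: text\n', 'data: "Hello"\n\n', and so on), and the empty-stream case produces no chunks because no frame is ever enqueued before the source closes.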