@lobehub/chat

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal interaction, and an extensible Function Call plugin system. It supports one-click, free deployment of your private ChatGPT/LLM web application.

import {
  BedrockRuntimeClient,
  InvokeModelCommand,
  InvokeModelWithResponseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';
import { experimental_buildLlama2Prompt } from 'ai/prompts';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import {
  ChatMethodOptions,
  ChatStreamPayload,
  Embeddings,
  EmbeddingsOptions,
  EmbeddingsPayload,
  ModelProvider,
} from '../types';
import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import {
  AWSBedrockClaudeStream,
  AWSBedrockLlamaStream,
  createBedrockStream,
} from '../utils/streams';

export interface LobeBedrockAIParams {
  accessKeyId?: string;
  accessKeySecret?: string;
  region?: string;
  sessionToken?: string;
}

export class LobeBedrockAI implements LobeRuntimeAI {
  private client: BedrockRuntimeClient;

  region: string;

  constructor({ region, accessKeyId, accessKeySecret, sessionToken }: LobeBedrockAIParams = {}) {
    if (!(accessKeyId && accessKeySecret))
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);

    this.region = region ?? 'us-east-1';

    this.client = new BedrockRuntimeClient({
      credentials: {
        accessKeyId,
        secretAccessKey: accessKeySecret,
        sessionToken,
      },
      region: this.region,
    });
  }

  async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
    if (payload.model.startsWith('meta')) return this.invokeLlamaModel(payload, options);

    return this.invokeClaudeModel(payload, options);
  }
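  // Routing sketch (illustrative only; these model IDs are assumptions based on
  // AWS Bedrock naming conventions, not values taken from this file): a payload
  // with model 'meta.llama2-13b-chat-v1' would take the Llama path above, while
  // 'anthropic.claude-3-sonnet-20240229-v1:0' would take the Claude path.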
  /**
   * Supports the Amazon Titan Text model series.
   * Cohere Embed models are not supported, because the text size per request
   * would exceed that series' limit of 2048 characters per single request.
   * [Bedrock embed guide] https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
   */
  async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
    // Fan out: one InvokeModel call per input text, resolved in parallel.
    const promises = input.map((inputText: string) =>
      this.invokeEmbeddingModel(
        {
          dimensions: payload.dimensions,
          input: inputText,
          model: payload.model,
        },
        options,
      ),
    );
    return Promise.all(promises);
  }

  private invokeEmbeddingModel = async (
    payload: EmbeddingsPayload,
    options?: EmbeddingsOptions,
  ): Promise<Embeddings> => {
    const command = new InvokeModelCommand({
      accept: 'application/json',
      body: JSON.stringify({
        dimensions: payload.dimensions,
        inputText: payload.input,
        normalize: true,
      }),
      contentType: 'application/json',
      modelId: payload.model,
    });

    try {
      const res = await this.client.send(command, { abortSignal: options?.signal });
      const responseBody = JSON.parse(new TextDecoder().decode(res.body));
      return responseBody.embedding;
    } catch (e) {
      const err = e as Error & { $metadata: any };

      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };
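  // Wire-format sketch (illustrative assumption based on the request built
  // above; the model ID and dimension count are examples, not values from this
  // file): invoking 'amazon.titan-embed-text-v2:0' with input 'hello' sends
  //
  //   { "dimensions": 1024, "inputText": "hello", "normalize": true }
  //
  // and the parsed response exposes the vector as `responseBody.embedding`.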
  private invokeClaudeModel = async (
    payload: ChatStreamPayload,
    options?: ChatMethodOptions,
  ): Promise<Response> => {
    const { max_tokens, messages, model, temperature, top_p, tools } = payload;
    const system_message = messages.find((m) => m.role === 'system');
    const user_messages = messages.filter((m) => m.role !== 'system');

    const command = new InvokeModelWithResponseStreamCommand({
      accept: 'application/json',
      body: JSON.stringify({
        anthropic_version: 'bedrock-2023-05-31',
        max_tokens: max_tokens || 4096,
        messages: await buildAnthropicMessages(user_messages),
        system: system_message?.content as string,
        // map the payload's 0-2 temperature range onto Claude's 0-1 range
        temperature: temperature / 2,
        tools: buildAnthropicTools(tools),
        top_p: top_p,
      }),
      contentType: 'application/json',
      modelId: model,
    });

    try {
      // Ask Claude for a streaming chat completion given the prompt
      const res = await this.client.send(command, { abortSignal: options?.signal });

      const claudeStream = createBedrockStream(res);

      const [prod, debug] = claudeStream.tee();

      if (process.env.DEBUG_BEDROCK_CHAT_COMPLETION === '1') {
        debugStream(debug).catch(console.error);
      }

      // Respond with the stream
      return StreamingResponse(AWSBedrockClaudeStream(prod, options?.callback), {
        headers: options?.headers,
      });
    } catch (e) {
      const err = e as Error & { $metadata: any };

      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };

  private invokeLlamaModel = async (
    payload: ChatStreamPayload,
    options?: ChatMethodOptions,
  ): Promise<Response> => {
    const { max_tokens, messages, model } = payload;
    const command = new InvokeModelWithResponseStreamCommand({
      accept: 'application/json',
      body: JSON.stringify({
        max_gen_len: max_tokens || 400,
        prompt: experimental_buildLlama2Prompt(messages as any),
      }),
      contentType: 'application/json',
      modelId: model,
    });

    try {
      // Ask Llama for a streaming chat completion given the prompt
      const res = await this.client.send(command, { abortSignal: options?.signal });

      const stream = createBedrockStream(res);

      const [prod, debug] = stream.tee();

      if (process.env.DEBUG_BEDROCK_CHAT_COMPLETION === '1') {
        debugStream(debug).catch(console.error);
      }

      // Respond with the stream
      return StreamingResponse(AWSBedrockLlamaStream(prod, options?.callback), {
        headers: options?.headers,
      });
    } catch (e) {
      const err = e as Error & { $metadata: any };

      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };
}

export default LobeBedrockAI;
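// Usage sketch (illustrative; the credential sources and model ID below are
// assumptions, not part of this module):
//
//   const runtime = new LobeBedrockAI({
//     accessKeyId: process.env.AWS_ACCESS_KEY_ID,
//     accessKeySecret: process.env.AWS_SECRET_ACCESS_KEY,
//     region: 'us-east-1',
//   });
//
//   const response = await runtime.chat({
//     messages: [{ content: 'Hello!', role: 'user' }],
//     model: 'anthropic.claude-3-haiku-20240307-v1:0',
//     temperature: 0.7,
//   });
//
// `response` is a streaming `Response` whose body emits the model's tokens.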