@lobehub/chat
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal interaction, and an extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
import {
  BedrockRuntimeClient,
  InvokeModelCommand,
  InvokeModelWithResponseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';
import { experimental_buildLlama2Prompt } from 'ai/prompts';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import {
  ChatMethodOptions,
  ChatStreamPayload,
  Embeddings,
  EmbeddingsOptions,
  EmbeddingsPayload,
  ModelProvider,
} from '../types';
import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import {
  AWSBedrockClaudeStream,
  AWSBedrockLlamaStream,
  createBedrockStream,
} from '../utils/streams';
export interface LobeBedrockAIParams {
  accessKeyId?: string;
  accessKeySecret?: string;
  region?: string;
  sessionToken?: string;
}
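
// A hypothetical params object (sketch only; the credential values below are
// placeholders, not real AWS keys):
//
// const params: LobeBedrockAIParams = {
//   accessKeyId: 'AKIA...',
//   accessKeySecret: '<secret>',
//   region: 'us-west-2',
// };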

export class LobeBedrockAI implements LobeRuntimeAI {
  private client: BedrockRuntimeClient;

  region: string;

  constructor({ region, accessKeyId, accessKeySecret, sessionToken }: LobeBedrockAIParams = {}) {
    if (!(accessKeyId && accessKeySecret))
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);

    this.region = region ?? 'us-east-1';

    this.client = new BedrockRuntimeClient({
      credentials: {
        accessKeyId,
        secretAccessKey: accessKeySecret,
        sessionToken,
      },
      region: this.region,
    });
  }
  async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
    if (payload.model.startsWith('meta')) return this.invokeLlamaModel(payload, options);

    return this.invokeClaudeModel(payload, options);
  }
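
  // Routing sketch (the model IDs below are illustrative Bedrock identifiers, not an
  // exhaustive list):
  //
  //   'meta.llama3-8b-instruct-v1:0'           -> invokeLlamaModel
  //   'anthropic.claude-3-haiku-20240307-v1:0' -> invokeClaudeModel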
  /**
   * Supports the Amazon Titan Text model series.
   * Cohere Embed models are not supported, because the text size per request
   * exceeds that series' limit of 2048 characters for a single request.
   *
   * @see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
   */
  async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
    // Titan embed models take a single text per request, so fan out one request per input.
    const promises = input.map((inputText: string) =>
      this.invokeEmbeddingModel(
        {
          dimensions: payload.dimensions,
          input: inputText,
          model: payload.model,
        },
        options,
      ),
    );

    return Promise.all(promises);
  }
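
  // Usage sketch (assumes an instantiated runtime; 'amazon.titan-embed-text-v2:0' is an
  // illustrative Titan model ID that accepts the optional `dimensions` parameter):
  //
  // const vectors = await bedrock.embeddings({
  //   dimensions: 1024,
  //   input: ['hello', 'world'],
  //   model: 'amazon.titan-embed-text-v2:0',
  // });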
  private invokeEmbeddingModel = async (
    payload: EmbeddingsPayload,
    options?: EmbeddingsOptions,
  ): Promise<Embeddings> => {
    const command = new InvokeModelCommand({
      accept: 'application/json',
      body: JSON.stringify({
        dimensions: payload.dimensions,
        inputText: payload.input,
        normalize: true,
      }),
      contentType: 'application/json',
      modelId: payload.model,
    });

    try {
      const res = await this.client.send(command, { abortSignal: options?.signal });
      const responseBody = JSON.parse(new TextDecoder().decode(res.body));
      return responseBody.embedding;
    } catch (e) {
      const err = e as Error & { $metadata: any };

      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };
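
  // For reference, a Titan embed response body decodes to JSON shaped roughly like the
  // sketch below (field names follow the Bedrock Titan docs; treat this as illustrative):
  //
  // { "embedding": [0.12, -0.03, ...], "inputTextTokenCount": 5 }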
  private invokeClaudeModel = async (
    payload: ChatStreamPayload,
    options?: ChatMethodOptions,
  ): Promise<Response> => {
    const { max_tokens, messages, model, temperature, top_p, tools } = payload;
    const system_message = messages.find((m) => m.role === 'system');
    const user_messages = messages.filter((m) => m.role !== 'system');

    const command = new InvokeModelWithResponseStreamCommand({
      accept: 'application/json',
      body: JSON.stringify({
        anthropic_version: 'bedrock-2023-05-31',
        max_tokens: max_tokens || 4096,
        messages: await buildAnthropicMessages(user_messages),
        system: system_message?.content as string,
        // The payload temperature uses a 0-2 range; Anthropic expects 0-1, so halve it.
        temperature: temperature / 2,
        tools: buildAnthropicTools(tools),
        top_p: top_p,
      }),
      contentType: 'application/json',
      modelId: model,
    });

    try {
      // Ask Claude for a streaming chat completion given the prompt
      const res = await this.client.send(command, { abortSignal: options?.signal });

      const claudeStream = createBedrockStream(res);

      const [prod, debug] = claudeStream.tee();

      if (process.env.DEBUG_BEDROCK_CHAT_COMPLETION === '1') {
        debugStream(debug).catch(console.error);
      }

      // Respond with the stream
      return StreamingResponse(AWSBedrockClaudeStream(prod, options?.callback), {
        headers: options?.headers,
      });
    } catch (e) {
      const err = e as Error & { $metadata: any };

      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };
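
  // Consumption sketch for the streamed Response (standard Web Streams API; the
  // variable names are illustrative):
  //
  // const res = await bedrock.chat(payload);
  // const reader = res.body!.getReader();
  // for (let r = await reader.read(); !r.done; r = await reader.read()) {
  //   process.stdout.write(new TextDecoder().decode(r.value));
  // }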
  private invokeLlamaModel = async (
    payload: ChatStreamPayload,
    options?: ChatMethodOptions,
  ): Promise<Response> => {
    const { max_tokens, messages, model } = payload;
    const command = new InvokeModelWithResponseStreamCommand({
      accept: 'application/json',
      body: JSON.stringify({
        max_gen_len: max_tokens || 400,
        prompt: experimental_buildLlama2Prompt(messages as any),
      }),
      contentType: 'application/json',
      modelId: model,
    });

    try {
      // Ask the Llama model for a streaming chat completion given the prompt
      const res = await this.client.send(command, { abortSignal: options?.signal });

      const stream = createBedrockStream(res);

      const [prod, debug] = stream.tee();

      if (process.env.DEBUG_BEDROCK_CHAT_COMPLETION === '1') {
        debugStream(debug).catch(console.error);
      }

      // Respond with the stream
      return StreamingResponse(AWSBedrockLlamaStream(prod, options?.callback), {
        headers: options?.headers,
      });
    } catch (e) {
      const err = e as Error & { $metadata: any };

      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };
}
export default LobeBedrockAI;
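
// End-to-end usage sketch (assumes valid AWS credentials are available; all values
// below are placeholders, and the model ID is illustrative):
//
// const bedrock = new LobeBedrockAI({
//   accessKeyId: process.env.AWS_ACCESS_KEY_ID,
//   accessKeySecret: process.env.AWS_SECRET_ACCESS_KEY,
//   region: 'us-east-1',
// });
// const response = await bedrock.chat({
//   messages: [{ content: 'Hello!', role: 'user' }],
//   model: 'anthropic.claude-3-haiku-20240307-v1:0',
//   temperature: 1,
// });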