UNPKG

@langchain/community

Version:
638 lines (637 loc) 25.9 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.ChatBedrock = exports.BedrockChat = exports.convertMessagesToPrompt = exports.convertMessagesToPromptAnthropic = void 0;
const signature_v4_1 = require("@smithy/signature-v4");
const protocol_http_1 = require("@smithy/protocol-http");
const eventstream_codec_1 = require("@smithy/eventstream-codec");
const util_utf8_1 = require("@smithy/util-utf8");
const sha256_js_1 = require("@aws-crypto/sha256-js");
const chat_models_1 = require("@langchain/core/language_models/chat_models");
const base_1 = require("@langchain/core/language_models/base");
const env_1 = require("@langchain/core/utils/env");
const messages_1 = require("@langchain/core/messages");
const outputs_1 = require("@langchain/core/outputs");
const function_calling_1 = require("@langchain/core/utils/function_calling");
const zod_to_json_schema_1 = require("zod-to-json-schema");
const index_js_1 = require("../../utils/bedrock/index.cjs");
const anthropic_js_1 = require("../../utils/bedrock/anthropic.cjs");
// Number of leading bytes that `_readChunks` needs before it can read a
// message's total length from the stream buffer.
const PRELUDE_TOTAL_LENGTH_BYTES = 4;
/**
 * Renders a single message as a text-prompt fragment.
 *
 * - "human" and "function" messages are prefixed with `humanPrompt`.
 * - "ai" messages are prefixed with `aiPrompt`.
 * - "system" messages are prefixed with `humanPrompt` and wrapped in
 *   `<admin>` tags.
 * - Any other `ChatMessage` is rendered as `\n\n<Role>: <content>`, with the
 *   first letter of the role upper-cased.
 *
 * @param message The message to render.
 * @param humanPrompt Prefix used for human, system and function messages.
 * @param aiPrompt Prefix used for AI messages.
 * @returns The prompt fragment for this message.
 * @throws Error if the message type is none of the above.
 */
function convertOneMessageToText(message, humanPrompt, aiPrompt) {
    if (message._getType() === "human") {
        return `${humanPrompt} ${message.content}`;
    }
    else if (message._getType() === "ai") {
        return `${aiPrompt} ${message.content}`;
    }
    else if (message._getType() === "system") {
        return `${humanPrompt} <admin>${message.content}</admin>`;
    }
    else if (message._getType() === "function") {
        return `${humanPrompt} ${message.content}`;
    }
    else if (messages_1.ChatMessage.isInstance(message)) {
        // Interpolate the content; previously the missing `$` emitted the
        // literal text "{message.content}" instead of the message body.
        return `\n\n${message.role[0].toUpperCase() + message.role.slice(1)}: ${message.content}`;
    }
    throw new Error(`Unknown role: ${message._getType()}`);
}
/**
 * Converts a list of messages into a single Anthropic-style text prompt.
 * The input array is not mutated. If the list is empty or does not end with
 * an AI message, an empty AI message is appended so that the prompt ends
 * with `aiPrompt`.
 *
 * @param messages Array of messages to be converted.
 * @param humanPrompt Prefix for human turns (defaults to "\n\nHuman:").
 * @param aiPrompt Prefix for AI turns (defaults to "\n\nAssistant:").
 * @returns The concatenated prompt string.
 */
function convertMessagesToPromptAnthropic(messages, humanPrompt = "\n\nHuman:", aiPrompt = "\n\nAssistant:") {
    const messagesCopy = [...messages];
    if (messagesCopy.length === 0 ||
        messagesCopy[messagesCopy.length - 1]._getType() !== "ai") {
        messagesCopy.push(new messages_1.AIMessage({ content: "" }));
    }
    return messagesCopy
        .map((message) => convertOneMessageToText(message, humanPrompt, aiPrompt))
        .join("");
}
exports.convertMessagesToPromptAnthropic = convertMessagesToPromptAnthropic;
/**
 * Function that converts an array of messages into a single string prompt
 * that can be used as input for a chat model. It delegates the conversion
 * logic to the appropriate provider-specific function.
 * @param messages Array of messages to be converted.
 * @param provider Name of the model provider; only "anthropic" is supported.
 * @returns A string prompt that can be used as input for a chat model.
 * @throws Error if the provider is not "anthropic".
 */
function convertMessagesToPrompt(messages, provider) {
    if (provider === "anthropic") {
        return convertMessagesToPromptAnthropic(messages);
    }
    throw new Error(`Provider ${provider} does not support chat.`);
}
exports.convertMessagesToPrompt = convertMessagesToPrompt;
/**
 * Normalizes a homogeneous list of tools into objects with `name`,
 * `description` and `input_schema` fields.
 *
 * Accepts either all structured tools (schema converted with
 * `zodToJsonSchema`), all OpenAI-format tools (fields read from
 * `tool.function`), or all Anthropic-format tools (returned unchanged).
 *
 * @param tools The tools to format; may be empty or undefined.
 * @returns The formatted tools, or an empty array when none are given.
 * @throws Error if the tools mix several recognized formats, or if none of
 * them is in a recognized format.
 */
function formatTools(tools) {
    if (!tools || !tools.length) {
        return [];
    }
    if (tools.every((tc) => (0, function_calling_1.isStructuredTool)(tc))) {
        return tools.map((tc) => ({
            name: tc.name,
            description: tc.description,
            input_schema: (0, zod_to_json_schema_1.zodToJsonSchema)(tc.schema),
        }));
    }
    if (tools.every((tc) => (0, base_1.isOpenAITool)(tc))) {
        return tools.map((tc) => ({
            name: tc.function.name,
            description: tc.function.description,
            input_schema: tc.function.parameters,
        }));
    }
    if (tools.every((tc) => (0, anthropic_js_1.isAnthropicTool)(tc))) {
        return tools;
    }
    // No single format matched every tool: distinguish "mixed formats" from
    // "nothing recognizable" so the caller gets an actionable message.
    if (tools.some((tc) => (0, function_calling_1.isStructuredTool)(tc)) ||
        tools.some((tc) => (0, base_1.isOpenAITool)(tc)) ||
        tools.some((tc) => (0, anthropic_js_1.isAnthropicTool)(tc))) {
        throw new Error("All tools passed to BedrockChat must be of the same type.");
    }
    throw new Error("Invalid tool format received.");
}
/**
 * A chat model that interacts with the Bedrock service. It extends
 * `BaseChatModel` and implements the `BaseBedrockInput` interface.
 * The class is designed to authenticate and
 * interact with the Bedrock service, which is a part of Amazon Web
 * Services (AWS). It uses AWS credentials for authentication and can be
 * configured with various parameters such as the model to use, the AWS
 * region, and the maximum number of tokens to generate.
 *
 * The `BedrockChat` class supports both non-streaming and streaming invocation of the model,
 * allowing for streaming responses and handling new token callbacks. It can be configured with
 * optional parameters like temperature, stop sequences, and guardrail settings for enhanced control
 * over the generated responses.
 *
 * @example
 * ```typescript
 * import { BedrockChat } from 'path-to-your-bedrock-chat-module';
 * import { HumanMessage } from '@langchain/core/messages';
 *
 * async function run() {
 *   // Instantiate the BedrockChat model with the desired configuration
 *   const model = new BedrockChat({
 *     model: "anthropic.claude-v2",
 *     region: "us-east-1",
 *     credentials: {
 *       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
 *       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
 *     },
 *     maxTokens: 150,
 *     temperature: 0.7,
 *     stopSequences: ["\n", " Human:", " Assistant:"],
 *     streaming: false,
 *     trace: "ENABLED",
 *     guardrailIdentifier: "your-guardrail-id",
 *     guardrailVersion: "1.0",
 *     guardrailConfig: {
 *       tagSuffix: "example",
 *       streamProcessingMode: "SYNCHRONOUS",
 *     },
 *   });
 *
 *   // Prepare the message to be sent to the model
 *   const message = new HumanMessage("Tell me a joke");
 *
 *   // Invoke the model with the message
 *   const res = await model.invoke([message]);
 *
 *   // Output the response from the model
 *   console.log(res);
 * }
 *
 * run().catch(console.error);
 * ```
 *
 * For streaming responses, use the following example:
 * @example
 * ```typescript
 * import { BedrockChat } from 'path-to-your-bedrock-chat-module';
 * import { HumanMessage } from '@langchain/core/messages';
 *
 * async function runStreaming() {
 *   // Instantiate the BedrockChat model with the desired configuration
 *   const model = new BedrockChat({
 *     model: "anthropic.claude-3-sonnet-20240229-v1:0",
 *     region: "us-east-1",
 *     credentials: {
 *       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
 *       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
 *     },
 *     maxTokens: 150,
 *     temperature: 0.7,
 *     stopSequences: ["\n", " Human:", " Assistant:"],
 *     streaming: true,
 *     trace: "ENABLED",
 *     guardrailIdentifier: "your-guardrail-id",
 *     guardrailVersion: "1.0",
 *     guardrailConfig: {
 *       tagSuffix: "example",
 *       streamProcessingMode: "SYNCHRONOUS",
 *     },
 *   });
 *
 *   // Prepare the message to be sent to the model
 *   const message = new HumanMessage("Tell me a joke");
 *
 *   // Stream the response from the model
 *   const stream = await model.stream([message]);
 *   for await (const chunk of stream) {
 *     // Output each chunk of the response
 *     console.log(chunk);
 *   }
 * }
 *
 * runStreaming().catch(console.error);
 * ```
 */
class BedrockChat extends chat_models_1.BaseChatModel {
    /** Maps constructor field names to their serialized aliases. */
    get lc_aliases() {
        return {
            model: "model_id",
            region: "region_name",
        };
    }
    /** Maps secret field paths to the identifiers used in their place when serializing. */
    get lc_secrets() {
        return {
            "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID",
            "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY",
        };
    }
    /** Extra attributes to include when serializing (the resolved region). */
    get lc_attributes() {
        return { region: this.region };
    }
    /** Parameters that identify this model instance. */
    _identifyingParams() {
        return {
            model: this.model,
        };
    }
    /** Type identifier of this model. */
    _llmType() {
        return "bedrock";
    }
    /** Class name used for serialization. */
    static lc_name() {
        return "BedrockChat";
    }
    /**
     * @param fields Model configuration. `credentials` is required; `region`
     * falls back to the `AWS_DEFAULT_REGION` environment variable.
     * @throws Error if the model's provider prefix is not supported, if no
     * region can be resolved, or if no credentials are given.
     */
    constructor(fields) {
        super(fields ?? {});
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "amazon.titan-tg1-large"
        });
        Object.defineProperty(this, "region", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "credentials", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        Object.defineProperty(this, "maxTokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        Object.defineProperty(this, "fetchFn", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "endpointHost", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        /** @deprecated Use as a call option using .bind() instead. */
        Object.defineProperty(this, "stopSequences", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "codec", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: new eventstream_codec_1.EventStreamCodec(util_utf8_1.toUtf8, util_utf8_1.fromUtf8)
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "usesMessagesApi", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        Object.defineProperty(this, "trace", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "guardrailIdentifier", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: ""
        });
        Object.defineProperty(this, "guardrailVersion", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: ""
        });
        Object.defineProperty(this, "guardrailConfig", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "_anthropicTools", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.model = fields?.model ?? this.model;
        // The provider is the part of the model id before the first ".".
        const allowedModels = [
            "ai21",
            "anthropic",
            "amazon",
            "cohere",
            "meta",
            "mistral",
        ];
        if (!allowedModels.includes(this.model.split(".")[0])) {
            throw new Error(`Unknown model: '${this.model}', only these are supported: ${allowedModels}`);
        }
        const region = fields?.region ?? (0, env_1.getEnvironmentVariable)("AWS_DEFAULT_REGION");
        if (!region) {
            throw new Error("Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field.");
        }
        this.region = region;
        const credentials = fields?.credentials;
        if (!credentials) {
            throw new Error("Please set the AWS credentials in the 'credentials' field.");
        }
        this.credentials = credentials;
        this.temperature = fields?.temperature ?? this.temperature;
        this.maxTokens = fields?.maxTokens ?? this.maxTokens;
        this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis);
        this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl;
        this.stopSequences = fields?.stopSequences;
        this.modelKwargs = fields?.modelKwargs;
        this.streaming = fields?.streaming ?? this.streaming;
        this.usesMessagesApi = canUseMessagesApi(this.model);
        this.trace = fields?.trace ?? this.trace;
        this.guardrailVersion = fields?.guardrailVersion ?? this.guardrailVersion;
        this.guardrailIdentifier =
            fields?.guardrailIdentifier ?? this.guardrailIdentifier;
        this.guardrailConfig = fields?.guardrailConfig;
    }
    /**
     * Builds the per-call parameters from instance settings and call options.
     * Tools bound via `bindTools` come first, followed by call-option tools.
     *
     * @param options Call options (`tools`, `stop`, `tool_choice`).
     * @returns The parameters passed to the Bedrock input adapter.
     * @throws Error if `tool_choice` is set, since it is not supported.
     */
    invocationParams(options) {
        if (options?.tool_choice) {
            throw new Error("'tool_choice' call option is not supported by BedrockChat.");
        }
        const callOptionTools = formatTools(options?.tools ?? []);
        return {
            tools: [...(this._anthropicTools ?? []), ...callOptionTools],
            temperature: this.temperature,
            max_tokens: this.maxTokens,
            stop: options?.stop ?? this.stopSequences,
            modelKwargs: this.modelKwargs,
            guardrailConfig: this.guardrailConfig,
        };
    }
    /**
     * Returns the `ls_*` tracing parameters describing this call.
     *
     * @param options Call options; may be undefined.
     */
    getLsParams(options) {
        const params = this.invocationParams(options);
        return {
            ls_provider: "bedrock",
            ls_model_name: this.model,
            ls_model_type: "chat",
            ls_temperature: params.temperature ?? undefined,
            ls_max_tokens: params.max_tokens ?? undefined,
            // `invocationParams` tolerates missing options, so do the same here.
            ls_stop: options?.stop,
        };
    }
    /**
     * Generates a chat result. When `streaming` is enabled, the streamed
     * chunks are concatenated into one generation; otherwise a single
     * non-streaming request is made.
     *
     * @throws Error if streaming is enabled and the stream yields no chunks.
     */
    async _generate(messages, options, runManager) {
        if (this.streaming) {
            const stream = this._streamResponseChunks(messages, options, runManager);
            let finalResult;
            for await (const chunk of stream) {
                if (finalResult === undefined) {
                    finalResult = chunk;
                }
                else {
                    finalResult = finalResult.concat(chunk);
                }
            }
            if (finalResult === undefined) {
                throw new Error("Could not parse final output from Bedrock streaming call.");
            }
            return {
                generations: [finalResult],
                llmOutput: finalResult.generationInfo,
            };
        }
        return this._generateNonStreaming(messages, options, runManager);
    }
    /**
     * Performs a single "invoke" request and converts the JSON response into
     * a chat result.
     *
     * @throws Error if the response status is not OK, or if the messages-API
     * output cannot be parsed.
     */
    async _generateNonStreaming(messages, options, _runManager) {
        const service = "bedrock-runtime";
        const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
        const provider = this.model.split(".")[0];
        const response = await this._signedFetch(messages, options, {
            bedrockMethod: "invoke",
            endpointHost,
            provider,
        });
        const json = await response.json();
        if (!response.ok) {
            throw new Error(`Error ${response.status}: ${json.message ?? JSON.stringify(json)}`);
        }
        if (this.usesMessagesApi) {
            const outputGeneration = index_js_1.BedrockLLMInputOutputAdapter.prepareMessagesOutput(provider, json);
            if (outputGeneration === undefined) {
                throw new Error("Failed to parse output generation.");
            }
            return {
                generations: [outputGeneration],
                llmOutput: outputGeneration.generationInfo,
            };
        }
        else {
            const text = index_js_1.BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
            return { generations: [{ text, message: new messages_1.AIMessage(text) }] };
        }
    }
    /**
     * Builds the request body, signs the request with AWS Signature V4 and
     * sends it with `fetchFn` through the caller.
     *
     * @param messages Messages to send.
     * @param options Call options; `options.signal` is forwarded to the caller.
     * @param fields `bedrockMethod`, `endpointHost` and `provider` for this call.
     * @returns The fetch response.
     */
    async _signedFetch(messages, options, fields) {
        const { bedrockMethod, endpointHost, provider } = fields;
        const { max_tokens, temperature, stop, modelKwargs, guardrailConfig, tools, } = this.invocationParams(options);
        const inputBody = this.usesMessagesApi
            ? index_js_1.BedrockLLMInputOutputAdapter.prepareMessagesInput(provider, messages, max_tokens, temperature, stop, modelKwargs, guardrailConfig, tools)
            : index_js_1.BedrockLLMInputOutputAdapter.prepareInput(provider, convertMessagesToPromptAnthropic(messages), max_tokens, temperature, stop, modelKwargs, bedrockMethod, guardrailConfig);
        const url = new URL(`https://${endpointHost}/model/${this.model}/${bedrockMethod}`);
        const request = new protocol_http_1.HttpRequest({
            hostname: url.hostname,
            path: url.pathname,
            protocol: url.protocol,
            method: "POST",
            body: JSON.stringify(inputBody),
            query: Object.fromEntries(url.searchParams.entries()),
            headers: {
                // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
                host: url.host,
                accept: "application/json",
                "content-type": "application/json",
                // Trace/guardrail headers are only sent when trace, identifier
                // and version are all set.
                ...(this.trace &&
                    this.guardrailIdentifier &&
                    this.guardrailVersion && {
                    "X-Amzn-Bedrock-Trace": this.trace,
                    "X-Amzn-Bedrock-GuardrailIdentifier": this.guardrailIdentifier,
                    "X-Amzn-Bedrock-GuardrailVersion": this.guardrailVersion,
                }),
            },
        });
        const signer = new signature_v4_1.SignatureV4({
            credentials: this.credentials,
            service: "bedrock",
            region: this.region,
            sha256: sha256_js_1.Sha256,
        });
        const signedRequest = await signer.sign(request);
        // Send request to AWS using the low-level fetch API
        const response = await this.caller.callWithOptions({ signal: options.signal }, async () => this.fetchFn(url, {
            headers: signedRequest.headers,
            body: signedRequest.body,
            method: signedRequest.method,
        }));
        return response;
    }
    /**
     * Streams generation chunks from Bedrock. For the anthropic, cohere, meta
     * and mistral providers the response stream is decoded event by event;
     * for other providers a plain "invoke" call is made and its result is
     * yielded as a single chunk.
     *
     * @throws Error if the HTTP status is not 2xx, if an event is not a JSON
     * "chunk" event, or if an event body carries a `message` field.
     */
    async *_streamResponseChunks(messages, options, runManager) {
        const provider = this.model.split(".")[0];
        const service = "bedrock-runtime";
        const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
        const isStreamingProvider = provider === "anthropic" ||
            provider === "cohere" ||
            provider === "meta" ||
            provider === "mistral";
        const bedrockMethod = isStreamingProvider
            ? "invoke-with-response-stream"
            : "invoke";
        const response = await this._signedFetch(messages, options, {
            bedrockMethod,
            endpointHost,
            provider,
        });
        if (response.status < 200 || response.status >= 300) {
            throw new Error(`Failed to access underlying url '${endpointHost}': got ${response.status} ${response.statusText}: ${await response.text()}`);
        }
        if (isStreamingProvider) {
            // Content is coerced to a plain string only when the call options
            // carry no tools.
            const coerceContentToString = !(0, anthropic_js_1._toolsInParams)(options);
            const reader = response.body?.getReader();
            const decoder = new TextDecoder();
            for await (const chunk of this._readChunks(reader)) {
                const event = this.codec.decode(chunk);
                if ((event.headers[":event-type"] !== undefined &&
                    event.headers[":event-type"].value !== "chunk") ||
                    event.headers[":content-type"].value !== "application/json") {
                    throw new Error(`Failed to get event chunk: got ${chunk}`);
                }
                const body = JSON.parse(decoder.decode(event.body));
                if (body.message) {
                    throw new Error(body.message);
                }
                if (body.bytes !== undefined) {
                    // `bytes` is base64 text; decode it to bytes, then to JSON.
                    const chunkResult = JSON.parse(decoder.decode(Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)));
                    if (this.usesMessagesApi) {
                        const generationChunk = index_js_1.BedrockLLMInputOutputAdapter.prepareMessagesOutput(provider, chunkResult, {
                            coerceContentToString,
                        });
                        if (generationChunk === undefined) {
                            continue;
                        }
                        if (provider === "anthropic" &&
                            generationChunk.generationInfo?.usage !== undefined) {
                            // Avoid bad aggregation in chunks, rely on final Bedrock data
                            delete generationChunk.generationInfo.usage;
                        }
                        const finalMetrics = generationChunk.generationInfo?.["amazon-bedrock-invocationMetrics"];
                        if (finalMetrics != null &&
                            typeof finalMetrics === "object" &&
                            (0, messages_1.isAIMessage)(generationChunk.message)) {
                            generationChunk.message.usage_metadata = {
                                input_tokens: finalMetrics.inputTokenCount,
                                output_tokens: finalMetrics.outputTokenCount,
                                total_tokens: finalMetrics.inputTokenCount + finalMetrics.outputTokenCount,
                            };
                        }
                        if (isChatGenerationChunk(generationChunk)) {
                            yield generationChunk;
                        }
                        // eslint-disable-next-line no-void
                        void runManager?.handleLLMNewToken(generationChunk.text);
                    }
                    else {
                        const text = index_js_1.BedrockLLMInputOutputAdapter.prepareOutput(provider, chunkResult);
                        yield new outputs_1.ChatGenerationChunk({
                            text,
                            message: new messages_1.AIMessageChunk({ content: text }),
                        });
                        // eslint-disable-next-line no-void
                        void runManager?.handleLLMNewToken(text);
                    }
                }
            }
        }
        else {
            const json = await response.json();
            const text = index_js_1.BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
            yield new outputs_1.ChatGenerationChunk({
                text,
                message: new messages_1.AIMessageChunk({ content: text }),
            });
            // eslint-disable-next-line no-void
            void runManager?.handleLLMNewToken(text);
        }
    }
    /**
     * Wraps a stream reader in an async iterable that yields one complete
     * length-prefixed message at a time. Incoming reads are accumulated in a
     * buffer; the first four bytes of the buffer are read as a big-endian
     * unsigned 32-bit total message length, and a message is yielded once
     * that many bytes are available.
     *
     * @param reader Reader exposing `read()` that resolves to `{ done, value }`.
     * @returns An async iterable of `Uint8Array` messages.
     */
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    _readChunks(reader) {
        function _concatChunks(a, b) {
            const newBuffer = new Uint8Array(a.length + b.length);
            newBuffer.set(a);
            newBuffer.set(b, a.length);
            return newBuffer;
        }
        function getMessageLength(buffer) {
            // Not enough bytes yet to read the length prefix.
            if (buffer.byteLength < PRELUDE_TOTAL_LENGTH_BYTES)
                return 0;
            const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
            return view.getUint32(0, false);
        }
        return {
            async *[Symbol.asyncIterator]() {
                let readResult = await reader.read();
                let buffer = new Uint8Array(0);
                while (!readResult.done) {
                    const chunk = readResult.value;
                    buffer = _concatChunks(buffer, chunk);
                    let messageLength = getMessageLength(buffer);
                    // One read may contain several messages; drain all complete ones.
                    while (buffer.byteLength >= PRELUDE_TOTAL_LENGTH_BYTES &&
                        buffer.byteLength >= messageLength) {
                        yield buffer.slice(0, messageLength);
                        buffer = buffer.slice(messageLength);
                        messageLength = getMessageLength(buffer);
                    }
                    readResult = await reader.read();
                }
            },
        };
    }
    /** Returns an empty object; no LLM output is combined. */
    _combineLLMOutput() {
        return {};
    }
    /**
     * Stores the given tools on this instance so that every subsequent call
     * includes them, and returns this same instance.
     *
     * @param tools Tools in structured, OpenAI or Anthropic format.
     * @throws Error if the model provider is not "anthropic".
     */
    bindTools(tools, _kwargs) {
        const provider = this.model.split(".")[0];
        if (provider !== "anthropic") {
            throw new Error("Currently, tool calling through Bedrock is only supported for Anthropic models.");
        }
        this._anthropicTools = formatTools(tools);
        return this;
    }
}
exports.BedrockChat = BedrockChat;
/**
 * Returns true when `x` is defined and has a `concat` method, which is how
 * a generation chunk is told apart from a non-chunk generation here.
 */
function isChatGenerationChunk(x) {
    return (x !== undefined && typeof x.concat === "function");
}
/**
 * Decides whether a model id is served through the messages API.
 * True for Anthropic models other than "claude-v2" / "claude-instant-v1",
 * and for Cohere "command-r-v1" / "command-r-plus-v1" models.
 *
 * @param model Full model id, e.g. "anthropic.claude-3-sonnet-20240229-v1:0".
 */
function canUseMessagesApi(model) {
    const modelProviderName = model.split(".")[0];
    if (modelProviderName === "anthropic" &&
        !model.includes("claude-v2") &&
        !model.includes("claude-instant-v1")) {
        return true;
    }
    if (modelProviderName === "cohere") {
        if (model.includes("command-r-v1")) {
            return true;
        }
        if (model.includes("command-r-plus-v1")) {
            return true;
        }
    }
    return false;
}
/**
 * @deprecated Use `BedrockChat` instead.
 */
exports.ChatBedrock = BedrockChat;