UNPKG

@langchain/community

Version:
257 lines (254 loc) 9.51 kB
Object.defineProperty(exports, Symbol.toStringTag, { value: "Module" });
const require_runtime = require("../../_virtual/_rolldown/runtime.cjs");
const require_index = require("../../utils/bedrock/index.cjs");
let _langchain_core_outputs = require("@langchain/core/outputs");
let _langchain_core_utils_env = require("@langchain/core/utils/env");
let _smithy_signature_v4 = require("@smithy/signature-v4");
let _smithy_protocol_http = require("@smithy/protocol-http");
let _smithy_eventstream_codec = require("@smithy/eventstream-codec");
let _smithy_util_utf8 = require("@smithy/util-utf8");
let _aws_crypto_sha256_js = require("@aws-crypto/sha256-js");
let _langchain_core_language_models_llms = require("@langchain/core/language_models/llms");

//#region src/llms/bedrock/web.ts
var web_exports = /* @__PURE__ */ require_runtime.__exportAll({ Bedrock: () => Bedrock });

/**
 * Model-ID prefixes that mark a prefixed ID such as "us.anthropic.<model>".
 * For these IDs the provider is the second dot-separated segment rather than
 * the first (see isInferenceModel / getModelProvider).
 * @see https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html#Concepts.RegionsAndAvailabilityZones.Regions
 */
const AWS_REGIONS = [
  "us",
  "sa",
  "me",
  "mx",
  "il",
  "eu",
  "cn",
  "ca",
  "ap",
  "af",
  "us-gov",
  "apac",
  "au",
  "jp",
  "global"
];

/** Providers the constructor accepts; any other provider is rejected. */
const ALLOWED_MODEL_PROVIDERS = [
  "ai21",
  "anthropic",
  "amazon",
  "cohere",
  "meta",
  "mistral",
  "deepseek"
];

/**
 * Providers that are called through `invoke-with-response-stream` when
 * streaming. Every other provider falls back to a single `invoke` call whose
 * whole output is emitted as one chunk.
 */
const STREAMING_MODEL_PROVIDERS = [
  "anthropic",
  "cohere",
  "meta",
  "mistral"
];

/** Width of the big-endian uint32 total-length field that starts every event-stream message. */
const PRELUDE_TOTAL_LENGTH_BYTES = 4;

/**
 * A type of Large Language Model (LLM) that interacts with the Bedrock
 * service. It extends the base `LLM` class and implements the
 * `BaseBedrockInput` interface. The class is designed to authenticate and
 * interact with the Bedrock service, which is a part of Amazon Web
 * Services (AWS). It uses AWS credentials for authentication and can be
 * configured with various parameters such as the model to use, the AWS
 * region, and the maximum number of tokens to generate.
 */
var Bedrock = class extends _langchain_core_language_models_llms.LLM {
  model = "amazon.titan-tg1-large";
  modelProvider;
  region;
  credentials;
  temperature = void 0;
  maxTokens = void 0;
  fetchFn;
  endpointHost;
  modelKwargs;
  codec = new _smithy_eventstream_codec.EventStreamCodec(_smithy_util_utf8.toUtf8, _smithy_util_utf8.fromUtf8);
  streaming = false;
  lc_serializable = true;

  get lc_aliases() {
    return {
      model: "model_id",
      region: "region_name"
    };
  }

  get lc_secrets() {
    return {
      "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID",
      "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY"
    };
  }

  get lc_attributes() {
    return { region: this.region };
  }

  _llmType() {
    return "bedrock";
  }

  static lc_name() {
    return "Bedrock";
  }

  /**
   * @param fields Optional configuration. `region` falls back to the
   *   AWS_DEFAULT_REGION environment variable; `credentials` is required;
   *   `endpointHost` (or the alias `endpointUrl`) overrides the default
   *   `bedrock-runtime.<region>.amazonaws.com` host.
   * @throws {Error} If the provider derived from `model` is not in
   *   ALLOWED_MODEL_PROVIDERS, if no region can be resolved, or if no
   *   credentials are supplied.
   */
  constructor(fields) {
    super(fields ?? {});
    this.model = fields?.model ?? this.model;
    this.modelProvider = getModelProvider(this.model);
    if (!ALLOWED_MODEL_PROVIDERS.includes(this.modelProvider)) throw new Error(`Unknown model provider: '${this.modelProvider}', only these are supported: ${ALLOWED_MODEL_PROVIDERS}`);
    const region = fields?.region ?? (0, _langchain_core_utils_env.getEnvironmentVariable)("AWS_DEFAULT_REGION");
    if (!region) throw new Error("Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field.");
    this.region = region;
    const credentials = fields?.credentials;
    if (!credentials) throw new Error("Please set the AWS credentials in the 'credentials' field.");
    this.credentials = credentials;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis);
    this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl;
    this.modelKwargs = fields?.modelKwargs;
    this.streaming = fields?.streaming ?? this.streaming;
  }

  /**
   * Calls the Bedrock model with a single prompt.
   *
   * When `streaming` is enabled, the streamed chunks are concatenated and
   * their combined text is returned. Otherwise a single `invoke` request is
   * made.
   *
   * @example
   * model.invoke("Tell me a joke.")
   *
   * @param prompt The prompt to pass into the model.
   * @param options Call options; `stop` and `signal` are forwarded to the request.
   * @param runManager Optional run manager that receives streamed tokens.
   * @returns The string generated by the model.
   * @throws {Error} `Error <status>: ...` when Bedrock answers with a non-2xx status.
   */
  async _call(prompt, options, runManager) {
    const endpointHost = this.endpointHost ?? `bedrock-runtime.${this.region}.amazonaws.com`;
    const provider = this.modelProvider;
    if (this.streaming) {
      const stream = this._streamResponseChunks(prompt, options, runManager);
      let finalResult;
      for await (const chunk of stream) {
        finalResult = finalResult === void 0 ? chunk : finalResult.concat(chunk);
      }
      return finalResult?.text ?? "";
    }
    const response = await this._signedFetch(prompt, options, {
      bedrockMethod: "invoke",
      endpointHost,
      provider
    });
    // Read the body as text first. An error response is not guaranteed to be
    // JSON, and a parse failure must not hide the HTTP status code.
    const rawBody = await response.text();
    let json;
    try {
      json = JSON.parse(rawBody);
    } catch (error) {
      if (!response.ok) throw new Error(`Error ${response.status}: ${rawBody}`, { cause: error });
      throw error;
    }
    if (!response.ok) throw new Error(`Error ${response.status}: ${json?.message ?? JSON.stringify(json)}`);
    return require_index.BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
  }

  /**
   * Builds the provider-specific request body, signs it with SigV4 for the
   * "bedrock" service, and POSTs it through `this.caller`, honouring
   * `options.signal`.
   *
   * @param prompt The prompt text.
   * @param options Call options; `stop` goes into the body, `signal` into the caller.
   * @param fields `{ bedrockMethod, endpointHost, provider }`, where
   *   `bedrockMethod` is "invoke" or "invoke-with-response-stream".
   * @returns The raw response returned by `fetchFn`.
   */
  async _signedFetch(prompt, options, fields) {
    const { bedrockMethod, endpointHost, provider } = fields;
    const inputBody = require_index.BedrockLLMInputOutputAdapter.prepareInput(provider, prompt, this.maxTokens, this.temperature, options.stop, this.modelKwargs, bedrockMethod);
    const url = new URL(`https://${endpointHost}/model/${this.model}/${bedrockMethod}`);
    const request = new _smithy_protocol_http.HttpRequest({
      hostname: url.hostname,
      path: url.pathname,
      protocol: url.protocol,
      method: "POST",
      body: JSON.stringify(inputBody),
      query: Object.fromEntries(url.searchParams.entries()),
      headers: {
        host: url.host,
        accept: "application/json",
        "content-type": "application/json"
      }
    });
    const signedRequest = await new _smithy_signature_v4.SignatureV4({
      credentials: this.credentials,
      service: "bedrock",
      region: this.region,
      sha256: _aws_crypto_sha256_js.Sha256
    }).sign(request);
    return await this.caller.callWithOptions({ signal: options.signal }, async () => this.fetchFn(url, {
      headers: signedRequest.headers,
      body: signedRequest.body,
      method: signedRequest.method
    }));
  }

  invocationParams(options) {
    return {
      model: this.model,
      region: this.region,
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      stop: options?.stop,
      modelKwargs: this.modelKwargs
    };
  }

  /**
   * Streams the model output as GenerationChunks.
   *
   * Providers in STREAMING_MODEL_PROVIDERS use `invoke-with-response-stream`,
   * and each AWS event-stream `chunk` frame is decoded into one chunk. Every
   * other provider is called once with `invoke`, and its full output is
   * yielded as a single chunk.
   *
   * @throws {Error} On a non-2xx status, a missing response body, an
   *   unexpected frame, or a frame whose JSON body carries a `message`.
   */
  async *_streamResponseChunks(prompt, options, runManager) {
    const provider = this.modelProvider;
    const supportsStreaming = STREAMING_MODEL_PROVIDERS.includes(provider);
    const bedrockMethod = supportsStreaming ? "invoke-with-response-stream" : "invoke";
    const endpointHost = this.endpointHost ?? `bedrock-runtime.${this.region}.amazonaws.com`;
    const response = await this._signedFetch(prompt, options, {
      bedrockMethod,
      endpointHost,
      provider
    });
    if (response.status < 200 || response.status >= 300) throw new Error(`Failed to access underlying url '${endpointHost}': got ${response.status} ${response.statusText}: ${await response.text()}`);

    if (!supportsStreaming) {
      const json = await response.json();
      const text = require_index.BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
      yield new _langchain_core_outputs.GenerationChunk({
        text,
        generationInfo: {}
      });
      runManager?.handleLLMNewToken(text);
      return;
    }

    const reader = response.body?.getReader();
    if (reader === void 0) throw new Error(`Failed to read streaming response from '${endpointHost}': response has no body`);
    const decoder = new TextDecoder();
    for await (const chunk of this._readChunks(reader)) {
      const event = this.codec.decode(chunk);
      // Optional chaining: a frame lacking either header must produce the
      // descriptive error below, not a TypeError on `.value`.
      const eventType = event.headers[":event-type"]?.value;
      const contentType = event.headers[":content-type"]?.value;
      // Frames without an `:event-type` header are let through; if their
      // JSON body carries a `message`, it is thrown just below.
      if ((eventType !== void 0 && eventType !== "chunk") || contentType !== "application/json") throw new Error(`Failed to get event chunk: got ${chunk}`);
      const body = JSON.parse(decoder.decode(event.body));
      if (body.message) throw new Error(body.message);
      if (body.bytes !== void 0) {
        // `bytes` is base64; atob yields a binary string with one char per byte.
        const payload = Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0);
        const chunkResult = JSON.parse(decoder.decode(payload));
        const text = require_index.BedrockLLMInputOutputAdapter.prepareOutput(provider, chunkResult);
        yield new _langchain_core_outputs.GenerationChunk({
          text,
          generationInfo: {}
        });
        runManager?.handleLLMNewToken(text);
      }
    }
  }

  /**
   * Re-frames a byte stream into complete event-stream messages.
   *
   * Each message starts with a 4-byte big-endian total length. Incoming
   * bytes are buffered until at least one whole message is available, and
   * each complete message is then yielded as its own Uint8Array. Leftover
   * bytes that never form a complete message are dropped at end of stream.
   * If iteration stops before the stream ends, the reader is cancelled.
   *
   * @param reader A ReadableStream reader over the response body.
   * @returns An async iterable of complete message buffers.
   */
  _readChunks(reader) {
    /** Returns a new array holding the bytes of `a` followed by those of `b`. */
    function _concatChunks(a, b) {
      const newBuffer = new Uint8Array(a.length + b.length);
      newBuffer.set(a);
      newBuffer.set(b, a.length);
      return newBuffer;
    }
    /** Reads the big-endian length prelude, or returns 0 while fewer than 4 bytes are buffered. */
    function getMessageLength(buffer) {
      if (buffer.byteLength < PRELUDE_TOTAL_LENGTH_BYTES) return 0;
      return new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength).getUint32(0, false);
    }
    return {
      async *[Symbol.asyncIterator]() {
        let buffer = new Uint8Array(0);
        let reachedEnd = false;
        try {
          let readResult = await reader.read();
          while (!readResult.done) {
            buffer = _concatChunks(buffer, readResult.value);
            let messageLength = getMessageLength(buffer);
            while (buffer.byteLength >= PRELUDE_TOTAL_LENGTH_BYTES && buffer.byteLength >= messageLength) {
              yield buffer.slice(0, messageLength);
              buffer = buffer.slice(messageLength);
              messageLength = getMessageLength(buffer);
            }
            readResult = await reader.read();
          }
          reachedEnd = true;
        } finally {
          if (!reachedEnd) {
            // Stopped early (the consumer broke out or threw, or a read
            // failed): cancel so the HTTP body is not left half-consumed.
            // A cancel failure is ignored so it cannot mask the original error.
            await reader.cancel().catch(() => {});
          }
        }
      }
    };
  }
};

/** True when the first dot-separated segment of `modelId` is one of AWS_REGIONS. */
function isInferenceModel(modelId) {
  const [prefix] = modelId.split(".");
  return AWS_REGIONS.includes(prefix);
}

/**
 * Extracts the provider name from a Bedrock model ID: the second segment for
 * region-prefixed IDs (e.g. "us.anthropic.x" -> "anthropic"), otherwise the
 * first segment (e.g. "amazon.titan-tg1-large" -> "amazon").
 */
function getModelProvider(modelId) {
  const parts = modelId.split(".");
  return isInferenceModel(modelId) ? parts[1] : parts[0];
}

//#endregion
exports.Bedrock = Bedrock;
Object.defineProperty(exports, "web_exports", {
  enumerable: true,
  get: function() {
    return web_exports;
  }
});
//# sourceMappingURL=web.cjs.map