@langchain/community
Version:
Third-party integrations for LangChain.js
1 lines • 11.4 kB
Source Map (JSON)
{"version":3,"file":"sagemaker_endpoint.cjs","names":["LLM","SageMakerRuntimeClient","InvokeEndpointCommand","InvokeEndpointWithResponseStreamCommand","GenerationChunk"],"sources":["../../src/llms/sagemaker_endpoint.ts"],"sourcesContent":["import {\n InvokeEndpointCommand,\n InvokeEndpointWithResponseStreamCommand,\n SageMakerRuntimeClient,\n SageMakerRuntimeClientConfig,\n} from \"@aws-sdk/client-sagemaker-runtime\";\nimport { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\nimport { GenerationChunk } from \"@langchain/core/outputs\";\nimport {\n type BaseLLMCallOptions,\n type BaseLLMParams,\n LLM,\n} from \"@langchain/core/language_models/llms\";\n\n/**\n * A handler class to transform input from LLM to a format that SageMaker\n * endpoint expects. Similarily, the class also handles transforming output from\n * the SageMaker endpoint to a format that LLM class expects.\n *\n * Example:\n * ```\n * class ContentHandler implements ContentHandlerBase<string, string> {\n * contentType = \"application/json\"\n * accepts = \"application/json\"\n *\n * transformInput(prompt: string, modelKwargs: Record<string, unknown>) {\n * const inputString = JSON.stringify({\n * prompt,\n * ...modelKwargs\n * })\n * return Buffer.from(inputString)\n * }\n *\n * transformOutput(output: Uint8Array) {\n * const responseJson = JSON.parse(Buffer.from(output).toString(\"utf-8\"))\n * return responseJson[0].generated_text\n * }\n *\n * }\n * ```\n */\nexport abstract class BaseSageMakerContentHandler<InputType, OutputType> {\n contentType = \"text/plain\";\n\n accepts = \"text/plain\";\n\n /**\n * Transforms the prompt and model arguments into a specific format for sending to SageMaker.\n * @param {InputType} prompt The prompt to be transformed.\n * @param {Record<string, unknown>} modelKwargs Additional arguments.\n * @returns {Promise<Uint8Array>} A promise that resolves to the formatted data for sending.\n */\n abstract transformInput(\n prompt: InputType,\n modelKwargs: Record<string, unknown>\n ): Promise<Uint8Array>;\n\n /**\n * Transforms SageMaker output into a desired format.\n * @param {Uint8Array} output The raw output from SageMaker.\n * @returns {Promise<OutputType>} A promise that resolves to the transformed data.\n */\n abstract transformOutput(output: Uint8Array): Promise<OutputType>;\n}\n\nexport type SageMakerLLMContentHandler = BaseSageMakerContentHandler<\n string,\n string\n>;\n\n/**\n * The SageMakerEndpointInput interface defines the input parameters for\n * the SageMakerEndpoint class, which includes the endpoint name, client\n * options for the SageMaker client, the content handler, and optional\n * keyword arguments for the model and the endpoint.\n */\nexport interface SageMakerEndpointInput extends BaseLLMParams {\n /**\n * The name of the endpoint from the deployed SageMaker model. Must be unique\n * within an AWS Region.\n */\n endpointName: string;\n /**\n * Options passed to the SageMaker client.\n */\n clientOptions: SageMakerRuntimeClientConfig;\n /**\n * Key word arguments to pass to the model.\n */\n modelKwargs?: Record<string, unknown>;\n /**\n * Optional attributes passed to the InvokeEndpointCommand\n */\n endpointKwargs?: Record<string, unknown>;\n /**\n * The content handler class that provides an input and output transform\n * functions to handle formats between LLM and the endpoint.\n */\n contentHandler: SageMakerLLMContentHandler;\n streaming?: boolean;\n}\n\n/**\n * The SageMakerEndpoint class is used to interact with SageMaker\n * Inference Endpoint models. It uses the AWS client for authentication,\n * which automatically loads credentials.\n * If a specific credential profile is to be used, the name of the profile\n * from the ~/.aws/credentials file must be passed. The credentials or\n * roles used should have the required policies to access the SageMaker\n * endpoint.\n */\nexport class SageMakerEndpoint extends LLM<BaseLLMCallOptions> {\n lc_serializable = true;\n\n static lc_name() {\n return \"SageMakerEndpoint\";\n }\n\n get lc_secrets(): { [key: string]: string } | undefined {\n return {\n \"clientOptions.credentials.accessKeyId\": \"AWS_ACCESS_KEY_ID\",\n \"clientOptions.credentials.secretAccessKey\": \"AWS_SECRET_ACCESS_KEY\",\n \"clientOptions.credentials.sessionToken\": \"AWS_SESSION_TOKEN\",\n };\n }\n\n endpointName: string;\n\n modelKwargs?: Record<string, unknown>;\n\n endpointKwargs?: Record<string, unknown>;\n\n client: SageMakerRuntimeClient;\n\n contentHandler: SageMakerLLMContentHandler;\n\n streaming: boolean;\n\n constructor(fields: SageMakerEndpointInput) {\n super(fields);\n\n if (!fields.clientOptions.region) {\n throw new Error(\n `Please pass a \"clientOptions\" object with a \"region\" field to the constructor`\n );\n }\n\n const endpointName = fields?.endpointName;\n if (!endpointName) {\n throw new Error(`Please pass an \"endpointName\" field to the constructor`);\n }\n\n const contentHandler = fields?.contentHandler;\n if (!contentHandler) {\n throw new Error(\n `Please pass a \"contentHandler\" field to the constructor`\n );\n }\n\n this.endpointName = fields.endpointName;\n this.contentHandler = fields.contentHandler;\n this.endpointKwargs = fields.endpointKwargs;\n this.modelKwargs = fields.modelKwargs;\n this.streaming = fields.streaming ?? false;\n this.client = new SageMakerRuntimeClient(fields.clientOptions);\n }\n\n _llmType() {\n return \"sagemaker_endpoint\";\n }\n\n /**\n * Calls the SageMaker endpoint and retrieves the result.\n * @param {string} prompt The input prompt.\n * @param {this[\"ParsedCallOptions\"]} options Parsed call options.\n * @param {CallbackManagerForLLMRun} runManager Optional run manager.\n * @returns {Promise<string>} A promise that resolves to the generated string.\n */\n /** @ignore */\n async _call(\n prompt: string,\n options: this[\"ParsedCallOptions\"],\n runManager?: CallbackManagerForLLMRun\n ): Promise<string> {\n return this.streaming\n ? await this.streamingCall(prompt, options, runManager)\n : await this.noStreamingCall(prompt, options);\n }\n\n private async streamingCall(\n prompt: string,\n options: this[\"ParsedCallOptions\"],\n runManager?: CallbackManagerForLLMRun\n ): Promise<string> {\n const chunks = [];\n for await (const chunk of this._streamResponseChunks(\n prompt,\n options,\n runManager\n )) {\n chunks.push(chunk.text);\n }\n return chunks.join(\"\");\n }\n\n private async noStreamingCall(\n prompt: string,\n options: this[\"ParsedCallOptions\"]\n ): Promise<string> {\n const body = await this.contentHandler.transformInput(\n prompt,\n this.modelKwargs ?? {}\n );\n const { contentType, accepts } = this.contentHandler;\n\n const response = await this.caller.call(() =>\n this.client.send(\n new InvokeEndpointCommand({\n EndpointName: this.endpointName,\n Body: body,\n ContentType: contentType,\n Accept: accepts,\n ...this.endpointKwargs,\n }),\n { abortSignal: options.signal }\n )\n );\n\n if (response.Body === undefined) {\n throw new Error(\"Inference result missing Body\");\n }\n return this.contentHandler.transformOutput(response.Body);\n }\n\n /**\n * Streams response chunks from the SageMaker endpoint.\n * @param {string} prompt The input prompt.\n * @param {this[\"ParsedCallOptions\"]} options Parsed call options.\n * @returns {AsyncGenerator<GenerationChunk>} An asynchronous generator yielding generation chunks.\n */\n async *_streamResponseChunks(\n prompt: string,\n options: this[\"ParsedCallOptions\"],\n runManager?: CallbackManagerForLLMRun\n ): AsyncGenerator<GenerationChunk> {\n const body = await this.contentHandler.transformInput(\n prompt,\n this.modelKwargs ?? {}\n );\n const { contentType, accepts } = this.contentHandler;\n\n const stream = await this.caller.call(() =>\n this.client.send(\n new InvokeEndpointWithResponseStreamCommand({\n EndpointName: this.endpointName,\n Body: body,\n ContentType: contentType,\n Accept: accepts,\n ...this.endpointKwargs,\n }),\n { abortSignal: options.signal }\n )\n );\n\n if (!stream.Body) {\n throw new Error(\"Inference result missing Body\");\n }\n\n for await (const chunk of stream.Body) {\n if (chunk.PayloadPart && chunk.PayloadPart.Bytes) {\n const text = await this.contentHandler.transformOutput(\n chunk.PayloadPart.Bytes\n );\n yield new GenerationChunk({\n text,\n generationInfo: {\n ...chunk,\n response: undefined,\n },\n });\n await runManager?.handleLLMNewToken(text);\n } else if (chunk.InternalStreamFailure) {\n throw new Error(chunk.InternalStreamFailure.message);\n } else if (chunk.ModelStreamError) {\n throw new Error(chunk.ModelStreamError.message);\n }\n }\n }\n}\n"],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAyCA,IAAsB,8BAAtB,MAAyE;CACvE,cAAc;CAEd,UAAU;;;;;;;;;;;AAmEZ,IAAa,oBAAb,cAAuCA,qCAAAA,IAAwB;CAC7D,kBAAkB;CAElB,OAAO,UAAU;AACf,SAAO;;CAGT,IAAI,aAAoD;AACtD,SAAO;GACL,yCAAyC;GACzC,6CAA6C;GAC7C,0CAA0C;GAC3C;;CAGH;CAEA;CAEA;CAEA;CAEA;CAEA;CAEA,YAAY,QAAgC;AAC1C,QAAM,OAAO;AAEb,MAAI,CAAC,OAAO,cAAc,OACxB,OAAM,IAAI,MACR,gFACD;AAIH,MAAI,CADiB,QAAQ,aAE3B,OAAM,IAAI,MAAM,yDAAyD;AAI3E,MAAI,CADmB,QAAQ,eAE7B,OAAM,IAAI,MACR,0DACD;AAGH,OAAK,eAAe,OAAO;AAC3B,OAAK,iBAAiB,OAAO;AAC7B,OAAK,iBAAiB,OAAO;AAC7B,OAAK,cAAc,OAAO;AAC1B,OAAK,YAAY,OAAO,aAAa;AACrC,OAAK,SAAS,IAAIC,kCAAAA,uBAAuB,OAAO,cAAc;;CAGhE,WAAW;AACT,SAAO;;;;;;;;;;CAWT,MAAM,MACJ,QACA,SACA,YACiB;AACjB,SAAO,KAAK,YACR,MAAM,KAAK,cAAc,QAAQ,SAAS,WAAW,GACrD,MAAM,KAAK,gBAAgB,QAAQ,QAAQ;;CAGjD,MAAc,cACZ,QACA,SACA,YACiB;EACjB,MAAM,SAAS,EAAE;AACjB,aAAW,MAAM,SAAS,KAAK,sBAC7B,QACA,SACA,WACD,CACC,QAAO,KAAK,MAAM,KAAK;AAEzB,SAAO,OAAO,KAAK,GAAG;;CAGxB,MAAc,gBACZ,QACA,SACiB;EACjB,MAAM,OAAO,MAAM,KAAK,eAAe,eACrC,QACA,KAAK,eAAe,EAAE,CACvB;EACD,MAAM,EAAE,aAAa,YAAY,KAAK;EAEtC,MAAM,WAAW,MAAM,KAAK,OAAO,WACjC,KAAK,OAAO,KACV,IAAIC,kCAAAA,sBAAsB;GACxB,cAAc,KAAK;GACnB,MAAM;GACN,aAAa;GACb,QAAQ;GACR,GAAG,KAAK;GACT,CAAC,EACF,EAAE,aAAa,QAAQ,QAAQ,CAChC,CACF;AAED,MAAI,SAAS,SAAS,KAAA,EACpB,OAAM,IAAI,MAAM,gCAAgC;AAElD,SAAO,KAAK,eAAe,gBAAgB,SAAS,KAAK;;;;;;;;CAS3D,OAAO,sBACL,QACA,SACA,YACiC;EACjC,MAAM,OAAO,MAAM,KAAK,eAAe,eACrC,QACA,KAAK,eAAe,EAAE,CACvB;EACD,MAAM,EAAE,aAAa,YAAY,KAAK;EAEtC,MAAM,SAAS,MAAM,KAAK,OAAO,WAC/B,KAAK,OAAO,KACV,IAAIC,kCAAAA,wCAAwC;GAC1C,cAAc,KAAK;GACnB,MAAM;GACN,aAAa;GACb,QAAQ;GACR,GAAG,KAAK;GACT,CAAC,EACF,EAAE,aAAa,QAAQ,QAAQ,CAChC,CACF;AAED,MAAI,CAAC,OAAO,KACV,OAAM,IAAI,MAAM,gCAAgC;AAGlD,aAAW,MAAM,SAAS,OAAO,KAC/B,KAAI,MAAM,eAAe,MAAM,YAAY,OAAO;GAChD,MAAM,OAAO,MAAM,KAAK,eAAe,gBACrC,MAAM,YAAY,MACnB;AACD,SAAM,IAAIC,wBAAAA,gBAAgB;IACxB;IACA,gBAAgB;KACd,GAAG;KACH,UAAU,KAAA;KACX;IACF,CAAC;AACF,SAAM,YAAY,kBAAkB,KAAK;aAChC,MAAM,sBACf,OAAM,IAAI,MAAM,MAAM,sBAAsB,QAAQ;WAC3C,MAAM,iBACf,OAAM,IAAI,MAAM,MAAM,iBAAiB,QAAQ"}