UNPKG

@langchain/community

Version:
118 lines (117 loc) 4.59 kB
/* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { expect, test } from "@jest/globals"; import { SageMakerEndpoint, } from "../sagemaker_endpoint.js"; // yarn test:single /{path_to}/langchain/src/llms/tests/sagemaker.int.test.ts describe.skip("Test SageMaker LLM", () => { test("without streaming", async () => { class LLama213BHandler { constructor() { Object.defineProperty(this, "contentType", { enumerable: true, configurable: true, writable: true, value: "application/json" }); Object.defineProperty(this, "accepts", { enumerable: true, configurable: true, writable: true, value: "application/json" }); } async transformInput(prompt, modelKwargs) { const payload = { inputs: [[{ role: "user", content: prompt }]], parameters: modelKwargs, }; const input_str = JSON.stringify(payload); return new TextEncoder().encode(input_str); } async transformOutput(output) { const response_json = JSON.parse(new TextDecoder("utf-8").decode(output)); const content = response_json[0]?.generation.content ?? ""; return content; } } const contentHandler = new LLama213BHandler(); const model = new SageMakerEndpoint({ endpointName: "aws-productbot-ai-dev-llama-2-13b-chat", streaming: false, modelKwargs: { temperature: 0.5, max_new_tokens: 700, top_p: 0.9, }, endpointKwargs: { CustomAttributes: "accept_eula=true", }, contentHandler, clientOptions: { region: "us-east-1", credentials: { accessKeyId: process.env.AWS_ACCESS_KEY_ID, secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY, }, }, }); const response = await model.invoke("hello, my name is John Doe, tell me a fun story about llamas."); expect(response.length).toBeGreaterThan(0); }); test("with streaming", async () => { class LLama213BHandler { constructor() { Object.defineProperty(this, "contentType", { enumerable: true, configurable: true, writable: true, value: "application/json" }); Object.defineProperty(this, "accepts", { enumerable: true, configurable: true, writable: true, value: "application/json" }); } async transformInput(prompt, modelKwargs) { const payload = { inputs: [[{ role: "user", content: prompt }]], parameters: modelKwargs, }; const input_str = JSON.stringify(payload); return new TextEncoder().encode(input_str); } async transformOutput(output) { return new TextDecoder("utf-8").decode(output); } } const contentHandler = new LLama213BHandler(); const model = new SageMakerEndpoint({ endpointName: "aws-productbot-ai-dev-llama-2-13b-chat", streaming: true, modelKwargs: { temperature: 0.5, max_new_tokens: 700, top_p: 0.9, }, endpointKwargs: { CustomAttributes: "accept_eula=true", }, contentHandler, clientOptions: { region: "us-east-1", credentials: { accessKeyId: process.env.AWS_ACCESS_KEY_ID, secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY, }, }, }); const response = await model.invoke("hello, my name is John Doe, tell me a fun story about llamas in 3 paragraphs"); const chunks = []; for await (const chunk of response) { chunks.push(chunk); } expect(response.length).toBeGreaterThan(0); }); });