UNPKG

@langchain/community

Version:
142 lines (139 loc) 4.71 kB
Object.defineProperty(exports, Symbol.toStringTag, { value: "Module" });
const require_runtime = require("../_virtual/_rolldown/runtime.cjs");
let _langchain_core_utils_env = require("@langchain/core/utils/env");
let _langchain_core_language_models_llms = require("@langchain/core/language_models/llms");

//#region src/llms/ai21.ts
var ai21_exports = /* @__PURE__ */ require_runtime.__exportAll({ AI21: () => AI21 });

/** Base URL used for every model except `j1-grande-instruct` when no `baseUrl` is configured. */
const AI21_DEFAULT_BASE_URL = "https://api.ai21.com/studio/v1";
/** Base URL used for `j1-grande-instruct` when no `baseUrl` is configured. */
const AI21_EXPERIMENTAL_BASE_URL = "https://api.ai21.com/studio/v1/experimental";

/**
 * Class representing the AI21 language model. It extends the LLM (Large
 * Language Model) class, providing a standard interface for interacting
 * with the AI21 language model.
 */
var AI21 = class AI21 extends _langchain_core_language_models_llms.LLM {
  lc_serializable = true;

  /**
   * Declares `ai21ApiKey` as a secret backed by `AI21_API_KEY`.
   * LangChain's Serializable uses this map to emit a secret reference
   * instead of the raw key when the model is serialized.
   */
  get lc_secrets() {
    return { ai21ApiKey: "AI21_API_KEY" };
  }

  /** Model name; also used as a path segment of the request URL. */
  model = "j2-jumbo-instruct";
  temperature = 0.7;
  maxTokens = 1024;
  minTokens = 0;
  topP = 1;
  // Each field gets its own object from the factory, so instances never share penalty config.
  presencePenalty = AI21.getDefaultAI21PenaltyData();
  countPenalty = AI21.getDefaultAI21PenaltyData();
  frequencyPenalty = AI21.getDefaultAI21PenaltyData();
  numResults = 1;
  logitBias;
  ai21ApiKey;
  /** Default stop sequences; mutually exclusive with non-empty call-time `stop`. */
  stop;
  /** Optional custom base URL; overrides the built-in AI21 endpoints when set. */
  baseUrl;

  /**
   * @param {object} [fields] - Optional overrides for any of the public fields.
   *   `ai21ApiKey` falls back to the `AI21_API_KEY` environment variable.
   */
  constructor(fields) {
    super(fields ?? {});
    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.minTokens = fields?.minTokens ?? this.minTokens;
    this.topP = fields?.topP ?? this.topP;
    this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
    this.countPenalty = fields?.countPenalty ?? this.countPenalty;
    this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
    this.numResults = fields?.numResults ?? this.numResults;
    this.logitBias = fields?.logitBias;
    this.ai21ApiKey = fields?.ai21ApiKey ?? (0, _langchain_core_utils_env.getEnvironmentVariable)("AI21_API_KEY");
    this.stop = fields?.stop;
    this.baseUrl = fields?.baseUrl;
  }

  /**
   * Method to validate the environment. It checks if the AI21 API key is
   * set. If not, it throws an error.
   * @throws {Error} When no API key was provided or found in the environment.
   */
  validateEnvironment() {
    if (!this.ai21ApiKey) {
      throw new Error("No AI21 API key found.\nPlease set it as \"AI21_API_KEY\" in your environment variables.");
    }
  }

  /**
   * Static method to get the default penalty data for AI21.
   * Returns a fresh object on every call.
   * @returns AI21PenaltyData
   */
  static getDefaultAI21PenaltyData() {
    return {
      scale: 0,
      applyToWhitespaces: true,
      applyToPunctuations: true,
      applyToNumbers: true,
      applyToStopwords: true,
      applyToEmojis: true
    };
  }

  /** Get the type of LLM. */
  _llmType() {
    return "ai21";
  }

  /** Get the default parameters for calling AI21 API. */
  get defaultParams() {
    return {
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      minTokens: this.minTokens,
      topP: this.topP,
      presencePenalty: this.presencePenalty,
      countPenalty: this.countPenalty,
      frequencyPenalty: this.frequencyPenalty,
      numResults: this.numResults,
      logitBias: this.logitBias
    };
  }

  /** Get the identifying parameters for this LLM. */
  get identifyingParams() {
    return { ...this.defaultParams, model: this.model };
  }

  /**
   * Call out to AI21's complete endpoint.
   *
   * @param {string} prompt - The prompt to pass into the model.
   * @param {object} [options] - Call options; `stop` (extra stop sequences)
   *   and `signal` (AbortSignal) are read.
   * @returns {Promise<string>} The text of the first completion, or "" if it has none.
   * @throws {Error} When no API key is configured, when both default and
   *   call-time stop sequences are non-empty, when the HTTP status is not OK
   *   (the error carries `.response`), or when the response has no completions.
   *
   * @example
   * const response = await ai21._call("Tell me a joke.");
   */
  async _call(prompt, options) {
    const callStop = options?.stop;
    this.validateEnvironment();

    const hasDefaultStop = Boolean(this.stop?.length);
    const hasCallStop = Boolean(callStop?.length);
    if (hasDefaultStop && hasCallStop) {
      throw new Error("`stop` found in both the input and default params.");
    }
    // Use whichever list is non-empty, so an empty default `stop` does not
    // mask stop sequences supplied at call time.
    const stopSequences = hasDefaultStop ? this.stop : callStop ?? [];

    // Parenthesized on purpose: `??` binds tighter than `?:`, so without the
    // parentheses a configured baseUrl would only act as the ternary's condition.
    const baseUrl = this.baseUrl ?? (this.model === "j1-grande-instruct" ? AI21_EXPERIMENTAL_BASE_URL : AI21_DEFAULT_BASE_URL);
    const url = `${baseUrl}/${this.model}/complete`;

    const headers = {
      Authorization: `Bearer ${this.ai21ApiKey}`,
      "Content-Type": "application/json"
    };
    const data = { prompt, stopSequences, ...this.defaultParams };
    const signal = options?.signal;

    const responseData = await this.caller.callWithOptions({ signal }, async () => {
      const response = await fetch(url, {
        method: "POST",
        headers,
        body: JSON.stringify(data),
        signal
      });
      if (!response.ok) {
        const error = /* @__PURE__ */ new Error(`AI21 call failed with status code ${response.status}`);
        error.response = response;
        throw error;
      }
      return response.json();
    });

    if (!responseData.completions || responseData.completions.length === 0 || !responseData.completions[0].data) {
      throw new Error("No completions found in response");
    }
    return responseData.completions[0].data.text ?? "";
  }
};
//#endregion

exports.AI21 = AI21;
Object.defineProperty(exports, "ai21_exports", {
  enumerable: true,
  get: function() {
    return ai21_exports;
  }
});
//# sourceMappingURL=ai21.cjs.map