@langchain/community
Version:
Third-party integrations for LangChain.js
142 lines (139 loc) • 4.71 kB
JavaScript
// Tag the CommonJS `exports` object with Symbol.toStringTag = "Module".
Object.defineProperty(exports, Symbol.toStringTag, { value: "Module" });
// Bundler runtime helpers; `__exportAll` is used below to build `ai21_exports`.
const require_runtime = require("../_virtual/_rolldown/runtime.cjs");
// Supplies `getEnvironmentVariable`, used by the AI21 constructor to look up "AI21_API_KEY".
let _langchain_core_utils_env = require("@langchain/core/utils/env");
// Supplies the `LLM` base class that AI21 extends.
let _langchain_core_language_models_llms = require("@langchain/core/language_models/llms");
//#region src/llms/ai21.ts
// Export map handed to the runtime's `__exportAll` helper: export name -> thunk.
// The thunk reads `AI21` only when invoked, so it is fine that the class is
// assigned further down in this file.
// NOTE(review): the exact shape of the returned object is defined by
// `__exportAll` in the runtime module, which is not visible here.
var ai21_exports = /* @__PURE__ */ require_runtime.__exportAll({ AI21: () => AI21 });
/**
 * Class representing the AI21 language model. It extends the LLM (Large
 * Language Model) class, providing a standard interface for interacting
 * with the AI21 language model.
 *
 * All constructor fields are optional. Every field except `logitBias`,
 * `stop` and `baseUrl` falls back to the class-field default declared below;
 * `ai21ApiKey` falls back to the "AI21_API_KEY" environment variable.
 */
var AI21 = class AI21 extends _langchain_core_language_models_llms.LLM {
  lc_serializable = true;
  // Model name; interpolated into the request URL path by `_call`.
  model = "j2-jumbo-instruct";
  // The parameters below are forwarded verbatim in the request body
  // (see `defaultParams`).
  temperature = .7;
  maxTokens = 1024;
  minTokens = 0;
  topP = 1;
  // Each penalty gets its own object instance from the static factory.
  presencePenalty = AI21.getDefaultAI21PenaltyData();
  countPenalty = AI21.getDefaultAI21PenaltyData();
  frequencyPenalty = AI21.getDefaultAI21PenaltyData();
  numResults = 1;
  // No defaults: these stay undefined unless supplied through `fields`.
  logitBias;
  ai21ApiKey;
  stop;
  baseUrl;
  /**
   * @param {object} [fields] Optional overrides for the fields declared above;
   * the same object is also passed to the LLM base constructor.
   */
  constructor(fields) {
    super(fields ?? {});
    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.minTokens = fields?.minTokens ?? this.minTokens;
    this.topP = fields?.topP ?? this.topP;
    this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
    this.countPenalty = fields?.countPenalty ?? this.countPenalty;
    this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
    this.numResults = fields?.numResults ?? this.numResults;
    this.logitBias = fields?.logitBias;
    this.ai21ApiKey = fields?.ai21ApiKey ?? (0, _langchain_core_utils_env.getEnvironmentVariable)("AI21_API_KEY");
    this.stop = fields?.stop;
    this.baseUrl = fields?.baseUrl;
  }
  /**
   * Method to validate the environment. It checks if the AI21 API key is
   * set. If not, it throws an error.
   * @throws {Error} When `ai21ApiKey` is missing or empty.
   */
  validateEnvironment() {
    if (!this.ai21ApiKey) throw new Error(`No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.`);
  }
  /**
   * Static method to get the default penalty data for AI21.
   * Returns a fresh object on every call, so callers never share state.
   * @returns AI21PenaltyData
   */
  static getDefaultAI21PenaltyData() {
    return {
      scale: 0,
      applyToWhitespaces: true,
      applyToPunctuations: true,
      applyToNumbers: true,
      applyToStopwords: true,
      applyToEmojis: true
    };
  }
  /** Get the type of LLM. */
  _llmType() {
    return "ai21";
  }
  /**
   * Get the default parameters for calling AI21 API. These are spread into
   * the request body by `_call`.
   */
  get defaultParams() {
    return {
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      minTokens: this.minTokens,
      topP: this.topP,
      presencePenalty: this.presencePenalty,
      countPenalty: this.countPenalty,
      frequencyPenalty: this.frequencyPenalty,
      numResults: this.numResults,
      logitBias: this.logitBias
    };
  }
  /** Get the identifying parameters for this LLM: `defaultParams` plus `model`. */
  get identifyingParams() {
    return {
      ...this.defaultParams,
      model: this.model
    };
  }
  /**
   * Call out to AI21's complete endpoint.
   *
   * @param {string} prompt The prompt to pass into the model.
   * @param {object} [options] Optional call options; `options.stop` (list of
   * stop sequences) and `options.signal` (abort signal passed to `fetch`)
   * are read.
   * @returns {Promise<string>} The text of the first completion, or "" when
   * that completion carries no text.
   * @throws {Error} When no API key is configured, when non-empty stop
   * sequences are given both on the instance and in `options`, when the HTTP
   * response is not ok (the error carries the response as `error.response`),
   * or when the response contains no completions.
   *
   * @example
   * const response = await ai21._call("Tell me a joke.");
   */
  async _call(prompt, options) {
    this.validateEnvironment();
    const callStop = options?.stop;
    const hasDefaultStop = (this.stop?.length ?? 0) > 0;
    const hasCallStop = (callStop?.length ?? 0) > 0;
    if (hasDefaultStop && hasCallStop) throw new Error("`stop` found in both the input and default params.");
    // An empty instance-level list counts as "not set", matching the guard
    // above, so it must not shadow stop sequences supplied for this call.
    const stop = hasDefaultStop ? this.stop : callStop ?? [];
    // The fallback is computed separately on purpose: written inline as
    // `a ?? b === c ? x : y`, the expression parses as `(a ?? (b === c)) ? x : y`
    // and a configured baseUrl would never be used as the URL.
    const defaultBaseUrl = this.model === "j1-grande-instruct" ? "https://api.ai21.com/studio/v1/experimental" : "https://api.ai21.com/studio/v1";
    const url = `${this.baseUrl ?? defaultBaseUrl}/${this.model}/complete`;
    const headers = {
      Authorization: `Bearer ${this.ai21ApiKey}`,
      "Content-Type": "application/json"
    };
    const data = {
      prompt,
      stopSequences: stop,
      ...this.defaultParams
    };
    const responseData = await this.caller.callWithOptions({}, async () => {
      const response = await fetch(url, {
        method: "POST",
        headers,
        body: JSON.stringify(data),
        // `options` may be omitted when `_call` is invoked directly.
        signal: options?.signal
      });
      if (!response.ok) {
        const error = /* @__PURE__ */ new Error(`AI21 call failed with status code ${response.status}`);
        // Attach the raw response so callers can inspect it.
        error.response = response;
        throw error;
      }
      return response.json();
    });
    if (!responseData.completions || responseData.completions.length === 0 || !responseData.completions[0].data) throw new Error("No completions found in response");
    return responseData.completions[0].data.text ?? "";
  }
};
//#endregion
// CommonJS surface of this module: the AI21 class as a plain property, and
// `ai21_exports` as an enumerable, getter-only accessor that reads the
// module-level binding each time it is accessed.
exports.AI21 = AI21;
Object.defineProperties(exports, {
  ai21_exports: {
    get: () => ai21_exports,
    enumerable: true
  }
});
//# sourceMappingURL=ai21.cjs.map