UNPKG

@langchain/community

Version:
466 lines (465 loc) 15.2 kB
import { __exportAll } from "../_virtual/_rolldown/runtime.js"; import { authenticateAndSetGatewayInstance, authenticateAndSetInstance, checkValidProps, expectOneOf } from "../utils/ibm.js"; import { GenerationChunk } from "@langchain/core/outputs"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { BaseLLM } from "@langchain/core/language_models/llms"; //#region src/llms/ibm.ts var ibm_exports = /* @__PURE__ */ __exportAll({ WatsonxLLM: () => WatsonxLLM }); /** * Integration with an LLM. */ var WatsonxLLM = class extends BaseLLM { static lc_name() { return "WatsonxLLM"; } lc_serializable = true; streaming = false; model; maxRetries = 0; version = "2024-05-31"; serviceUrl; maxTokens; maxNewTokens; spaceId; projectId; idOrName; decodingMethod; lengthPenalty; minNewTokens; randomSeed; stopSequence; temperature; timeLimit; topK; topP; repetitionPenalty; truncateInputTokens; returnOptions; includeStopSequence; maxConcurrency; watsonxCallbacks; modelGateway = false; modelGatewayKwargs = {}; service; gateway; checkValidProperties(fields, includeCommonProps = true) { const authProps = [ "serviceUrl", "watsonxAIApikey", "watsonxAIBearerToken", "watsonxAIUsername", "watsonxAIPassword", "watsonxAIUrl", "watsonxAIAuthType", "disableSSL" ]; const sharedProps = [ "maxRetries", "watsonxCallbacks", "authenticator", "serviceUrl", "version", "streaming", "callbackManager", "callbacks", "maxConcurrency", "cache", "metadata", "concurrency", "onFailedAttempt", "concurrency", "verbose", "tags" ]; const gatewayProps = [ "temperature", "topP", "model", "modelGatewayKwargs", "modelGateway", "verbose", "tags", "maxTokens" ]; const deploymentProps = ["idOrName"]; const projectOrSpaceProps = [ "spaceId", "projectId", "temperature", "topP", "timeLimit", "model", "maxNewTokens", "decodingMethod", "lengthPenalty", "minNewTokens", "randomSeed", "stopSequence", "topK", "repetitionPenalty", "truncateInputTokens", "returnOptions", "includeStopSequence" ]; const validProps = []; if (includeCommonProps) validProps.push(...authProps, ...sharedProps); if (this.modelGateway) validProps.push(...gatewayProps); else if (this.idOrName) validProps.push(...deploymentProps); else if (this.spaceId || this.projectId) validProps.push(...projectOrSpaceProps); checkValidProps(fields, validProps); } constructor(fields) { super(fields); expectOneOf(fields, [ "spaceId", "projectId", "idOrName", "modelGateway" ], true); this.idOrName = fields?.idOrName; this.projectId = fields?.projectId; this.modelGateway = fields.modelGateway || this.modelGateway; this.spaceId = fields?.spaceId; this.checkValidProperties(fields); this.model = fields.model ?? this.model; this.serviceUrl = fields.serviceUrl; this.version = fields.version; this.topP = fields.topP; this.temperature = fields.temperature; this.maxNewTokens = fields.maxNewTokens ?? fields.maxTokens; this.decodingMethod = fields.decodingMethod; this.lengthPenalty = fields.lengthPenalty; this.minNewTokens = fields.minNewTokens; this.maxTokens = fields.maxTokens; this.randomSeed = fields.randomSeed; this.stopSequence = fields.stopSequence; this.timeLimit = fields.timeLimit; this.topK = fields.topK; this.repetitionPenalty = fields.repetitionPenalty; this.truncateInputTokens = fields.truncateInputTokens; this.returnOptions = fields.returnOptions; this.includeStopSequence = fields.includeStopSequence; this.modelGatewayKwargs = fields.modelGatewayKwargs || this.modelGatewayKwargs; this.maxRetries = fields.maxRetries || this.maxRetries; this.maxConcurrency = fields.maxConcurrency; this.streaming = fields.streaming || this.streaming; this.watsonxCallbacks = fields.watsonxCallbacks || this.watsonxCallbacks; const { watsonxAIApikey, watsonxAIAuthType, watsonxAIBearerToken, watsonxAIUsername, watsonxAIPassword, watsonxAIUrl, disableSSL, version, serviceUrl } = fields; const authData = { watsonxAIApikey, watsonxAIAuthType, watsonxAIBearerToken, watsonxAIUsername, watsonxAIPassword, watsonxAIUrl, disableSSL, version, serviceUrl }; if (this.modelGateway) { const gateway = authenticateAndSetGatewayInstance(authData); if (gateway) this.gateway = gateway; else throw new Error("You have not provided any type of authentication"); } else { const service = authenticateAndSetInstance(authData); if (service) this.service = service; else throw new Error("You have not provided any type of authentication"); } } get lc_secrets() { return { authenticator: "AUTHENTICATOR", apiKey: "WATSONX_AI_APIKEY", apikey: "WATSONX_AI_APIKEY", watsonxAIAuthType: "WATSONX_AI_AUTH_TYPE", watsonxAIApikey: "WATSONX_AI_APIKEY", watsonxAIBearerToken: "WATSONX_AI_BEARER_TOKEN", watsonxAIUsername: "WATSONX_AI_USERNAME", watsonxAIPassword: "WATSONX_AI_PASSWORD", watsonxAIUrl: "WATSONX_AI_URL" }; } get lc_aliases() { return { authenticator: "authenticator", apikey: "watsonx_ai_apikey", apiKey: "watsonx_ai_apikey", watsonxAIAuthType: "watsonx_ai_auth_type", watsonxAIApikey: "watsonx_ai_apikey", watsonxAIBearerToken: "watsonx_ai_bearer_token", watsonxAIUsername: "watsonx_ai_username", watsonxAIPassword: "watsonx_ai_password", watsonxAIUrl: "watsonx_ai_url" }; } invocationParams(options) { const { parameters } = options; const { signal, maxRetries, maxConcurrency, timeout, ...rest } = options; if (parameters) this.checkValidProperties(parameters, false); if (this.idOrName && Object.keys(rest).length > 0) throw new Error("Options cannot be provided to a deployed model"); if (this.idOrName) return void 0; if (this.modelGateway) { const modelGatewayParams = { ...this?.modelGatewayKwargs, ...parameters?.modelGatewayKwargs }; return { stop: options?.stop ?? this.stopSequence, temperature: parameters?.temperature ?? this.temperature, topP: parameters?.topP ?? this.topP, maxTokens: parameters?.maxTokens ?? this.maxTokens, ...modelGatewayParams }; } return { stop_sequences: options?.stop ?? this.stopSequence, temperature: parameters?.temperature ?? this.temperature, top_p: parameters?.topP ?? this.topP, max_new_tokens: parameters?.maxNewTokens ?? this.maxNewTokens ?? parameters?.maxTokens ?? this.maxTokens, decoding_method: parameters?.decodingMethod ?? this.decodingMethod, length_penalty: parameters?.lengthPenalty ?? this.lengthPenalty, min_new_tokens: parameters?.minNewTokens ?? this.minNewTokens, random_seed: parameters?.randomSeed ?? this.randomSeed, time_limit: parameters?.timeLimit ?? this.timeLimit ?? timeout, top_k: parameters?.topK ?? this.topK, repetition_penalty: parameters?.repetitionPenalty ?? this.repetitionPenalty, truncate_input_tokens: parameters?.truncateInputTokens ?? this.truncateInputTokens, return_options: parameters?.returnOptions ?? this.returnOptions, include_stop_sequence: parameters?.includeStopSequence ?? this.includeStopSequence }; } invocationCallbacks(options) { return options.watsonxCallbacks ?? this.watsonxCallbacks; } scopeId() { if (this.projectId) return { projectId: this.projectId, modelId: this.model }; else if (this.spaceId) return { spaceId: this.spaceId, modelId: this.model }; else if (this.idOrName) return { idOrName: this.idOrName, modelId: this.model }; else if (this.modelGateway) return { modelId: this.model }; else throw new Error("Invalid mode type. Please make sure you have provided correct parameters"); } async listModels() { if (this.service) { const { service } = this; const listModelParams = { filters: "function_text_generation" }; return (await this.completionWithRetry(() => service.listFoundationModelSpecs(listModelParams))).result.resources?.map((item) => item.model_id); } else throw new Error("This method is not supported in this model gateway"); } async generateSingleMessage(input, options, stream) { const { signal, stop, maxRetries, maxConcurrency, timeout, ...requestOptions } = options; const parameters = this.invocationParams(options); const watsonxCallbacks = this.invocationCallbacks(options); if (stream) { if (this.service) if (this.idOrName) return await this.service.deploymentGenerateTextStream({ idOrName: this.idOrName, ...requestOptions, parameters: { ...parameters, prompt_variables: { input } }, returnObject: true, signal }); else return await this.service.generateTextStream({ input, parameters, ...this.scopeId(), ...requestOptions, returnObject: true, signal }, watsonxCallbacks); else if (this.gateway) return await this.gateway.completion.create({ ...parameters, model: this.model, prompt: input, stream: true, signal, returnObject: true }); } else if (this.service) { const tokenUsage = { generated_token_count: 0, input_token_count: 0 }; return (await (this.idOrName ? this.service.deploymentGenerateText({ ...requestOptions, idOrName: this.idOrName, parameters: { ...parameters, prompt_variables: { input } }, signal }, watsonxCallbacks) : this.service.generateText({ input, parameters, ...this.scopeId(), ...requestOptions, signal }, watsonxCallbacks))).result.results.map((result) => { tokenUsage.generated_token_count += result.generated_token_count ? result.generated_token_count : 0; tokenUsage.input_token_count += result.input_token_count ? result.input_token_count : 0; return { text: result.generated_text, generationInfo: { stop_reason: result.stop_reason, input_token_count: result.input_token_count, generated_token_count: result.generated_token_count } }; }); } else if (this.gateway) { const textGeneration = await this.gateway.completion.create({ ...parameters, prompt: input, model: this.model, signal }); const tokenUsage = textGeneration.result.usage; return textGeneration.result.choices.map((choice) => { return { text: choice.text ?? "", generationInfo: { stop_reason: choice.finish_reason, input_token_count: tokenUsage?.prompt_tokens, generated_token_count: tokenUsage?.completion_tokens } }; }); } throw new Error("No service or gateway set. Please check your intsance init"); } async completionWithRetry(callback, options) { const caller = new AsyncCaller({ maxConcurrency: options?.maxConcurrency || this.maxConcurrency, maxRetries: this.maxRetries }); return options ? caller.callWithOptions({ signal: options.signal }, async () => callback()) : caller.call(async () => callback()); } async _generate(prompts, options, runManager) { const tokenUsage = { generated_token_count: 0, input_token_count: 0 }; if (this.streaming) return { generations: await Promise.all(prompts.map(async (prompt, promptIdx) => { const stream = this._streamResponseChunks(prompt, options); const geneartionsArray = []; for await (const chunk of stream) { const completion = chunk?.generationInfo?.completion ?? 0; geneartionsArray[completion] ??= { text: "", stop_reason: "", generated_token_count: 0, input_token_count: 0 }; geneartionsArray[completion].generated_token_count = chunk?.generationInfo?.usage_metadata.generated_token_count ?? 0; geneartionsArray[completion].input_token_count += chunk?.generationInfo?.usage_metadata.input_token_count ?? 0; geneartionsArray[completion].stop_reason = chunk?.generationInfo?.stop_reason; geneartionsArray[completion].text += chunk.text; if (chunk.text) runManager?.handleLLMNewToken(chunk.text, { prompt: promptIdx, completion: 0 }); } return geneartionsArray.map((item) => { const { text, ...rest } = item; tokenUsage.generated_token_count = rest.generated_token_count; tokenUsage.input_token_count += rest.input_token_count; return { text, generationInfo: rest }; }); })), llmOutput: { tokenUsage } }; else return { generations: await Promise.all(prompts.map(async (prompt) => { const callback = () => this.generateSingleMessage(prompt, options, false); const response = await this.completionWithRetry(callback, options); const [generated_token_count, input_token_count] = response.reduce((acc, curr) => { let generated = 0; let inputed = 0; if (curr?.generationInfo?.generated_token_count) generated = curr.generationInfo.generated_token_count + acc[0]; if (curr?.generationInfo?.input_token_count) inputed = curr.generationInfo.input_token_count + acc[1]; return [generated, inputed]; }, [0, 0]); tokenUsage.generated_token_count += generated_token_count; tokenUsage.input_token_count += input_token_count; return response; })), llmOutput: { tokenUsage } }; } async getNumTokens(content, options) { if (this.service) { const { service } = this; const params = { ...this.scopeId(), input: content, parameters: options }; const callback = () => service.tokenizeText(params); return (await this.completionWithRetry(callback)).result.result.token_count; } else throw new Error("This method is not supported in model gateway"); } async *_streamResponseChunks(prompt, options, runManager) { const callback = () => this.generateSingleMessage(prompt, options, true); const streamInferDeployedPrompt = await this.completionWithRetry(callback); const responseChunk = { id: 0, event: "", data: { results: [] } }; for await (const chunk of streamInferDeployedPrompt) { const results = "model_id" in chunk.data ? chunk.data.results : chunk.data.choices; const usage = "usage" in chunk.data ? chunk.data.usage : {}; for (const [index, item] of results.entries()) { yield new GenerationChunk("generated_text" in item ? { text: item.generated_text, generationInfo: { stop_reason: item.stop_reason, completion: index, usage_metadata: { generated_token_count: item.generated_token_count, input_token_count: item.input_token_count, stop_reason: item.stop_reason } } } : { text: item.text ?? "", generationInfo: { stop_reason: item.finish_reason, completion: index, usage_metadata: { generated_token_count: usage?.completion_tokens, input_token_count: usage?.prompt_tokens, stop_reason: item.finish_reason } } }); if (!this.streaming) runManager?.handleLLMNewToken("generated_text" in item ? item.generated_text : item.text ?? ""); } Object.assign(responseChunk, { id: 0, event: "", data: {} }); } } _llmType() { return "watsonx"; } }; //#endregion export { WatsonxLLM, ibm_exports }; //# sourceMappingURL=ibm.js.map