UNPKG

@langchain/community

Version:
115 lines (114 loc) 4.24 kB
Object.defineProperty(exports, Symbol.toStringTag, { value: "Module" });
const require_runtime = require("../_virtual/_rolldown/runtime.cjs");
let _langchain_core_outputs = require("@langchain/core/outputs");
let _langchain_core_utils_env = require("@langchain/core/utils/env");
let _langchain_core_language_models_llms = require("@langchain/core/language_models/llms");

//#region src/llms/replicate.ts
var replicate_exports = /* @__PURE__ */ require_runtime.__exportAll({ Replicate: () => Replicate });
/**
 * Class responsible for managing the interaction with the Replicate API.
 * It handles the API key and model details, makes the actual API calls,
 * and converts the API response into a format usable by the rest of the
 * LangChain framework.
 * @example
 * ```typescript
 * const model = new Replicate({
 *   model: "replicate/flan-t5-xl:3ae0799123a1fe11f8c89fd99632f843fc5f7a761630160521c4253149754523",
 * });
 *
 * const res = await model.invoke(
 *   "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:"
 * );
 * console.log({ res });
 * ```
 */
var Replicate = class Replicate extends _langchain_core_language_models_llms.LLM {
	static lc_name() {
		return "Replicate";
	}
	/** Maps the `apiKey` field to the secret/env name used during serialization. */
	get lc_secrets() {
		return { apiKey: "REPLICATE_API_TOKEN" };
	}
	lc_serializable = true;
	/** Model reference, e.g. "owner/name" or "owner/name:version". */
	model;
	/** Extra input fields merged into every request payload. */
	input;
	/** Replicate API token. */
	apiKey;
	/**
	 * Name of the input field that receives the prompt. If left undefined it is
	 * inferred (and cached) on first use from the model version's input schema.
	 */
	promptKey;
	/**
	 * @param fields - Constructor options: `model` (required), optional `apiKey`,
	 *   `input`, and `promptKey`, plus base LLM options.
	 * @throws {Error} When no API key is provided via `fields.apiKey`,
	 *   `REPLICATE_API_KEY`, or `REPLICATE_API_TOKEN`.
	 */
	constructor(fields) {
		super(fields);
		const apiKey = fields?.apiKey
			?? (0, _langchain_core_utils_env.getEnvironmentVariable)("REPLICATE_API_KEY")
			?? (0, _langchain_core_utils_env.getEnvironmentVariable)("REPLICATE_API_TOKEN");
		if (!apiKey) throw new Error("Please set the REPLICATE_API_TOKEN environment variable");
		this.apiKey = apiKey;
		this.model = fields.model;
		this.input = fields.input ?? {};
		this.promptKey = fields.promptKey;
	}
	_llmType() {
		return "replicate";
	}
	/**
	 * Runs the model once and returns its output as a single string.
	 * String outputs are returned as-is, array outputs (token lists) are
	 * concatenated, and anything else is stringified.
	 * @ignore
	 */
	async _call(prompt, options) {
		const replicate = await this._prepareReplicate();
		const input = await this._getReplicateInput(replicate, prompt);
		const output = await this.caller.callWithOptions({ signal: options?.signal }, () => replicate.run(this.model, { input }));
		if (typeof output === "string") return output;
		if (Array.isArray(output)) return output.join("");
		return String(output);
	}
	/**
	 * Streams model output. Each "output" event becomes a GenerationChunk (and is
	 * reported to the run manager as a new token); a "done" event yields a final
	 * empty chunk marked `finished: true`.
	 */
	async *_streamResponseChunks(prompt, options, runManager) {
		const replicate = await this._prepareReplicate();
		const input = await this._getReplicateInput(replicate, prompt);
		const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => replicate.stream(this.model, { input }));
		for await (const chunk of stream) {
			if (chunk.event === "output") {
				yield new _langchain_core_outputs.GenerationChunk({
					text: chunk.data,
					generationInfo: chunk
				});
				await runManager?.handleLLMNewToken(chunk.data ?? "");
			}
			if (chunk.event === "done") yield new _langchain_core_outputs.GenerationChunk({
				text: "",
				generationInfo: { finished: true }
			});
		}
	}
	/**
	 * Lazily loads the optional `replicate` package.
	 * @throws {Error} With install instructions if the package cannot be
	 *   imported; the original import failure is preserved as `cause`.
	 * @ignore
	 */
	static async imports() {
		try {
			const { default: Replicate } = await import("replicate");
			return { Replicate };
		} catch (error) {
			throw new Error("Please install replicate as a dependency with, e.g. `pnpm install replicate`", { cause: error });
		}
	}
	/** Creates a Replicate client authenticated with this instance's API key. */
	async _prepareReplicate() {
		const { Replicate: ReplicateClient } = await Replicate.imports();
		return new ReplicateClient({
			userAgent: "langchain",
			auth: this.apiKey
		});
	}
	/**
	 * Builds the request `input` payload for the model.
	 *
	 * If `promptKey` is unset, it is inferred from the model version's OpenAPI
	 * input schema: the property with the lowest `x-order` is used. The result is
	 * cached on the instance. Falls back to "prompt" when no version is given in
	 * `model`, the schema has no input properties, or the property map is empty.
	 *
	 * Note: `this.input` is spread after the prompt, so an explicit entry under
	 * the same key in `input` overrides the prompt.
	 */
	async _getReplicateInput(replicate, prompt) {
		if (this.promptKey === void 0) {
			const [modelString, versionString] = this.model.split(":");
			const [owner, name] = modelString.split("/");
			let promptKey = "prompt";
			// Without an explicit version there is no version schema to fetch;
			// requesting one would ask the API for a version named "undefined".
			if (versionString) {
				const version = await replicate.models.versions.get(owner, name, versionString);
				const inputProperties = version.openapi_schema?.components?.schemas?.Input?.properties;
				if (inputProperties !== void 0) {
					const [firstEntry] = Object.entries(inputProperties).sort(([_keyA, valueA], [_keyB, valueB]) => (valueA["x-order"] || 0) - (valueB["x-order"] || 0));
					// Optional indexing so an empty property map falls back to "prompt"
					// instead of throwing on `undefined[0]`.
					promptKey = firstEntry?.[0] ?? "prompt";
				}
			}
			this.promptKey = promptKey;
		}
		return {
			[this.promptKey]: prompt,
			...this.input
		};
	}
};

//#endregion
exports.Replicate = Replicate;
Object.defineProperty(exports, "replicate_exports", {
	enumerable: true,
	get: function() {
		return replicate_exports;
	}
});

//# sourceMappingURL=replicate.cjs.map