ai-utils.js

Build AI applications, chatbots, and agents with JavaScript and TypeScript.

"use strict"; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.HuggingFaceTextGenerationModel = void 0; const zod_1 = __importDefault(require("zod")); const AbstractModel_js_1 = require("../../model-function/AbstractModel.cjs"); const callWithRetryAndThrottle_js_1 = require("../../util/api/callWithRetryAndThrottle.cjs"); const postToApi_js_1 = require("../../util/api/postToApi.cjs"); const HuggingFaceError_js_1 = require("./HuggingFaceError.cjs"); const PromptMappingTextGenerationModel_js_1 = require("../../prompt/PromptMappingTextGenerationModel.cjs"); /** * Create a text generation model that calls a Hugging Face Inference API Text Generation Task. * * @see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task * * @example * const model = new HuggingFaceTextGenerationModel({ * model: "tiiuae/falcon-7b", * temperature: 0.7, * maxTokens: 500, * retry: retryWithExponentialBackoff({ maxTries: 5 }), * }); * * const { text } = await generateText( * model, * "Write a short story about a robot learning to love:\n\n" * ); */ class HuggingFaceTextGenerationModel extends AbstractModel_js_1.AbstractModel { constructor(settings) { super({ settings }); Object.defineProperty(this, "provider", { enumerable: true, configurable: true, writable: true, value: "huggingface" }); Object.defineProperty(this, "contextWindowSize", { enumerable: true, configurable: true, writable: true, value: undefined }); Object.defineProperty(this, "tokenizer", { enumerable: true, configurable: true, writable: true, value: undefined }); Object.defineProperty(this, "countPromptTokens", { enumerable: true, configurable: true, writable: true, value: undefined }); Object.defineProperty(this, "generateDeltaStreamResponse", { enumerable: true, configurable: true, writable: true, value: undefined }); Object.defineProperty(this, "extractTextDelta", { enumerable: true, configurable: true, writable: true, value: undefined }); } get modelName() { return this.settings.model; } get apiKey() { const apiKey = this.settings.apiKey ?? process.env.HUGGINGFACE_API_KEY; if (apiKey == null) { throw new Error("No Hugging Face API key provided. 
Pass it in the constructor or set the HUGGINGFACE_API_KEY environment variable."); } return apiKey; } async callAPI(prompt, options) { const run = options?.run; const settings = options?.settings; const callSettings = Object.assign({ apiKey: this.apiKey, options: { useCache: true, waitForModel: true, }, }, this.settings, settings, { abortSignal: run?.abortSignal, inputs: prompt, }); return (0, callWithRetryAndThrottle_js_1.callWithRetryAndThrottle)({ retry: this.settings.retry, throttle: this.settings.throttle, call: async () => callHuggingFaceTextGenerationAPI(callSettings), }); } generateTextResponse(prompt, options) { return this.callAPI(prompt, options); } extractText(response) { return response[0].generated_text; } mapPrompt(promptMapping) { return new PromptMappingTextGenerationModel_js_1.PromptMappingTextGenerationModel({ model: this, promptMapping, }); } withSettings(additionalSettings) { return new HuggingFaceTextGenerationModel(Object.assign({}, this.settings, additionalSettings)); } get maxCompletionTokens() { return this.settings.maxNewTokens; } withMaxCompletionTokens(maxCompletionTokens) { return this.withSettings({ maxNewTokens: maxCompletionTokens }); } withStopTokens() { // stop tokens are not supported by the HuggingFace API return this; } } exports.HuggingFaceTextGenerationModel = HuggingFaceTextGenerationModel; const huggingFaceTextGenerationResponseSchema = zod_1.default.array(zod_1.default.object({ generated_text: zod_1.default.string(), })); /** * Call a Hugging Face Inference API Text Generation Task to generate a text completion for the given prompt. * * @see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task * * @example * const response = await callHuggingFaceTextGenerationAPI({ * apiKey: HUGGINGFACE_API_KEY, * model: "tiiuae/falcon-7b", * inputs: "Write a short story about a robot learning to love:\n\n", * temperature: 700, * maxNewTokens: 500, * options: { * waitForModel: true, * }, * }); * * console.log(response[0].generated_text); */ async function callHuggingFaceTextGenerationAPI({ baseUrl = "https://api-inference.huggingface.co/models", abortSignal, apiKey, model, inputs, topK, topP, temperature, repetitionPenalty, maxNewTokens, maxTime, numReturnSequences, doSample, options, }) { return (0, postToApi_js_1.postJsonToApi)({ url: `${baseUrl}/${model}`, apiKey, body: { inputs, top_k: topK, top_p: topP, temperature, repetition_penalty: repetitionPenalty, max_new_tokens: maxNewTokens, max_time: maxTime, num_return_sequences: numReturnSequences, do_sample: doSample, options: options ? { use_cache: options?.useCache, wait_for_model: options?.waitForModel, } : undefined, }, failedResponseHandler: HuggingFaceError_js_1.failedHuggingFaceCallResponseHandler, successfulResponseHandler: (0, postToApi_js_1.createJsonResponseHandler)(huggingFaceTextGenerationResponseSchema), abortSignal, }); }