UNPKG

@vulcan-sql/extension-huggingface

Version:

Hugging Face feature for VulcanSQL

56 lines 3.01 kB
"use strict"; var _a; Object.defineProperty(exports, "__esModule", { value: true }); exports.Runner = exports.Builder = exports.TextGenerationFilter = void 0; const tslib_1 = require("tslib"); const core_1 = require("@vulcan-sql/core"); const lodash_1 = require("lodash"); const model_1 = require("../model"); const utils_1 = require("../utils"); /** * Get text generation url. Used gpt2 model be default value. * See: https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task * */ const getUrl = (model = 'gpt2') => `${model_1.apiInferenceEndpoint}/${model}`; const TextGenerationFilter = ({ args, value, options, }) => tslib_1.__awaiter(void 0, void 0, void 0, function* () { const token = options === null || options === void 0 ? void 0 : options.accessToken; if (!token) throw new core_1.InternalError('please given access token'); if (!(0, lodash_1.isArray)(value)) throw new core_1.InternalError('Input value must be an array of object'); if (!(typeof args === 'object') || !(0, lodash_1.has)(args, 'query')) throw new core_1.InternalError('Must provide "query" keyword argument'); if (!args['query']) throw new core_1.InternalError('The "query" argument must have value'); // Convert the data result to JSON string as question context const context = JSON.stringify(value); // omit hidden value '__keywords' from args, it generated from nunjucks and not related to HuggingFace. const _b = (0, lodash_1.omit)(args, '__keywords'), { query, model, endpoint } = _b, otherArgs = tslib_1.__rest(_b, ["query", "model", "endpoint"]); const inferenceOptions = (0, lodash_1.pick)(otherArgs, ['use_cache', 'wait_for_model']); const parameters = (0, lodash_1.omit)(otherArgs, ['use_cache', 'wait_for_model', 'endpoint']); const payload = { inputs: `Context: ${context}. Question: ${query}}`, parameters: { return_full_text: false, max_new_tokens: 250, temperature: 0.1, } }; if (!(0, lodash_1.isEmpty)(parameters)) payload.parameters = parameters; if (!(0, lodash_1.isEmpty)(inferenceOptions)) payload.options = inferenceOptions; try { // if not given endpoint, use default HuggingFace inference endpoint const url = endpoint ? endpoint : getUrl(model); const results = yield (0, utils_1.postRequest)(url, payload, token); // get the "generated_text" field, and trim leading and trailing white space. return String(results[0]['generated_text']).trim(); } catch (error) { throw new core_1.InternalError(`Error when sending data to Hugging Face for executing TextGeneration tasks, details: ${error.message}`); } }); exports.TextGenerationFilter = TextGenerationFilter; _a = (0, core_1.createFilterExtension)('huggingface_text_generation', exports.TextGenerationFilter), exports.Builder = _a[0], exports.Runner = _a[1]; //# sourceMappingURL=textGeneration.js.map