/**
 * @vulcan-sql/extension-huggingface
 * Hugging Face feature for VulcanSQL.
 * (Compiled JavaScript output; original source is TypeScript.)
 * 56 lines • 3.01 kB
 */
;
var _a;
Object.defineProperty(exports, "__esModule", { value: true });
exports.Runner = exports.Builder = exports.TextGenerationFilter = void 0;
const tslib_1 = require("tslib");
const core_1 = require("@vulcan-sql/core");
const lodash_1 = require("lodash");
const model_1 = require("../model");
const utils_1 = require("../utils");
/**
 * Build the Hugging Face inference URL for a given model.
 * Falls back to the "gpt2" model when no model name is supplied.
 * See: https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
 */
const getUrl = (model) => {
    const selectedModel = model === undefined ? 'gpt2' : model;
    return `${model_1.apiInferenceEndpoint}/${selectedModel}`;
};
/**
 * VulcanSQL filter that forwards the piped query result to a Hugging Face
 * text-generation model and returns the generated answer.
 *
 * @param args    keyword arguments from the template; must contain a truthy
 *                "query", may contain "model", "endpoint", inference options
 *                ("use_cache", "wait_for_model") and model parameters.
 * @param value   the upstream data result; must be an array (serialized to
 *                JSON and embedded in the prompt as context).
 * @param options filter options; must carry `accessToken` for the HF API.
 * @returns the trimmed "generated_text" field of the first result.
 * @throws InternalError on missing token/arguments, bad input, or API failure.
 */
const TextGenerationFilter = async ({ args, value, options, }) => {
    const token = options?.accessToken;
    if (!token)
        throw new core_1.InternalError('please given access token');
    if (!(0, lodash_1.isArray)(value))
        throw new core_1.InternalError('Input value must be an array of object');
    if (!(typeof args === 'object') || !(0, lodash_1.has)(args, 'query'))
        throw new core_1.InternalError('Must provide "query" keyword argument');
    if (!args['query'])
        throw new core_1.InternalError('The "query" argument must have value');
    // Convert the data result to JSON string as question context.
    const context = JSON.stringify(value);
    // Omit hidden value '__keywords' from args: it is generated by nunjucks
    // and not related to HuggingFace. Pull out the known keywords; everything
    // else is treated as model parameters / inference options.
    const { query, model, endpoint, ...otherArgs } = (0, lodash_1.omit)(args, '__keywords');
    const inferenceOptions = (0, lodash_1.pick)(otherArgs, ['use_cache', 'wait_for_model']);
    // 'endpoint' can no longer appear in otherArgs (destructured above);
    // keeping it in the omit list is a harmless safety net.
    const parameters = (0, lodash_1.omit)(otherArgs, ['use_cache', 'wait_for_model', 'endpoint']);
    const payload = {
        // BUG FIX: a stray '}' used to be appended right after ${query},
        // polluting every prompt sent to the model.
        inputs: `Context: ${context}. Question: ${query}`,
        // Default generation parameters, replaced wholesale (not merged)
        // when the caller supplies any parameter of their own.
        parameters: {
            return_full_text: false,
            max_new_tokens: 250,
            temperature: 0.1,
        },
    };
    if (!(0, lodash_1.isEmpty)(parameters))
        payload.parameters = parameters;
    if (!(0, lodash_1.isEmpty)(inferenceOptions))
        payload.options = inferenceOptions;
    try {
        // If no endpoint is given, use the default HuggingFace inference endpoint.
        const url = endpoint ? endpoint : getUrl(model);
        const results = await (0, utils_1.postRequest)(url, payload, token);
        // Get the "generated_text" field, and trim leading and trailing white space.
        return String(results[0]['generated_text']).trim();
    }
    catch (error) {
        throw new core_1.InternalError(`Error when sending data to Hugging Face for executing TextGeneration tasks, details: ${error.message}`);
    }
};
exports.TextGenerationFilter = TextGenerationFilter;
// Register the filter with VulcanSQL under the template name
// 'huggingface_text_generation'; createFilterExtension returns a
// [Builder, Runner] pair which we re-export individually.
const extensionPair = (0, core_1.createFilterExtension)('huggingface_text_generation', exports.TextGenerationFilter);
exports.Builder = extensionPair[0];
exports.Runner = extensionPair[1];
//# sourceMappingURL=textGeneration.js.map