ai-utils.js
Build AI applications, chatbots, and agents with JavaScript and TypeScript.
"use strict";
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.HuggingFaceTextGenerationModel = void 0;
const zod_1 = __importDefault(require("zod"));
const AbstractModel_js_1 = require("../../model-function/AbstractModel.cjs");
const callWithRetryAndThrottle_js_1 = require("../../util/api/callWithRetryAndThrottle.cjs");
const postToApi_js_1 = require("../../util/api/postToApi.cjs");
const HuggingFaceError_js_1 = require("./HuggingFaceError.cjs");
const PromptMappingTextGenerationModel_js_1 = require("../../prompt/PromptMappingTextGenerationModel.cjs");
/**
* Create a text generation model that calls a Hugging Face Inference API Text Generation Task.
*
* @see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
*
* @example
* const model = new HuggingFaceTextGenerationModel({
* model: "tiiuae/falcon-7b",
* temperature: 0.7,
* maxNewTokens: 500,
* retry: retryWithExponentialBackoff({ maxTries: 5 }),
* });
*
* const { text } = await generateText(
* model,
* "Write a short story about a robot learning to love:\n\n"
* );
*/
class HuggingFaceTextGenerationModel extends AbstractModel_js_1.AbstractModel {
constructor(settings) {
super({ settings });
Object.defineProperty(this, "provider", {
enumerable: true,
configurable: true,
writable: true,
value: "huggingface"
});
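// Context window size, tokenizer, prompt token counting, and delta streaming
// are not supported for this model, so the corresponding properties stay undefined.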
Object.defineProperty(this, "contextWindowSize", {
enumerable: true,
configurable: true,
writable: true,
value: undefined
});
Object.defineProperty(this, "tokenizer", {
enumerable: true,
configurable: true,
writable: true,
value: undefined
});
Object.defineProperty(this, "countPromptTokens", {
enumerable: true,
configurable: true,
writable: true,
value: undefined
});
Object.defineProperty(this, "generateDeltaStreamResponse", {
enumerable: true,
configurable: true,
writable: true,
value: undefined
});
Object.defineProperty(this, "extractTextDelta", {
enumerable: true,
configurable: true,
writable: true,
value: undefined
});
}
get modelName() {
return this.settings.model;
}
get apiKey() {
const apiKey = this.settings.apiKey ?? process.env.HUGGINGFACE_API_KEY;
if (apiKey == null) {
throw new Error("No Hugging Face API key provided. Pass it in the constructor or set the HUGGINGFACE_API_KEY environment variable.");
}
return apiKey;
}
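// callAPI merges the default call options, the constructor settings, and any
// per-call settings, then invokes the Hugging Face API with the configured
// retry and throttle behavior.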
async callAPI(prompt, options) {
const run = options?.run;
const settings = options?.settings;
const callSettings = Object.assign({
apiKey: this.apiKey,
options: {
useCache: true,
waitForModel: true,
},
}, this.settings, settings, {
abortSignal: run?.abortSignal,
inputs: prompt,
});
return (0, callWithRetryAndThrottle_js_1.callWithRetryAndThrottle)({
retry: this.settings.retry,
throttle: this.settings.throttle,
call: async () => callHuggingFaceTextGenerationAPI(callSettings),
});
}
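// The raw API response is used directly as the text generation response;
// extractText reads the generated text out of it.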
generateTextResponse(prompt, options) {
return this.callAPI(prompt, options);
}
extractText(response) {
return response[0].generated_text;
}
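// Wraps this model so that prompts in another format (e.g. instruction or chat
// prompts) are mapped to plain text before being sent to the API. A minimal
// usage sketch (the prompt mapping name below is illustrative, not part of this file):
//   const instructionModel = model.mapPrompt(myInstructionPromptMapping);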
mapPrompt(promptMapping) {
return new PromptMappingTextGenerationModel_js_1.PromptMappingTextGenerationModel({
model: this,
promptMapping,
});
}
withSettings(additionalSettings) {
return new HuggingFaceTextGenerationModel(Object.assign({}, this.settings, additionalSettings));
}
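// maxCompletionTokens is backed by the Hugging Face maxNewTokens setting.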
get maxCompletionTokens() {
return this.settings.maxNewTokens;
}
withMaxCompletionTokens(maxCompletionTokens) {
return this.withSettings({ maxNewTokens: maxCompletionTokens });
}
withStopTokens() {
// stop tokens are not supported by the HuggingFace API
return this;
}
}
exports.HuggingFaceTextGenerationModel = HuggingFaceTextGenerationModel;
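// The text generation task responds with an array of { generated_text } objects
// (one entry per returned sequence); the schema validates that shape.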
const huggingFaceTextGenerationResponseSchema = zod_1.default.array(zod_1.default.object({
generated_text: zod_1.default.string(),
}));
/**
* Call a Hugging Face Inference API Text Generation Task to generate a text completion for the given prompt.
*
* @see https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
*
* @example
* const response = await callHuggingFaceTextGenerationAPI({
* apiKey: HUGGINGFACE_API_KEY,
* model: "tiiuae/falcon-7b",
* inputs: "Write a short story about a robot learning to love:\n\n",
* temperature: 0.7,
* maxNewTokens: 500,
* options: {
* waitForModel: true,
* },
* });
*
* console.log(response[0].generated_text);
*/
async function callHuggingFaceTextGenerationAPI({ baseUrl = "https://api-inference.huggingface.co/models", abortSignal, apiKey, model, inputs, topK, topP, temperature, repetitionPenalty, maxNewTokens, maxTime, numReturnSequences, doSample, options, }) {
return (0, postToApi_js_1.postJsonToApi)({
url: `${baseUrl}/${model}`,
apiKey,
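// Map the camelCase settings to the snake_case parameter names expected by
// the Hugging Face Inference API.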
body: {
inputs,
top_k: topK,
top_p: topP,
temperature,
repetition_penalty: repetitionPenalty,
max_new_tokens: maxNewTokens,
max_time: maxTime,
num_return_sequences: numReturnSequences,
do_sample: doSample,
options: options
? {
use_cache: options?.useCache,
wait_for_model: options?.waitForModel,
}
: undefined,
},
failedResponseHandler: HuggingFaceError_js_1.failedHuggingFaceCallResponseHandler,
successfulResponseHandler: (0, postToApi_js_1.createJsonResponseHandler)(huggingFaceTextGenerationResponseSchema),
abortSignal,
});
}