inference-server
Libraries and a server for building AI applications, with adapters to various native bindings for local inference. Integrate it into your application, or run it as a microservice.
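Below is a minimal sketch of how the transformers.js engine in this file (engine.js) might be driven directly. The config fields, callbacks, and task shape mirror what prepareModel, createInstance, and processChatCompletionTask read in the source that follows; the model URL, paths, and model id are placeholders, and in practice the surrounding server wires these calls together, so treat this as illustrative rather than as the packaged public API.

import * as engine from './engine.js';

// Placeholder config; field names mirror what this module reads.
const config = {
    id: 'my-chat-model',
    task: 'chat-completion',
    url: 'https://huggingface.co/<org>/<model>', // Hugging Face repo URL (placeholder)
    location: '/path/to/models/my-chat-model',
    modelsCachePath: '/path/to/models',
    dtype: 'fp32',
};
const log = (level, message, meta) => console.log(level, message, meta ?? '');

// 1. Make sure the model files exist locally (downloads from the Hub if missing).
await engine.prepareModel({ config, log }, (p) => console.log(p.file, p.loadedBytes, p.totalBytes));
// 2. Load tokenizer/model/processor components and warm the model up.
const instance = await engine.createInstance({ config, log });
// 3. Run a chat completion against the loaded instance.
const result = await engine.processChatCompletionTask(
    { messages: [{ role: 'user', content: 'Hello!' }], maxTokens: 32, onChunk: (c) => process.stdout.write(c.text) },
    { instance, config },
);
console.log('\n' + result.finishReason, result.message.content);
// 4. Free model memory when done.
await engine.disposeInstance(instance);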
JavaScript
import path from 'node:path';
import fs from 'node:fs';
import { env, AutoModel, AutoProcessor, AutoTokenizer, RawImage, TextStreamer, mean_pooling, Tensor, StoppingCriteria, softmax, TextClassificationPipeline, } from '@huggingface/transformers';
import { LogLevels } from '../../lib/logger.js';
import { resampleAudioBuffer } from '../../lib/loadAudio.js';
import { resolveModelFileLocation } from '../../lib/resolveModelFileLocation.js';
import { moveDirectoryContents } from '../../lib/moveDirectoryContents.js';
import { fetchBuffer, normalizeTransformersJsClass, parseHuggingfaceModelIdAndBranch, remoteFileExists, } from './util.js';
import { validateModelFiles } from './validateModelFiles.js';
import { acquireModelFileLocks } from './acquireModelFileLocks.js';
import { loadModelComponents, loadSpeechModelComponents } from './loadModelComponents.js';
import { flattenMessageTextContent } from '../../lib/flattenMessageTextContent.js';
export const autoGpu = true;
let didConfigureEnvironment = false;
function configureEnvironment(modelsPath) {
// console.debug({
// cacheDir: env.cacheDir,
// localModelPaths: env.localModelPath,
// })
// env.useFSCache = false
// env.useCustomCache = true
// env.customCache = new TransformersFileCache(modelsPath)
env.localModelPath = '';
didConfigureEnvironment = true;
}
async function disposeModelComponents(modelComponents) {
if (modelComponents.model && 'dispose' in modelComponents.model) {
await modelComponents.model.dispose();
}
}
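// Ensures the model files for `config` exist locally: acquires per-file locks,
// validates what is already on disk, downloads any missing components from the
// Hugging Face Hub, moves them into `config.location`, and returns the file list
// plus the parsed contents of the model's *.json config files.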
export async function prepareModel({ config, log }, onProgress, signal) {
if (!didConfigureEnvironment) {
configureEnvironment(config.modelsCachePath);
}
fs.mkdirSync(config.location, { recursive: true });
const releaseFileLocks = await acquireModelFileLocks(config, signal);
if (signal?.aborted) {
releaseFileLocks();
return;
}
log(LogLevels.info, `Preparing transformers.js model at ${config.location}`, {
model: config.id,
});
const downloadModelFiles = async (modelOpts, { modelId, branch }, requiredComponents = ['model', 'tokenizer', 'processor', 'vocoder']) => {
const modelClass = normalizeTransformersJsClass(modelOpts.modelClass, AutoModel);
const downloadPromises = {};
const progressCallback = (progress) => {
if (onProgress && progress.status === 'progress') {
onProgress({
file: env.cacheDir + progress.name + '/' + progress.file,
loadedBytes: progress.loaded || 0,
totalBytes: progress.total || 0,
});
}
};
if (requiredComponents.includes('model')) {
const modelDownloadPromise = modelClass.from_pretrained(modelId, {
revision: branch,
dtype: modelOpts.dtype || 'fp32',
progress_callback: progressCallback,
// use_external_data_format: true, // https://github.com/xenova/transformers.js/blob/38a3bf6dab2265d9f0c2f613064535863194e6b9/src/models.js#L205-L207
});
downloadPromises.model = modelDownloadPromise;
}
if (requiredComponents.includes('tokenizer')) {
const hasTokenizer = await remoteFileExists(`${config.url}/blob/${branch}/tokenizer.json`);
if (hasTokenizer) {
const tokenizerClass = normalizeTransformersJsClass(modelOpts.tokenizerClass, AutoTokenizer);
const tokenizerDownload = tokenizerClass.from_pretrained(modelId, {
revision: branch,
progress_callback: progressCallback,
});
downloadPromises.tokenizer = tokenizerDownload;
}
}
if (requiredComponents.includes('processor')) {
const processorClass = normalizeTransformersJsClass(modelOpts.processorClass, AutoProcessor);
if (modelOpts.processor?.url) {
const { modelId, branch } = parseHuggingfaceModelIdAndBranch(modelOpts.processor.url);
const processorDownload = processorClass.from_pretrained(modelId, {
revision: branch,
progress_callback: progressCallback,
});
downloadPromises.processor = processorDownload;
}
else {
const [hasProcessor, hasPreprocessor] = await Promise.all([
remoteFileExists(`${config.url}/blob/${branch}/processor_config.json`),
remoteFileExists(`${config.url}/blob/${branch}/preprocessor_config.json`),
]);
if (hasProcessor || hasPreprocessor) {
const processorDownload = processorClass.from_pretrained(modelId, {
revision: branch,
progress_callback: progressCallback,
});
downloadPromises.processor = processorDownload;
}
}
}
if (requiredComponents.includes('vocoder') && 'vocoder' in modelOpts) {
if (modelOpts.vocoder?.url) {
const { modelId, branch } = parseHuggingfaceModelIdAndBranch(modelOpts.vocoder.url);
const vocoderDownload = AutoModel.from_pretrained(modelId, {
revision: branch,
progress_callback: progressCallback,
});
downloadPromises.vocoder = vocoderDownload;
}
}
await Promise.all(Object.values(downloadPromises));
const modelComponents = {};
if (downloadPromises.model) {
modelComponents.model = (await downloadPromises.model);
}
if (downloadPromises.tokenizer) {
modelComponents.tokenizer = (await downloadPromises.tokenizer);
}
if (downloadPromises.processor) {
modelComponents.processor = (await downloadPromises.processor);
}
await disposeModelComponents(modelComponents);
// return modelComponents
};
const downloadSpeakerEmbeddings = async (speakerEmbeddings) => {
const speakerEmbeddingsPromises = [];
for (const speakerEmbedding of Object.values(speakerEmbeddings)) {
if (speakerEmbedding instanceof Float32Array) {
// nothing to download if we have the embeddings already
continue;
}
if (!speakerEmbedding.url) {
continue;
}
const speakerEmbeddingsPath = resolveModelFileLocation({
url: speakerEmbedding.url,
filePath: speakerEmbedding.file,
modelsCachePath: config.modelsCachePath,
});
const url = speakerEmbedding.url;
speakerEmbeddingsPromises.push((async () => {
const buffer = await fetchBuffer(url);
await fs.promises.writeFile(speakerEmbeddingsPath, buffer);
})());
}
return Promise.all(speakerEmbeddingsPromises);
};
const downloadModel = async (validationResult) => {
log(LogLevels.info, `${validationResult.message} - Downloading files`, {
model: config.id,
url: config.url,
location: config.location,
errors: validationResult.errors,
});
const modelDownloadPromises = [];
if (!config.url) {
throw new Error(`Missing URL for model ${config.id}`);
}
const { modelId, branch } = parseHuggingfaceModelIdAndBranch(config.url);
const directoriesToCopy = {};
const modelCacheDir = path.join(env.cacheDir, modelId);
directoriesToCopy[modelCacheDir] = config.location;
const prepareDownloadedVocoder = (vocoderOpts) => {
if (!vocoderOpts.url) {
return;
}
const vocoderPath = resolveModelFileLocation({
url: vocoderOpts.url,
filePath: vocoderOpts.file,
modelsCachePath: config.modelsCachePath,
});
const { modelId } = parseHuggingfaceModelIdAndBranch(vocoderOpts.url);
const vocoderCacheDir = path.join(env.cacheDir, modelId);
directoriesToCopy[vocoderCacheDir] = vocoderPath;
};
const prepareDownloadedProcessor = (processorOpts) => {
if (!processorOpts.url) {
return;
}
const processorPath = resolveModelFileLocation({
url: processorOpts.url,
filePath: processorOpts.file,
modelsCachePath: config.modelsCachePath,
});
const { modelId } = parseHuggingfaceModelIdAndBranch(processorOpts.url);
const processorCacheDir = path.join(env.cacheDir, modelId);
directoriesToCopy[processorCacheDir] = processorPath;
};
const requiredComponents = validationResult.errors?.primaryModel
? Object.keys(validationResult.errors.primaryModel)
: undefined;
modelDownloadPromises.push(downloadModelFiles(config, { modelId, branch }, requiredComponents));
if (config.processor?.url) {
prepareDownloadedProcessor(config.processor);
}
if (config.vocoder?.url) {
prepareDownloadedVocoder(config.vocoder);
}
if (config?.speakerEmbeddings) {
modelDownloadPromises.push(downloadSpeakerEmbeddings(config.speakerEmbeddings));
}
if (config.textModel) {
const requiredComponents = validationResult.errors?.textModel
? Object.keys(validationResult.errors.textModel)
: undefined;
modelDownloadPromises.push(downloadModelFiles(config.textModel, { modelId, branch }, requiredComponents));
}
if (config.visionModel) {
const requiredComponents = validationResult.errors?.visionModel
? Object.keys(validationResult.errors.visionModel)
: undefined;
modelDownloadPromises.push(downloadModelFiles(config.visionModel, { modelId, branch }, requiredComponents));
if (config.processor?.url) {
prepareDownloadedProcessor(config.processor);
}
}
if (config.speechModel) {
const requiredComponents = validationResult.errors?.speechModel
? Object.keys(validationResult.errors.speechModel)
: undefined;
modelDownloadPromises.push(downloadModelFiles(config.speechModel, { modelId, branch }, requiredComponents));
if (config.speechModel.vocoder?.url) {
prepareDownloadedVocoder(config.speechModel.vocoder);
}
if (config.speechModel?.speakerEmbeddings) {
modelDownloadPromises.push(downloadSpeakerEmbeddings(config.speechModel.speakerEmbeddings));
}
}
await Promise.all(modelDownloadPromises);
if (signal?.aborted) {
return;
}
// move all downloads to their final location
await Promise.all(Object.entries(directoriesToCopy).map(([from, to]) => {
if (fs.existsSync(from)) {
return moveDirectoryContents(from, to);
}
return Promise.resolve();
}));
};
try {
const validationResults = await validateModelFiles(config);
if (signal?.aborted) {
return;
}
if (validationResults) {
if (config.url) {
await downloadModel(validationResults);
}
else {
throw new Error(`Model files are invalid: ${validationResults.message}`);
}
}
}
finally {
releaseFileLocks();
}
const configMeta = {};
const fileList = [];
const modelFiles = fs.readdirSync(config.location, { recursive: true });
const pushFile = (file) => {
const targetFile = path.join(config.location, file);
const targetStat = fs.statSync(targetFile);
fileList.push({
file: targetFile,
size: targetStat.size,
});
if (targetFile.endsWith('.json')) {
const key = path.basename(targetFile).replace('.json', '');
configMeta[key] = JSON.parse(fs.readFileSync(targetFile, 'utf8'));
}
};
// add model files to the list
for (const file of modelFiles) {
pushFile(file.toString());
}
// add extra stuff from external repos
if (config.visionModel?.processor) {
const processorPath = resolveModelFileLocation({
url: config.visionModel.processor.url,
filePath: config.visionModel.processor.file,
modelsCachePath: config.modelsCachePath,
});
const processorFiles = fs.readdirSync(processorPath, { recursive: true });
for (const file of processorFiles) {
pushFile(file.toString());
}
}
return {
files: fileList,
...configMeta,
};
}
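// Loads the primary model components plus any optional text/vision/speech
// sub-models in parallel, then runs a one-token generation to warm up
// chat/text-completion models before the instance is handed out.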
export async function createInstance({ config, log }, signal) {
const modelLoadPromises = [];
modelLoadPromises.push(loadModelComponents(config, config));
if (config.textModel) {
modelLoadPromises.push(loadModelComponents(config.textModel, config));
}
else {
modelLoadPromises.push(Promise.resolve(undefined));
}
if (config.visionModel) {
modelLoadPromises.push(loadModelComponents(config.visionModel, config));
}
else {
modelLoadPromises.push(Promise.resolve(undefined));
}
if (config.speechModel) {
modelLoadPromises.push(loadSpeechModelComponents(config.speechModel, config));
}
else {
modelLoadPromises.push(Promise.resolve(undefined));
}
const models = await Promise.all(modelLoadPromises);
const instance = {
primary: models[0],
text: models[1],
vision: models[2],
speech: models[3],
};
// warm up model by doing a tiny generation
if (config.task === 'chat-completion') {
const chatModel = instance.text || instance.primary;
if (chatModel.tokenizer && !chatModel.processor) {
// TODO figure out a way to warm up using processor?
const inputs = chatModel.tokenizer('a');
await chatModel.model.generate({ ...inputs, max_new_tokens: 1 });
}
}
if (config.task === 'text-completion') {
const textModel = instance.text || instance.primary;
const inputs = textModel.tokenizer('a');
await textModel.model.generate({ ...inputs, max_new_tokens: 1 });
}
// TODO warm up other model types
// ie for whisper, this seems to speed up the initial response time
// await model.generate({
// input_features: full([1, 80, 3000], 0.0),
// max_new_tokens: 1,
// });
return instance;
}
export async function disposeInstance(instance) {
const disposePromises = [];
if (instance.primary) {
disposePromises.push(disposeModelComponents(instance.primary));
}
if (instance.vision) {
disposePromises.push(disposeModelComponents(instance.vision));
}
if (instance.speech) {
disposePromises.push(disposeModelComponents(instance.speech));
}
await Promise.all(disposePromises);
}
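// Stopping criteria that can be flipped externally (via an abort signal or a
// stop trigger in the streamer callback) to halt generation on the next step.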
class CustomStoppingCriteria extends StoppingCriteria {
stopped;
constructor() {
super();
this.stopped = false;
}
stop() {
this.stopped = true;
}
reset() {
this.stopped = false;
}
_call(inputIds, scores) {
return new Array(inputIds.length).fill(this.stopped);
}
}
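// Normalizes chat messages for the tokenizer/processor: string content passes
// through, image parts are converted to RawImage instances (collected in
// `images`) and replaced by an <image_placeholder> text part.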
function prepareInputMessages(messages) {
const images = [];
const inputMessages = messages.map((message) => {
if (typeof message.content === 'string') {
return {
role: message.role,
content: message.content,
};
}
else if (Array.isArray(message.content)) {
return {
role: message.role,
content: message.content.map((part) => {
if (part.type === 'text') {
return part;
}
else if (part.type === 'image') {
const rawImage = new RawImage(new Uint8ClampedArray(part.image.data), part.image.width, part.image.height, part.image.channels);
images.push(rawImage);
return {
type: 'text',
text: '<image_placeholder>',
};
}
else {
throw new Error('Invalid message content: unknown part type');
}
}),
};
}
else {
throw new Error('Invalid message content: must be string or array');
}
});
return {
inputMessages: inputMessages.map((message) => {
return {
role: message.role,
content: flattenMessageTextContent(message.content),
};
}),
images,
};
}
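// Chat completion: applies the chat template, tokenizes (or runs the processor
// for multimodal models), streams tokens to `task.onChunk`, and reports token
// counts plus the reason generation stopped.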
export async function processChatCompletionTask(task, ctx, signal) {
const { instance } = ctx;
if (!task.messages) {
throw new Error('Messages are required for chat completion.');
}
const chatModel = instance.text || instance.primary;
if (!(chatModel.tokenizer && chatModel.model)) {
throw new Error('Chat model is not loaded.');
}
const { images, inputMessages } = prepareInputMessages(task.messages);
let inputs;
let inputTokenCount = 0;
const inputText = chatModel.tokenizer.apply_chat_template(inputMessages, {
tokenize: false,
add_generation_prompt: true,
return_dict: true,
});
if (chatModel.processor) {
inputs = await chatModel.processor(inputMessages, {
images,
});
}
else {
inputs = chatModel.tokenizer(inputText, {
return_tensor: true,
add_special_tokens: false,
});
}
inputTokenCount = inputs.input_ids.size;
const stoppingCriteria = new CustomStoppingCriteria();
signal?.addEventListener('abort', () => {
stoppingCriteria.stop();
});
let responseText = '';
let finishReason = 'cancel';
const streamer = new TextStreamer(chatModel.tokenizer, {
skip_prompt: true,
decode_kwargs: {
skip_special_tokens: true,
},
callback_function: (output) => {
responseText += output;
if (task.stop && task.stop.some((stopToken) => output.includes(stopToken))) {
stoppingCriteria.stop();
finishReason = 'stopTrigger';
}
if (task.onChunk) {
const tokens = chatModel.tokenizer.encode(output);
task.onChunk({ text: output, tokens: tokens });
}
},
});
const maxTokens = task.maxTokens ?? 128;
const outputs = (await chatModel.model.generate({
...inputs,
// common params
max_new_tokens: maxTokens,
repetition_penalty: task.repeatPenalty ?? 1.0, // 1 = no penalty
temperature: task.temperature ?? 1.0,
top_k: task.topK ?? 50,
top_p: task.topP ?? 1.0,
// do_sample: true,
// num_beams: 1,
// num_return_sequences: 2, // TODO https://github.com/huggingface/transformers.js/issues/1007
// eos_token_id: stopTokens[0], // TODO implement stop
streamer,
// transformers-exclusive params
// Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
// length_penalty: -64,
// The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay.
// exponential_decay_length_penalty: [1, 64],
// typical_p: 1,
// epsilon_cutoff: 0,
// eta_cutoff: 0,
// diversity_penalty: 0,
// encoder_repetition_penalty: 1.0, // 1 = no penalty
// no_repeat_ngram_size: 0,
// forced_eos_token_id: [],
// bad_words_ids: [],
// force_words_ids: [],
// suppress_tokens: [],
stopping_criteria: stoppingCriteria,
}));
const outputTokenCount = outputs.size;
// const hasEogToken = outputs.
// @ts-ignore
const outputTexts = chatModel.tokenizer.batch_decode(outputs, { skip_special_tokens: false });
const eosToken = chatModel.tokenizer._tokenizer_config.eos_token;
const hasEogToken = outputTexts[0].endsWith(eosToken);
const completionTokenCount = outputTokenCount - inputTokenCount;
if (hasEogToken) {
finishReason = 'eogToken';
}
else if (completionTokenCount >= maxTokens) {
finishReason = 'maxTokens';
}
return {
finishReason,
message: {
role: 'assistant',
content: responseText,
},
promptTokens: inputTokenCount,
completionTokens: outputTokenCount - inputTokenCount,
contextTokens: outputTokenCount,
};
}
// TextGenerationPipeline https://github.com/huggingface/transformers.js/blob/705cfc456f8b8f114891e1503b0cdbaa97cf4b11/src/pipelines.js#L977
// Generation Args https://github.com/huggingface/transformers.js/blob/705cfc456f8b8f114891e1503b0cdbaa97cf4b11/src/generation/configuration_utils.js#L11
// default Model.generate https://github.com/huggingface/transformers.js/blob/705cfc456f8b8f114891e1503b0cdbaa97cf4b11/src/models.js#L1378
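// Plain text completion over `task.prompt`; mirrors the chat path but without
// the chat template, and slices the echoed prompt off the decoded output.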
export async function processTextCompletionTask(task, ctx, signal) {
const { instance } = ctx;
if (!task.prompt) {
throw new Error('Prompt is required for text completion.');
}
const textModel = instance.text || instance.primary;
if (!(textModel?.tokenizer && textModel?.model)) {
throw new Error('Text model is not loaded.');
}
if (!('generate' in textModel.model)) {
throw new Error('Text model does not support generation.');
}
textModel.tokenizer.padding_side = 'left';
const inputs = textModel.tokenizer(task.prompt, {
add_special_tokens: false,
padding: true,
truncation: true,
});
const stoppingCriteria = new CustomStoppingCriteria();
signal?.addEventListener('abort', () => {
stoppingCriteria.stop();
});
let finishReason = 'cancel';
const streamer = new TextStreamer(textModel.tokenizer, {
skip_prompt: true,
callback_function: (output) => {
if (task.stop && task.stop.some((stopToken) => output.includes(stopToken))) {
stoppingCriteria.stop();
finishReason = 'stopTrigger';
}
if (task.onChunk) {
const tokens = textModel.tokenizer.encode(output);
task.onChunk({ text: output, tokens: tokens });
}
},
});
const maxTokens = task.maxTokens ?? 128;
const outputs = await textModel.model.generate({
...inputs,
renormalize_logits: true,
output_scores: true, // TODO currently no effect
return_dict_in_generate: true,
// common params
max_new_tokens: maxTokens,
repetition_penalty: task.repeatPenalty ?? 1.0, // 1 = no penalty
temperature: task.temperature,
top_k: task.topK,
top_p: task.topP,
// num_beams: 1,
// num_return_sequences: 2, // TODO https://github.com/huggingface/transformers.js/issues/1007
// eos_token_id: stopTokens[0], // TODO implement stop
length_penalty: 0,
streamer,
// transformers-exclusive params
// length_penalty: -64, // Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
// The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay.
// exponential_decay_length_penalty: [1, 64],
// typical_p: 1,
// epsilon_cutoff: 0,
// eta_cutoff: 0,
// diversity_penalty: 0,
// encoder_repetition_penalty: 1.0, // 1 = no penalty
// no_repeat_ngram_size: 0,
// forced_eos_token_id: [],
// bad_words_ids: [],
// force_words_ids: [],
// suppress_tokens: [],
});
// @ts-ignore
const outputTexts = textModel.tokenizer.batch_decode(outputs.sequences, {
skip_special_tokens: true,
clean_up_tokenization_spaces: true,
});
const generatedText = outputTexts[0].slice(task.prompt.length);
// @ts-ignore
const outputTokenCount = outputs.sequences.tolist().reduce((acc, sequence) => acc + sequence.length, 0);
const inputTokenCount = inputs.input_ids.size;
// const outputTexts = chatModel.tokenizer.batch_decode(outputs, { skip_special_tokens: false })
const eosToken = textModel.tokenizer._tokenizer_config.eos_token;
const hasEogToken = outputTexts[0].endsWith(eosToken);
const completionTokenCount = outputTokenCount - inputTokenCount;
if (hasEogToken) {
finishReason = 'eogToken';
}
else if (completionTokenCount >= maxTokens) {
finishReason = 'maxTokens';
}
return {
finishReason,
text: generatedText,
promptTokens: inputTokenCount,
completionTokens: completionTokenCount,
contextTokens: outputTokenCount,
};
}
// see https://github.com/xenova/transformers.js/blob/v3/src/utils/tensor.js
// https://github.com/xenova/transformers.js/blob/v3/src/pipelines.js#L1284
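// Computes embeddings for text and/or image inputs, with optional mean/cls
// pooling and truncation to `task.dimensions`.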
export async function processEmbeddingTask(task, ctx, signal) {
const { instance, config } = ctx;
if (!task.input) {
throw new Error('Input is required for embedding.');
}
const inputs = Array.isArray(task.input) ? task.input : [task.input];
const normalizedInputs = inputs.map((input) => {
if (typeof input === 'string') {
return {
type: 'text',
content: input,
};
}
else if (input.type) {
return input;
}
else {
throw new Error('Invalid input type');
}
});
const embeddings = [];
let inputTokens = 0;
const applyPooling = (result, pooling, modelInputs) => {
if (pooling === 'mean') {
return mean_pooling(result, modelInputs.attention_mask);
}
else if (pooling === 'cls') {
return result.slice(null, 0);
}
else {
throw Error(`Pooling method '${pooling}' not supported.`);
}
};
const truncateDimensions = (result, dimensions) => {
const truncatedData = new Float32Array(dimensions);
truncatedData.set(result.data.slice(0, dimensions));
return truncatedData;
};
for (const embeddingInput of normalizedInputs) {
if (signal?.aborted) {
break;
}
let result;
let modelInputs;
if (embeddingInput.type === 'text') {
const modelComponents = instance.text || instance.primary;
if (!modelComponents?.tokenizer || !modelComponents?.model) {
throw new Error('Text model is not loaded.');
}
modelInputs = modelComponents.tokenizer(embeddingInput.content, {
padding: true, // pads input if it is shorter than context window
truncation: true, // truncates input if it exceeds context window
});
inputTokens += modelInputs.input_ids.size;
// @ts-ignore TODO check _call
const modelOutputs = await modelComponents.model(modelInputs);
result =
modelOutputs.last_hidden_state ??
modelOutputs.logits ??
modelOutputs.token_embeddings ??
modelOutputs.text_embeds;
}
else if (embeddingInput.type === 'image') {
const modelComponents = instance.vision || instance.primary;
if (!modelComponents?.processor || !modelComponents?.model) {
throw new Error('Vision model is not loaded.');
}
// const
// const { data, info } = await sharp(embeddingInput.content.data).raw().toBuffer({ resolveWithObject: true })
const image = embeddingInput.content;
const rawImage = new RawImage(new Uint8ClampedArray(image.data), image.width, image.height, image.channels);
modelInputs = await modelComponents.processor(rawImage);
// @ts-ignore TODO check _call
const modelOutputs = await modelComponents.model(modelInputs);
result = modelOutputs.last_hidden_state ?? modelOutputs.logits ?? modelOutputs.image_embeds;
}
if (task.pooling) {
result = applyPooling(result, task.pooling, modelInputs);
}
if (task.dimensions && result.data.length > task.dimensions) {
embeddings.push(truncateDimensions(result, task.dimensions));
}
else {
embeddings.push(result.data);
}
}
return {
embeddings,
inputTokens,
};
}
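// Image-to-text (captioning / VQA style): runs the processor on the image,
// optionally prepends a tokenized prompt, and decodes the generated tokens.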
export async function processImageToTextTask(task, ctx, signal) {
const { instance } = ctx;
if (!task.image) {
throw new Error('No image provided');
}
const image = task.image;
const rawImage = new RawImage(new Uint8ClampedArray(image.data), image.width, image.height, image.channels);
if (signal?.aborted) {
return;
}
const modelComponents = instance.vision || instance.primary;
if (!(modelComponents && modelComponents.tokenizer && modelComponents.processor && modelComponents.model)) {
throw new Error('No model loaded');
}
if (!('generate' in modelComponents.model)) {
throw new Error('Model does not support generation');
}
let textInputs = {};
if (task.prompt) {
textInputs = modelComponents.tokenizer(task.prompt);
}
const imageInputs = await modelComponents.processor(rawImage);
const outputTokens = await modelComponents.model.generate({
...textInputs,
...imageInputs,
max_new_tokens: task.maxTokens ?? 128,
});
// @ts-ignore
const outputText = modelComponents.tokenizer.batch_decode(outputTokens, {
skip_special_tokens: true,
});
return {
text: outputText[0],
};
}
// see examples
// https://huggingface.co/docs/transformers.js/guides/node-audio-processing
// https://github.com/xenova/transformers.js/tree/v3/examples/node-audio-processing
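// Speech-to-text: resamples audio to 16 kHz mono if needed, runs the processor,
// and streams the transcription via `task.onChunk`.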
export async function processSpeechToTextTask(task, ctx, signal) {
const { instance } = ctx;
if (!task.audio) {
throw new Error('No audio provided');
}
const modelComponents = instance.speech || instance.primary;
if (!(modelComponents?.tokenizer && modelComponents?.model)) {
throw new Error('No speech model loaded');
}
const streamer = new TextStreamer(modelComponents.tokenizer, {
skip_prompt: true,
// skip_special_tokens: true,
callback_function: (output) => {
if (task.onChunk) {
task.onChunk({ text: output });
}
},
});
let inputSamples = task.audio.samples;
if (task.audio.sampleRate !== 16000) {
inputSamples = await resampleAudioBuffer(task.audio.samples, {
inputSampleRate: task.audio.sampleRate,
outputSampleRate: 16000,
nChannels: 1,
});
}
const inputs = await modelComponents.processor(inputSamples);
if (!('generate' in modelComponents.model)) {
throw new Error('Speech model class does not support text generation');
}
const outputs = await modelComponents.model.generate({
...inputs,
max_new_tokens: task.maxTokens ?? 128,
language: task.language ?? 'en',
streamer,
});
// @ts-ignore
const outputText = modelComponents.tokenizer.batch_decode(outputs, {
skip_special_tokens: true,
});
return {
text: outputText[0],
};
}
// TextToAudioPipeline https://github.com/huggingface/transformers.js/blob/e129c47c65a049173f35e6263fd8d9f660dfc1a7/src/pipelines.js#L2663
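// Text-to-speech: tokenizes the text, picks speaker embeddings (by `task.voice`
// or the first configured voice), and returns the generated waveform.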
export async function processTextToSpeechTask(task, ctx, signal) {
const { instance } = ctx;
const modelComponents = instance.speech || instance.primary;
if (!modelComponents?.model || !modelComponents?.tokenizer) {
throw new Error('No speech model loaded');
}
if (!('generate_speech' in modelComponents.model)) {
throw new Error('The model does not support speech generation');
}
const encodedInputs = modelComponents.tokenizer(task.text, {
padding: true,
truncation: true,
});
if (!('speakerEmbeddings' in modelComponents)) {
throw new Error('No speaker embeddings supplied');
}
let speakerEmbeddings = modelComponents.speakerEmbeddings?.[Object.keys(modelComponents.speakerEmbeddings)[0]];
if (!speakerEmbeddings) {
throw new Error('No speaker embeddings supplied');
}
if (task.voice) {
speakerEmbeddings = modelComponents.speakerEmbeddings?.[task.voice];
if (!speakerEmbeddings) {
throw new Error(`No speaker embeddings found for voice ${task.voice}`);
}
}
if (signal?.aborted) {
throw new Error('Task aborted');
}
const speakerEmbeddingsTensor = new Tensor('float32', speakerEmbeddings, [1, speakerEmbeddings.length]);
const outputs = await modelComponents.model.generate_speech(encodedInputs.input_ids, speakerEmbeddingsTensor, {
vocoder: modelComponents.vocoder,
});
if (!outputs.waveform) {
throw new Error('No waveform generated');
}
const sampleRate = modelComponents.processor.feature_extractor.config.sampling_rate;
return {
audio: {
samples: outputs.waveform.data,
sampleRate,
channels: 1,
},
};
}
// ObjectDetectionPipeline https://github.com/huggingface/transformers.js/blob/6bd45ac66a861f37f3f95b81ac4b6d796a4ee231/src/pipelines.js#L2336
// ZeroShotObjectDetection https://github.com/huggingface/transformers.js/blob/6bd45ac66a861f37f3f95b81ac4b6d796a4ee231/src/pipelines.js#L2471
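// Object detection: zero-shot when `task.labels` are provided (labels are
// tokenized and scored against the image), otherwise uses the model's own
// id2label classes; returns scored bounding boxes in pixel coordinates.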
export async function processObjectDetectionTask(task, ctx, signal) {
const { instance } = ctx;
if (!task.image) {
throw new Error('No image provided');
}
const image = task.image;
const rawImage = new RawImage(new Uint8ClampedArray(image.data), image.width, image.height, image.channels);
const modelComponents = instance.vision || instance.primary;
if (!(modelComponents && modelComponents.model)) {
throw new Error('No model loaded');
}
if (signal?.aborted) {
throw new Error('Task aborted');
}
const results = [];
if (task?.labels?.length) {
if (!modelComponents.tokenizer || !modelComponents.processor) {
throw new Error('Model components not loaded.');
}
const labelInputs = modelComponents.tokenizer(task.labels, {
padding: true,
truncation: true,
});
const imageInputs = await modelComponents.processor([rawImage]);
const output = await modelComponents.model({
...labelInputs,
pixel_values: imageInputs.pixel_values[0].unsqueeze_(0),
});
// @ts-ignore
const processed = modelComponents.processor.image_processor.post_process_object_detection(output, task.threshold ?? 0.5, [[image.height, image.width]], true)[0];
for (let i = 0; i < processed.boxes.length; i++) {
results.push({
score: processed.scores[i],
label: task.labels[processed.classes[i]],
box: {
x: processed.boxes[i][0],
y: processed.boxes[i][1],
width: processed.boxes[i][2] - processed.boxes[i][0],
height: processed.boxes[i][3] - processed.boxes[i][1],
},
});
}
}
else {
// @ts-ignore
const { pixel_values, pixel_mask } = await modelComponents.processor([rawImage]);
const output = await modelComponents.model({ pixel_values, pixel_mask });
// @ts-ignore
const processed = modelComponents.processor.image_processor.post_process_object_detection(output, task.threshold ?? 0.5, [[image.height, image.width]], null, false);
// Add labels
// @ts-ignore
const id2label = modelComponents.model.config.id2label;
for (const batch of processed) {
for (let i = 0; i < batch.boxes.length; i++) {
results.push({
score: batch.scores[i],
label: id2label[batch.classes[i]],
box: {
x: batch.boxes[i][0],
y: batch.boxes[i][1],
width: batch.boxes[i][2] - batch.boxes[i][0],
height: batch.boxes[i][3] - batch.boxes[i][1],
},
});
}
}
}
return {
detections: results,
};
}
// https://github.com/huggingface/transformers.js/blob/6f43f244e04522545d3d939589c761fdaff057d4/src/pipelines.js#L1135
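// Text classification: without `task.labels` this wraps the stock
// TextClassificationPipeline; with labels it performs zero-shot NLI-style
// classification, scoring each premise/hypothesis pair via the
// entailment/contradiction logits.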
export async function processTextClassificationTask(task, ctx, signal) {
const { instance } = ctx;
const modelComponents = instance.text || instance.primary;
if (!modelComponents?.tokenizer || !modelComponents?.model) {
throw new Error('No text model loaded');
}
if (signal?.aborted) {
throw new Error('Task aborted');
}
if (!task.labels?.length) {
// Reuse the pipeline for normal text classification
const pipeline = new TextClassificationPipeline({
task: 'text-classification',
model: modelComponents.model,
tokenizer: modelComponents.tokenizer,
});
const pipelineRes = await pipeline(task.input, { top_k: task.topK });
if (Array.isArray(pipelineRes)) {
const resultItems = pipelineRes;
const classifications = resultItems.map((item) => {
const labels = [
{
name: item.label,
score: item.score,
},
];
return { labels };
});
return { classifications };
}
const singleResultItem = pipelineRes;
const labels = [
{
name: singleResultItem.label,
score: singleResultItem.score,
},
];
return {
classifications: [{ labels }],
};
}
// Zero shot classification
// @ts-ignore
const label2id = modelComponents.model.config.label2id;
let entailmentId = label2id['entailment'];
if (entailmentId === undefined) {
console.warn("Could not find 'entailment' in label2id mapping. Using 2 as entailment_id.");
entailmentId = 2;
}
let contradictionId = label2id['contradiction'] ?? label2id['not_entailment'];
if (contradictionId === undefined) {
console.warn("Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id.");
contradictionId = 0;
}
const texts = [];
if (typeof task.input === 'string') {
texts.push(task.input);
}
else if (Array.isArray(task.input)) {
texts.push(...task.input);
}
else {
throw new Error('Invalid input');
}
const hypotheses = task.labels.map((label) => task.hypothesisTemplate.replace('{}', label));
// How to perform the softmax over the logits:
// - true: softmax over the entailment vs. contradiction dim for each label independently
// - false: softmax the "entailment" logits over all candidate labels
const softmaxEach = task.labels.length === 1;
const toReturn = [];
for (const premise of texts) {
const entailsLogits = [];
for (const hypothesis of hypotheses) {
const inputs = modelComponents.tokenizer(premise, {
text_pair: hypothesis,
padding: true,
truncation: true,
});
const outputs = await modelComponents.model(inputs);
if (softmaxEach) {
entailsLogits.push([outputs.logits.data[contradictionId], outputs.logits.data[entailmentId]]);
}
else {
entailsLogits.push(outputs.logits.data[entailmentId]);
}
}
const scores = softmaxEach ? entailsLogits.map((x) => softmax(x)[1]) : softmax(entailsLogits);
const scoresSorted = scores.map((x, i) => [x, i]).sort((a, b) => b[0] - a[0]);
toReturn.push({
sequence: premise,
labels: scoresSorted.map((x) => task.labels[x[1]]),
scores: scoresSorted.map((x) => x[0]),
});
}
const classifications = toReturn.map((x) => {
let labels = x.labels.map((label, i) => {
return {
name: label,
score: x.scores[i],
};
});
if (task.topK) {
labels = labels.slice(0, task.topK);
}
return { labels };
});
return { classifications };
}
//# sourceMappingURL=engine.js.map