inference-server
Version:
Libraries and a server for building AI applications, with adapters to various native bindings for local inference. Integrate it into your application, or run it as a microservice.
JavaScript
import fs from 'node:fs';
import { AutoModel, AutoProcessor, AutoTokenizer, } from '@huggingface/transformers';
import { resolveModelFileLocation } from '../../lib/resolveModelFileLocation.js';
import { parseHuggingfaceModelIdAndBranch, remoteFileExists, normalizeTransformersJsClass } from './util.js';
// Currently, transformers.js models are validated by trying to load them, catching any error, and checking whether the error message contains the signature below.
// The major upside is that transformers.js itself figures out which files are required; the downside is that we have to hope the message text stays the same.
const fileNotFoundSignature = 'file was not found locally';
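// Tries to load the model from the local path only (with the configured device and dtype); returns an error string if required files are missing, and rethrows any other error.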
async function validateModel(modelOpts, config, modelPath) {
const modelClass = normalizeTransformersJsClass(modelOpts.modelClass, AutoModel);
const device = config.device?.gpu ? 'gpu' : 'cpu';
try {
const model = await modelClass.from_pretrained(modelPath, {
local_files_only: true,
device: device,
dtype: modelOpts.dtype || 'fp32',
});
await model.dispose();
}
catch (error) {
if (error.message.includes(fileNotFoundSignature)) {
return `Failed to load model (${error.message})`;
}
throw error;
}
return undefined;
}
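// When a repo URL is configured and it advertises a tokenizer.json, loads the tokenizer from the local path to confirm it was downloaded.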
async function validateTokenizer(modelOpts, config, modelPath) {
const tokenizerClass = normalizeTransformersJsClass(modelOpts.tokenizerClass, AutoTokenizer);
try {
if (config.url) {
const { branch } = parseHuggingfaceModelIdAndBranch(config.url);
const hasTokenizer = await remoteFileExists(`${config.url}/blob/${branch}/tokenizer.json`);
if (hasTokenizer) {
await tokenizerClass.from_pretrained(modelPath, {
local_files_only: true,
});
}
}
}
catch (error) {
if (error.message.includes(fileNotFoundSignature)) {
return `Failed to load tokenizer (${error.message})`;
}
throw error;
}
return undefined;
}
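// Loads the processor from an explicitly configured processor location, or from the model path when a processorClass is set or the repo has a (pre)processor_config.json.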
async function validateProcessor(modelOpts, config, modelPath) {
const processorClass = normalizeTransformersJsClass(modelOpts.processorClass, AutoProcessor);
try {
if (modelOpts.processor) {
const processorPath = resolveModelFileLocation({
url: modelOpts.processor.url,
filePath: modelOpts.processor.file,
modelsCachePath: config.modelsCachePath,
});
await processorClass.from_pretrained(processorPath, {
local_files_only: true,
});
}
else {
if (modelOpts.processorClass) {
await processorClass.from_pretrained(modelPath, {
local_files_only: true,
});
}
else if (config.url) {
const { branch } = parseHuggingfaceModelIdAndBranch(config.url);
const [hasProcessor, hasPreprocessor] = await Promise.all([
remoteFileExists(`${config.url}/blob/${branch}/processor_config.json`),
remoteFileExists(`${config.url}/blob/${branch}/preprocessor_config.json`),
]);
if (hasProcessor || hasPreprocessor) {
await processorClass.from_pretrained(modelPath, {
local_files_only: true,
});
}
}
}
}
catch (error) {
if (error.message.includes(fileNotFoundSignature)) {
return `Failed to load processor (${error.message})`;
}
throw error;
}
return undefined;
}
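// If a vocoder is configured, loads it from its resolved local location; returns an error string if its files are missing.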
async function validateVocoder(modelOpts, config, modelPath) {
const vocoderClass = normalizeTransformersJsClass(modelOpts.vocoderClass, AutoModel);
if (modelOpts.vocoder) {
const vocoderPath = resolveModelFileLocation({
url: modelOpts.vocoder.url,
filePath: modelOpts.vocoder.file,
modelsCachePath: config.modelsCachePath,
});
try {
await vocoderClass.from_pretrained(vocoderPath, {
local_files_only: true,
});
}
catch (error) {
if (error.message.includes(fileNotFoundSignature)) {
return `Failed to load vocoder (${error.message})`;
}
throw error;
}
}
return undefined;
}
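// Validates the model, tokenizer, processor and (if configured) vocoder in parallel, returning a map of per-component error messages.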
async function validateModelComponents(modelOpts, config, modelPath) {
const componentValidationPromises = [
validateModel(modelOpts, config, modelPath),
validateTokenizer(modelOpts, config, modelPath),
validateProcessor(modelOpts, config, modelPath),
];
if ('vocoder' in modelOpts) {
componentValidationPromises.push(validateVocoder(modelOpts, config, modelPath));
}
const [model, tokenizer, processor, vocoder] = await Promise.all(componentValidationPromises);
const result = {};
if (model)
result.model = model;
if (tokenizer)
result.tokenizer = tokenizer;
if (processor)
result.processor = processor;
if (vocoder)
result.vocoder = vocoder;
return result;
}
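// For speech models, first checks that any file-based speaker embeddings (entries that are not Float32Arrays) exist on disk, then validates the shared components.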
async function validateSpeechModel(modelOpts, config, modelPath) {
if (modelOpts.speakerEmbeddings) {
for (const voice of Object.values(modelOpts.speakerEmbeddings)) {
if (voice instanceof Float32Array) {
continue;
}
const speakerEmbeddingsPath = resolveModelFileLocation({
url: voice.url,
filePath: voice.file,
modelsCachePath: config.modelsCachePath,
});
if (!fs.existsSync(speakerEmbeddingsPath)) {
return `Speaker embeddings file does not exist: ${speakerEmbeddingsPath}`;
}
}
}
return validateModelComponents(modelOpts, config, modelPath);
}
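// Entry point: validates the primary model plus any configured text, vision and speech models against the files in config.location.
// Returns undefined on success, otherwise an object with a message and, for component failures, a per-model errors map.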
export async function validateModelFiles(config) {
if (!fs.existsSync(config.location)) {
return {
message: `model directory does not exist: ${config.location}`,
};
}
let modelPath = config.location;
if (!modelPath.endsWith('/')) {
modelPath += '/';
}
const modelValidationPromises = {};
// const noModelConfigured = !config.textModel && !config.visionModel && !config.speechModel
modelValidationPromises.primaryModel = validateModelComponents(config, config, modelPath);
if (config.textModel) {
modelValidationPromises.textModel = validateModelComponents(config.textModel, config, modelPath);
}
if (config.visionModel) {
modelValidationPromises.visionModel = validateModelComponents(config.visionModel, config, modelPath);
}
if (config.speechModel) {
modelValidationPromises.speechModel = validateSpeechModel(config.speechModel, config, modelPath);
}
await Promise.all(Object.values(modelValidationPromises));
const validationErrors = {};
const primaryModelErrors = await modelValidationPromises.primaryModel;
if (primaryModelErrors && Object.keys(primaryModelErrors).length) {
validationErrors.primaryModel = primaryModelErrors;
}
const textModelErrors = await modelValidationPromises.textModel;
if (textModelErrors && Object.keys(textModelErrors).length) {
validationErrors.textModel = textModelErrors;
}
const visionModelErrors = await modelValidationPromises.visionModel;
if (visionModelErrors && Object.keys(visionModelErrors).length) {
validationErrors.visionModel = visionModelErrors;
}
const speechModelErrors = await modelValidationPromises.speechModel;
if (speechModelErrors && Object.keys(speechModelErrors).length) {
validationErrors.speechModel = speechModelErrors;
}
const vocoderModelErrors = await modelValidationPromises.vocoderModel;
if (vocoderModelErrors && Object.keys(vocoderModelErrors).length) {
validationErrors.vocoderModel = vocoderModelErrors;
}
if (Object.keys(validationErrors).length > 0) {
return {
message: 'Failed to validate model components',
errors: validationErrors,
};
}
return undefined;
}
//# sourceMappingURL=validateModelFiles.js.map
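Example usage — a minimal sketch, assuming the module path and a Hugging Face-hosted model; the config values (location, url, modelsCachePath, device, dtype) are hypothetical, and which optional fields (textModel, visionModel, speechModel, processor, vocoder, speakerEmbeddings) you need depends on the model being validated:

JavaScript
import { validateModelFiles } from './validateModelFiles.js';

// Hypothetical config: `location` points at the downloaded model directory; the
// remaining fields are the optional inputs read by the validators above.
const validationError = await validateModelFiles({
    location: '/models/huggingface/Xenova/whisper-tiny/',
    url: 'https://huggingface.co/Xenova/whisper-tiny',
    modelsCachePath: '/models/',
    device: { gpu: false },
    dtype: 'fp32',
});

if (validationError) {
    console.error(validationError.message, validationError.errors ?? '');
} else {
    console.log('All required model files are present.');
}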