UNPKG

inference-server

Version:

Libraries and server to build AI applications. Adapters to various native bindings allowing local inference. Integrate it with your application, or use as a microservice.

202 lines 8.34 kB
import fs from 'node:fs'; import { AutoModel, AutoProcessor, AutoTokenizer, } from '@huggingface/transformers'; import { resolveModelFileLocation } from '../../lib/resolveModelFileLocation.js'; import { parseHuggingfaceModelIdAndBranch, remoteFileExists, normalizeTransformersJsClass } from './util.js'; // currently the way transformers.js models are validated is by trying to load them, catching the error and checking if the error message contains the signature. // major upside being that we let transformers.js logic figure out which files are required. downside is, we have to hope that the message text stays the same. const fileNotFoundSignature = 'file was not found locally'; async function validateModel(modelOpts, config, modelPath) { const modelClass = normalizeTransformersJsClass(modelOpts.modelClass, AutoModel); const device = config.device?.gpu ? 'gpu' : 'cpu'; try { const model = await modelClass.from_pretrained(modelPath, { local_files_only: true, device: device, dtype: modelOpts.dtype || 'fp32', }); await model.dispose(); } catch (error) { if (error.message.includes(fileNotFoundSignature)) { return `Failed to load model (${error.message})`; } throw error; } return undefined; } async function validateTokenizer(modelOpts, config, modelPath) { const tokenizerClass = normalizeTransformersJsClass(modelOpts.tokenizerClass, AutoTokenizer); try { if (config.url) { const { branch } = parseHuggingfaceModelIdAndBranch(config.url); const hasTokenizer = await remoteFileExists(`${config.url}/blob/${branch}/tokenizer.json`); if (hasTokenizer) { await tokenizerClass.from_pretrained(modelPath, { local_files_only: true, }); } } } catch (error) { if (error.message.includes(fileNotFoundSignature)) { return `Failed to load model (${error.message})`; } throw error; } return undefined; } async function validateProcessor(modelOpts, config, modelPath) { const processorClass = normalizeTransformersJsClass(modelOpts.processorClass, AutoProcessor); try { if (modelOpts.processor) { const processorPath = resolveModelFileLocation({ url: modelOpts.processor.url, filePath: modelOpts.processor.file, modelsCachePath: config.modelsCachePath, }); await processorClass.from_pretrained(processorPath, { local_files_only: true, }); } else { if (modelOpts.processorClass) { await processorClass.from_pretrained(modelPath, { local_files_only: true, }); } else if (config.url) { const { branch } = parseHuggingfaceModelIdAndBranch(config.url); const [hasProcessor, hasPreprocessor] = await Promise.all([ remoteFileExists(`${config.url}/blob/${branch}/processor_config.json`), remoteFileExists(`${config.url}/blob/${branch}/preprocessor_config.json`), ]); if (hasProcessor || hasPreprocessor) { await processorClass.from_pretrained(modelPath, { local_files_only: true, }); } } } } catch (error) { if (error.message.includes(fileNotFoundSignature)) { return `Failed to load model (${error.message})`; } throw error; } return undefined; } async function validateVocoder(modelOpts, config, modelPath) { const vocoderClass = normalizeTransformersJsClass(modelOpts.vocoderClass, AutoModel); if (modelOpts.vocoder) { const vocoderPath = resolveModelFileLocation({ url: modelOpts.vocoder.url, filePath: modelOpts.vocoder.file, modelsCachePath: config.modelsCachePath, }); try { await vocoderClass.from_pretrained(vocoderPath, { local_files_only: true, }); } catch (error) { if (error.message.includes(fileNotFoundSignature)) { return `Failed to load vocoder (${error.message})`; } throw error; } } return undefined; } async function validateModelComponents(modelOpts, config, modelPath) { const componentValidationPromises = [ validateModel(modelOpts, config, modelPath), validateTokenizer(modelOpts, config, modelPath), validateProcessor(modelOpts, config, modelPath), ]; if ('vocoder' in modelOpts) { componentValidationPromises.push(validateVocoder(modelOpts, config, modelPath)); } const [model, tokenizer, processor, vocoder] = await Promise.all(componentValidationPromises); const result = {}; if (model) result.model = model; if (tokenizer) result.tokenizer = tokenizer; if (processor) result.processor = processor; if (vocoder) result.vocoder = vocoder; return result; } async function validateSpeechModel(modelOpts, config, modelPath) { if (modelOpts.speakerEmbeddings) { for (const voice of Object.values(modelOpts.speakerEmbeddings)) { if (voice instanceof Float32Array) { continue; } const speakerEmbeddingsPath = resolveModelFileLocation({ url: voice.url, filePath: voice.file, modelsCachePath: config.modelsCachePath, }); if (!fs.existsSync(speakerEmbeddingsPath)) { return `Speaker embeddings file does not exist: ${speakerEmbeddingsPath}`; } } } return validateModelComponents(modelOpts, config, modelPath); } export async function validateModelFiles(config) { if (!fs.existsSync(config.location)) { return { message: `model directory does not exist: ${config.location}`, }; } let modelPath = config.location; if (!modelPath.endsWith('/')) { modelPath += '/'; } const modelValidationPromises = {}; // const noModelConfigured = !config.textModel && !config.visionModel && !config.speechModel modelValidationPromises.primaryModel = validateModelComponents(config, config, modelPath); if (config.textModel) { modelValidationPromises.textModel = validateModelComponents(config.textModel, config, modelPath); } if (config.visionModel) { modelValidationPromises.visionModel = validateModelComponents(config.visionModel, config, modelPath); } if (config.speechModel) { modelValidationPromises.speechModel = validateSpeechModel(config.speechModel, config, modelPath); } await Promise.all(Object.values(modelValidationPromises)); const validationErrors = {}; const primaryModelErrors = await modelValidationPromises.primaryModel; if (primaryModelErrors && Object.keys(primaryModelErrors).length) { validationErrors.primaryModel = primaryModelErrors; } const textModelErrors = await modelValidationPromises.textModel; if (textModelErrors && Object.keys(textModelErrors).length) { validationErrors.textModel = textModelErrors; } const visionModelErrors = await modelValidationPromises.visionModel; if (visionModelErrors && Object.keys(visionModelErrors).length) { validationErrors.visionModel = visionModelErrors; } const speechModelErrors = await modelValidationPromises.speechModel; if (speechModelErrors && Object.keys(speechModelErrors).length) { validationErrors.speechModel = speechModelErrors; } const vocoderModelErrors = await modelValidationPromises.vocoderModel; if (vocoderModelErrors && Object.keys(vocoderModelErrors).length) { validationErrors.vocoderModel = vocoderModelErrors; } if (Object.keys(validationErrors).length > 0) { return { message: 'Failed to validate model components', errors: validationErrors, }; } return undefined; } //# sourceMappingURL=validateModelFiles.js.map