inference-server

Libraries and a server for building AI applications, with adapters for various native bindings that enable local inference. Integrate it into your application, or run it as a microservice.

import fs from 'node:fs';
import { AutoModel, AutoProcessor, AutoTokenizer } from '@huggingface/transformers';
import { resolveModelFileLocation } from '../../lib/resolveModelFileLocation.js';
import { normalizeTransformersJsClass } from './util.js';

export async function loadModelComponents(modelOpts, config) {
    const device = config.device?.gpu ? 'gpu' : 'cpu';
    const modelClass = normalizeTransformersJsClass(modelOpts.modelClass, AutoModel);
    let modelPath = config.location;
    if (!modelPath.endsWith('/')) {
        modelPath += '/';
    }
    const loadPromises = [];
    const modelPromise = modelClass.from_pretrained(modelPath, {
        local_files_only: true,
        device: device,
        dtype: modelOpts.dtype || 'fp32',
    });
    loadPromises.push(modelPromise);
    const hasTokenizer = fs.existsSync(modelPath + 'tokenizer.json');
    if (hasTokenizer) {
        const tokenizerClass = normalizeTransformersJsClass(modelOpts.tokenizerClass, AutoTokenizer);
        const tokenizerPromise = tokenizerClass.from_pretrained(modelPath, {
            local_files_only: true,
        });
        loadPromises.push(tokenizerPromise);
    } else {
        loadPromises.push(Promise.resolve(undefined));
    }
    const hasPreprocessor = fs.existsSync(modelPath + 'preprocessor_config.json');
    const hasProcessor = fs.existsSync(modelPath + 'processor_config.json');
    if (hasProcessor || hasPreprocessor || modelOpts.processor) {
        const processorClass = normalizeTransformersJsClass(modelOpts.processorClass, AutoProcessor);
        if (modelOpts.processor) {
            const processorPath = resolveModelFileLocation({
                url: modelOpts.processor.url,
                filePath: modelOpts.processor.file,
                modelsCachePath: config.modelsCachePath,
            });
            const processorPromise = processorClass.from_pretrained(processorPath, {
                local_files_only: true,
            });
            loadPromises.push(processorPromise);
        } else {
            const processorPromise = processorClass.from_pretrained(modelPath, {
                local_files_only: true,
            });
            loadPromises.push(processorPromise);
        }
    } else {
        loadPromises.push(Promise.resolve(undefined));
    }
    if ('vocoder' in modelOpts && modelOpts.vocoder) {
        const vocoderClass = normalizeTransformersJsClass(modelOpts.vocoderClass, AutoModel);
        const vocoderPath = resolveModelFileLocation({
            url: modelOpts.vocoder.url,
            filePath: modelOpts.vocoder.file,
            modelsCachePath: config.modelsCachePath,
        });
        const vocoderPromise = vocoderClass.from_pretrained(vocoderPath, {
            local_files_only: true,
        });
        loadPromises.push(vocoderPromise);
    } else {
        loadPromises.push(Promise.resolve(undefined));
    }
    if ('speakerEmbeddings' in modelOpts && modelOpts.speakerEmbeddings) {
        const speakerEmbeddings = modelOpts.speakerEmbeddings;
        const speakerEmbeddingsPromises = [];
        for (const speakerEmbedding of Object.values(speakerEmbeddings)) {
            if (speakerEmbedding instanceof Float32Array) {
                speakerEmbeddingsPromises.push(Promise.resolve(speakerEmbedding));
                continue;
            }
            const speakerEmbeddingPath = resolveModelFileLocation({
                url: speakerEmbedding.url,
                filePath: speakerEmbedding.file,
                modelsCachePath: config.modelsCachePath,
            });
            const speakerEmbeddingPromise = fs.promises
                .readFile(speakerEmbeddingPath)
                .then((data) => new Float32Array(data.buffer));
            speakerEmbeddingsPromises.push(speakerEmbeddingPromise);
        }
        loadPromises.push(Promise.all(speakerEmbeddingsPromises));
    } else {
        loadPromises.push(Promise.resolve(undefined));
    }
    const loadedComponents = await Promise.all(loadPromises);
    const modelComponents = {};
    if (loadedComponents[0]) {
        modelComponents.model = loadedComponents[0];
    }
    if (loadedComponents[1]) {
        modelComponents.tokenizer = loadedComponents[1];
    }
    if (loadedComponents[2]) {
        modelComponents.processor = loadedComponents[2];
    }
    if (loadedComponents[3]) {
        modelComponents.vocoder = loadedComponents[3];
    }
    if (loadedComponents[4] && 'speakerEmbeddings' in modelOpts && modelOpts.speakerEmbeddings) {
        const loadedSpeakerEmbeddings = loadedComponents[4];
        modelComponents.speakerEmbeddings = Object.fromEntries(
            Object.keys(modelOpts.speakerEmbeddings).map((key, index) => [key, loadedSpeakerEmbeddings[index]]),
        );
    }
    return modelComponents;
}

export async function loadSpeechModelComponents(modelOpts, config) {
    const loadPromises = [loadModelComponents(modelOpts, config)];
    if (modelOpts.vocoder) {
        const vocoderClass = modelOpts.vocoderClass ?? AutoModel;
        const vocoderPath = resolveModelFileLocation({
            url: modelOpts.vocoder.url,
            filePath: modelOpts.vocoder.file,
            modelsCachePath: config.modelsCachePath,
        });
        const vocoderPromise = vocoderClass.from_pretrained(vocoderPath, {
            local_files_only: true,
        });
        loadPromises.push(vocoderPromise);
    } else {
        loadPromises.push(Promise.resolve(undefined));
    }
    if ('speakerEmbeddings' in modelOpts && modelOpts.speakerEmbeddings) {
        const speakerEmbeddings = modelOpts.speakerEmbeddings;
        const speakerEmbeddingsPromises = [];
        for (const speakerEmbedding of Object.values(speakerEmbeddings)) {
            if (speakerEmbedding instanceof Float32Array) {
                speakerEmbeddingsPromises.push(Promise.resolve(speakerEmbedding));
                continue;
            }
            const speakerEmbeddingPath = resolveModelFileLocation({
                url: speakerEmbedding.url,
                filePath: speakerEmbedding.file,
                modelsCachePath: config.modelsCachePath,
            });
            const speakerEmbeddingPromise = fs.promises
                .readFile(speakerEmbeddingPath)
                .then((data) => new Float32Array(data.buffer));
            speakerEmbeddingsPromises.push(speakerEmbeddingPromise);
        }
        loadPromises.push(Promise.all(speakerEmbeddingsPromises));
    }
    const loadedComponents = await Promise.all(loadPromises);
    const speechModelInstance = loadedComponents[0];
    if (loadedComponents[1]) {
        speechModelInstance.vocoder = loadedComponents[1];
    }
    if (loadedComponents[2] && 'speakerEmbeddings' in modelOpts && modelOpts.speakerEmbeddings) {
        const loadedSpeakerEmbeddings = loadedComponents[2];
        speechModelInstance.speakerEmbeddings = Object.fromEntries(
            Object.keys(modelOpts.speakerEmbeddings).map((key, index) => [key, loadedSpeakerEmbeddings[index]]),
        );
    }
    return speechModelInstance;
}
//# sourceMappingURL=loadModelComponents.js.map
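
Usage sketch. The shapes of modelOpts and config below are inferred from the property reads in loadModelComponents (config.location, config.device?.gpu, config.modelsCachePath, modelOpts.dtype); the paths and dtype value are placeholders, not part of this module.

// Hypothetical example: loads a model from a local directory that contains the
// weights plus tokenizer.json / preprocessor_config.json as applicable.
import { loadModelComponents } from './loadModelComponents.js';

const modelOpts = {
    // modelClass, tokenizerClass and processorClass fall back to the Auto* classes when omitted
    dtype: 'fp32',
};

const config = {
    location: '/models/example-model',   // placeholder path to the cached model directory
    modelsCachePath: '/models/cache',    // placeholder; used to resolve processor/vocoder/speaker files
    device: { gpu: false },              // 'cpu' is selected when no GPU is configured
};

const { model, tokenizer, processor } = await loadModelComponents(modelOpts, config);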