UNPKG

inference-server

Version:

Libraries and a server for building AI applications. Adapters to various native bindings allow local inference. Integrate it into your application, or run it as a microservice.

207 lines 8.7 kB
import { promises as fs, existsSync } from 'node:fs';
import PQueue from 'p-queue';
import prettyMilliseconds from 'pretty-ms';
import prettyBytes from 'pretty-bytes';
import { LogLevels, createSublogger, } from './lib/logger.js';
import { formatBytesPerSecond, mergeAbortSignals } from './lib/util.js';

/**
 * Holds the configured models and drives their preparation lifecycle
 * (download/validate via the matching engine) with bounded concurrency.
 * Model records move through status: 'unloaded' -> 'preparing' -> 'ready' | 'error'.
 */
export class ModelStore {
    // PQueue limiting how many models are prepared at once.
    prepareQueue;
    // modelId -> model config + runtime state (status, downloads, meta).
    models = {};
    // Engine implementations, injected via init(); undefined until then.
    engines;
    // Aborts all in-flight prepare work on dispose().
    prepareController;
    modelsCachePath;
    log;

    /**
     * @param {object} options
     * @param {object} options.models - map of modelId -> model config.
     * @param {string} options.modelsCachePath - directory for downloaded model files.
     * @param {number} [options.prepareConcurrency=10] - max concurrent prepares.
     * @param {*} [options.log] - logger config passed to createSublogger.
     */
    constructor(options) {
        this.prepareController = new AbortController();
        this.log = createSublogger(options.log);
        this.prepareQueue = new PQueue({
            concurrency: options.prepareConcurrency ?? 10,
        });
        this.modelsCachePath = options.modelsCachePath;
        // Copy configs so we can attach mutable runtime state without touching caller input.
        this.models = Object.fromEntries(Object.entries(options.models).map(([modelId, model]) => [
            modelId,
            {
                ...model,
                status: 'unloaded',
            },
        ]));
    }

    /**
     * Stores the engine implementations and kicks off model preparation.
     * Models with prepare === 'blocking' or minInstances > 0 are awaited;
     * prepare === 'async' models are started but intentionally not awaited.
     * @param {object} engines - map of engine name -> engine implementation.
     */
    async init(engines) {
        this.engines = engines;
        if (!existsSync(this.modelsCachePath)) {
            await fs.mkdir(this.modelsCachePath, { recursive: true });
        }
        const blockingPromises = [];
        for (const modelId in this.models) {
            const model = this.models[modelId];
            if (model.prepare === 'blocking' || model.minInstances > 0) {
                blockingPromises.push(this.prepareModel(modelId));
            }
            else if (model.prepare === 'async') {
                // Fire-and-forget: errors are captured inside prepareModel (status = 'error').
                this.prepareModel(modelId);
            }
        }
        if (blockingPromises.length) {
            this.log(LogLevels.debug, `Preparing files for ${blockingPromises.length} models`);
            await Promise.all(blockingPromises);
            this.log(LogLevels.debug, 'All files for initially required models are ready');
        }
    }

    /** Aborts any in-flight model preparation. */
    dispose() {
        this.prepareController.abort();
    }

    /**
     * Records a download progress event for one file of a model,
     * lazily creating the per-file DownloadTracker.
     * @param {string} modelId
     * @param {{file: string, loadedBytes: number, totalBytes: number}} progress
     */
    onDownloadProgress(modelId, progress) {
        const model = this.models[modelId];
        if (!model.downloads) {
            model.downloads = new Map();
        }
        if (model.downloads.has(progress.file)) {
            const tracker = model.downloads.get(progress.file);
            tracker.pushProgress(progress);
        }
        else {
            // 5s sliding window for speed/ETA estimation.
            const tracker = new DownloadTracker(5000);
            tracker.pushProgress(progress);
            model.downloads.set(progress.file, tracker);
        }
    }

    // makes sure all required files for the model exist and are valid
    // checking model checksums and reading metadata is model + engine specific and can be slow
    /**
     * Prepares a model's files via its engine, queued on prepareQueue.
     * Sets model.status to 'preparing', then 'ready' on success or 'error'
     * on failure (errors are logged, not rethrown). Logs aggregate download
     * progress every 10 seconds while preparing.
     * @param {string} modelId
     * @param {AbortSignal} [signal] - merged with the store-wide abort signal.
     * @throws {Error} if init() has not been called yet.
     */
    async prepareModel(modelId, signal) {
        const model = this.models[modelId];
        if (!this.engines) {
            throw new Error('No engines available - did you call init()?');
        }
        model.status = 'preparing';
        const engine = this.engines[model.engine];
        this.log(LogLevels.info, 'Preparing model', {
            model: modelId,
            task: model.task,
        });
        await this.prepareQueue.add(async () => {
            // Engines without a prepareModel hook need no file preparation.
            if (!('prepareModel' in engine)) {
                model.status = 'ready';
                return model;
            }
            // Periodic aggregate progress log across all files being downloaded.
            const logProgressInterval = setInterval(() => {
                const progress = Array.from(model.downloads?.values() ?? [])
                    .map((tracker) => tracker.getStatus())
                    .reduce((acc, status) => {
                    acc.loadedBytes += status?.loadedBytes || 0;
                    acc.totalBytes += status?.totalBytes || 0;
                    acc.speed += status?.speed || 0;
                    return acc;
                }, { loadedBytes: 0, totalBytes: 0, speed: 0 });
                if (progress.totalBytes) {
                    const percent = (progress.loadedBytes / progress.totalBytes) * 100;
                    const formattedTotalBytes = prettyBytes(progress.totalBytes, { space: false });
                    const formattedLoadedBytes = prettyBytes(progress.loadedBytes, { space: false });
                    this.log(LogLevels.info, `Downloading at ${formatBytesPerSecond(progress.speed)} ${percent.toFixed(1)}% - ${formattedLoadedBytes} of ${formattedTotalBytes}`, {
                        model: modelId,
                    });
                }
            }, 10000);
            try {
                // NOTE(review): uses model.id here while the rest of this method uses
                // modelId — presumably the config carries id === modelId; verify upstream.
                const modelMeta = await engine.prepareModel({ config: model, log: this.log }, (progress) => {
                    this.onDownloadProgress(model.id, progress);
                }, mergeAbortSignals([signal, this.prepareController.signal]));
                model.downloads = undefined;
                model.meta = modelMeta;
                model.status = 'ready';
                this.log(LogLevels.info, 'Model ready to use', {
                    model: modelId,
                    task: model.task,
                });
            }
            catch (error) {
                // Deliberately swallowed: failure is surfaced via status = 'error'.
                this.log(LogLevels.error, 'Error preparing model', {
                    model: modelId,
                    error: error,
                });
                model.status = 'error';
            }
            finally {
                clearInterval(logProgressInterval);
            }
            return model;
        });
    }

    /**
     * Returns a serializable snapshot of every model's status, including
     * per-file download progress (speed, ETA, formatted sizes) when available.
     * @returns {object} modelId -> status info.
     */
    getStatus() {
        const formatFloat = (num) => parseFloat(num?.toFixed(2) || '0');
        const storeStatusInfo = Object.fromEntries(Object.entries(this.models).map(([modelId, model]) => {
            let downloads = undefined;
            if (model.downloads) {
                downloads = [...model.downloads].reduce((acc, [key, download]) => {
                    const status = download.getStatus();
                    const latestState = download.progressBuffer[download.progressBuffer.length - 1];
                    const totalBytes = latestState?.totalBytes ?? 0;
                    const loadedBytes = latestState?.loadedBytes ?? 0;
                    const etaSeconds = status?.etaSeconds ?? 0;
                    const formattedEta = prettyMilliseconds(etaSeconds * 1000);
                    const formattedTotalBytes = prettyBytes(totalBytes);
                    const formattedLoadedBytes = prettyBytes(loadedBytes);
                    acc.push({
                        file: key,
                        loadedBytes,
                        formattedLoadedBytes,
                        totalBytes,
                        formattedTotalBytes,
                        // NOTE(review): DownloadTracker reports percent as a 0-1 fraction,
                        // not 0-100 — confirm consumers expect a fraction here.
                        percent: formatFloat(status?.percent),
                        speed: formatFloat(status?.speed),
                        etaSeconds: formatFloat(etaSeconds),
                        formattedEta,
                    });
                    return acc;
                }, []);
            }
            return [
                modelId,
                {
                    engine: model.engine,
                    device: model.device,
                    minInstances: model.minInstances,
                    maxInstances: model.maxInstances,
                    status: model.status,
                    downloads,
                },
            ];
        }));
        return storeStatusInfo;
    }
}

/**
 * Tracks download progress events within a sliding time window and
 * derives speed (bytes/s), ETA (seconds), and completion fraction.
 */
class DownloadTracker {
    // Progress samples ({loadedBytes, totalBytes, timestamp}) inside the window.
    progressBuffer = [];
    // Sliding window length in milliseconds.
    timeWindow;

    constructor(timeWindow = 1000) {
        this.timeWindow = timeWindow;
    }

    /** Records a progress sample and drops samples older than the window. */
    pushProgress({ loadedBytes, totalBytes }) {
        const timestamp = Date.now();
        this.progressBuffer.push({ loadedBytes, totalBytes, timestamp });
        this.cleanup();
    }

    /** Evicts samples older than timeWindow from the buffer. */
    cleanup() {
        const cutoffTime = Date.now() - this.timeWindow;
        this.progressBuffer = this.progressBuffer.filter((item) => item.timestamp >= cutoffTime);
    }

    /**
     * Computes speed/ETA from the oldest and newest samples in the window.
     * @returns {{speed:number, etaSeconds:number, percent:number, loadedBytes:number, totalBytes:number}|null}
     *   null while fewer than two samples exist. percent is a 0-1 fraction.
     *   NOTE(review): percent is NaN when totalBytes is 0 — callers guard via ?. chains.
     */
    getStatus() {
        if (this.progressBuffer.length < 2) {
            return null; // Not enough data to calculate speed and ETA
        }
        const latestState = this.progressBuffer[this.progressBuffer.length - 1];
        const previousState = this.progressBuffer[0]; // oldest state within the time window
        const bytesLoaded = latestState.loadedBytes - previousState.loadedBytes;
        const timeElapsed = latestState.timestamp - previousState.timestamp; // in milliseconds
        const speed = bytesLoaded / (timeElapsed / 1000); // bytes per second
        const remainingBytes = latestState.totalBytes - latestState.loadedBytes;
        const eta = speed > 0 ? remainingBytes / speed : 0;
        return {
            speed,
            etaSeconds: eta,
            percent: latestState.loadedBytes / latestState.totalBytes,
            loadedBytes: latestState.loadedBytes,
            totalBytes: latestState.totalBytes,
        };
    }
}
//# sourceMappingURL=store.js.map