/**
 * inference-server
 *
 * Libraries and server to build AI applications. Adapters to various native
 * bindings allowing local inference. Integrate it with your application, or
 * use as a microservice.
 */
import { promises as fs, existsSync } from 'node:fs';
import PQueue from 'p-queue';
import prettyMilliseconds from 'pretty-ms';
import prettyBytes from 'pretty-bytes';
import { LogLevels, createSublogger, } from './lib/logger.js';
import { formatBytesPerSecond, mergeAbortSignals } from './lib/util.js';
export class ModelStore {
    prepareQueue;
    models = {};
    engines;
    prepareController;
    modelsCachePath;
    log;
    /**
     * Tracks model configs, their lifecycle status and file downloads.
     * @param {object} options
     * @param {object} options.models - map of modelId -> model configuration
     * @param {string} options.modelsCachePath - directory for downloaded model files
     * @param {number} [options.prepareConcurrency=10] - max concurrent prepare tasks
     * @param {*} [options.log] - logger options forwarded to createSublogger
     */
    constructor(options) {
        this.prepareController = new AbortController();
        this.log = createSublogger(options.log);
        this.prepareQueue = new PQueue({
            concurrency: options.prepareConcurrency ?? 10,
        });
        this.modelsCachePath = options.modelsCachePath;
        // Clone each config and tag it with its initial lifecycle status.
        this.models = Object.fromEntries(Object.entries(options.models).map(([modelId, model]) => [
            modelId,
            {
                ...model,
                status: 'unloaded',
            },
        ]));
    }
    /**
     * Stores the engine registry, ensures the cache directory exists, and
     * starts model preparation. Models with `prepare: 'blocking'` or
     * `minInstances > 0` are awaited; `prepare: 'async'` models are prepared
     * in the background.
     * @param {object} engines - map of engine name -> engine implementation
     */
    async init(engines) {
        this.engines = engines;
        if (!existsSync(this.modelsCachePath)) {
            await fs.mkdir(this.modelsCachePath, { recursive: true });
        }
        const blockingPromises = [];
        for (const modelId in this.models) {
            const model = this.models[modelId];
            if (model.prepare === 'blocking' || model.minInstances > 0) {
                blockingPromises.push(this.prepareModel(modelId));
            }
            else if (model.prepare === 'async') {
                // Fire-and-forget: attach a handler so a failure cannot
                // surface as an unhandled promise rejection.
                this.prepareModel(modelId).catch((error) => {
                    this.log(LogLevels.error, 'Error preparing model', {
                        model: modelId,
                        error: error,
                    });
                });
            }
        }
        if (blockingPromises.length) {
            this.log(LogLevels.debug, `Preparing files for ${blockingPromises.length} models`);
            await Promise.all(blockingPromises);
            this.log(LogLevels.debug, 'All files for initially required models are ready');
        }
    }
    /** Aborts any in-flight model preparation. */
    dispose() {
        this.prepareController.abort();
    }
    /**
     * Records a download progress event for one of a model's files,
     * lazily creating the per-file DownloadTracker (5s sliding window).
     * @param {string} modelId - key into this.models
     * @param {{file: string, loadedBytes: number, totalBytes: number}} progress
     */
    onDownloadProgress(modelId, progress) {
        const model = this.models[modelId];
        if (!model) {
            // Ignore progress events for models we do not know about.
            return;
        }
        if (!model.downloads) {
            model.downloads = new Map();
        }
        let tracker = model.downloads.get(progress.file);
        if (!tracker) {
            tracker = new DownloadTracker(5000);
            model.downloads.set(progress.file, tracker);
        }
        tracker.pushProgress(progress);
    }
    // makes sure all required files for the model exist and are valid
    // checking model checksums and reading metadata is model + engine specific and can be slow
    async prepareModel(modelId, signal) {
        const model = this.models[modelId];
        if (!this.engines) {
            throw new Error('No engines available - did you call init()?');
        }
        model.status = 'preparing';
        const engine = this.engines[model.engine];
        this.log(LogLevels.info, 'Preparing model', {
            model: modelId,
            task: model.task,
        });
        await this.prepareQueue.add(async () => {
            // Engines without a prepareModel hook need no file preparation.
            if (!('prepareModel' in engine)) {
                model.status = 'ready';
                return model;
            }
            // Periodically log aggregated download progress across all files.
            const logProgressInterval = setInterval(() => {
                const progress = Array.from(model.downloads?.values() ?? [])
                    .map((tracker) => tracker.getStatus())
                    .reduce((acc, status) => {
                    acc.loadedBytes += status?.loadedBytes || 0;
                    acc.totalBytes += status?.totalBytes || 0;
                    acc.speed += status?.speed || 0;
                    return acc;
                }, { loadedBytes: 0, totalBytes: 0, speed: 0 });
                if (progress.totalBytes) {
                    const percent = (progress.loadedBytes / progress.totalBytes) * 100;
                    const formattedTotalBytes = prettyBytes(progress.totalBytes, { space: false });
                    const formattedLoadedBytes = prettyBytes(progress.loadedBytes, { space: false });
                    this.log(LogLevels.info, `Downloading at ${formatBytesPerSecond(progress.speed)} ${percent.toFixed(1)}% - ${formattedLoadedBytes} of ${formattedTotalBytes}`, {
                        model: modelId,
                    });
                }
            }, 10000);
            try {
                // Use modelId (the store key) for progress routing; the config
                // object is not guaranteed to carry an `id` field.
                const modelMeta = await engine.prepareModel({ config: model, log: this.log }, (progress) => {
                    this.onDownloadProgress(modelId, progress);
                }, mergeAbortSignals([signal, this.prepareController.signal]));
                model.downloads = undefined;
                model.meta = modelMeta;
                model.status = 'ready';
                this.log(LogLevels.info, 'Model ready to use', {
                    model: modelId,
                    task: model.task,
                });
            }
            catch (error) {
                this.log(LogLevels.error, 'Error preparing model', {
                    model: modelId,
                    error: error,
                });
                model.status = 'error';
            }
            finally {
                clearInterval(logProgressInterval);
            }
            return model;
        });
    }
    /**
     * Returns a serializable status snapshot for every model, including
     * formatted per-file download progress where downloads are in flight.
     * @returns {object} map of modelId -> status info
     */
    getStatus() {
        const formatFloat = (num) => parseFloat(num?.toFixed(2) || '0');
        const storeStatusInfo = Object.fromEntries(Object.entries(this.models).map(([modelId, model]) => {
            let downloads = undefined;
            if (model.downloads) {
                downloads = [...model.downloads].reduce((acc, [key, download]) => {
                    const status = download.getStatus();
                    // Latest raw sample gives byte counts even before the
                    // tracker has enough data to compute speed/ETA.
                    const latestState = download.progressBuffer[download.progressBuffer.length - 1];
                    const totalBytes = latestState?.totalBytes ?? 0;
                    const loadedBytes = latestState?.loadedBytes ?? 0;
                    const etaSeconds = status?.etaSeconds ?? 0;
                    const formattedEta = prettyMilliseconds(etaSeconds * 1000);
                    const formattedTotalBytes = prettyBytes(totalBytes);
                    const formattedLoadedBytes = prettyBytes(loadedBytes);
                    acc.push({
                        file: key,
                        loadedBytes,
                        formattedLoadedBytes,
                        totalBytes,
                        formattedTotalBytes,
                        percent: formatFloat(status?.percent),
                        speed: formatFloat(status?.speed),
                        etaSeconds: formatFloat(etaSeconds),
                        formattedEta,
                    });
                    return acc;
                }, []);
            }
            return [
                modelId,
                {
                    engine: model.engine,
                    device: model.device,
                    minInstances: model.minInstances,
                    maxInstances: model.maxInstances,
                    status: model.status,
                    downloads,
                },
            ];
        }));
        return storeStatusInfo;
    }
}
class DownloadTracker {
    // Sliding-window buffer of {loadedBytes, totalBytes, timestamp} samples.
    progressBuffer = [];
    timeWindow;
    /**
     * @param {number} [timeWindow=1000] - sliding window length in milliseconds
     */
    constructor(timeWindow = 1000) {
        this.timeWindow = timeWindow;
    }
    /** Appends a progress sample (stamped with Date.now()) and prunes old ones. */
    pushProgress({ loadedBytes, totalBytes }) {
        const timestamp = Date.now();
        this.progressBuffer.push({ loadedBytes, totalBytes, timestamp });
        this.cleanup();
    }
    /** Evicts samples that fell outside the sliding time window. */
    cleanup() {
        const cutoffTime = Date.now() - this.timeWindow;
        this.progressBuffer = this.progressBuffer.filter((item) => item.timestamp >= cutoffTime);
    }
    /**
     * Derives speed (bytes/s), ETA (seconds) and completion percentage (0-100)
     * from the oldest and newest samples in the window.
     * @returns {{speed:number, etaSeconds:number, percent:number, loadedBytes:number, totalBytes:number}|null}
     *   null while fewer than two samples are buffered
     */
    getStatus() {
        if (this.progressBuffer.length < 2) {
            return null; // Not enough data to calculate speed and ETA
        }
        const latestState = this.progressBuffer[this.progressBuffer.length - 1];
        const previousState = this.progressBuffer[0]; // oldest state within the time window
        const bytesLoaded = latestState.loadedBytes - previousState.loadedBytes;
        const timeElapsed = latestState.timestamp - previousState.timestamp; // in milliseconds
        // Guard against division by zero when both samples share a timestamp.
        const speed = timeElapsed > 0 ? bytesLoaded / (timeElapsed / 1000) : 0; // bytes per second
        const remainingBytes = latestState.totalBytes - latestState.loadedBytes;
        const eta = speed > 0 ? remainingBytes / speed : 0;
        // Report percent on a 0-100 scale (matching the field name and the
        // percentage computed by the progress logger); guard totalBytes === 0.
        const percent = latestState.totalBytes > 0
            ? (latestState.loadedBytes / latestState.totalBytes) * 100
            : 0;
        return {
            speed,
            etaSeconds: eta,
            percent,
            loadedBytes: latestState.loadedBytes,
            totalBytes: latestState.totalBytes,
        };
    }
}
//# sourceMappingURL=store.js.map