inference-server
Version:
Libraries and server to build AI applications. Adapters to various native bindings allowing local inference. Integrate it with your application, or use as a microservice.
284 lines (258 loc) • 8.08 kB
text/typescript
import { promises as fs, existsSync } from 'node:fs'
import PQueue from 'p-queue'
import prettyMilliseconds from 'pretty-ms'
import prettyBytes from 'pretty-bytes'
import {
FileDownloadProgress,
ModelConfig,
ModelEngine,
} from '#package/types/index.js'
import {
Logger,
LogLevels,
LogLevel,
createSublogger,
} from '#package/lib/logger.js'
import { formatBytesPerSecond, mergeAbortSignals } from '#package/lib/util.js'
// Minimal descriptor for a file belonging to a model.
// NOTE(review): not referenced anywhere in this chunk — possibly used elsewhere in the file, or dead; confirm before removing.
interface ModelFile {
	// File size in bytes.
	size: number
}
/**
 * A configured model plus the runtime state ModelStore tracks for it.
 */
export interface StoredModel extends ModelConfig {
	// Engine-specific metadata returned by engine.prepareModel(); shape depends on the engine.
	meta?: unknown
	// Per-file download progress trackers, keyed by file name. Present only
	// while a download is in flight; cleared once the model is ready.
	downloads?: Map<string, DownloadTracker>
	// Preparation lifecycle: 'unloaded' -> 'preparing' -> 'ready' | 'error'.
	status: 'unloaded' | 'preparing' | 'ready' | 'error'
}
/**
 * Constructor options for ModelStore.
 */
export interface ModelStoreOptions {
	// Directory where model files are cached; created on init() if missing.
	modelsCachePath: string
	// Model configurations, keyed by model id.
	models: Record<string, ModelConfig>
	// Max number of models prepared (downloaded/validated) concurrently. Defaults to 10.
	prepareConcurrency?: number
	// Either an existing logger to wrap or a log level for a fresh sublogger.
	log?: Logger | LogLevel
}
/**
 * Holds all configured models, makes sure their files exist in the local
 * cache (delegating download/validation to the model's engine) and exposes
 * aggregated status/progress information.
 */
export class ModelStore {
	/** Queue bounding how many models may be prepared concurrently. */
	prepareQueue: PQueue
	/** All known models, keyed by model id, with their runtime preparation state. */
	models: Record<string, StoredModel> = {}
	/** Engine implementations keyed by engine name; populated via init(). */
	engines?: Record<string, ModelEngine>
	/** Aborts all in-flight preparation work when dispose() is called. */
	private prepareController: AbortController
	private modelsCachePath: string
	private log: Logger

	constructor(options: ModelStoreOptions) {
		this.prepareController = new AbortController()
		this.log = createSublogger(options.log)
		this.prepareQueue = new PQueue({
			concurrency: options.prepareConcurrency ?? 10,
		})
		this.modelsCachePath = options.modelsCachePath
		// Every configured model starts out unloaded; prepareModel() advances the status.
		this.models = Object.fromEntries(
			Object.entries(options.models).map(([modelId, model]) => [
				modelId,
				{
					...model,
					status: 'unloaded',
				},
			]),
		)
	}

	/**
	 * Creates the cache directory if needed and kicks off preparation for
	 * models that request it. Models with prepare === 'blocking' or
	 * minInstances > 0 are awaited; prepare === 'async' models are prepared
	 * in the background (their errors are logged inside prepareModel).
	 * @param engines - Engine implementations, keyed by the engine name referenced in model configs.
	 */
	async init(engines: Record<string, ModelEngine>) {
		this.engines = engines
		if (!existsSync(this.modelsCachePath)) {
			await fs.mkdir(this.modelsCachePath, { recursive: true })
		}
		const blockingPromises = []
		for (const modelId in this.models) {
			const model = this.models[modelId]
			if (model.prepare === 'blocking' || model.minInstances > 0) {
				blockingPromises.push(this.prepareModel(modelId))
			} else if (model.prepare === 'async') {
				// Fire and forget; failures surface via model.status === 'error'.
				void this.prepareModel(modelId)
			}
		}
		if (blockingPromises.length) {
			this.log(LogLevels.debug, `Preparing files for ${blockingPromises.length} models`)
			await Promise.all(blockingPromises)
			this.log(LogLevels.debug, 'All files for initially required models are ready')
		}
	}

	/** Aborts any in-flight downloads / preparation work. */
	dispose() {
		this.prepareController.abort()
	}

	// Records a progress sample for one downloading file, creating the
	// tracker on first sight of that file.
	private onDownloadProgress(
		modelId: string,
		progress: { file: string; loadedBytes: number; totalBytes: number },
	) {
		const model = this.models[modelId]
		if (!model.downloads) {
			model.downloads = new Map()
		}
		const existingTracker = model.downloads.get(progress.file)
		if (existingTracker) {
			existingTracker.pushProgress(progress)
		} else {
			// 5s sliding window over which speed / ETA are computed.
			const tracker = new DownloadTracker(5000)
			tracker.pushProgress(progress)
			model.downloads.set(progress.file, tracker)
		}
	}

	// makes sure all required files for the model exist and are valid
	// checking model checksums and reading metadata is model + engine specific and can be slow
	async prepareModel(modelId: string, signal?: AbortSignal) {
		const model = this.models[modelId]
		if (!model) {
			// Fail with a clear message instead of a TypeError on model.status below.
			throw new Error(`Unknown model: ${modelId}`)
		}
		if (!this.engines) {
			throw new Error('No engines available - did you call init()?')
		}
		const engine = this.engines[model.engine]
		if (!engine) {
			// Previously this surfaced as a TypeError inside the queue task.
			throw new Error(`No engine registered for "${model.engine}" (model ${modelId})`)
		}
		model.status = 'preparing'
		this.log(LogLevels.info, 'Preparing model', {
			model: modelId,
			task: model.task,
		})
		await this.prepareQueue.add(async () => {
			// Engines without a prepareModel hook need no file preparation.
			if (!('prepareModel' in engine)) {
				model.status = 'ready'
				return model
			}
			// Periodically log aggregate download progress across all files of this model.
			const logProgressInterval = setInterval(() => {
				const progress = Array.from(model.downloads?.values() ?? [])
					.map((tracker) => tracker.getStatus())
					.reduce(
						(acc, status) => {
							acc.loadedBytes += status?.loadedBytes || 0
							acc.totalBytes += status?.totalBytes || 0
							acc.speed += status?.speed || 0
							return acc
						},
						{ loadedBytes: 0, totalBytes: 0, speed: 0 },
					)
				if (progress.totalBytes) {
					const percent = (progress.loadedBytes / progress.totalBytes) * 100
					const formattedTotalBytes = prettyBytes(progress.totalBytes, { space: false })
					const formattedLoadedBytes = prettyBytes(progress.loadedBytes, { space: false })
					this.log(LogLevels.info, `Downloading at ${formatBytesPerSecond(progress.speed)} ${percent.toFixed(1)}% - ${formattedLoadedBytes} of ${formattedTotalBytes}`, {
						model: modelId,
					})
				}
			}, 10000)
			try {
				const modelMeta = await engine.prepareModel(
					{ config: model, log: this.log },
					(progress) => {
						// BUGFIX: track progress under the registry key (modelId), not
						// model.id, which may be unset or differ from the key this.models
						// is indexed by — that caused progress to be lost or to throw.
						this.onDownloadProgress(modelId, progress)
					},
					mergeAbortSignals([signal, this.prepareController.signal]),
				)
				model.downloads = undefined
				model.meta = modelMeta
				model.status = 'ready'
				this.log(LogLevels.info, 'Model ready to use', {
					model: modelId,
					task: model.task,
				})
			} catch (error) {
				this.log(LogLevels.error, 'Error preparing model', {
					model: modelId,
					error: error,
				})
				model.status = 'error'
			} finally {
				clearInterval(logProgressInterval)
			}
			return model
		})
	}

	/**
	 * Returns a JSON-friendly snapshot of every model's status, including
	 * per-file download progress while a download is in flight.
	 */
	getStatus() {
		const formatFloat = (num?: number) => parseFloat(num?.toFixed(2) || '0')
		const storeStatusInfo = Object.fromEntries(
			Object.entries(this.models).map(([modelId, model]) => {
				let downloads: Array<Record<string, unknown>> | undefined
				if (model.downloads) {
					downloads = []
					for (const [file, tracker] of model.downloads) {
						const status = tracker.getStatus()
						// Latest raw sample is available even before the tracker has
						// enough data to compute speed / ETA (status === null).
						const latestState =
							tracker.progressBuffer[tracker.progressBuffer.length - 1]
						const totalBytes = latestState?.totalBytes ?? 0
						const loadedBytes = latestState?.loadedBytes ?? 0
						const etaSeconds = status?.etaSeconds ?? 0
						downloads.push({
							file,
							loadedBytes,
							formattedLoadedBytes: prettyBytes(loadedBytes),
							totalBytes,
							formattedTotalBytes: prettyBytes(totalBytes),
							percent: formatFloat(status?.percent),
							speed: formatFloat(status?.speed),
							etaSeconds: formatFloat(etaSeconds),
							formattedEta: prettyMilliseconds(etaSeconds * 1000),
						})
					}
				}
				return [
					modelId,
					{
						engine: model.engine,
						device: model.device,
						minInstances: model.minInstances,
						maxInstances: model.maxInstances,
						status: model.status,
						downloads,
					},
				]
			}),
		)
		return storeStatusInfo
	}
}
// One raw download-progress sample, as buffered by DownloadTracker.
type ProgressState = {
	loadedBytes: number
	totalBytes: number
	timestamp: number // in milliseconds
}
// Derived download status computed by DownloadTracker.getStatus().
type DownloadStatus = {
	// Completion of the download. NOTE(review): getStatus currently fills this
	// with a 0-1 fraction despite the name — confirm the intended scale.
	percent: number
	// Average transfer rate over the tracker's time window, in bytes per second.
	speed: number
	// Estimated seconds until completion at the current speed.
	etaSeconds: number
	loadedBytes: number
	totalBytes: number
}
/**
 * Tracks download progress samples within a sliding time window and derives
 * speed, ETA and completion percentage from them.
 */
class DownloadTracker {
	/** Progress samples within the sliding window, oldest first. */
	progressBuffer: ProgressState[] = []
	private timeWindow: number

	/**
	 * @param timeWindow - Sliding window in milliseconds over which speed and ETA are computed.
	 */
	constructor(timeWindow: number = 1000) {
		this.timeWindow = timeWindow
	}

	/** Records a progress sample and evicts samples older than the window. */
	pushProgress({ loadedBytes, totalBytes }: FileDownloadProgress): void {
		const timestamp = Date.now()
		this.progressBuffer.push({ loadedBytes, totalBytes, timestamp })
		this.cleanup()
	}

	// Drops buffered samples that have fallen outside the time window.
	private cleanup(): void {
		const cutoffTime = Date.now() - this.timeWindow
		this.progressBuffer = this.progressBuffer.filter(
			(item) => item.timestamp >= cutoffTime,
		)
	}

	/**
	 * Derives speed (bytes/s), ETA (seconds) and completion percentage (0-100)
	 * from the buffered samples.
	 * @returns The current status, or null until at least two samples exist.
	 */
	getStatus(): DownloadStatus | null {
		if (this.progressBuffer.length < 2) {
			return null // Not enough data to calculate speed and ETA
		}
		const latestState = this.progressBuffer[this.progressBuffer.length - 1]
		const previousState = this.progressBuffer[0] // oldest state within the time window
		const bytesLoaded = latestState.loadedBytes - previousState.loadedBytes
		const timeElapsed = latestState.timestamp - previousState.timestamp // in milliseconds
		// BUGFIX: two samples within the same millisecond previously produced
		// speed === Infinity via division by zero; report 0 instead.
		const speed = timeElapsed > 0 ? bytesLoaded / (timeElapsed / 1000) : 0 // bytes per second
		const remainingBytes = latestState.totalBytes - latestState.loadedBytes
		const eta = speed > 0 ? remainingBytes / speed : 0
		// BUGFIX: report an actual percentage (0-100) rather than a 0-1 fraction,
		// matching the percent values ModelStore logs. Guard totalBytes === 0 (NaN).
		const percent =
			latestState.totalBytes > 0
				? (latestState.loadedBytes / latestState.totalBytes) * 100
				: 0
		return {
			speed,
			etaSeconds: eta,
			percent,
			loadedBytes: latestState.loadedBytes,
			totalBytes: latestState.totalBytes,
		}
	}
}