inference-server
Libraries and a server for building AI applications, with adapters for various native bindings that enable local inference. Integrate it with your application, or run it as a microservice.
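Below is a minimal usage sketch of this module's exports; the import path, config values, and logger are illustrative assumptions rather than part of the package, and the context shapes are simplified.

// Usage sketch (hypothetical paths and values):
//
//   import {
//     prepareModel,
//     createInstance,
//     processTextToImageTask,
//   } from './engines/stable-diffusion/engine.js'
//
//   const config = {
//     id: 'sd-1.5',
//     url: 'https://example.com/sd-v1-5.safetensors',
//     location: '/models/sd-v1-5.safetensors',
//     modelsCachePath: '/models/cache',
//   } as StableDiffusionModelConfig
//   const log = (level, message, meta) => console.log(level, message, meta ?? '')
//
//   await prepareModel({ config, log })                    // download + validate files
//   const instance = await createInstance({ config, log }) // load into memory
//   const { images, seed } = await processTextToImageTask(
//     { prompt: 'a watercolor fox', width: 512, height: 512 },
//     { instance, config, log },
//   )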
import StableDiffusion from '@lmagder/node-stable-diffusion-cpp'
import { gguf } from '@huggingface/gguf'
import fs from 'node:fs'
import path from 'node:path'
import {
EngineContext,
FileDownloadProgress,
ModelConfig,
TextToImageTaskResult,
ModelFileSource,
Image,
TextToImageTaskArgs,
EngineTaskContext,
ImageToImageTaskArgs,
} from '#package/types/index.js'
import { LogLevel, LogLevels } from '#package/lib/logger.js'
import { downloadModelFile } from '#package/lib/downloadModelFile.js'
import { resolveModelFileLocation } from '#package/lib/resolveModelFileLocation.js'
import { acquireFileLock } from '#package/lib/acquireFileLock.js'
import { getRandomNumber } from '#package/lib/util.js'
import { StableDiffusionSamplingMethod, StableDiffusionSchedule, StableDiffusionWeightType } from './types.js'
import { validateModelFiles, ModelValidationResult } from './validateModelFiles.js'
import { parseQuantization, getWeightType, getSamplingMethod } from './util.js'
export interface StableDiffusionInstance {
context: StableDiffusion.Context
}
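// Configuration for a stable-diffusion model. `location` points at the main
// weights file; when `diffusionModel` is set it is treated as a standalone
// diffusion model rather than a full checkpoint (see createInstance). The
// optional components mirror stable-diffusion.cpp's loadable parts: clipL
// and clipG (CLIP text encoders), t5xxl (T5-XXL text encoder), vae,
// controlNet, and taesd (tiny autoencoder for faster decoding).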
export interface StableDiffusionModelConfig extends ModelConfig {
location: string
sha256?: string
clipL?: ModelFileSource
clipG?: ModelFileSource
vae?: ModelFileSource
t5xxl?: ModelFileSource
controlNet?: ModelFileSource
taesd?: ModelFileSource
diffusionModel?: boolean
model?: ModelFileSource
loras?: ModelFileSource[]
samplingMethod?: StableDiffusionSamplingMethod
weightType?: StableDiffusionWeightType
schedule?: StableDiffusionSchedule
device?: {
gpu?: boolean | 'auto' | (string & {})
cpuThreads?: number
}
}
interface StableDiffusionModelMeta {
gguf: any
}
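// Advertises automatic GPU selection to the engine host; the assumption here
// is that the host consults this flag when `config.device.gpu` is 'auto' or unset.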
export const autoGpu = true
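// Ensures the model files exist and are valid before any instance is created:
// takes a file lock on the model location, validates the on-disk files,
// downloads whatever is missing or invalid, then re-validates. For .gguf
// models the parsed GGUF metadata is returned as the model meta.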
export async function prepareModel(
{ config, log }: EngineContext<StableDiffusionModelConfig, StableDiffusionModelMeta>,
onProgress?: (progress: FileDownloadProgress) => void,
signal?: AbortSignal,
) {
fs.mkdirSync(path.dirname(config.location), { recursive: true })
const releaseFileLock = await acquireFileLock(config.location)
if (signal?.aborted) {
releaseFileLock()
return
}
log(LogLevels.info, `Preparing stable-diffusion model at ${config.location}`, {
model: config.id,
})
const downloadModel = (url: string, validationResult: ModelValidationResult) => {
log(LogLevels.info, `${validationResult.message} - Downloading model files`, {
model: config.id,
url: config.url,
location: config.location,
errors: validationResult.errors,
})
const downloadPromises = []
if (validationResult.errors.model && config.location) {
downloadPromises.push(
downloadModelFile({
url: url,
filePath: config.location,
modelsCachePath: config.modelsCachePath,
onProgress,
signal,
}),
)
}
const pushDownload = (src: ModelFileSource) => {
if (!src.url) {
return
}
downloadPromises.push(
downloadModelFile({
url: src.url,
filePath: src.file,
modelsCachePath: config.modelsCachePath,
onProgress,
signal,
}),
)
}
if (validationResult.errors.clipG && config.clipG) {
pushDownload(config.clipG)
}
if (validationResult.errors.clipL && config.clipL) {
pushDownload(config.clipL)
}
if (validationResult.errors.vae && config.vae) {
pushDownload(config.vae)
}
if (validationResult.errors.t5xxl && config.t5xxl) {
pushDownload(config.t5xxl)
}
if (validationResult.errors.controlNet && config.controlNet) {
pushDownload(config.controlNet)
}
if (validationResult.errors.taesd && config.taesd) {
pushDownload(config.taesd)
}
if (config.loras) {
for (const lora of config.loras) {
pushDownload(lora)
}
}
return Promise.all(downloadPromises)
}
try {
if (signal?.aborted) {
return
}
const validationResult = await validateModelFiles(config)
if (signal?.aborted) {
return
}
if (validationResult) {
if (config.url) {
await downloadModel(config.url, validationResult)
} else {
throw new Error(`${validationResult.message} - No URL provided`)
}
}
const finalValidation = await validateModelFiles(config)
if (finalValidation) {
throw new Error(`Downloaded files are invalid: ${finalValidation.message}`)
}
const result: Partial<StableDiffusionModelMeta> = {}
if (config.location.endsWith('.gguf')) {
const { metadata } = await gguf(config.location, {
allowLocalFile: true,
})
result.gguf = metadata
}
return result
} finally {
releaseFileLock()
}
}
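// Loads the model into a node-stable-diffusion-cpp context. Component file
// locations (vae, clip, t5xxl, controlNet, taesd) are resolved from their
// url/file sources, the weight type comes from config or is parsed from the
// file name, and LoRAs are looked up in a `loras` directory next to the
// model file.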
export async function createInstance({ config, log }: EngineContext<StableDiffusionModelConfig>, signal?: AbortSignal) {
log(LogLevels.debug, 'Load Stable Diffusion model', config)
const handleLog = (level: string, message: string) => {
log(level as LogLevel, message)
}
const handleProgress = (step: number, steps: number, time: number) => {
log(LogLevels.debug, `Progress: ${step}/${steps} (${time}ms)`)
}
const resolveComponentLocation = (src?: ModelFileSource) => {
if (src) {
return resolveModelFileLocation({
url: src.url,
filePath: src.file,
modelsCachePath: config.modelsCachePath,
})
}
return undefined
}
const vaeFilePath = resolveComponentLocation(config.vae)
const clipLFilePath = resolveComponentLocation(config.clipL)
const clipGFilePath = resolveComponentLocation(config.clipG)
const t5xxlFilePath = resolveComponentLocation(config.t5xxl)
const controlNetFilePath = resolveComponentLocation(config.controlNet)
const taesdFilePath = resolveComponentLocation(config.taesd)
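// Resolve the weight type: an explicit config value wins; otherwise try to
// parse a quantization suffix from the model file name (e.g. a hypothetical
// "model-q8_0.gguf" would yield q8_0).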
let weightType = config.weightType ? getWeightType(config.weightType) : undefined
if (typeof weightType === 'undefined') {
const quantization = parseQuantization(config.location)
if (quantization) {
weightType = getWeightType(quantization)
}
}
if (typeof weightType === 'undefined') {
log(LogLevels.warn, 'Failed to parse model weight type (quantization) from file name, falling back to f32', {
file: config.location,
})
}
const loraDir = path.join(path.dirname(config.location), 'loras')
const contextParams = {
model: !config.diffusionModel ? config.location : undefined,
diffusionModel: config.diffusionModel ? config.location : undefined,
numThreads: config.device?.cpuThreads,
vae: vaeFilePath,
clipL: clipLFilePath,
clipG: clipGFilePath,
t5xxl: t5xxlFilePath,
controlNet: controlNetFilePath,
taesd: taesdFilePath,
weightType: weightType,
loraDir: loraDir,
// TODO how to expose?
// keepClipOnCpu: true,
// keepControlNetOnCpu: true,
// keepVaeOnCpu: true,
}
log(LogLevels.debug, 'Creating context with', contextParams)
const context = await StableDiffusion.createContext(
// @ts-ignore
contextParams,
handleLog,
handleProgress,
)
return {
context,
}
}
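// Runs txt2img generation. A random seed is drawn when the task does not
// provide one, so the effective seed can be returned with the result.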
export async function processTextToImageTask(
task: TextToImageTaskArgs,
ctx: EngineTaskContext<StableDiffusionInstance, StableDiffusionModelConfig, StableDiffusionModelMeta>,
signal?: AbortSignal,
): Promise<TextToImageTaskResult> {
const { instance, config, log } = ctx
const seed = task.seed ?? getRandomNumber(0, 1000000)
const results = await instance.context.txt2img({
prompt: task.prompt,
negativePrompt: task.negativePrompt,
width: task.width || 512,
height: task.height || 512,
batchCount: task.batchCount,
sampleMethod: getSamplingMethod(task.samplingMethod || config.samplingMethod),
sampleSteps: task.sampleSteps,
cfgScale: task.cfgScale,
guidance: task.guidance,
styleRatio: task.styleRatio,
controlStrength: task.controlStrength,
normalizeInput: false,
seed,
})
const images: Image[] = []
for (const img of results) {
images.push({
data: img.data,
width: img.width,
height: img.height,
channels: img.channel,
})
}
if (!images.length) {
throw new Error('No images generated')
}
return {
images: images,
seed,
}
}
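// Runs img2img generation against the provided init image; task.strength
// controls how strongly the init image is re-noised before sampling.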
export async function processImageToImageTask(
task: ImageToImageTaskArgs,
ctx: EngineTaskContext<StableDiffusionInstance, StableDiffusionModelConfig, StableDiffusionModelMeta>,
signal?: AbortSignal,
): Promise<TextToImageTaskResult> {
const { instance, config, log } = ctx
const seed = task.seed ?? getRandomNumber(0, 1000000)
log(LogLevels.debug, 'Image to image task', {
width: task.image.width,
height: task.image.height,
channels: task.image.channels,
})
const initImage = {
data: task.image.data,
width: task.image.width,
height: task.image.height,
channel: task.image.channels as 3 | 4,
}
const results = await instance.context.img2img({
initImage,
prompt: task.prompt,
width: task.width || 512,
height: task.height || 512,
batchCount: task.batchCount,
sampleMethod: getSamplingMethod(task.samplingMethod || config.samplingMethod),
cfgScale: task.cfgScale,
sampleSteps: task.sampleSteps,
guidance: task.guidance,
strength: task.strength,
styleRatio: task.styleRatio,
controlStrength: task.controlStrength,
seed,
})
const images: Image[] = []
for (const img of results) {
images.push({
data: img.data,
width: img.width,
height: img.height,
channels: img.channel,
})
}
if (!images.length) {
throw new Error('No images generated')
}
return {
images: images,
seed,
}
}