UNPKG

inference-server

Version:

Libraries and server to build AI applications. Adapters to various native bindings allowing local inference. Integrate it with your application, or use as a microservice.

266 lines 9.31 kB
import StableDiffusion from '@lmagder/node-stable-diffusion-cpp';
import { gguf } from '@huggingface/gguf';
import fs from 'node:fs';
import path from 'node:path';
import { LogLevels } from '../../lib/logger.js';
import { downloadModelFile } from '../../lib/downloadModelFile.js';
import { resolveModelFileLocation } from '../../lib/resolveModelFileLocation.js';
import { acquireFileLock } from '../../lib/acquireFileLock.js';
import { getRandomNumber } from '../../lib/util.js';
import { validateModelFiles } from './validateModelFiles.js';
import { parseQuantization, getWeightType, getSamplingMethod } from './util.js';

export const autoGpu = true;

/**
 * Optional model components, in the order their downloads are queued.
 * Each key is looked up both on `config` (the `{ url, file }` source) and on
 * the validation result's `errors` object.
 */
const COMPONENT_KEYS = ['clipG', 'clipL', 'vae', 't5xxl', 'controlNet', 'taesd'];

/**
 * Starts downloads for every file the validation result flagged, plus all
 * configured LoRAs that have a URL, and waits for all of them.
 *
 * @param {object} engineArgs - `{ config, log }` as passed to `prepareModel`.
 * @param {object} validationResult - Truthy result of `validateModelFiles`;
 *   read here for its `message` and `errors` properties.
 * @param {Function} [onProgress] - Forwarded to `downloadModelFile`.
 * @param {AbortSignal} [signal] - Forwarded to `downloadModelFile`.
 * @returns {Promise<Array>} Settles when every queued download has settled
 *   (rejects on the first failure, per `Promise.all`).
 */
function downloadInvalidFiles({ config, log }, validationResult, onProgress, signal) {
  log(LogLevels.info, `${validationResult.message} - Downloading model files`, {
    model: config.id,
    url: config.url,
    location: config.location,
    errors: validationResult.errors,
  });

  const downloadPromises = [];

  if (validationResult.errors.model && config.location) {
    downloadPromises.push(
      downloadModelFile({
        url: config.url,
        filePath: config.location,
        modelsCachePath: config.modelsCachePath,
        onProgress,
        signal,
      }),
    );
  }

  // Components without a URL cannot be fetched and are silently skipped.
  const pushDownload = (src) => {
    if (!src.url) {
      return;
    }
    downloadPromises.push(
      downloadModelFile({
        url: src.url,
        filePath: src.file,
        modelsCachePath: config.modelsCachePath,
        onProgress,
        signal,
      }),
    );
  };

  for (const key of COMPONENT_KEYS) {
    if (validationResult.errors[key] && config[key]) {
      pushDownload(config[key]);
    }
  }

  // LoRAs are queued regardless of the validation result.
  if (config.loras) {
    for (const lora of config.loras) {
      pushDownload(lora);
    }
  }

  return Promise.all(downloadPromises);
}

/**
 * Ensures the model file and its configured components exist and validate,
 * downloading them when validation fails and a URL is configured.
 *
 * A file lock on `config.location` is held for the duration of the call and
 * is always released, including on abort and on error.
 *
 * @param {object} engineArgs - `{ config, log }`.
 * @param {Function} [onProgress] - Download progress callback.
 * @param {AbortSignal} [signal] - When already aborted at one of the
 *   checkpoints, the function returns `undefined` early.
 * @returns {Promise<object|undefined>} An object that carries `gguf` metadata
 *   when the model file name ends in `.gguf`; `undefined` when aborted.
 * @throws {Error} When validation fails and no URL is configured, or when the
 *   files are still invalid after downloading.
 */
export async function prepareModel({ config, log }, onProgress, signal) {
  fs.mkdirSync(path.dirname(config.location), { recursive: true });
  const releaseFileLock = await acquireFileLock(config.location);

  // Everything after acquiring the lock runs inside try/finally so the lock
  // cannot leak, no matter which statement throws or returns early.
  try {
    if (signal?.aborted) {
      return;
    }
    log(LogLevels.info, `Preparing stable-diffusion model at ${config.location}`, {
      model: config.id,
    });

    const validationResults = await validateModelFiles(config);
    if (signal?.aborted) {
      return;
    }

    if (validationResults) {
      if (!config.url) {
        throw new Error(`${validationResults.message} - No URL provided`);
      }
      await downloadInvalidFiles({ config, log }, validationResults, onProgress, signal);
    }

    const finalValidationError = await validateModelFiles(config);
    if (finalValidationError) {
      // The validation result is an object; interpolating it directly would
      // render as "[object Object]", so report its message instead.
      throw new Error(`Downloaded files are invalid: ${finalValidationError.message}`);
    }

    const result = {};
    if (config.location.endsWith('.gguf')) {
      const { metadata } = await gguf(config.location, {
        allowLocalFile: true,
      });
      result.gguf = metadata;
    }
    return result;
  } finally {
    releaseFileLock();
  }
}

/**
 * Creates a stable-diffusion context for the configured model.
 *
 * @param {object} engineArgs - `{ config, log }`.
 * @param {AbortSignal} [signal] - Currently unused; kept for interface parity.
 * @returns {Promise<{context: object}>} Wrapper around the created context.
 */
export async function createInstance({ config, log }, signal) {
  log(LogLevels.debug, 'Load Stable Diffusion model', config);

  const handleLog = (level, message) => {
    log(level, message);
  };
  const handleProgress = (step, steps, time) => {
    log(LogLevels.debug, `Progress: ${step}/${steps} (${time}ms)`);
  };

  // Resolves an optional `{ url, file }` component to a local path;
  // unconfigured components resolve to undefined.
  const resolveComponentLocation = (src) => {
    if (!src) {
      return undefined;
    }
    return resolveModelFileLocation({
      url: src.url,
      filePath: src.file,
      modelsCachePath: config.modelsCachePath,
    });
  };

  const vaeFilePath = resolveComponentLocation(config.vae);
  const clipLFilePath = resolveComponentLocation(config.clipL);
  const clipGFilePath = resolveComponentLocation(config.clipG);
  const t5xxlFilePath = resolveComponentLocation(config.t5xxl);
  const controlNetFilePath = resolveComponentLocation(config.controlNet);
  const taesdFilePath = resolveComponentLocation(config.taesd);

  // An explicit config.weightType wins; otherwise try to derive the weight
  // type from the quantization parsed out of the model location.
  let weightType = config.weightType ? getWeightType(config.weightType) : undefined;
  if (typeof weightType === 'undefined') {
    const quantization = parseQuantization(config.location);
    if (quantization) {
      weightType = getWeightType(quantization);
    }
  }
  if (typeof weightType === 'undefined') {
    log(LogLevels.warn, 'Failed to parse model weight type (quantization) from file name, falling back to f32', {
      file: config.location,
    });
  }

  const loraDir = path.join(path.dirname(config.location), 'loras');
  const contextParams = {
    // config.location is passed either as `model` or as `diffusionModel`,
    // never both, depending on the config.diffusionModel flag.
    model: !config.diffusionModel ? config.location : undefined,
    diffusionModel: config.diffusionModel ? config.location : undefined,
    numThreads: config.device?.cpuThreads,
    vae: vaeFilePath,
    clipL: clipLFilePath,
    clipG: clipGFilePath,
    t5xxl: t5xxlFilePath,
    controlNet: controlNetFilePath,
    taesd: taesdFilePath,
    weightType: weightType,
    loraDir: loraDir,
    // TODO how to expose?
    // keepClipOnCpu: true,
    // keepControlNetOnCpu: true,
    // keepVaeOnCpu: true,
  };
  log(LogLevels.debug, 'Creating context with', contextParams);

  const context = await StableDiffusion.createContext(
    // @ts-ignore
    contextParams,
    handleLog,
    handleProgress,
  );
  return {
    context,
  };
}

/**
 * Maps the images returned by the context to the task result shape
 * (`channel` becomes `channels`).
 *
 * @param {Iterable<object>} results - Images returned by txt2img / img2img.
 * @returns {Array<object>} Non-empty list of `{ data, width, height, channels }`.
 * @throws {Error} When no images were produced.
 */
function collectImages(results) {
  const images = Array.from(results, (img) => ({
    data: img.data,
    width: img.width,
    height: img.height,
    channels: img.channel,
  }));
  if (!images.length) {
    throw new Error('No images generated');
  }
  return images;
}

/**
 * Runs a text-to-image task on the instance's context.
 *
 * @param {object} task - Task parameters; width/height fall back to 512 and a
 *   seed is drawn via `getRandomNumber(0, 1000000)` when `task.seed` is nullish.
 * @param {object} ctx - `{ instance, config, log }`.
 * @param {AbortSignal} [signal] - Currently unused.
 * @returns {Promise<{images: Array<object>, seed: number}>}
 * @throws {Error} When no images were generated.
 */
export async function processTextToImageTask(task, ctx, signal) {
  const { instance, config } = ctx;
  const seed = task.seed ?? getRandomNumber(0, 1000000);
  const results = await instance.context.txt2img({
    prompt: task.prompt,
    negativePrompt: task.negativePrompt,
    width: task.width || 512,
    height: task.height || 512,
    batchCount: task.batchCount,
    sampleMethod: getSamplingMethod(task.samplingMethod || config.samplingMethod),
    sampleSteps: task.sampleSteps,
    cfgScale: task.cfgScale,
    guidance: task.guidance,
    styleRatio: task.styleRatio,
    controlStrength: task.controlStrength,
    normalizeInput: false,
    seed,
  });
  return {
    images: collectImages(results),
    seed,
  };
}

/**
 * Runs an image-to-image task on the instance's context, using `task.image`
 * as the init image.
 *
 * @param {object} task - Task parameters; `task.image` must provide `data`,
 *   `width`, `height` and `channels`.
 * @param {object} ctx - `{ instance, config, log }`.
 * @param {AbortSignal} [signal] - Currently unused.
 * @returns {Promise<{images: Array<object>, seed: number}>}
 * @throws {Error} When no images were generated.
 */
export async function processImageToImageTask(task, ctx, signal) {
  const { instance, config, log } = ctx;
  const seed = task.seed ?? getRandomNumber(0, 1000000);

  // Routed through the injected logger rather than console so the output
  // follows the same path as the rest of this module's diagnostics.
  log(LogLevels.debug, 'processImageToImageTask', {
    width: task.image.width,
    height: task.image.height,
    channel: task.image.channels,
  });

  // The context takes the channel count under `channel` (singular).
  const initImage = {
    data: task.image.data,
    width: task.image.width,
    height: task.image.height,
    channel: task.image.channels,
  };
  const results = await instance.context.img2img({
    initImage,
    prompt: task.prompt,
    width: task.width || 512,
    height: task.height || 512,
    batchCount: task.batchCount,
    sampleMethod: getSamplingMethod(task.samplingMethod || config.samplingMethod),
    cfgScale: task.cfgScale,
    sampleSteps: task.sampleSteps,
    guidance: task.guidance,
    strength: task.strength,
    styleRatio: task.styleRatio,
    controlStrength: task.controlStrength,
    seed,
  });
  return {
    images: collectImages(results),
    seed,
  };
}
//# sourceMappingURL=engine.js.map