inference-server
Libraries and a server for building AI applications. Adapters to various native bindings enable local inference. Integrate it with your application, or run it as a microservice.
import StableDiffusion from '@lmagder/node-stable-diffusion-cpp';
import { gguf } from '@huggingface/gguf';
import fs from 'node:fs';
import path from 'node:path';
import { LogLevels } from '../../lib/logger.js';
import { downloadModelFile } from '../../lib/downloadModelFile.js';
import { resolveModelFileLocation } from '../../lib/resolveModelFileLocation.js';
import { acquireFileLock } from '../../lib/acquireFileLock.js';
import { getRandomNumber } from '../../lib/util.js';
import { validateModelFiles } from './validateModelFiles.js';
import { parseQuantization, getWeightType, getSamplingMethod } from './util.js';
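// Capability flag consumed by the engine host (GPU selection can be automatic).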
export const autoGpu = true;
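/**
 * Prepares the model directory: acquires a file lock, validates the files on
 * disk, and downloads anything missing or invalid. For `.gguf` models the
 * parsed GGUF metadata is returned.
 */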
export async function prepareModel({ config, log }, onProgress, signal) {
fs.mkdirSync(path.dirname(config.location), { recursive: true });
const releaseFileLock = await acquireFileLock(config.location);
if (signal?.aborted) {
releaseFileLock();
return;
}
log(LogLevels.info, `Preparing stable-diffusion model at ${config.location}`, {
model: config.id,
});
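    // Downloads the main model file and any components flagged as missing by
    // validation (CLIP, VAE, T5-XXL, ControlNet, TAESD, LoRAs) in parallel.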
const downloadModel = (url, validationResult) => {
log(LogLevels.info, `${validationResult.message} - Downloading model files`, {
model: config.id,
url: config.url,
location: config.location,
errors: validationResult.errors,
});
const downloadPromises = [];
if (validationResult.errors.model && config.location) {
downloadPromises.push(downloadModelFile({
url: url,
filePath: config.location,
modelsCachePath: config.modelsCachePath,
onProgress,
signal,
}));
}
const pushDownload = (src) => {
if (!src.url) {
return;
}
downloadPromises.push(downloadModelFile({
url: src.url,
filePath: src.file,
modelsCachePath: config.modelsCachePath,
onProgress,
signal,
}));
};
if (validationResult.errors.clipG && config.clipG) {
pushDownload(config.clipG);
}
if (validationResult.errors.clipL && config.clipL) {
pushDownload(config.clipL);
}
if (validationResult.errors.vae && config.vae) {
pushDownload(config.vae);
}
if (validationResult.errors.t5xxl && config.t5xxl) {
pushDownload(config.t5xxl);
}
if (validationResult.errors.controlNet && config.controlNet) {
pushDownload(config.controlNet);
}
if (validationResult.errors.taesd && config.taesd) {
pushDownload(config.taesd);
}
if (config.loras) {
for (const lora of config.loras) {
if (!lora.url) {
continue;
}
pushDownload(lora);
}
}
return Promise.all(downloadPromises);
};
try {
if (signal?.aborted) {
return;
}
const validationResults = await validateModelFiles(config);
if (signal?.aborted) {
return;
}
if (validationResults) {
if (config.url) {
await downloadModel(config.url, validationResults);
}
else {
throw new Error(`${validationResults.message} - No URL provided`);
}
}
const finalValidationError = await validateModelFiles(config);
if (finalValidationError) {
throw new Error(`Downloaded files are invalid: ${finalValidationError.message}`);
}
const result = {};
if (config.location.endsWith('.gguf')) {
const { metadata } = await gguf(config.location, {
allowLocalFile: true,
});
result.gguf = metadata;
}
return result;
}
finally {
releaseFileLock();
}
}
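/**
 * Creates a stable-diffusion.cpp context for the prepared model. Auxiliary
 * component paths (VAE, CLIP-L/G, T5-XXL, ControlNet, TAESD) are resolved
 * from the cache, and the weight type is taken from config or parsed from
 * the model file name.
 */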
export async function createInstance({ config, log }, signal) {
log(LogLevels.debug, 'Load Stable Diffusion model', config);
const handleLog = (level, message) => {
log(level, message);
};
const handleProgress = (step, steps, time) => {
log(LogLevels.debug, `Progress: ${step}/${steps} (${time}ms)`);
};
const resolveComponentLocation = (src) => {
if (src) {
return resolveModelFileLocation({
url: src.url,
filePath: src.file,
modelsCachePath: config.modelsCachePath,
});
}
return undefined;
};
const vaeFilePath = resolveComponentLocation(config.vae);
const clipLFilePath = resolveComponentLocation(config.clipL);
const clipGFilePath = resolveComponentLocation(config.clipG);
const t5xxlFilePath = resolveComponentLocation(config.t5xxl);
const controlNetFilePath = resolveComponentLocation(config.controlNet);
const taesdFilePath = resolveComponentLocation(config.taesd);
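    // Resolve the weight type: use the configured value if present, otherwise
    // try to parse the quantization from the model file name.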
let weightType = config.weightType ? getWeightType(config.weightType) : undefined;
if (typeof weightType === 'undefined') {
const quantization = parseQuantization(config.location);
if (quantization) {
weightType = getWeightType(quantization);
}
}
if (typeof weightType === 'undefined') {
log(LogLevels.warn, 'Failed to parse model weight type (quantization) from file name, falling back to f32', {
file: config.location,
});
}
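    // LoRA adapters are expected in a "loras" directory next to the model file.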
const loraDir = path.join(path.dirname(config.location), 'loras');
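    // Pass the file either as a full checkpoint (`model`) or as standalone
    // diffusion weights (`diffusionModel`), depending on config.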
const contextParams = {
model: !config.diffusionModel ? config.location : undefined,
diffusionModel: config.diffusionModel ? config.location : undefined,
numThreads: config.device?.cpuThreads,
vae: vaeFilePath,
clipL: clipLFilePath,
clipG: clipGFilePath,
t5xxl: t5xxlFilePath,
controlNet: controlNetFilePath,
taesd: taesdFilePath,
weightType: weightType,
loraDir: loraDir,
// TODO how to expose?
// keepClipOnCpu: true,
// keepControlNetOnCpu: true,
// keepVaeOnCpu: true,
};
log(LogLevels.debug, 'Creating context with', contextParams);
const context = await StableDiffusion.createContext(
    // @ts-ignore
    contextParams,
    handleLog,
    handleProgress,
);
return {
context,
};
}
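/**
 * Generates images from a text prompt. Defaults to 512x512 when no size is
 * given and returns the seed that was used (random if none was provided).
 */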
export async function processTextToImageTask(task, ctx, signal) {
const { instance, config, log } = ctx;
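    // Use the caller's seed, or pick a random one so it can be reported back.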
const seed = task.seed ?? getRandomNumber(0, 1000000);
const results = await instance.context.txt2img({
prompt: task.prompt,
negativePrompt: task.negativePrompt,
width: task.width || 512,
height: task.height || 512,
batchCount: task.batchCount,
sampleMethod: getSamplingMethod(task.samplingMethod || config.samplingMethod),
sampleSteps: task.sampleSteps,
cfgScale: task.cfgScale,
guidance: task.guidance,
styleRatio: task.styleRatio,
controlStrength: task.controlStrength,
normalizeInput: false,
seed,
});
const images = [];
for (const img of results) {
images.push({
data: img.data,
width: img.width,
height: img.height,
channels: img.channel,
});
}
if (!images.length) {
throw new Error('No images generated');
}
return {
images: images,
seed,
};
}
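/**
 * Generates images from an input image plus a text prompt (img2img). The
 * input pixel data is passed through as the initial image, and the seed used
 * is returned with the results.
 */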
export async function processImageToImageTask(task, ctx, signal) {
const { instance, config, log } = ctx;
const seed = task.seed ?? getRandomNumber(0, 1000000);
log(LogLevels.debug, 'processImageToImageTask', {
    width: task.image.width,
    height: task.image.height,
    channel: task.image.channels,
});
const initImage = {
data: task.image.data,
width: task.image.width,
height: task.image.height,
channel: task.image.channels,
};
const results = await instance.context.img2img({
initImage,
prompt: task.prompt,
width: task.width || 512,
height: task.height || 512,
batchCount: task.batchCount,
sampleMethod: getSamplingMethod(task.samplingMethod || config.samplingMethod),
cfgScale: task.cfgScale,
sampleSteps: task.sampleSteps,
guidance: task.guidance,
strength: task.strength,
styleRatio: task.styleRatio,
controlStrength: task.controlStrength,
seed,
});
const images = [];
for (const img of results) {
images.push({
data: img.data,
width: img.width,
height: img.height,
channels: img.channel,
});
}
if (!images.length) {
throw new Error('No images generated');
}
return {
images: images,
seed,
};
}