UNPKG

@lmagder/node-stable-diffusion-cpp

Version:

Node bindings for https://github.com/leejet/stable-diffusion.cpp

90 lines 4.57 kB
import fs from "node:fs";
import os from "node:os";
import path from "node:path";
import { createRequire } from "node:module";
import loadBinding from "pkg-prebuilds";
import { fileTypeFromBuffer, fileTypeFromStream } from "file-type";
import xz from "xz-decompress";
import { Stream } from "node:stream";
import { buffer } from "node:stream/consumers";
import decompress from "decompress";
import { hasher } from "node-object-hash";

// Install-time script. If a `cuda_version.json` sits next to the resolved native
// binding, download the listed CUDA redistributable components from NVIDIA.
// Their shared libraries are then copied (flattened) beside the binding.
// Skipped under `npm ci`, and skipped when the marker file records the same
// version list already installed for this platform/arch.

const require = createRequire(import.meta.url);
// These decompress plugins are loaded through CommonJS `require`.
const decompressTar = require("decompress-tar");
const decompressUnzip = require("decompress-unzip");

const bindingOptionsFile = "binding-options.cjs";
/** Components to fetch; each is looked up by id in the version list. */
const components = ["libcublas", "cuda_cudart"];
const versionListFile = "cuda_version.json";
/** Holds the hash of the last installed version list (plus platform/arch). */
const downloadMarkerFile = "cudadeps.done";
const repoUrl = "https://developer.download.nvidia.com/compute/cuda/redist";

/** Node architecture names -> architecture names used in the redist folder layout. */
const nodeToCudaArch = {
  arm64: "aarch64",
  ppc64: "ppc64le",
  x64: "x86_64",
};

/** Node platform names -> platform names used in the redist folder layout. */
const nodeToCudaPlatform = {
  linux: "linux",
  win32: "windows",
};

/**
 * `decompress` plugin factory for `.tar.xz` archives.
 * The returned plugin yields an empty entry list when the input is not xz.
 * Otherwise it xz-decodes the input and hands the resulting stream to decompress-tar.
 */
const decompressTarXz = () => async (input) => {
  const type = Buffer.isBuffer(input)
    ? await fileTypeFromBuffer(input)
    : await fileTypeFromStream(input);
  if (!type || type.ext !== "xz") {
    return [];
  }
  const tar = decompressTar();
  // XzReadableStream takes a web stream, so bridge from a Node stream first.
  const inStream = Stream.Readable.toWeb(
    Buffer.isBuffer(input)
      ? new Stream.PassThrough().end(input)
      : Stream.Readable.from(input),
  );
  const xzStream = new xz.XzReadableStream(inStream);
  return tar(Stream.Readable.fromWeb(xzStream));
};

// npm_config_* lets cross-installs target another platform/arch than the host.
const requestedArch = process.env.npm_config_arch || os.arch();
const requestedPlatform = process.env.npm_config_platform || os.platform();
const arch = nodeToCudaArch[requestedArch];
const platform = nodeToCudaPlatform[requestedPlatform];
const cudaSubfolder = `${platform}-${arch}`;
const archiveExt = platform === "windows" ? "zip" : "tar.xz";
const fileExt = platform === "windows" ? ".dll" : ".so";

const options = require(path.join(process.cwd(), bindingOptionsFile));
// Find the correct bindings file
const resolvedPath = path.dirname(loadBinding.resolve(process.cwd(), options, false, true));
const versionListPath = path.join(resolvedPath, versionListFile);

/**
 * Download one component's redist archive and extract it in memory.
 * @param {string} componentId - component folder name in the redist repository.
 * @param {string} componentVersion - version string from the version list.
 * @returns {Promise<Array<object>>} the entries produced by `decompress`.
 * @throws {Error} if the HTTP response is not OK or has no body.
 */
async function downloadComponent(componentId, componentVersion) {
  const archivePath = `${repoUrl}/${componentId}/${cudaSubfolder}/${componentId}-${cudaSubfolder}-${componentVersion}-archive.${archiveExt}`;
  console.info(`Downloading ${archivePath}...`);
  const file = await fetch(archivePath);
  if (!file.body || !file.ok) throw new Error(`Downloading ${archivePath} failed`);
  const data = await buffer(file.body);
  console.info(`Done.`);
  console.info(`Extracting...`);
  const archiveFiles = await decompress(data, {
    plugins: [decompressTarXz(), decompressUnzip()],
  });
  console.info(`Done`);
  return archiveFiles;
}

/**
 * Recreate one extracted archive entry directly inside `destDir`.
 * The archive's directory structure is dropped; only the basename is kept.
 * @param {object} entry - a `decompress` entry (`path`, `type`, `data`, `mode`, `linkname`).
 * @param {string} destDir - directory to write into.
 */
function installEntry(entry, destDir) {
  const dest = path.join(destDir, path.basename(entry.path));
  console.info(`Writing ${dest}`);
  // existsSync follows symlinks, so it reports a dangling link as absent and
  // symlinkSync then fails with EEXIST. rmSync removes the link itself and
  // ignores a missing path.
  fs.rmSync(dest, { force: true });
  if (entry.type === "file") {
    fs.writeFileSync(dest, entry.data, { encoding: "binary", mode: entry.mode });
  } else if (entry.type === "symlink" && entry.linkname) {
    fs.symlinkSync(entry.linkname, dest);
  } else if (entry.type === "link" && entry.linkname) {
    // A hard-link target is an archive path. Passed as-is it would resolve
    // against process.cwd(), so point it at the flattened copy in destDir.
    fs.linkSync(path.join(destDir, path.basename(entry.linkname)), dest);
  }
}

if (fs.existsSync(versionListPath) && process.env.npm_command !== "ci") {
  const versionList = JSON.parse(fs.readFileSync(versionListPath, { encoding: "utf8" }));
  const versionListHash = hasher({ sort: true }).hash(versionList) + "_" + cudaSubfolder;
  const downloadMarkerPath = path.join(resolvedPath, downloadMarkerFile);
  const componentCount = Object.keys(versionList).length;
  const needsDownload =
    !fs.existsSync(downloadMarkerPath) ||
    fs.readFileSync(downloadMarkerPath).toString() !== versionListHash;

  if (componentCount > 0 && needsDownload) {
    if (!arch || !platform) {
      throw new Error(
        `No CUDA redistributable mapping for platform "${requestedPlatform}" / arch "${requestedArch}"`,
      );
    }
    console.info(`Downloading components ${components} for ${arch} - ${platform}`);
    for (const componentId of components) {
      const componentVersion = versionList[componentId]?.version;
      if (!componentVersion) {
        throw new Error(`${versionListFile} has no version for component "${componentId}"`);
      }
      const archiveFiles = await downloadComponent(componentId, componentVersion);
      // Sorting by type name puts "file" before "link" and "symlink".
      // Link targets are therefore written before the links that reference them.
      archiveFiles.sort((x, y) => x.type.localeCompare(y.type));
      for (const entry of archiveFiles) {
        const lowerPath = entry.path.toLowerCase();
        // Keep only shared libraries; link-time stub libraries are skipped.
        if (lowerPath.includes(fileExt) && !lowerPath.includes("/stubs/")) {
          installEntry(entry, resolvedPath);
        }
      }
    }
    // Written only after every component is installed, so a failed run retries.
    fs.writeFileSync(downloadMarkerPath, versionListHash);
  }
}
//# sourceMappingURL=cudadeps.js.map