onnxruntime-node
Version:
ONNXRuntime Node.js binding
307 lines (275 loc) • 10 kB
JavaScript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
;
const fs = require('fs');
const https = require('https');
const { execFileSync } = require('child_process');
const path = require('path');
const os = require('os');
const AdmZip = require('adm-zip'); // Use adm-zip instead of spawn
async function downloadFile(url, dest) {
return new Promise((resolve, reject) => {
const file = fs.createWriteStream(dest);
https
.get(url, (res) => {
if (res.statusCode !== 200) {
file.close();
fs.unlinkSync(dest);
reject(new Error(`Failed to download from ${url}. HTTP status code = ${res.statusCode}`));
return;
}
res.pipe(file);
file.on('finish', () => {
file.close();
resolve();
});
file.on('error', (err) => {
fs.unlinkSync(dest);
reject(err);
});
})
.on('error', (err) => {
fs.unlinkSync(dest);
reject(err);
});
});
}
async function downloadJson(url) {
return new Promise((resolve, reject) => {
https
.get(url, (res) => {
const { statusCode } = res;
const contentType = res.headers['content-type'];
if (!statusCode) {
reject(new Error('No response statud code from server.'));
return;
}
if (statusCode >= 400 && statusCode < 500) {
resolve(null);
return;
} else if (statusCode !== 200) {
reject(new Error(`Failed to download build list. HTTP status code = ${statusCode}`));
return;
}
if (!contentType || !/^application\/json/.test(contentType)) {
reject(new Error(`unexpected content type: ${contentType}`));
return;
}
res.setEncoding('utf8');
let rawData = '';
res.on('data', (chunk) => {
rawData += chunk;
});
res.on('end', () => {
try {
resolve(JSON.parse(rawData));
} catch (e) {
reject(e);
}
});
res.on('error', (err) => {
reject(err);
});
})
.on('error', (err) => {
reject(err);
});
});
}
async function installPackages(packages, manifests, feeds) {
// Step.1: resolve packages
const resolvedPackages = new Map();
for (const packageCandidates of packages) {
// iterate all candidates from packagesInfo and try to find the first one that exists
for (const { feed, version } of packageCandidates.versions) {
const { type, index } = feeds[feed];
const pkg = await resolvePackage(type, index, packageCandidates.name, version);
if (pkg) {
resolvedPackages.set(packageCandidates, pkg);
break;
}
}
if (!resolvedPackages.has(packageCandidates)) {
throw new Error(`Failed to resolve package. No package exists for: ${JSON.stringify(packageCandidates)}`);
}
}
// Step.2: download packages
for (const [pkgInfo, pkg] of resolvedPackages) {
const manifestsForPackage = manifests.filter((x) => x.packagesInfo === pkgInfo);
await pkg.download(manifestsForPackage);
}
}
async function resolvePackage(type, index, packageName, version) {
// https://learn.microsoft.com/en-us/nuget/api/overview
const nugetPackageUrlResolver = async (index, packageName, version) => {
// STEP.1 - get Nuget package index
const nugetIndex = await downloadJson(index);
if (!nugetIndex) {
throw new Error(`Failed to download Nuget index from ${index}`);
}
// STEP.2 - get the base url of "PackageBaseAddress/3.0.0"
const packageBaseUrl = nugetIndex.resources.find((x) => x['@type'] === 'PackageBaseAddress/3.0.0')?.['@id'];
if (!packageBaseUrl) {
throw new Error(`Failed to find PackageBaseAddress in Nuget index`);
}
// STEP.3 - get the package version info
const packageInfo = await downloadJson(`${packageBaseUrl}${packageName.toLowerCase()}/index.json`);
if (!packageInfo.versions.includes(version.toLowerCase())) {
throw new Error(`Failed to find specific package versions for ${packageName} in ${index}`);
}
// STEP.4 - generate the package URL
const packageUrl = `${packageBaseUrl}${packageName.toLowerCase()}/${version.toLowerCase()}/${packageName.toLowerCase()}.${version.toLowerCase()}.nupkg`;
const packageFileName = `${packageName.toLowerCase()}.${version.toLowerCase()}.nupkg`;
return {
download: async (manifests) => {
if (manifests.length === 0) {
return;
}
// Create a temporary directory
const tempDir = path.join(os.tmpdir(), `onnxruntime-node-pkgs_${Date.now()}`);
fs.mkdirSync(tempDir, { recursive: true });
try {
const packageFilePath = path.join(tempDir, packageFileName);
// Download the NuGet package
console.log(`Downloading ${packageUrl}`);
await downloadFile(packageUrl, packageFilePath);
// Load the NuGet package (which is a ZIP file)
let zip;
try {
zip = new AdmZip(packageFilePath);
} catch (err) {
throw new Error(`Failed to open NuGet package: ${err.message}`);
}
// Extract only the needed files from the package
const extractDir = path.join(tempDir, 'extracted');
fs.mkdirSync(extractDir, { recursive: true });
// Process each manifest and extract/copy files to their destinations
for (const manifest of manifests) {
const { filepath, pathInPackage } = manifest;
// Create directory for the target file
const targetDir = path.dirname(filepath);
fs.mkdirSync(targetDir, { recursive: true });
// Check if the file exists directly in the zip
const zipEntry = zip.getEntry(pathInPackage);
if (!zipEntry) {
throw new Error(`Failed to find ${pathInPackage} in NuGet package`);
}
console.log(`Extracting ${pathInPackage} to ${filepath}`);
// Extract just this entry to a temporary location
const extractedFilePath = path.join(extractDir, path.basename(pathInPackage));
zip.extractEntryTo(zipEntry, extractDir, false, true);
// Copy to the final destination
fs.copyFileSync(extractedFilePath, filepath);
}
} finally {
// Clean up the temporary directory - always runs even if an error occurs
try {
fs.rmSync(tempDir, { recursive: true });
} catch (e) {
console.warn(`Failed to clean up temporary directory: ${tempDir}`, e);
// Don't rethrow this error as it would mask the original error
}
}
},
};
};
switch (type) {
case 'nuget':
return await nugetPackageUrlResolver(index, packageName, version);
default:
throw new Error(`Unsupported package type: ${type}`);
}
}
function tryGetCudaVersion() {
// Should only return 11 or 12.
// try to get the CUDA version from the system ( `nvcc --version` )
let ver = 12;
try {
const nvccVersion = execFileSync('nvcc', ['--version'], { encoding: 'utf8' });
const match = nvccVersion.match(/release (\d+)/);
if (match) {
ver = parseInt(match[1]);
if (ver !== 11 && ver !== 12) {
throw new Error(`Unsupported CUDA version: ${ver}`);
}
}
} catch (e) {
if (e?.code === 'ENOENT') {
console.warn('`nvcc` not found. Assuming CUDA 12.');
} else {
console.warn('Failed to detect CUDA version from `nvcc --version`:', e.message);
}
}
// assume CUDA 12 if failed to detect
return ver;
}
function parseInstallFlag() {
let flag = process.env.ONNXRUNTIME_NODE_INSTALL || process.env.npm_config_onnxruntime_node_install;
if (!flag) {
for (let i = 0; i < process.argv.length; i++) {
if (process.argv[i].startsWith('--onnxruntime-node-install=')) {
flag = process.argv[i].split('=')[1];
break;
} else if (process.argv[i] === '--onnxruntime-node-install') {
flag = 'true';
}
}
}
switch (flag) {
case 'true':
case '1':
case 'ON':
return true;
case 'skip':
return false;
case undefined: {
flag = parseInstallCudaFlag();
if (flag === 'skip') {
return false;
}
if (flag === 11) {
throw new Error('CUDA 11 is no longer supported. Please consider using CPU or upgrade to CUDA 12.');
}
if (flag === 12) {
return 'cuda12';
}
return undefined;
}
default:
if (!flag || typeof flag !== 'string') {
throw new Error(`Invalid value for --onnxruntime-node-install: ${flag}`);
}
}
}
function parseInstallCudaFlag() {
let flag = process.env.ONNXRUNTIME_NODE_INSTALL_CUDA || process.env.npm_config_onnxruntime_node_install_cuda;
if (!flag) {
for (let i = 0; i < process.argv.length; i++) {
if (process.argv[i].startsWith('--onnxruntime-node-install-cuda=')) {
flag = process.argv[i].split('=')[1];
break;
} else if (process.argv[i] === '--onnxruntime-node-install-cuda') {
flag = 'true';
}
}
}
switch (flag) {
case 'true':
case '1':
case 'ON':
return tryGetCudaVersion();
case 'v11':
return 11;
case 'v12':
return 12;
case 'skip':
case undefined:
return flag;
default:
throw new Error(`Invalid value for --onnxruntime-node-install-cuda: ${flag}`);
}
}
module.exports = {
installPackages,
parseInstallFlag,
};