wrapture
Wrapture lets you go from a Python-trained model to deployable JavaScript with a single command. It generates TypeScript bindings and a Web/Node-compatible wrapper, using WebGPU/WASM-ready ONNX runtimes.
/* eslint-disable import/no-unused-modules */
import { mkdir, writeFile } from 'node:fs/promises';
import path from 'node:path';
import { log } from './log-level.js';
/**
 * Options for generating ONNX wrapper files.
 */
export interface GenerateWrapperOptionsInterface {
  /**
   * The backend to use for inference. This affects the model file used.
   * If set to `'wasm'`, the generated wrapper will load `model_quant.onnx`;
   * otherwise it will load `model.onnx`.
   *
   * (`(string & {})` keeps autocomplete for the known literals while still
   * accepting any other backend string; a plain `| string` would collapse
   * the union to `string`.)
   */
  backend: 'wasm' | 'webgl' | (string & {});
}
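
// Illustrative sketch of how the backend option maps to model files (the
// './out' path here is hypothetical):
//
//   await generateWrapper('./out', { backend: 'wasm' });  // wrapper loads model_quant.onnx
//   await generateWrapper('./out', { backend: 'webgl' }); // wrapper loads model.onnx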
/**
 * Generates a TypeScript wrapper (`wrapped.ts`) and a type definition file (`wrapped.d.ts`)
 * for use with `onnxruntime-web`, including utility functions like `softmax`, `argmax`,
 * and a typed `predict()` function.
 *
 * The generated code loads the correct ONNX model based on the provided backend.
 *
 * @param {string} outputDir - The directory where the wrapper files will be written.
 * @param {GenerateWrapperOptionsInterface} opts - Wrapper generation options, including backend type.
 * @returns {Promise<void>} A Promise that resolves when the wrapper files are successfully written.
 *
 * @throws Will throw an error if file writing fails.
 *
 * @example
 * ```ts
 * await generateWrapper('./dist', { backend: 'wasm' });
 * // Creates `wrapped.ts` and `wrapped.d.ts` in ./dist
 * ```
 *
 * @see https://www.npmjs.com/package/onnxruntime-web
 */
export const generateWrapper = async (
  outputDir: string,
  opts: GenerateWrapperOptionsInterface
): Promise<void> => {
  log.info('Generating wrapper files...');
  const wrapper = `import { InferenceSession, Tensor } from 'onnxruntime-web';

// Numerically stable softmax: subtracting the max logit before
// exponentiating avoids overflow for large logit values.
const softmax = (logits) => {
  const max = Math.max(...logits);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
};

// Index of the largest value in the array.
const argmax = (arr) => {
  return arr.reduce((maxIdx, val, idx, src) => (val > src[maxIdx] ? idx : maxIdx), 0);
};

export const loadModel = async () => {
  const session = await InferenceSession.create(
    new URL('./${opts.backend === 'wasm' ? 'model_quant.onnx' : 'model.onnx'}', import.meta.url).href
  );
  return {
    predict: async (input) => {
      // Assumes the exported graph names its single input "input" and its
      // single output "output".
      const feeds = { input: new Tensor('float32', input.data, input.dims) };
      const results = await session.run(feeds);
      const raw = results.output.data;
      if (!(raw instanceof Float32Array)) {
        throw new Error('Expected Float32Array logits but got something else');
      }
      const logits = raw;
      const probabilities = softmax(Array.from(logits));
      const predictedClass = argmax(probabilities);
      return { logits, probabilities, predictedClass };
    }
  };
};
`;
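
  // Companion declaration file: mirrors the shapes produced by the wrapper
  // above so plain-JavaScript consumers still get editor type hints.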
  const typings = `export interface ModelInput {
  data: Float32Array;
  dims: number[];
}

export interface ModelOutput {
  logits: Float32Array;
  probabilities: number[];
  predictedClass: number;
}

export interface LoadedModel {
  predict(input: ModelInput): Promise<ModelOutput>;
}

/**
 * Load the ONNX model and return a wrapper with a \`predict()\` function.
 */
export function loadModel(): Promise<LoadedModel>;
`;
  try {
    // Make sure the output directory exists before writing into it.
    await mkdir(outputDir, { recursive: true });
    await writeFile(path.join(outputDir, 'wrapped.ts'), wrapper);
    await writeFile(path.join(outputDir, 'wrapped.d.ts'), typings);
    log.info('Wrapper files generated');
  } catch (error) {
    log.error('Failed to generate wrapper files');
    throw error;
  }
};
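
/*
 * Consumer-side sketch of the generated wrapper (illustrative only; the
 * 1x3x224x224 image-shaped input below is a hypothetical example):
 *
 *   import { loadModel } from './wrapped.js';
 *
 *   const model = await loadModel();
 *   const { probabilities, predictedClass } = await model.predict({
 *     data: new Float32Array(1 * 3 * 224 * 224),
 *     dims: [1, 3, 224, 224],
 *   });
 *   console.log(predictedClass, Math.max(...probabilities));
 */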