react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
115 lines (109 loc) • 5.34 kB
JavaScript
;
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { BaseModule } from '../BaseModule';
import { PNG } from 'pngjs/browser';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';
/**
 * Module for text-to-image generation tasks.
 * @category Typescript API
 */
export class TextToImageModule extends BaseModule {
  /**
   * @param nativeModule - Handle to the loaded native text-to-image pipeline.
   * @param inferenceCallback - Optional callback invoked after each diffusion step with the step index.
   */
  constructor(nativeModule, inferenceCallback) {
    super();
    this.nativeModule = nativeModule;
    // Wrap the user-supplied callback so a missing callback is a safe no-op.
    this.inferenceCallback = (stepIdx) => {
      inferenceCallback?.(stepIdx);
    };
  }
  /**
   * Creates a Text to Image instance for a built-in model.
   * @param namedSources - An object specifying the model name, pipeline sources, and optional inference callback.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `TextToImageModule` instance.
   * @throws {RnExecutorchError} If downloading or loading the pipeline fails.
   * @example
   * ```ts
   * import { TextToImageModule, BK_SDM_TINY_VPRED_512 } from 'react-native-executorch';
   * const tti = await TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512);
   * ```
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    try {
      const nativeModule = await TextToImageModule.load(namedSources, onDownloadProgress);
      return new TextToImageModule(nativeModule, namedSources.inferenceCallback);
    } catch (error) {
      Logger.error('Load failed:', error);
      // Normalize arbitrary thrown values into an RnExecutorchError before rethrowing.
      throw parseUnknownError(error);
    }
  }
  /**
   * Creates a Text to Image instance with user-provided model binaries.
   * Use this when working with a custom-exported diffusion pipeline.
   * Internally uses `'custom'` as the model name for telemetry.
   * @remarks The native model contract for this method is not formally defined and may change
   * between releases. Refer to the native source code for the current expected tensor interface.
   * @param sources - An object containing the pipeline source paths.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @param inferenceCallback - Optional callback triggered after each diffusion step.
   * @returns A Promise resolving to a `TextToImageModule` instance.
   */
  static fromCustomModel(sources, onDownloadProgress = () => {}, inferenceCallback) {
    return TextToImageModule.fromModelName(
      {
        modelName: 'custom',
        ...sources,
        inferenceCallback,
      },
      onDownloadProgress
    );
  }
  /**
   * Downloads all five pipeline artifacts and loads them into the native runtime.
   * @param model - Source descriptors for tokenizer, scheduler, encoder, UNet, and decoder.
   * @param onDownloadProgressCallback - Progress callback forwarded to the fetcher.
   * @returns A Promise resolving to the native module handle.
   * @throws {RnExecutorchError} DownloadInterrupted when any artifact is missing after the fetch.
   */
  static async load(model, onDownloadProgressCallback) {
    const results = await ResourceFetcher.fetch(
      onDownloadProgressCallback,
      model.tokenizerSource,
      model.schedulerSource,
      model.encoderSource,
      model.unetSource,
      model.decoderSource
    );
    // All five artifacts must be present; a null result, a short array, or any
    // empty path means the download was interrupted part-way through.
    if (!results || results.length !== 5 || results.some((path) => !path)) {
      throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.');
    }
    const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] = results;
    // The scheduler config is a JSON file on local disk; read it via the file:// scheme.
    const response = await fetch('file://' + schedulerPath);
    const schedulerConfig = await response.json();
    return global.loadTextToImage(tokenizerPath, encoderPath, unetPath, decoderPath, schedulerConfig.beta_start, schedulerConfig.beta_end, schedulerConfig.num_train_timesteps, schedulerConfig.steps_offset);
  }
  /**
   * Runs the model to generate an image described by `input`, and conditioned by `seed`, performing `numSteps` inference steps.
   * The resulting image, with dimensions `imageSize`×`imageSize` pixels, is returned as a base64-encoded string.
   * @param input - The text prompt to generate the image from.
   * @param imageSize - The desired width and height of the output image in pixels.
   * @param numSteps - The number of inference steps to perform.
   * @param seed - An optional seed for random number generation to ensure reproducibility. Pass `0` to use seed zero; omit for a random seed.
   * @returns A Base64-encoded string representing the generated PNG image, or `''` if the native module produced no pixels.
   */
  async forward(input, imageSize = 512, numSteps = 5, seed) {
    // Use `??` rather than a truthiness check so an explicit seed of 0 is
    // forwarded to the native side instead of being replaced by the
    // "random seed" sentinel -1.
    const output = await this.nativeModule.generate(input, imageSize, numSteps, seed ?? -1, this.inferenceCallback);
    const outputArray = new Uint8Array(output);
    if (!outputArray.length) {
      return '';
    }
    // Wrap the raw pixel buffer in a PNG container; colorType 6 = RGBA.
    const png = new PNG({
      width: imageSize,
      height: imageSize,
    });
    png.data = outputArray;
    const pngBuffer = PNG.sync.write(png, {
      colorType: 6,
    });
    const pngArray = new Uint8Array(pngBuffer);
    // Build the binary string in chunks: spreading the whole array into
    // String.fromCharCode would exceed the engine's maximum argument count.
    let binary = '';
    const chunkSize = 8192;
    for (let i = 0; i < pngArray.length; i += chunkSize) {
      binary += String.fromCharCode(...pngArray.subarray(i, i + chunkSize));
    }
    return btoa(binary);
  }
  /**
   * Interrupts model generation. The model is stopped in the nearest step.
   */
  interrupt() {
    this.nativeModule.interrupt();
  }
}
//# sourceMappingURL=TextToImageModule.js.map