react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
199 lines (187 loc) • 6.86 kB
text/typescript
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { ResourceSource } from '../../types/common';
import { TextToImageModelName } from '../../types/tti';
import { BaseModule } from '../BaseModule';
import { PNG } from 'pngjs/browser';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';
/**
* Module for text-to-image generation tasks.
* @category Typescript API
*/
export class TextToImageModule extends BaseModule {
  // Wrapped per-step callback; always callable so forward() never null-checks.
  private readonly inferenceCallback: (stepIdx: number) => void;
  private constructor(
    nativeModule: unknown,
    inferenceCallback?: (stepIdx: number) => void
  ) {
    super();
    this.nativeModule = nativeModule;
    // Wrap the (possibly absent) user callback in a no-op-safe closure.
    this.inferenceCallback = (stepIdx: number) => {
      inferenceCallback?.(stepIdx);
    };
  }
  /**
   * Creates a Text to Image instance for a built-in model.
   * @param namedSources - An object specifying the model name, pipeline sources, and optional inference callback.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `TextToImageModule` instance.
   * @throws RnExecutorchError when any pipeline file fails to download or the native load fails.
   * @example
   * ```ts
   * import { TextToImageModule, BK_SDM_TINY_VPRED_512 } from 'react-native-executorch';
   * const tti = await TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512);
   * ```
   */
  static async fromModelName(
    namedSources: {
      modelName: TextToImageModelName;
      tokenizerSource: ResourceSource;
      schedulerSource: ResourceSource;
      encoderSource: ResourceSource;
      unetSource: ResourceSource;
      decoderSource: ResourceSource;
      inferenceCallback?: (stepIdx: number) => void;
    },
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<TextToImageModule> {
    try {
      const nativeModule = await TextToImageModule.load(
        namedSources,
        onDownloadProgress
      );
      return new TextToImageModule(
        nativeModule,
        namedSources.inferenceCallback
      );
    } catch (error) {
      // Log with context, then normalize to the library's error type for callers.
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }
  /**
   * Creates a Text to Image instance with user-provided model binaries.
   * Use this when working with a custom-exported diffusion pipeline.
   * Internally uses `'custom'` as the model name for telemetry.
   * @remarks The native model contract for this method is not formally defined and may change
   * between releases. Refer to the native source code for the current expected tensor interface.
   * @param sources - An object containing the pipeline source paths.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @param inferenceCallback - Optional callback triggered after each diffusion step.
   * @returns A Promise resolving to a `TextToImageModule` instance.
   */
  static fromCustomModel(
    sources: {
      tokenizerSource: ResourceSource;
      schedulerSource: ResourceSource;
      encoderSource: ResourceSource;
      unetSource: ResourceSource;
      decoderSource: ResourceSource;
    },
    onDownloadProgress: (progress: number) => void = () => {},
    inferenceCallback?: (stepIdx: number) => void
  ): Promise<TextToImageModule> {
    // Delegate to fromModelName with the sentinel 'custom' model name.
    return TextToImageModule.fromModelName(
      {
        modelName: 'custom' as TextToImageModelName,
        ...sources,
        inferenceCallback,
      },
      onDownloadProgress
    );
  }
  /**
   * Downloads all five pipeline files, parses the scheduler JSON config, and
   * hands paths plus scheduler parameters to the native loader.
   * @throws RnExecutorchError with DownloadInterrupted when any file is missing.
   */
  private static async load(
    model: {
      tokenizerSource: ResourceSource;
      schedulerSource: ResourceSource;
      encoderSource: ResourceSource;
      unetSource: ResourceSource;
      decoderSource: ResourceSource;
    },
    onDownloadProgressCallback: (progress: number) => void
  ): Promise<unknown> {
    const results = await ResourceFetcher.fetch(
      onDownloadProgressCallback,
      model.tokenizerSource,
      model.schedulerSource,
      model.encoderSource,
      model.unetSource,
      model.decoderSource
    );
    // The fetcher may return null/short results on interruption; reject early.
    if (!results || results.length !== 5) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.DownloadInterrupted,
        'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
      );
    }
    const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] =
      results;
    // Each individual path may still be empty/undefined even when length is 5.
    if (
      !tokenizerPath ||
      !schedulerPath ||
      !encoderPath ||
      !unetPath ||
      !decoderPath
    ) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.DownloadInterrupted,
        'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
      );
    }
    // The scheduler config is a local JSON file; read it via the file:// scheme.
    const response = await fetch('file://' + schedulerPath);
    const schedulerConfig = await response.json();
    return global.loadTextToImage(
      tokenizerPath,
      encoderPath,
      unetPath,
      decoderPath,
      schedulerConfig.beta_start,
      schedulerConfig.beta_end,
      schedulerConfig.num_train_timesteps,
      schedulerConfig.steps_offset
    );
  }
  /**
   * Runs the model to generate an image described by `input`, and conditioned by `seed`, performing `numSteps` inference steps.
   * The resulting image, with dimensions `imageSize`×`imageSize` pixels, is returned as a base64-encoded string.
   * @param input - The text prompt to generate the image from.
   * @param imageSize - The desired width and height of the output image in pixels.
   * @param numSteps - The number of inference steps to perform.
   * @param seed - An optional seed for random number generation to ensure reproducibility. `0` is a valid seed.
   * @returns A Base64-encoded string representing the generated PNG image, or `''` if the native side produced no pixels (e.g. generation was interrupted).
   */
  async forward(
    input: string,
    imageSize: number = 512,
    numSteps: number = 5,
    seed?: number
  ): Promise<string> {
    const output = await this.nativeModule.generate(
      input,
      imageSize,
      numSteps,
      // -1 is the "no seed" sentinel. Use ?? (not a truthiness check) so a
      // legitimate seed of 0 is forwarded instead of being replaced by -1.
      seed ?? -1,
      this.inferenceCallback
    );
    const outputArray = new Uint8Array(output);
    if (!outputArray.length) {
      return '';
    }
    // Wrap the raw RGBA pixel buffer in a PNG container (colorType 6 = RGBA).
    const png = new PNG({ width: imageSize, height: imageSize });
    png.data = outputArray as unknown as Buffer;
    const pngBuffer = PNG.sync.write(png, { colorType: 6 });
    const pngArray = new Uint8Array(pngBuffer as unknown as ArrayBufferLike);
    // Base64-encode in chunks: spreading the whole array into
    // String.fromCharCode would overflow the argument limit on large images.
    let binary = '';
    const chunkSize = 8192;
    for (let i = 0; i < pngArray.length; i += chunkSize) {
      binary += String.fromCharCode(...pngArray.subarray(i, i + chunkSize));
    }
    return btoa(binary);
  }
  /**
   * Interrupts model generation. The model is stopped in the nearest step.
   */
  public interrupt(): void {
    this.nativeModule.interrupt();
  }
}