UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

199 lines (187 loc) 6.86 kB
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { ResourceSource } from '../../types/common';
import { TextToImageModelName } from '../../types/tti';
import { BaseModule } from '../BaseModule';
import { PNG } from 'pngjs/browser';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';

/**
 * Module for text-to-image generation tasks.
 * @category Typescript API
 */
export class TextToImageModule extends BaseModule {
  // Wrapper around the user-supplied callback; always defined so the native
  // side can invoke it unconditionally after every diffusion step.
  private inferenceCallback: (stepIdx: number) => void;

  private constructor(
    nativeModule: unknown,
    inferenceCallback?: (stepIdx: number) => void
  ) {
    super();
    this.nativeModule = nativeModule;
    this.inferenceCallback = (stepIdx: number) => {
      inferenceCallback?.(stepIdx);
    };
  }

  /**
   * Creates a Text to Image instance for a built-in model.
   * @param namedSources - An object specifying the model name, pipeline sources, and optional inference callback.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `TextToImageModule` instance.
   * @throws RnExecutorchError (via `parseUnknownError`) when downloading or native loading fails.
   * @example
   * ```ts
   * import { TextToImageModule, BK_SDM_TINY_VPRED_512 } from 'react-native-executorch';
   * const tti = await TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512);
   * ```
   */
  static async fromModelName(
    namedSources: {
      modelName: TextToImageModelName;
      tokenizerSource: ResourceSource;
      schedulerSource: ResourceSource;
      encoderSource: ResourceSource;
      unetSource: ResourceSource;
      decoderSource: ResourceSource;
      inferenceCallback?: (stepIdx: number) => void;
    },
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<TextToImageModule> {
    try {
      const nativeModule = await TextToImageModule.load(
        namedSources,
        onDownloadProgress
      );
      return new TextToImageModule(
        nativeModule,
        namedSources.inferenceCallback
      );
    } catch (error) {
      // Log with context before normalizing to the library's error type.
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }

  /**
   * Creates a Text to Image instance with user-provided model binaries.
   * Use this when working with a custom-exported diffusion pipeline.
   * Internally uses `'custom'` as the model name for telemetry.
   * @remarks The native model contract for this method is not formally defined and may change
   * between releases. Refer to the native source code for the current expected tensor interface.
   * @param sources - An object containing the pipeline source paths.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @param inferenceCallback - Optional callback triggered after each diffusion step.
   * @returns A Promise resolving to a `TextToImageModule` instance.
   */
  static fromCustomModel(
    sources: {
      tokenizerSource: ResourceSource;
      schedulerSource: ResourceSource;
      encoderSource: ResourceSource;
      unetSource: ResourceSource;
      decoderSource: ResourceSource;
    },
    onDownloadProgress: (progress: number) => void = () => {},
    inferenceCallback?: (stepIdx: number) => void
  ): Promise<TextToImageModule> {
    // Delegates to fromModelName with the sentinel 'custom' model name.
    return TextToImageModule.fromModelName(
      {
        modelName: 'custom' as TextToImageModelName,
        ...sources,
        inferenceCallback,
      },
      onDownloadProgress
    );
  }

  /**
   * Downloads all five pipeline artifacts, reads the scheduler JSON config,
   * and instantiates the native text-to-image pipeline.
   * @throws RnExecutorchError with `DownloadInterrupted` when any artifact is missing.
   */
  private static async load(
    model: {
      tokenizerSource: ResourceSource;
      schedulerSource: ResourceSource;
      encoderSource: ResourceSource;
      unetSource: ResourceSource;
      decoderSource: ResourceSource;
    },
    onDownloadProgressCallback: (progress: number) => void
  ): Promise<unknown> {
    const results = await ResourceFetcher.fetch(
      onDownloadProgressCallback,
      model.tokenizerSource,
      model.schedulerSource,
      model.encoderSource,
      model.unetSource,
      model.decoderSource
    );
    if (!results || results.length !== 5) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.DownloadInterrupted,
        'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
      );
    }
    const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] =
      results;
    // Guard each path individually: a partial fetch can yield empty entries
    // even when the overall length is correct.
    if (
      !tokenizerPath ||
      !schedulerPath ||
      !encoderPath ||
      !unetPath ||
      !decoderPath
    ) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.DownloadInterrupted,
        'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
      );
    }
    const response = await fetch('file://' + schedulerPath);
    // NOTE(review): schedulerConfig is parsed JSON used without schema
    // validation — assumes beta_start/beta_end/num_train_timesteps/
    // steps_offset are present and numeric; verify against exported configs.
    const schedulerConfig = await response.json();
    return global.loadTextToImage(
      tokenizerPath,
      encoderPath,
      unetPath,
      decoderPath,
      schedulerConfig.beta_start,
      schedulerConfig.beta_end,
      schedulerConfig.num_train_timesteps,
      schedulerConfig.steps_offset
    );
  }

  /**
   * Runs the model to generate an image described by `input`, and conditioned by `seed`, performing `numSteps` inference steps.
   * The resulting image, with dimensions `imageSize`×`imageSize` pixels, is returned as a base64-encoded string.
   * @param input - The text prompt to generate the image from.
   * @param imageSize - The desired width and height of the output image in pixels.
   * @param numSteps - The number of inference steps to perform.
   * @param seed - An optional seed for random number generation to ensure reproducibility. `0` is a valid seed.
   * @returns A Base64-encoded string representing the generated PNG image, or `''` when the native module produced no pixels.
   */
  async forward(
    input: string,
    imageSize: number = 512,
    numSteps: number = 5,
    seed?: number
  ): Promise<string> {
    const output = await this.nativeModule.generate(
      input,
      imageSize,
      numSteps,
      // -1 is the "no seed provided" sentinel for the native side.
      // Use ?? (not a truthiness check) so the valid seed 0 is preserved
      // instead of being silently replaced by -1.
      seed ?? -1,
      this.inferenceCallback
    );
    const outputArray = new Uint8Array(output);
    if (!outputArray.length) {
      return '';
    }
    // Wrap the raw pixel buffer in a PNG; colorType 6 = RGBA (truecolor + alpha).
    const png = new PNG({ width: imageSize, height: imageSize });
    png.data = outputArray as unknown as Buffer;
    const pngBuffer = PNG.sync.write(png, { colorType: 6 });
    return TextToImageModule.encodeBase64(
      new Uint8Array(pngBuffer as unknown as ArrayBufferLike)
    );
  }

  /**
   * Base64-encodes a byte array via btoa. Processes the bytes in chunks so
   * String.fromCharCode never receives an argument list large enough to
   * overflow the engine's argument limit.
   */
  private static encodeBase64(bytes: Uint8Array): string {
    let binary = '';
    const chunkSize = 8192;
    for (let i = 0; i < bytes.length; i += chunkSize) {
      binary += String.fromCharCode(...bytes.subarray(i, i + chunkSize));
    }
    return btoa(binary);
  }

  /**
   * Interrupts model generation. The model is stopped in the nearest step.
   */
  public interrupt(): void {
    this.nativeModule.interrupt();
  }
}