UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

97 lines (89 loc) 2.97 kB
import { ResourceSource } from '../../types/common';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { Logger } from '../../common/Logger';

/**
 * Module for Tokenizer functionalities.
 *
 * Every method except {@link TokenizerModule.load} forwards to `nativeModule`,
 * which is only assigned by a successful `load()` call. Call `load()` first;
 * otherwise the forwarding methods reject because `nativeModule` is still
 * undefined.
 * @category Typescript API
 */
export class TokenizerModule {
  /**
   * Native module instance. Unset until `load()` completes successfully.
   */
  nativeModule: any;

  /**
   * Loads the tokenizer from the specified source.
   * `tokenizerSource` is a string that points to the location of the tokenizer JSON file.
   *
   * The source is resolved through `ResourceFetcher.fetch`; the first returned
   * path is handed to `global.loadTokenizerModule` and the result is stored in
   * `nativeModule`.
   * @param tokenizer - Object containing `tokenizerSource`.
   * @param onDownloadProgressCallback - Optional callback to monitor download progress.
   * @throws The value produced by `parseUnknownError` for any failure during
   * fetching or native loading. A missing path after fetching is reported as
   * `RnExecutorchErrorCode.DownloadInterrupted`. The failure is logged before
   * being re-thrown.
   */
  async load(
    tokenizer: { tokenizerSource: ResourceSource },
    onDownloadProgressCallback: (progress: number) => void = () => {}
  ): Promise<void> {
    try {
      const paths = await ResourceFetcher.fetch(
        onDownloadProgressCallback,
        tokenizer.tokenizerSource
      );
      // Only a single resource is requested, so only the first path matters.
      const path = paths?.[0];
      if (!path) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.DownloadInterrupted,
          'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
        );
      }
      this.nativeModule = await global.loadTokenizerModule(path);
    } catch (error) {
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }

  /**
   * Converts a string into an array of token IDs.
   * @param input - The input string to be tokenized.
   * @returns An array of token IDs.
   */
  async encode(input: string): Promise<number[]> {
    return await this.nativeModule.encode(input);
  }

  /**
   * Converts an array of token IDs into a string.
   * @param tokens - Array of token IDs to be decoded.
   * @param skipSpecialTokens - Whether to skip special tokens during decoding (default: true).
   * @returns The decoded string. An empty `tokens` array yields `''` without
   * calling the native module.
   */
  async decode(
    tokens: number[],
    skipSpecialTokens: boolean = true
  ): Promise<string> {
    // Nothing to decode: answer locally instead of calling the native module.
    if (tokens.length === 0) {
      return '';
    }
    return await this.nativeModule.decode(tokens, skipSpecialTokens);
  }

  /**
   * Returns the size of the tokenizer's vocabulary.
   * @returns The vocabulary size.
   */
  async getVocabSize(): Promise<number> {
    return await this.nativeModule.getVocabSize();
  }

  /**
   * Returns the token associated to the ID.
   * @param tokenId - ID of the token.
   * @returns The token string associated to ID.
   */
  async idToToken(tokenId: number): Promise<string> {
    // Awaited like the sibling wrappers so all native calls settle the same way.
    return await this.nativeModule.idToToken(tokenId);
  }

  /**
   * Returns the ID associated to the token.
   * @param token - The token string.
   * @returns The ID associated to the token.
   */
  async tokenToId(token: string): Promise<number> {
    return await this.nativeModule.tokenToId(token);
  }
}