react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
97 lines (89 loc) • 2.97 kB
text/typescript
import { ResourceSource } from '../../types/common';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { Logger } from '../../common/Logger';
/**
 * Module for Tokenizer functionalities.
 *
 * An instance starts without a native tokenizer; {@link TokenizerModule.load}
 * must complete successfully before any of the encoding/decoding/lookup
 * methods can delegate to the native side.
 * @category Typescript API
 */
export class TokenizerModule {
  /**
   * Native module instance. Remains `undefined` until `load()` has succeeded.
   */
  nativeModule: any;

  /**
   * Loads the tokenizer from the specified source.
   * `tokenizerSource` is a string that points to the location of the tokenizer JSON file.
   *
   * The source is resolved through `ResourceFetcher.fetch`; the first returned
   * path is handed to `global.loadTokenizerModule` and the result is stored in
   * `nativeModule`. Any failure is logged and rethrown after being normalized
   * by `parseUnknownError`.
   * @param tokenizer - Object containing `tokenizerSource`.
   * @param onDownloadProgressCallback - Optional callback to monitor download progress.
   * @throws A `DownloadInterrupted` error (normalized by `parseUnknownError`) when the fetcher yields no path.
   */
  async load(
    tokenizer: { tokenizerSource: ResourceSource },
    onDownloadProgressCallback: (progress: number) => void = () => {}
  ): Promise<void> {
    try {
      const paths = await ResourceFetcher.fetch(
        onDownloadProgressCallback,
        tokenizer.tokenizerSource
      );
      const path = paths?.[0];
      if (!path) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.DownloadInterrupted,
          'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
        );
      }
      this.nativeModule = await global.loadTokenizerModule(path);
    } catch (error) {
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }

  /**
   * Converts a string into an array of token IDs.
   * @param input - The input string to be tokenized.
   * @returns An array of token IDs.
   * @throws `RnExecutorchError` (`ModuleNotLoaded`) when called before `load()` has succeeded.
   */
  async encode(input: string): Promise<number[]> {
    return await this.requireNativeModule('encode').encode(input);
  }

  /**
   * Converts an array of token IDs into a string.
   * @param tokens - Array of token IDs to be decoded.
   * @param skipSpecialTokens - Whether to skip special tokens during decoding (default: true).
   * @returns The decoded string. An empty `tokens` array yields `''` without calling the native module.
   * @throws `RnExecutorchError` (`ModuleNotLoaded`) when `tokens` is non-empty and `load()` has not succeeded.
   */
  async decode(
    tokens: number[],
    skipSpecialTokens: boolean = true
  ): Promise<string> {
    // Checked before the loaded-guard on purpose: the empty case never needed
    // a native module, and existing callers may rely on that.
    if (tokens.length === 0) {
      return '';
    }
    return await this.requireNativeModule('decode').decode(
      tokens,
      skipSpecialTokens
    );
  }

  /**
   * Returns the size of the tokenizer's vocabulary.
   * @returns The vocabulary size.
   * @throws `RnExecutorchError` (`ModuleNotLoaded`) when called before `load()` has succeeded.
   */
  async getVocabSize(): Promise<number> {
    return await this.requireNativeModule('getVocabSize').getVocabSize();
  }

  /**
   * Returns the token associated to the ID.
   * @param tokenId - ID of the token.
   * @returns The token string associated to ID.
   * @throws `RnExecutorchError` (`ModuleNotLoaded`) when called before `load()` has succeeded.
   */
  async idToToken(tokenId: number): Promise<string> {
    return await this.requireNativeModule('idToToken').idToToken(tokenId);
  }

  /**
   * Returns the ID associated to the token.
   * @param token - The token string.
   * @returns The ID associated to the token.
   * @throws `RnExecutorchError` (`ModuleNotLoaded`) when called before `load()` has succeeded.
   */
  async tokenToId(token: string): Promise<number> {
    return await this.requireNativeModule('tokenToId').tokenToId(token);
  }

  /**
   * Returns the native module, or throws a descriptive error when `load()`
   * has not populated it yet. Without this guard the callers would fail with
   * an opaque `TypeError` from dereferencing `undefined`.
   * @param caller - Name of the public method requesting the module, used in the error message.
   */
  private requireNativeModule(caller: string): any {
    if (this.nativeModule == null) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        `The tokenizer is not loaded. Please call load() before calling ${caller}().`
      );
    }
    return this.nativeModule;
  }
}