// cui-llama.rn: fork of llama.rn for ChatterUI
import { NativeEventEmitter, DeviceEventEmitter, Platform } from 'react-native'
import type { DeviceEventEmitterStatic } from 'react-native'
import RNLlama from './NativeRNLlama'
import type {
NativeContextParams,
NativeLlamaContext,
NativeCompletionParams,
NativeCompletionTokenProb,
NativeCompletionResult,
NativeTokenizeResult,
NativeEmbeddingResult,
NativeSessionLoadResult,
NativeCPUFeatures,
NativeEmbeddingParams,
NativeRerankParams,
NativeRerankResult,
NativeCompletionTokenProbItem,
NativeCompletionResultTimings,
JinjaFormattedChatResult,
FormattedChatResult,
NativeImageProcessingResult,
NativeLlamaChatMessage,
} from './NativeRNLlama'
import type {
SchemaGrammarConverterPropOrder,
SchemaGrammarConverterBuiltinRule,
} from './grammar'
import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar'
export type RNLlamaMessagePart = {
type: string
text?: string
image_url?: {
url?: string
}
input_audio?: {
format: string
data?: string
url?: string
}
}
export type RNLlamaOAICompatibleMessage = {
role: string
content?: string | RNLlamaMessagePart[]
}
export type {
NativeContextParams,
NativeLlamaContext,
NativeCompletionParams,
NativeCompletionTokenProb,
NativeCompletionResult,
NativeTokenizeResult,
NativeEmbeddingResult,
NativeSessionLoadResult,
NativeEmbeddingParams,
NativeRerankParams,
NativeRerankResult,
NativeCompletionTokenProbItem,
NativeCompletionResultTimings,
FormattedChatResult,
JinjaFormattedChatResult,
NativeImageProcessingResult,
// Deprecated
SchemaGrammarConverterPropOrder,
SchemaGrammarConverterBuiltinRule,
}
export const RNLLAMA_MTMD_DEFAULT_MEDIA_MARKER = '<__media__>'
export { SchemaGrammarConverter, convertJsonSchemaToGrammar }
const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
const EVENT_ON_TOKEN = '@RNLlama_onToken'
const EVENT_ON_NATIVE_LOG = '@RNLlama_onNativeLog'
let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
if (Platform.OS === 'ios') {
// @ts-ignore
EventEmitter = new NativeEventEmitter(RNLlama)
}
if (Platform.OS === 'android') {
EventEmitter = DeviceEventEmitter
}
const logListeners: Array<(level: string, text: string) => void> = []
// @ts-ignore
if (EventEmitter) {
EventEmitter.addListener(
EVENT_ON_NATIVE_LOG,
(evt: { level: string; text: string }) => {
logListeners.forEach((listener) => listener(evt.level, evt.text))
},
)
// Disable log forwarding initially so the native default log callback is used
RNLlama?.toggleNativeLog?.(false)?.catch?.(() => {})
}
export type TokenData = {
token: string
completion_probabilities?: Array<NativeCompletionTokenProb>
}
type TokenNativeEvent = {
contextId: number
tokenResult: TokenData
}
export enum CACHE_TYPE {
F16 = 'f16',
F32 = 'f32',
Q8_0 = 'q8_0',
Q4_0 = 'q4_0',
Q4_1 = 'q4_1',
IQ4_NL = 'iq4_nl',
Q5_0 = 'q5_0',
Q5_1 = 'q5_1'
}
export type ContextParams = Omit<
NativeContextParams,
'cache_type_k' | 'cache_type_v' | 'pooling_type'
> & {
cache_type_k?: CACHE_TYPE
cache_type_v?: CACHE_TYPE
pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank'
}
export type EmbeddingParams = NativeEmbeddingParams
export type RerankParams = {
normalize?: number
}
export type RerankResult = {
score: number
index: number
document?: string
}
export type CompletionResponseFormat = {
type: 'text' | 'json_object' | 'json_schema'
json_schema?: {
strict?: boolean
schema: object
}
schema?: object // for json_object type
}
export type CompletionBaseParams = {
prompt?: string
messages?: RNLlamaOAICompatibleMessage[]
chatTemplate?: string // deprecated
chat_template?: string
jinja?: boolean
tools?: object
parallel_tool_calls?: object
tool_choice?: string
response_format?: CompletionResponseFormat
media_paths?: string | string[]
}
export type CompletionParams = Omit<
NativeCompletionParams,
'emit_partial_completion' | 'prompt'
> &
CompletionBaseParams
export type BenchResult = {
modelDesc: string
modelSize: number
modelNParams: number
ppAvg: number
ppStd: number
tgAvg: number
tgStd: number
}
const getJsonSchema = (responseFormat?: CompletionResponseFormat) => {
if (responseFormat?.type === 'json_schema') {
return responseFormat.json_schema?.schema
}
if (responseFormat?.type === 'json_object') {
return responseFormat.schema || {}
}
return null
}
export class LlamaContext {
id: number
gpu: boolean = false
reasonNoGPU: string = ''
model: NativeLlamaContext['model']
constructor({ contextId, gpu, reasonNoGPU, model }: NativeLlamaContext) {
this.id = contextId
this.gpu = gpu
this.reasonNoGPU = reasonNoGPU
this.model = model
}
/**
* Load cached prompt & completion state from a file.
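*
* A minimal sketch (assuming `context` is an initialized LlamaContext; the path is illustrative):
* @example
* const loaded = await context.loadSession('/path/to/chat.session')
* console.log('session restored:', loaded)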
*/
async loadSession(filepath: string): Promise<NativeSessionLoadResult> {
let path = filepath
if (path.startsWith('file://')) path = path.slice(7)
return RNLlama.loadSession(this.id, path)
}
/**
* Save current cached prompt & completion state to a file.
*/
async saveSession(
filepath: string,
options?: { tokenSize: number },
): Promise<number> {
return RNLlama.saveSession(this.id, filepath, options?.tokenSize || -1)
}
isLlamaChatSupported(): boolean {
return !!this.model.chatTemplates.llamaChat
}
isJinjaSupported(): boolean {
const { minja } = this.model.chatTemplates
return !!minja?.toolUse || !!minja?.default
}
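/**
* Format chat messages into a prompt using the model's chat template.
* Multimodal parts (image_url / input_audio) are replaced with the media marker
* and their paths are collected into media_paths.
*
* A minimal sketch (assuming `context` is an initialized LlamaContext; the message text is illustrative):
* @example
* const formatted = await context.getFormattedChat(
*   [{ role: 'user', content: 'Hi!' }],
*   null,
*   { jinja: true },
* )
*/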
async getFormattedChat(
messages: RNLlamaOAICompatibleMessage[],
template?: string | null,
params?: {
jinja?: boolean
response_format?: CompletionResponseFormat
tools?: object
parallel_tool_calls?: object
tool_choice?: string
enable_thinking?: boolean
},
): Promise<FormattedChatResult | JinjaFormattedChatResult> {
const mediaPaths: string[] = []
const chat = messages.map((msg) => {
if (Array.isArray(msg.content)) {
const content = msg.content.map((part) => {
// Handle multimodal content
if (part.type === 'image_url') {
let path = part.image_url?.url || ''
if (path?.startsWith('file://')) path = path.slice(7)
mediaPaths.push(path)
return {
type: 'text',
text: RNLLAMA_MTMD_DEFAULT_MEDIA_MARKER,
}
} else if (part.type === 'input_audio') {
const { input_audio: audio } = part
if (!audio) throw new Error('input_audio is required')
const { format } = audio
if (format !== 'wav' && format !== 'mp3') {
throw new Error(`Unsupported audio format: ${format}`)
}
if (audio.url) {
const path = audio.url.replace(/file:\/\//, '')
mediaPaths.push(path)
} else if (audio.data) {
mediaPaths.push(audio.data)
}
return {
type: 'text',
text: RNLLAMA_MTMD_DEFAULT_MEDIA_MARKER,
}
}
return part
})
return {
...msg,
content,
}
}
return msg
}) as NativeLlamaChatMessage[]
const useJinja = this.isJinjaSupported() && params?.jinja
let tmpl
if (template) tmpl = template // Force replace if provided
const jsonSchema = getJsonSchema(params?.response_format)
const result = await RNLlama.getFormattedChat(
this.id,
JSON.stringify(chat),
tmpl,
{
jinja: useJinja,
json_schema: jsonSchema ? JSON.stringify(jsonSchema) : undefined,
tools: params?.tools ? JSON.stringify(params.tools) : undefined,
parallel_tool_calls: params?.parallel_tool_calls
? JSON.stringify(params.parallel_tool_calls)
: undefined,
tool_choice: params?.tool_choice,
enable_thinking: params?.enable_thinking ?? true,
},
)
if (!useJinja) {
return {
type: 'llama-chat',
prompt: result as string,
has_media: mediaPaths.length > 0,
media_paths: mediaPaths,
}
}
const jinjaResult = result as JinjaFormattedChatResult
jinjaResult.type = 'jinja'
jinjaResult.has_media = mediaPaths.length > 0
jinjaResult.media_paths = mediaPaths
return jinjaResult
}
/**
* Generate a completion based on the provided parameters
* @param params Completion parameters including prompt or messages
* @param callback Optional callback for token-by-token streaming
* @returns Promise resolving to the completion result
*
* Note: For multimodal support, you can include a media_paths parameter.
* The media files will be processed and added to the context before generating text.
* Multimodal support must be enabled via initMultimodal() first.
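*
* A minimal streaming sketch (assuming `context` is an initialized LlamaContext;
* the message text and n_predict value are illustrative):
* @example
* const result = await context.completion(
*   {
*     messages: [{ role: 'user', content: 'Hello!' }],
*     n_predict: 128,
*   },
*   (data) => console.log(data.token),
* )
* console.log(result)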
*/
async completion(
params: CompletionParams,
callback?: (data: TokenData) => void,
): Promise<NativeCompletionResult> {
const nativeParams = {
...params,
prompt: params.prompt || '',
emit_partial_completion: !!callback,
}
if (params.messages) {
const formattedResult = await this.getFormattedChat(
params.messages,
params.chat_template || params.chatTemplate,
{
jinja: params.jinja,
tools: params.tools,
parallel_tool_calls: params.parallel_tool_calls,
tool_choice: params.tool_choice,
enable_thinking: params.enable_thinking,
},
)
if (formattedResult.type === 'jinja') {
const jinjaResult = formattedResult as JinjaFormattedChatResult
nativeParams.prompt = jinjaResult.prompt || ''
if (typeof jinjaResult.chat_format === 'number')
nativeParams.chat_format = jinjaResult.chat_format
if (jinjaResult.grammar) nativeParams.grammar = jinjaResult.grammar
if (typeof jinjaResult.grammar_lazy === 'boolean')
nativeParams.grammar_lazy = jinjaResult.grammar_lazy
if (jinjaResult.grammar_triggers)
nativeParams.grammar_triggers = jinjaResult.grammar_triggers
if (jinjaResult.preserved_tokens)
nativeParams.preserved_tokens = jinjaResult.preserved_tokens
if (jinjaResult.additional_stops) {
if (!nativeParams.stop) nativeParams.stop = []
nativeParams.stop.push(...jinjaResult.additional_stops)
}
if (jinjaResult.has_media) {
nativeParams.media_paths = jinjaResult.media_paths
}
} else if (formattedResult.type === 'llama-chat') {
const llamaChatResult = formattedResult as FormattedChatResult
nativeParams.prompt = llamaChatResult.prompt || ''
if (llamaChatResult.has_media) {
nativeParams.media_paths = llamaChatResult.media_paths
}
}
} else {
nativeParams.prompt = params.prompt || ''
}
// Fall back to explicitly provided media_paths when none were extracted from messages
if (!nativeParams.media_paths && params.media_paths) {
nativeParams.media_paths = params.media_paths
}
if (nativeParams.response_format && !nativeParams.grammar) {
const jsonSchema = getJsonSchema(params.response_format)
if (jsonSchema) nativeParams.json_schema = JSON.stringify(jsonSchema)
}
let tokenListener: any =
callback &&
EventEmitter.addListener(EVENT_ON_TOKEN, (evt: TokenNativeEvent) => {
const { contextId, tokenResult } = evt
if (contextId !== this.id) return
callback(tokenResult)
})
if (!nativeParams.prompt) throw new Error('Prompt is required')
const promise = RNLlama.completion(this.id, nativeParams)
return promise
.then((completionResult) => {
tokenListener?.remove()
tokenListener = null
return completionResult
})
.catch((err: any) => {
tokenListener?.remove()
tokenListener = null
throw err
})
}
stopCompletion(): Promise<void> {
return RNLlama.stopCompletion(this.id)
}
/**
* Tokenize text or text with images
* @param text Text to tokenize
* @param params.media_paths Array of image paths to tokenize (if multimodal is enabled)
* @returns Promise resolving to the tokenize result
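*
* A minimal sketch (assuming `context` is an initialized LlamaContext; the image path is illustrative):
* @example
* const result = await context.tokenizeAsync('Hello world')
* // With multimodal enabled, media paths can be passed as well:
* // const withImage = await context.tokenizeAsync(text, { media_paths: ['/path/to/image.jpg'] })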
*/
tokenizeAsync(
text: string,
{
media_paths: mediaPaths,
}: {
media_paths?: string[]
} = {},
): Promise<NativeTokenizeResult> {
return RNLlama.tokenizeAsync(this.id, text, mediaPaths)
}
tokenizeSync(
text: string,
{
media_paths: mediaPaths,
}: {
media_paths?: string[]
} = {},
): NativeTokenizeResult {
return RNLlama.tokenizeSync(this.id, text, mediaPaths)
}
detokenize(tokens: number[]): Promise<string> {
return RNLlama.detokenize(this.id, tokens)
}
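/**
* Compute an embedding for the given text.
*
* A minimal sketch (assuming `context` was initialized with embedding support enabled):
* @example
* const result = await context.embedding('Hello world')
*/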
embedding(
text: string,
params?: EmbeddingParams,
): Promise<NativeEmbeddingResult> {
return RNLlama.embedding(this.id, text, params || {})
}
/**
* Rerank documents based on relevance to a query
* @param query The query text to rank documents against
* @param documents Array of document texts to rank
* @param params Optional reranking parameters
* @returns Promise resolving to an array of ranking results with scores and indices
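*
* A minimal sketch (assuming `context` uses a reranking-capable model, e.g. initialized
* with pooling_type 'rank'; the query and documents are illustrative):
* @example
* const ranked = await context.rerank('What is the capital of France?', [
*   'Paris is the capital of France.',
*   'Berlin is the capital of Germany.',
* ])
* ranked.forEach((r) => console.log(r.score, r.document))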
*/
async rerank(
query: string,
documents: string[],
params?: RerankParams,
): Promise<RerankResult[]> {
const results = await RNLlama.rerank(this.id, query, documents, params || {})
// Attach the original document text and sort by score descending
return results
.map((result) => ({
...result,
document: documents[result.index],
}))
.sort((a, b) => b.score - a.score)
}
async bench(
pp: number,
tg: number,
pl: number,
nr: number,
): Promise<BenchResult> {
const result = await RNLlama.bench(this.id, pp, tg, pl, nr)
const [modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd] =
JSON.parse(result)
return {
modelDesc,
modelSize,
modelNParams,
ppAvg,
ppStd,
tgAvg,
tgStd,
}
}
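/**
* Apply LoRA adapters to the loaded model.
*
* A minimal sketch (assuming `context` is an initialized LlamaContext; the adapter path and scale are illustrative):
* @example
* await context.applyLoraAdapters([{ path: '/path/to/adapter.gguf', scaled: 1.0 }])
*/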
async applyLoraAdapters(
loraList: Array<{ path: string; scaled?: number }>,
): Promise<void> {
let loraAdapters: Array<{ path: string; scaled?: number }> = []
if (loraList)
loraAdapters = loraList.map((l) => ({
path: l.path.replace(/file:\/\//, ''),
scaled: l.scaled,
}))
return RNLlama.applyLoraAdapters(this.id, loraAdapters)
}
async removeLoraAdapters(): Promise<void> {
return RNLlama.removeLoraAdapters(this.id)
}
async getLoadedLoraAdapters(): Promise<
Array<{ path: string; scaled?: number }>
> {
return RNLlama.getLoadedLoraAdapters(this.id)
}
/**
* Initialize multimodal support with a mmproj file
* @param params Parameters for multimodal support
* @param params.path Path to the multimodal projector file
* @param params.use_gpu Whether to use GPU
* @returns Promise resolving to true if initialization was successful
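*
* A minimal sketch (assuming `context` is an initialized LlamaContext; the mmproj path is illustrative):
* @example
* const ok = await context.initMultimodal({
*   path: 'file:///path/to/mmproj.gguf',
*   use_gpu: true,
* })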
*/
async initMultimodal({
path,
use_gpu: useGpu,
}: {
path: string
use_gpu?: boolean
}): Promise<boolean> {
if (path.startsWith('file://')) path = path.slice(7)
return RNLlama.initMultimodal(this.id, {
path,
use_gpu: useGpu ?? true,
})
}
/**
* Check if multimodal support is enabled
* @returns Promise resolving to true if multimodal is enabled
*/
async isMultimodalEnabled(): Promise<boolean> {
return await RNLlama.isMultimodalEnabled(this.id)
}
/**
* Check multimodal support
* @returns Promise resolving to an object with vision and audio support
*/
async getMultimodalSupport(): Promise<{
vision: boolean
audio: boolean
}> {
return await RNLlama.getMultimodalSupport(this.id)
}
/**
* Release multimodal support
* @returns Promise resolving to void
*/
async releaseMultimodal(): Promise<void> {
return await RNLlama.releaseMultimodal(this.id)
}
/**
* Initialize TTS support with a vocoder model
* @param params Parameters for TTS support
* @param params.path Path to the vocoder model
* @returns Promise resolving to true if initialization was successful
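*
* A minimal sketch (assuming `context` is an initialized LlamaContext; the vocoder path is illustrative):
* @example
* const ok = await context.initVocoder({ path: '/path/to/vocoder.gguf' })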
*/
async initVocoder({ path }: { path: string }): Promise<boolean> {
if (path.startsWith('file://')) path = path.slice(7)
return await RNLlama.initVocoder(this.id, path)
}
/**
* Check if TTS support is enabled
* @returns Promise resolving to true if TTS is enabled
*/
async isVocoderEnabled(): Promise<boolean> {
return await RNLlama.isVocoderEnabled(this.id)
}
/**
* Get a formatted audio completion prompt
* @param speaker Speaker configuration object, or null for the default speaker
* @param textToSpeak Text to speak
* @returns Promise resolving to the formatted audio completion prompt
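*
* A minimal sketch (assuming the vocoder was initialized via initVocoder; passing null uses the default speaker):
* @example
* const prompt = await context.getFormattedAudioCompletion(null, 'Hello world')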
*/
async getFormattedAudioCompletion(
speaker: object | null,
textToSpeak: string,
): Promise<string> {
return await RNLlama.getFormattedAudioCompletion(
this.id,
speaker ? JSON.stringify(speaker) : '',
textToSpeak,
)
}
/**
* Get guide tokens for audio completion
* @param textToSpeak Text to speak
* @returns Promise resolving to the guide tokens
*/
async getAudioCompletionGuideTokens(
textToSpeak: string,
): Promise<Array<number>> {
return await RNLlama.getAudioCompletionGuideTokens(this.id, textToSpeak)
}
/**
* Decode audio tokens
* @param tokens Array of audio tokens
* @returns Promise resolving to the decoded audio data
*/
async decodeAudioTokens(tokens: number[]): Promise<Array<number>> {
return await RNLlama.decodeAudioTokens(this.id, tokens)
}
/**
* Release TTS support
* @returns Promise resolving to void
*/
async releaseVocoder(): Promise<void> {
return await RNLlama.releaseVocoder(this.id)
}
async release(): Promise<void> {
return RNLlama.releaseContext(this.id)
}
}
export async function getCpuFeatures(): Promise<NativeCPUFeatures> {
if (Platform.OS === 'android') {
return RNLlama.getCpuFeatures()
}
console.warn('getCpuFeatures() is an Android-only feature')
return {
i8mm: false,
armv8: false,
dotprod: false,
}
}
export async function toggleNativeLog(enabled: boolean): Promise<void> {
return RNLlama.toggleNativeLog(enabled)
}
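/**
* Subscribe to native log lines forwarded from llama.cpp.
*
* A minimal sketch:
* @example
* await toggleNativeLog(true)
* const sub = addNativeLogListener((level, text) => console.log(`[${level}] ${text}`))
* // later: sub.remove()
*/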
export function addNativeLogListener(
listener: (level: string, text: string) => void,
): { remove: () => void } {
logListeners.push(listener)
return {
remove: () => {
logListeners.splice(logListeners.indexOf(listener), 1)
},
}
}
export async function setContextLimit(limit: number): Promise<void> {
return RNLlama.setContextLimit(limit)
}
let contextIdCounter = 0
const contextIdRandom = () =>
process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000)
const modelInfoSkip = [
// Large fields
'tokenizer.ggml.tokens',
'tokenizer.ggml.token_type',
'tokenizer.ggml.merges',
'tokenizer.ggml.scores',
]
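/**
* Read GGUF metadata from a model file without creating a context.
* Large tokenizer fields listed in modelInfoSkip are omitted.
*
* A minimal sketch (the model path is illustrative):
* @example
* const info = await loadLlamaModelInfo('file:///path/to/model.gguf')
*/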
export async function loadLlamaModelInfo(model: string): Promise<Object> {
let path = model
if (path.startsWith('file://')) path = path.slice(7)
return RNLlama.modelInfo(path, modelInfoSkip)
}
const poolTypeMap = {
// When pooling_type is undefined, the native side treats it as unspecified (-1)
none: 0,
mean: 1,
cls: 2,
last: 3,
rank: 4,
}
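/**
* Initialize a new llama context from a GGUF model file.
*
* A minimal sketch (the model path is illustrative; any other NativeContextParams
* fields can be passed alongside `model`):
* @example
* const context = await initLlama(
*   { model: 'file:///path/to/model.gguf' },
*   (progress) => console.log('load progress:', progress),
* )
*/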
export async function initLlama(
{
model,
is_model_asset: isModelAsset,
pooling_type: poolingType,
lora,
lora_list: loraList,
...rest
}: ContextParams,
onProgress?: (progress: number) => void,
): Promise<LlamaContext> {
let path = model
if (path.startsWith('file://')) path = path.slice(7)
let loraPath = lora
if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7)
let loraAdapters: Array<{ path: string; scaled?: number }> = []
if (loraList)
loraAdapters = loraList.map((l) => ({
path: l.path.replace(/file:\/\//, ''),
scaled: l.scaled,
}))
const contextId = contextIdCounter + contextIdRandom()
contextIdCounter += 1
let removeProgressListener: any = null
if (onProgress) {
removeProgressListener = EventEmitter.addListener(
EVENT_ON_INIT_CONTEXT_PROGRESS,
(evt: { contextId: number; progress: number }) => {
if (evt.contextId !== contextId) return
onProgress(evt.progress)
},
)
}
const poolType = poolTypeMap[poolingType as keyof typeof poolTypeMap]
const {
gpu,
reasonNoGPU,
model: modelDetails,
androidLib,
} = await RNLlama.initContext(contextId, {
model: path,
is_model_asset: !!isModelAsset,
use_progress_callback: !!onProgress,
pooling_type: poolType,
lora: loraPath,
lora_list: loraAdapters,
...rest,
}).catch((err: any) => {
removeProgressListener?.remove()
throw err
})
removeProgressListener?.remove()
return new LlamaContext({
contextId,
gpu,
reasonNoGPU,
model: modelDetails,
androidLib,
})
}
export async function releaseAllLlama(): Promise<void> {
return RNLlama.releaseAllContexts()
}