// cui-llama.rn: fork of llama.rn for ChatterUI
import { NativeEventEmitter, DeviceEventEmitter, Platform } from 'react-native'
import type { DeviceEventEmitterStatic } from 'react-native'
import RNLlama from './NativeRNLlama'
import type {
NativeContextParams,
NativeLlamaContext,
NativeCompletionParams,
NativeCompletionTokenProb,
NativeCompletionResult,
NativeTokenizeResult,
NativeEmbeddingResult,
NativeSessionLoadResult,
NativeCPUFeatures,
NativeEmbeddingParams,
NativeCompletionTokenProbItem,
NativeCompletionResultTimings,
} from './NativeRNLlama'
import type { SchemaGrammarConverterPropOrder, SchemaGrammarConverterBuiltinRule } from './grammar'
import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar'
import type { RNLlamaMessagePart, RNLlamaOAICompatibleMessage } from './chat'
import { formatChat } from './chat'
export type {
NativeContextParams,
NativeLlamaContext,
NativeCompletionParams,
NativeCompletionTokenProb,
NativeCompletionResult,
NativeTokenizeResult,
NativeEmbeddingResult,
NativeSessionLoadResult,
NativeEmbeddingParams,
NativeCompletionTokenProbItem,
NativeCompletionResultTimings,
RNLlamaMessagePart,
RNLlamaOAICompatibleMessage,
SchemaGrammarConverterPropOrder,
SchemaGrammarConverterBuiltinRule,
}
export { SchemaGrammarConverter, convertJsonSchemaToGrammar }
const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
const EVENT_ON_TOKEN = '@RNLlama_onToken'
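// Streaming tokens and init progress are delivered as events. iOS bridges them
// through a NativeEventEmitter bound to the RNLlama module; Android uses the
// global DeviceEventEmitter.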
let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
if (Platform.OS === 'ios') {
// @ts-ignore
EventEmitter = new NativeEventEmitter(RNLlama)
}
if (Platform.OS === 'android') {
EventEmitter = DeviceEventEmitter
}
export type TokenData = {
token: string
completion_probabilities?: Array<NativeCompletionTokenProb>
}
type TokenNativeEvent = {
contextId: number
tokenResult: TokenData
}
export enum GGML_TYPE {
LM_GGML_TYPE_F32 = 0,
LM_GGML_TYPE_F16 = 1,
LM_GGML_TYPE_Q4_0 = 2,
LM_GGML_TYPE_Q4_1 = 3,
// LM_GGML_TYPE_Q4_2 = 4, support has been removed
// LM_GGML_TYPE_Q4_3 = 5, support has been removed
LM_GGML_TYPE_Q5_0 = 6,
LM_GGML_TYPE_Q5_1 = 7,
LM_GGML_TYPE_Q8_0 = 8,
LM_GGML_TYPE_Q8_1 = 9,
LM_GGML_TYPE_Q2_K = 10,
LM_GGML_TYPE_Q3_K = 11,
LM_GGML_TYPE_Q4_K = 12,
LM_GGML_TYPE_Q5_K = 13,
LM_GGML_TYPE_Q6_K = 14,
LM_GGML_TYPE_Q8_K = 15,
LM_GGML_TYPE_IQ2_XXS = 16,
LM_GGML_TYPE_IQ2_XS = 17,
LM_GGML_TYPE_IQ3_XXS = 18,
LM_GGML_TYPE_IQ1_S = 19,
LM_GGML_TYPE_IQ4_NL = 20,
LM_GGML_TYPE_IQ3_S = 21,
LM_GGML_TYPE_IQ2_S = 22,
LM_GGML_TYPE_IQ4_XS = 23,
LM_GGML_TYPE_I8 = 24,
LM_GGML_TYPE_I16 = 25,
LM_GGML_TYPE_I32 = 26,
LM_GGML_TYPE_I64 = 27,
LM_GGML_TYPE_F64 = 28,
LM_GGML_TYPE_IQ1_M = 29,
LM_GGML_TYPE_BF16 = 30,
// LM_GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
// LM_GGML_TYPE_Q4_0_4_8 = 32,
// LM_GGML_TYPE_Q4_0_8_8 = 33,
LM_GGML_TYPE_TQ1_0 = 34,
LM_GGML_TYPE_TQ2_0 = 35,
// LM_GGML_TYPE_IQ4_NL_4_4 = 36,
// LM_GGML_TYPE_IQ4_NL_4_8 = 37,
// LM_GGML_TYPE_IQ4_NL_8_8 = 38,
LM_GGML_TYPE_COUNT = 39,
};
export type ContextParams = Omit<
NativeContextParams,
'cache_type_k' | 'cache_type_v' | 'pooling_type'
> & {
cache_type_k?: GGML_TYPE
cache_type_v?: GGML_TYPE
pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank'
}
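// Illustrative sketch of a ContextParams value requesting a quantized KV cache.
// The model path is a placeholder and `n_ctx` is assumed from NativeContextParams;
// only the fields overridden above are defined in this file.
//
//   const params: ContextParams = {
//     model: 'file:///path/to/model.gguf',
//     n_ctx: 4096,
//     cache_type_k: GGML_TYPE.LM_GGML_TYPE_Q8_0,
//     cache_type_v: GGML_TYPE.LM_GGML_TYPE_Q8_0,
//     pooling_type: 'none',
//   }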
export type EmbeddingParams = NativeEmbeddingParams
export type CompletionParams = Omit<
NativeCompletionParams,
'emit_partial_completion' | 'prompt'
> & {
prompt?: string
messages?: RNLlamaOAICompatibleMessage[]
chatTemplate?: string
}
export type BenchResult = {
  modelDesc: string
  modelSize: number // model size in bytes
  modelNParams: number // number of model parameters
  ppAvg: number // prompt processing speed, average (tokens/s)
  ppStd: number // prompt processing speed, standard deviation
  tgAvg: number // token generation speed, average (tokens/s)
  tgStd: number // token generation speed, standard deviation
}
export class LlamaContext {
id: number
gpu: boolean = false
reasonNoGPU: string = ''
model: {
isChatTemplateSupported?: boolean
} = {}
constructor({ contextId, gpu, reasonNoGPU, model }: NativeLlamaContext) {
this.id = contextId
this.gpu = gpu
this.reasonNoGPU = reasonNoGPU
this.model = model
}
/**
* Load cached prompt & completion state from a file.
*/
async loadSession(filepath: string): Promise<NativeSessionLoadResult> {
let path = filepath
if (path.startsWith('file://')) path = path.slice(7)
return RNLlama.loadSession(this.id, path)
}
/**
* Save current cached prompt & completion state to a file.
*/
  async saveSession(
    filepath: string,
    options?: { tokenSize: number },
  ): Promise<number> {
    let path = filepath
    // Strip the file:// scheme for consistency with loadSession
    if (path.startsWith('file://')) path = path.slice(7)
    return RNLlama.saveSession(this.id, path, options?.tokenSize || -1)
  }
async getFormattedChat(
messages: RNLlamaOAICompatibleMessage[],
template?: string,
): Promise<string> {
const chat = formatChat(messages)
let tmpl = this.model?.isChatTemplateSupported ? undefined : 'chatml'
if (template) tmpl = template // Force replace if provided
return RNLlama.getFormattedChat(this.id, chat, tmpl)
}
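  /**
   * Run a completion. If `messages` is given it is formatted with
   * getFormattedChat and takes precedence over `prompt`; when a callback is
   * supplied, partial tokens are streamed via the token event.
   *
   * Illustrative sketch (values are placeholders; `n_predict` and the `text`
   * result field are assumed from the native types):
   *
   *   const result = await context.completion(
   *     { messages: [{ role: 'user', content: 'Hello!' }], n_predict: 64 },
   *     (data) => console.log(data.token),
   *   )
   *   console.log(result.text)
   */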
async completion(
params: CompletionParams,
callback?: (data: TokenData) => void,
): Promise<NativeCompletionResult> {
let finalPrompt = params.prompt
if (params.messages) {
// messages always win
finalPrompt = await this.getFormattedChat(params.messages, params.chatTemplate)
}
    // Validate the prompt before registering the token listener so an early
    // throw cannot leak the listener.
    if (!finalPrompt) throw new Error('Prompt is required')
    let tokenListener: any =
      callback &&
      EventEmitter.addListener(EVENT_ON_TOKEN, (evt: TokenNativeEvent) => {
        const { contextId, tokenResult } = evt
        if (contextId !== this.id) return
        callback(tokenResult)
      })
const promise = RNLlama.completion(this.id, {
...params,
prompt: finalPrompt,
emit_partial_completion: !!callback,
})
return promise
.then((completionResult) => {
tokenListener?.remove()
tokenListener = null
return completionResult
})
.catch((err: any) => {
tokenListener?.remove()
tokenListener = null
throw err
})
}
stopCompletion(): Promise<void> {
return RNLlama.stopCompletion(this.id)
}
tokenizeAsync(text: string): Promise<NativeTokenizeResult> {
return RNLlama.tokenizeAsync(this.id, text)
}
tokenizeSync(text: string): NativeTokenizeResult {
return RNLlama.tokenizeSync(this.id, text)
}
detokenize(tokens: number[]): Promise<string> {
return RNLlama.detokenize(this.id, tokens)
}
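  /**
   * Compute an embedding for the given text. The context typically needs to be
   * initialized with embedding support enabled. Illustrative sketch (the
   * `embedding` field name is assumed from NativeEmbeddingResult):
   *
   *   const { embedding } = await context.embedding('hello world')
   */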
embedding(
text: string,
params?: EmbeddingParams,
): Promise<NativeEmbeddingResult> {
return RNLlama.embedding(this.id, text, params || {})
}
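  /**
   * Benchmark the loaded model. The arguments follow the llama.cpp bench
   * convention: pp = prompt tokens, tg = tokens to generate, pl = parallel
   * sequences, nr = repetitions. Illustrative sketch:
   *
   *   const { ppAvg, tgAvg } = await context.bench(512, 128, 1, 3)
   */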
async bench(
pp: number,
tg: number,
pl: number,
nr: number,
): Promise<BenchResult> {
const result = await RNLlama.bench(this.id, pp, tg, pl, nr)
const [modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd] =
JSON.parse(result)
return {
modelDesc,
modelSize,
modelNParams,
ppAvg,
ppStd,
tgAvg,
tgStd,
}
}
async release(): Promise<void> {
return RNLlama.releaseContext(this.id)
}
}
export async function getCpuFeatures(): Promise<NativeCPUFeatures> {
return RNLlama.getCpuFeatures()
}
export async function setContextLimit(limit: number): Promise<void> {
return RNLlama.setContextLimit(limit)
}
let contextIdCounter = 0
const contextIdRandom = () =>
process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000)
const modelInfoSkip = [
// Large fields
'tokenizer.ggml.tokens',
'tokenizer.ggml.token_type',
'tokenizer.ggml.merges',
]
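/**
 * Read metadata from a GGUF model file. Large tokenizer arrays (listed in
 * modelInfoSkip above) are omitted from the result. Illustrative sketch (the
 * path is a placeholder):
 *
 *   const info = await loadLlamaModelInfo('file:///path/to/model.gguf')
 */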
export async function loadLlamaModelInfo(model: string): Promise<Object> {
let path = model
if (path.startsWith('file://')) path = path.slice(7)
return RNLlama.modelInfo(path, modelInfoSkip)
}
const poolTypeMap = {
  // -1 (unspecified) is represented by leaving pooling_type undefined
none: 0,
mean: 1,
cls: 2,
last: 3,
rank: 4,
}
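/**
 * Initialize a llama context from a GGUF model and return a LlamaContext
 * wrapper. `onProgress`, if given, receives load-progress updates.
 *
 * Illustrative end-to-end sketch (paths and values are placeholders; `n_ctx`,
 * `n_gpu_layers`, and the `text` result field are assumed from the native types):
 *
 *   const context = await initLlama(
 *     { model: 'file:///path/to/model.gguf', n_ctx: 2048, n_gpu_layers: 99 },
 *     (progress) => console.log('loading:', progress),
 *   )
 *   const { text } = await context.completion({
 *     messages: [{ role: 'user', content: 'Hi!' }],
 *   })
 *   await context.release()
 */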
export async function initLlama(
{
model,
is_model_asset: isModelAsset,
pooling_type: poolingType,
lora,
...rest
}: ContextParams,
onProgress?: (progress: number) => void,
): Promise<LlamaContext> {
let path = model
if (path.startsWith('file://')) path = path.slice(7)
let loraPath = lora
if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7)
const contextId = contextIdCounter + contextIdRandom()
contextIdCounter += 1
let removeProgressListener: any = null
if (onProgress) {
removeProgressListener = EventEmitter.addListener(
EVENT_ON_INIT_CONTEXT_PROGRESS,
(evt: { contextId: number; progress: number }) => {
if (evt.contextId !== contextId) return
onProgress(evt.progress)
},
)
}
const poolType = poolTypeMap[poolingType as keyof typeof poolTypeMap]
const {
gpu,
reasonNoGPU,
model: modelDetails,
} = await RNLlama.initContext(contextId, {
model: path,
is_model_asset: !!isModelAsset,
use_progress_callback: !!onProgress,
pooling_type: poolType,
lora: loraPath,
...rest,
}).catch((err: any) => {
removeProgressListener?.remove()
throw err
})
removeProgressListener?.remove()
return new LlamaContext({ contextId, gpu, reasonNoGPU, model: modelDetails })
}
export async function releaseAllLlama(): Promise<void> {
return RNLlama.releaseAllContexts()
}