llama.rn

React Native binding of llama.cpp

import { NativeEventEmitter, DeviceEventEmitter, Platform } from 'react-native'
import type { DeviceEventEmitterStatic } from 'react-native'
import RNLlama from './NativeRNLlama'
import type {
  NativeContextParams,
  NativeLlamaContext,
  NativeCompletionParams,
  NativeCompletionTokenProb,
  NativeCompletionResult,
  NativeTokenizeResult,
  NativeEmbeddingResult,
  NativeSessionLoadResult,
  NativeEmbeddingParams,
  NativeCompletionTokenProbItem,
  NativeCompletionResultTimings,
} from './NativeRNLlama'
import type {
  SchemaGrammarConverterPropOrder,
  SchemaGrammarConverterBuiltinRule,
} from './grammar'
import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar'
import type { RNLlamaMessagePart, RNLlamaOAICompatibleMessage } from './chat'
import { formatChat } from './chat'

export type {
  NativeContextParams,
  NativeLlamaContext,
  NativeCompletionParams,
  NativeCompletionTokenProb,
  NativeCompletionResult,
  NativeTokenizeResult,
  NativeEmbeddingResult,
  NativeSessionLoadResult,
  NativeEmbeddingParams,
  NativeCompletionTokenProbItem,
  NativeCompletionResultTimings,
  RNLlamaMessagePart,
  RNLlamaOAICompatibleMessage,
  SchemaGrammarConverterPropOrder,
  SchemaGrammarConverterBuiltinRule,
}

export { SchemaGrammarConverter, convertJsonSchemaToGrammar }

const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
const EVENT_ON_TOKEN = '@RNLlama_onToken'

let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
if (Platform.OS === 'ios') {
  // @ts-ignore
  EventEmitter = new NativeEventEmitter(RNLlama)
}
if (Platform.OS === 'android') {
  EventEmitter = DeviceEventEmitter
}

export type TokenData = {
  token: string
  completion_probabilities?: Array<NativeCompletionTokenProb>
}

type TokenNativeEvent = {
  contextId: number
  tokenResult: TokenData
}

export type ContextParams = Omit<
  NativeContextParams,
  'cache_type_k' | 'cache_type_v' | 'pooling_type'
> & {
  cache_type_k?:
    | 'f16'
    | 'f32'
    | 'q8_0'
    | 'q4_0'
    | 'q4_1'
    | 'iq4_nl'
    | 'q5_0'
    | 'q5_1'
  cache_type_v?:
    | 'f16'
    | 'f32'
    | 'q8_0'
    | 'q4_0'
    | 'q4_1'
    | 'iq4_nl'
    | 'q5_0'
    | 'q5_1'
  pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank'
}

export type EmbeddingParams = NativeEmbeddingParams

export type CompletionParams = Omit<
  NativeCompletionParams,
  'emit_partial_completion' | 'prompt'
> & {
  prompt?: string
  messages?: RNLlamaOAICompatibleMessage[]
  chatTemplate?: string
}

export type BenchResult = {
  modelDesc: string
  modelSize: number
  modelNParams: number
  ppAvg: number
  ppStd: number
  tgAvg: number
  tgStd: number
}

export class LlamaContext {
  id: number

  gpu: boolean = false

  reasonNoGPU: string = ''

  model: {
    isChatTemplateSupported?: boolean
  } = {}

  constructor({ contextId, gpu, reasonNoGPU, model }: NativeLlamaContext) {
    this.id = contextId
    this.gpu = gpu
    this.reasonNoGPU = reasonNoGPU
    this.model = model
  }

  /**
   * Load cached prompt & completion state from a file.
   */
  async loadSession(filepath: string): Promise<NativeSessionLoadResult> {
    let path = filepath
    if (path.startsWith('file://')) path = path.slice(7)
    return RNLlama.loadSession(this.id, path)
  }

  /**
   * Save current cached prompt & completion state to a file.
   */
  async saveSession(
    filepath: string,
    options?: { tokenSize: number },
  ): Promise<number> {
    return RNLlama.saveSession(this.id, filepath, options?.tokenSize || -1)
  }

  async getFormattedChat(
    messages: RNLlamaOAICompatibleMessage[],
    template?: string,
  ): Promise<string> {
    const chat = formatChat(messages)
    let tmpl = this.model?.isChatTemplateSupported ? undefined : 'chatml'
    if (template) tmpl = template // Force replace if provided
    return RNLlama.getFormattedChat(this.id, chat, tmpl)
  }

  async completion(
    params: CompletionParams,
    callback?: (data: TokenData) => void,
  ): Promise<NativeCompletionResult> {
    let finalPrompt = params.prompt
    if (params.messages) {
      // messages always win
      finalPrompt = await this.getFormattedChat(
        params.messages,
        params.chatTemplate,
      )
    }

    let tokenListener: any =
      callback &&
      EventEmitter.addListener(EVENT_ON_TOKEN, (evt: TokenNativeEvent) => {
        const { contextId, tokenResult } = evt
        if (contextId !== this.id) return
        callback(tokenResult)
      })

    if (!finalPrompt) throw new Error('Prompt is required')
    const promise = RNLlama.completion(this.id, {
      ...params,
      prompt: finalPrompt,
      emit_partial_completion: !!callback,
    })
    return promise
      .then((completionResult) => {
        tokenListener?.remove()
        tokenListener = null
        return completionResult
      })
      .catch((err: any) => {
        tokenListener?.remove()
        tokenListener = null
        throw err
      })
  }

  stopCompletion(): Promise<void> {
    return RNLlama.stopCompletion(this.id)
  }

  tokenize(text: string): Promise<NativeTokenizeResult> {
    return RNLlama.tokenize(this.id, text)
  }

  detokenize(tokens: number[]): Promise<string> {
    return RNLlama.detokenize(this.id, tokens)
  }

  embedding(
    text: string,
    params?: EmbeddingParams,
  ): Promise<NativeEmbeddingResult> {
    return RNLlama.embedding(this.id, text, params || {})
  }

  async bench(
    pp: number,
    tg: number,
    pl: number,
    nr: number,
  ): Promise<BenchResult> {
    const result = await RNLlama.bench(this.id, pp, tg, pl, nr)
    const [modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd] =
      JSON.parse(result)
    return {
      modelDesc,
      modelSize,
      modelNParams,
      ppAvg,
      ppStd,
      tgAvg,
      tgStd,
    }
  }

  async applyLoraAdapters(
    loraList: Array<{ path: string; scaled?: number }>,
  ): Promise<void> {
    let loraAdapters: Array<{ path: string; scaled?: number }> = []
    if (loraList)
      loraAdapters = loraList.map((l) => ({
        path: l.path.replace(/file:\/\//, ''),
        scaled: l.scaled,
      }))
    return RNLlama.applyLoraAdapters(this.id, loraAdapters)
  }

  async removeLoraAdapters(): Promise<void> {
    return RNLlama.removeLoraAdapters(this.id)
  }

  async getLoadedLoraAdapters(): Promise<
    Array<{ path: string; scaled?: number }>
  > {
    return RNLlama.getLoadedLoraAdapters(this.id)
  }

  async release(): Promise<void> {
    return RNLlama.releaseContext(this.id)
  }
}

export async function setContextLimit(limit: number): Promise<void> {
  return RNLlama.setContextLimit(limit)
}

let contextIdCounter = 0
const contextIdRandom = () =>
  process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000)

const modelInfoSkip = [
  // Large fields
  'tokenizer.ggml.tokens',
  'tokenizer.ggml.token_type',
  'tokenizer.ggml.merges',
]
export async function loadLlamaModelInfo(model: string): Promise<Object> {
  let path = model
  if (path.startsWith('file://')) path = path.slice(7)
  return RNLlama.modelInfo(path, modelInfoSkip)
}

const poolTypeMap = {
  // -1 is unspecified as undefined
  none: 0,
  mean: 1,
  cls: 2,
  last: 3,
  rank: 4,
}

export async function initLlama(
  {
    model,
    is_model_asset: isModelAsset,
    pooling_type: poolingType,
    lora,
    lora_list: loraList,
    ...rest
  }: ContextParams,
  onProgress?: (progress: number) => void,
): Promise<LlamaContext> {
  let path = model
  if (path.startsWith('file://')) path = path.slice(7)

  let loraPath = lora
  if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7)

  let loraAdapters: Array<{ path: string; scaled?: number }> = []
  if (loraList)
    loraAdapters = loraList.map((l) => ({
      path: l.path.replace(/file:\/\//, ''),
      scaled: l.scaled,
    }))

  const contextId = contextIdCounter + contextIdRandom()
  contextIdCounter += 1

  let removeProgressListener: any = null
  if (onProgress) {
    removeProgressListener = EventEmitter.addListener(
      EVENT_ON_INIT_CONTEXT_PROGRESS,
      (evt: { contextId: number; progress: number }) => {
        if (evt.contextId !== contextId) return
        onProgress(evt.progress)
      },
    )
  }

  const poolType = poolTypeMap[poolingType as keyof typeof poolTypeMap]
  const {
    gpu,
    reasonNoGPU,
    model: modelDetails,
    androidLib,
  } = await RNLlama.initContext(contextId, {
    model: path,
    is_model_asset: !!isModelAsset,
    use_progress_callback: !!onProgress,
    pooling_type: poolType,
    lora: loraPath,
    lora_list: loraAdapters,
    ...rest,
  }).catch((err: any) => {
    removeProgressListener?.remove()
    throw err
  })
  removeProgressListener?.remove()
  return new LlamaContext({
    contextId,
    gpu,
    reasonNoGPU,
    model: modelDetails,
    androidLib,
  })
}

export async function releaseAllLlama(): Promise<void> {
  return RNLlama.releaseAllContexts()
}
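For context, a minimal usage sketch of the exported API above (initLlama, LlamaContext.completion, release, releaseAllLlama). The model path is a placeholder, and fields such as n_ctx, n_predict, and result.text belong to the native param/result types declared in ./NativeRNLlama, which are assumed here rather than shown in this file:

import { initLlama, releaseAllLlama } from 'llama.rn'

async function runExample() {
  // Hypothetical path: point this at a GGUF model already present on the device.
  const context = await initLlama(
    {
      model: 'file:///path/to/model.gguf',
      n_ctx: 2048, // assumed NativeContextParams field
    },
    (progress) => console.log('Load progress:', progress),
  )

  // Tokens stream through the callback while the promise resolves with the full result.
  const result = await context.completion(
    {
      messages: [{ role: 'user', content: 'Hello!' }],
      n_predict: 128, // assumed NativeCompletionParams field
    },
    (data) => console.log(data.token),
  )
  console.log(result.text) // assumed NativeCompletionResult field

  await context.release()
  await releaseAllLlama()
}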