// llama.rn: React Native binding of llama.cpp
import { NativeEventEmitter, DeviceEventEmitter, Platform } from 'react-native';
import RNLlama from './NativeRNLlama';
import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar';
import { formatChat } from './chat';
export { SchemaGrammarConverter, convertJsonSchemaToGrammar };
const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress';
const EVENT_ON_TOKEN = '@RNLlama_onToken';
// Streaming and progress events: iOS uses a NativeEventEmitter bound to the
// native module, Android uses DeviceEventEmitter.
let EventEmitter;
if (Platform.OS === 'ios') {
// @ts-ignore
EventEmitter = new NativeEventEmitter(RNLlama);
}
if (Platform.OS === 'android') {
EventEmitter = DeviceEventEmitter;
}
/**
 * Wrapper around a single native llama.cpp context.
 */
export class LlamaContext {
gpu = false;
reasonNoGPU = '';
model = {};
constructor({ contextId, gpu, reasonNoGPU, model }) {
this.id = contextId;
this.gpu = gpu;
this.reasonNoGPU = reasonNoGPU;
this.model = model;
}
/**
* Load cached prompt & completion state from a file.
*/
async loadSession(filepath) {
let path = filepath;
if (path.startsWith('file://')) path = path.slice(7);
return RNLlama.loadSession(this.id, path);
}
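// Usage sketch for loadSession (illustrative, not part of this module):
// restore a previously saved prompt/completion cache into an existing
// LlamaContext. The `context` variable and the path are assumptions for the
// example; a file:// prefix is stripped automatically.
//
//   await context.loadSession('file:///path/to/session.bin')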
/**
* Save current cached prompt & completion state to a file.
*/
async saveSession(filepath, options) {
return RNLlama.saveSession(this.id, filepath, options?.tokenSize || -1);
}
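// Usage sketch for saveSession (illustrative): persist the current
// prompt/completion cache. `tokenSize` presumably caps how many tokens are
// saved; -1 (the default passed when omitted) is treated here as "all".
//
//   await context.saveSession('/path/to/session.bin', { tokenSize: -1 })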
/**
 * Format chat messages into a prompt using the model's chat template
 * (falls back to ChatML when the model's template is not supported).
 */
async getFormattedChat(messages, template) {
const chat = formatChat(messages);
let tmpl = this.model?.isChatTemplateSupported ? undefined : 'chatml';
if (template) tmpl = template; // Force replace if provided
return RNLlama.getFormattedChat(this.id, chat, tmpl);
}
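// Usage sketch for getFormattedChat (illustrative): render OpenAI-style
// role/content messages into a prompt string via the chat template.
//
//   const prompt = await context.getFormattedChat([
//     { role: 'system', content: 'You are a helpful assistant.' },
//     { role: 'user', content: 'Hello!' },
//   ])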
/**
 * Run a completion. `params.messages` (if provided) is formatted into the
 * prompt via getFormattedChat; otherwise `params.prompt` is used directly.
 * The optional callback receives partial token results while generating.
 */
async completion(params, callback) {
let finalPrompt = params.prompt;
if (params.messages) {
// messages always win
finalPrompt = await this.getFormattedChat(params.messages, params.chatTemplate);
}
let tokenListener = callback && EventEmitter.addListener(EVENT_ON_TOKEN, evt => {
const {
contextId,
tokenResult
} = evt;
if (contextId !== this.id) return;
callback(tokenResult);
});
if (!finalPrompt) throw new Error('Prompt is required');
const promise = RNLlama.completion(this.id, {
...params,
prompt: finalPrompt,
emit_partial_completion: !!callback
});
return promise.then(completionResult => {
// Stop listening for partial tokens once the completion settles.
tokenListener?.remove();
tokenListener = null;
return completionResult;
}).catch(err => {
tokenListener?.remove();
tokenListener = null;
throw err;
});
}
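// Usage sketch for completion (illustrative): stream a chat completion.
// Treat the sampling parameters shown (n_predict, temperature, stop), the
// callback payload field `token`, and the result field `text` as assumptions
// taken from the llama.rn typings rather than guarantees of this file.
//
//   const result = await context.completion(
//     {
//       messages: [{ role: 'user', content: 'Tell me a joke.' }],
//       n_predict: 128,
//       temperature: 0.7,
//       stop: ['</s>'],
//     },
//     (data) => console.log('partial token:', data.token),
//   )
//   console.log('full text:', result.text)
//
// An in-flight completion can be aborted with context.stopCompletion().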
stopCompletion() {
return RNLlama.stopCompletion(this.id);
}
tokenize(text) {
return RNLlama.tokenize(this.id, text);
}
detokenize(tokens) {
return RNLlama.detokenize(this.id, tokens);
}
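// Usage sketch for tokenize/detokenize (illustrative): round-trip text
// through the model tokenizer. The `tokens` field on the tokenize result is
// an assumption from the typings.
//
//   const { tokens } = await context.tokenize('Hello world')
//   const text = await context.detokenize(tokens)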
embedding(text, params) {
return RNLlama.embedding(this.id, text, params || {});
}
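// Usage sketch for embedding (illustrative): compute an embedding vector.
// Assumes the context was initialized with embedding support and that the
// result exposes an `embedding` array, per the typings.
//
//   const { embedding } = await context.embedding('Hello world')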
async bench(pp, tg, pl, nr) {
const result = await RNLlama.bench(this.id, pp, tg, pl, nr);
const [modelDesc, modelSize, modelNParams, ppAvg, ppStd, tgAvg, tgStd] = JSON.parse(result);
return {
modelDesc,
modelSize,
modelNParams,
ppAvg,
ppStd,
tgAvg,
tgStd
};
}
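// Usage sketch for bench (illustrative): benchmark the loaded model. The
// argument order (prompt-processing tokens, generated tokens, parallel
// sequences, repetitions) mirrors llama.cpp's bench tool; that mapping is an
// assumption, not stated in this file.
//
//   const { ppAvg, tgAvg } = await context.bench(512, 128, 1, 3)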
async applyLoraAdapters(loraList) {
let loraAdapters = [];
if (loraList) loraAdapters = loraList.map(l => ({
path: l.path.replace(/file:\/\//, ''),
scaled: l.scaled
}));
return RNLlama.applyLoraAdapters(this.id, loraAdapters);
}
async removeLoraAdapters() {
return RNLlama.removeLoraAdapters(this.id);
}
async getLoadedLoraAdapters() {
return RNLlama.getLoadedLoraAdapters(this.id);
}
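// Usage sketch for the LoRA methods (illustrative): hot-swap adapters on a
// running context. The file:// prefix is stripped from each adapter path;
// treating `scaled` as the adapter scale factor is an assumption.
//
//   await context.applyLoraAdapters([{ path: 'file:///path/to/adapter.gguf', scaled: 1.0 }])
//   const loaded = await context.getLoadedLoraAdapters()
//   await context.removeLoraAdapters()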
async release() {
return RNLlama.releaseContext(this.id);
}
}
export async function setContextLimit(limit) {
return RNLlama.setContextLimit(limit);
}
let contextIdCounter = 0;
const contextIdRandom = () => process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000);
const modelInfoSkip = [
// Large fields
'tokenizer.ggml.tokens', 'tokenizer.ggml.token_type', 'tokenizer.ggml.merges'];
export async function loadLlamaModelInfo(model) {
let path = model;
if (path.startsWith('file://')) path = path.slice(7);
return RNLlama.modelInfo(path, modelInfoSkip);
}
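// Usage sketch for loadLlamaModelInfo (illustrative): read GGUF metadata
// without creating a context. Large tokenizer arrays are skipped (see
// modelInfoSkip above); the exact keys returned depend on the model file.
//
//   const info = await loadLlamaModelInfo('file:///path/to/model.gguf')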
const poolTypeMap = {
// The native "unspecified" pooling type (-1) is represented by leaving the
// value undefined (i.e. pooling_type omitted).
none: 0,
mean: 1,
cls: 2,
last: 3,
rank: 4
};
/**
 * Initialize a llama.cpp context from a GGUF model file and return a
 * LlamaContext. Native load progress is reported through the optional
 * onProgress callback.
 */
export async function initLlama({
model,
is_model_asset: isModelAsset,
pooling_type: poolingType,
lora,
lora_list: loraList,
...rest
}, onProgress) {
// The native layer expects plain filesystem paths, so strip any file:// scheme.
let path = model;
if (path.startsWith('file://')) path = path.slice(7);
let loraPath = lora;
if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7);
let loraAdapters = [];
if (loraList) loraAdapters = loraList.map(l => ({
path: l.path.replace(/file:\/\//, ''),
scaled: l.scaled
}));
const contextId = contextIdCounter + contextIdRandom();
contextIdCounter += 1;
let removeProgressListener = null;
if (onProgress) {
removeProgressListener = EventEmitter.addListener(EVENT_ON_INIT_CONTEXT_PROGRESS, evt => {
if (evt.contextId !== contextId) return;
onProgress(evt.progress);
});
}
const poolType = poolTypeMap[poolingType];
const {
gpu,
reasonNoGPU,
model: modelDetails,
androidLib
} = await RNLlama.initContext(contextId, {
model: path,
is_model_asset: !!isModelAsset,
use_progress_callback: !!onProgress,
pooling_type: poolType,
lora: loraPath,
lora_list: loraAdapters,
...rest
}).catch(err => {
removeProgressListener?.remove();
throw err;
});
removeProgressListener?.remove();
return new LlamaContext({
contextId,
gpu,
reasonNoGPU,
model: modelDetails,
androidLib
});
}
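// Usage sketch for initLlama (illustrative): create a context and release it
// when done. `model` is required; the other option names shown (n_ctx,
// n_gpu_layers) are assumptions forwarded untouched to the native
// initContext call via ...rest.
//
//   const context = await initLlama(
//     { model: 'file:///path/to/model.gguf', n_ctx: 2048, n_gpu_layers: 99 },
//     (progress) => console.log('loading progress:', progress),
//   )
//   console.log(context.gpu ? 'GPU enabled' : `CPU only: ${context.reasonNoGPU}`)
//   // ... run completions ...
//   await context.release()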
export async function releaseAllLlama() {
return RNLlama.releaseAllContexts();
}
//# sourceMappingURL=index.js.map