UNPKG

@react-native-ai/mlc

Version:
235 lines (234 loc) 6.9 kB
"use strict";

import NativeMLCEngine from './NativeMLCEngine';

/**
 * Provider factory for on-device MLC models.
 *
 * `mlc.languageModel(modelId)` returns an `MlcChatLanguageModel` bound to
 * `modelId` (default: 'Llama-3.2-3B-Instruct').
 */
export const mlc = {
  languageModel: (modelId = 'Llama-3.2-3B-Instruct') => {
    return new MlcChatLanguageModel(modelId);
  },
};

/**
 * Converts SDK tool definitions into the shape handed to the native engine.
 *
 * Only tools with `type === 'function'` are kept. Each tool's schema
 * `properties` are flattened to a `{ [propertyName]: description }` map;
 * property types and the schema's `required` list are not forwarded.
 *
 * @param {Array<object>} tools - SDK tool definitions.
 * @returns {Array<object>} `{ type: 'function', function: { name, description, parameters } }` entries.
 */
const convertToolsToNativeFormat = (tools) =>
  tools
    .filter((tool) => tool.type === 'function')
    .map((tool) => {
      const parameters = {};
      const properties = tool.inputSchema?.properties;
      if (properties) {
        for (const [key, value] of Object.entries(properties)) {
          // Skip empty/falsy schema entries; missing descriptions become ''.
          if (value) {
            parameters[key] = value.description || '';
          }
        }
      }
      return {
        type: 'function',
        function: {
          name: tool.name,
          description: tool.description,
          parameters,
        },
      };
    });

/**
 * Maps the SDK `toolChoice` option onto the values passed to the native engine.
 *
 * Only 'none' and 'auto' are forwarded. A missing value, or any other choice
 * (e.g. 'required' or a specific tool), becomes 'none'. Unsupported values
 * also log a warning.
 *
 * @param {{type: string} | undefined} toolChoice
 * @returns {'none' | 'auto'}
 */
const convertToolChoice = (toolChoice) => {
  if (!toolChoice) {
    return 'none';
  }
  if (toolChoice.type === 'none' || toolChoice.type === 'auto') {
    return toolChoice.type;
  }
  console.warn(`Unsupported toolChoice value: ${JSON.stringify(toolChoice)}. Defaulting to 'none'.`);
  // Return what the warning promises (previously returned undefined).
  return 'none';
};

/**
 * Normalizes the engine's snake_case finish reason into the SDK's unified form.
 *
 * 'tool_calls' -> 'tool-calls'; 'stop' and 'length' pass through; anything
 * else becomes 'other'. The original value is kept in `raw`.
 *
 * @param {string | undefined} finishReason
 * @returns {{unified: string, raw: string | undefined}}
 */
const convertFinishReason = (finishReason) => {
  let unified = 'other';
  if (finishReason === 'tool_calls') {
    unified = 'tool-calls';
  } else if (finishReason === 'stop') {
    unified = 'stop';
  } else if (finishReason === 'length') {
    unified = 'length';
  }
  return { unified, raw: finishReason };
};

/**
 * Builds the options object passed to `NativeMLCEngine.generateText` /
 * `streamText` from the SDK call options. Shared by `doGenerate` and `doStream`.
 *
 * A JSON response format is forwarded as `{ type: 'json_object', schema }`,
 * with the schema serialized to a string.
 *
 * @param {object} options - SDK call options.
 * @returns {object}
 */
const buildGenerationOptions = (options) => ({
  temperature: options.temperature,
  maxTokens: options.maxOutputTokens,
  topP: options.topP,
  topK: options.topK,
  responseFormat:
    options.responseFormat?.type === 'json'
      ? { type: 'json_object', schema: JSON.stringify(options.responseFormat.schema) }
      : undefined,
  tools: convertToolsToNativeFormat(options.tools || []),
  toolChoice: convertToolChoice(options.toolChoice),
});

/**
 * Converts the engine's usage record (`prompt_tokens` / `completion_tokens`)
 * into the SDK usage shape.
 *
 * A missing usage record yields undefined totals instead of throwing. This
 * matters inside the stream-complete listener, where a throw would leave the
 * stream open forever.
 *
 * @param {object | undefined} usage
 * @returns {object}
 */
const convertUsage = (usage) => ({
  inputTokens: {
    total: usage?.prompt_tokens,
    noCache: undefined,
    cacheRead: undefined,
    cacheWrite: undefined,
  },
  outputTokens: {
    total: usage?.completion_tokens,
    text: undefined,
    reasoning: undefined,
  },
});

/**
 * Wraps the engine's extra usage fields (`usage.extra`, if any) as MLC
 * provider metadata.
 *
 * @param {object | undefined} usage
 * @returns {{mlc: {extraUsage: object}}}
 */
const buildProviderMetadata = (usage) => ({
  mlc: { extraUsage: { ...usage?.extra } },
});

/**
 * Language model backed by the native MLC engine, for a single `modelId`.
 *
 * Model lifecycle methods (prepare/download/unload/remove) delegate directly
 * to `NativeMLCEngine`. `doGenerate` / `doStream` translate between the SDK
 * call options and the engine's request/response shapes.
 */
class MlcChatLanguageModel {
  specificationVersion = 'v3';
  supportedUrls = {};
  provider = 'mlc';

  /** @param {string} modelId - Identifier of the MLC model to run. */
  constructor(modelId) {
    this.modelId = modelId;
  }

  /** Asks the native engine to prepare this model; returns the engine's result. */
  prepare() {
    return NativeMLCEngine.prepareModel(this.modelId);
  }

  /**
   * Downloads this model, forwarding each native progress event to
   * `progressCallback` (optional).
   *
   * The progress subscription is removed whether the download succeeds or
   * fails. Previously a rejected download leaked the listener.
   *
   * @param {(event: object) => void} [progressCallback]
   * @returns {Promise<void>}
   */
  async download(progressCallback) {
    const removeListener = NativeMLCEngine.onDownloadProgress((event) => {
      progressCallback?.(event);
    });
    try {
      await NativeMLCEngine.downloadModel(this.modelId);
    } finally {
      removeListener.remove();
    }
  }

  /** Asks the native engine to unload the current model (no model id is passed). */
  unload() {
    return NativeMLCEngine.unloadModel();
  }

  /** Asks the native engine to remove this model. */
  remove() {
    return NativeMLCEngine.removeModel(this.modelId);
  }

  /**
   * Flattens SDK prompt messages to `{ role, content }` with string content.
   *
   * Array content is reduced to the concatenation of its text parts. Any
   * non-text part (e.g. files, tool calls, tool results) is dropped with a
   * console warning.
   *
   * @param {Array<object>} messages
   * @returns {Array<{role: string, content: string}>}
   */
  prepareMessages(messages) {
    return messages.map((message) => {
      const content = Array.isArray(message.content)
        ? message.content.reduce((acc, part) => {
            if (part.type === 'text') {
              return acc + part.text;
            }
            console.warn('Unsupported message content type:', part);
            return acc;
          }, '')
        : message.content;
      return { role: message.role, content };
    });
  }

  /**
   * Runs a single non-streaming generation.
   *
   * The returned `content` holds a text part (only when the engine returned
   * non-empty text) followed by one tool-call part per engine tool call. Each
   * tool call's `arguments` is JSON-encoded into `input`.
   *
   * @param {object} options - SDK call options.
   * @returns {Promise<object>}
   */
  async doGenerate(options) {
    const messages = this.prepareMessages(options.prompt);
    const generationOptions = buildGenerationOptions(options);
    const response = await NativeMLCEngine.generateText(messages, generationOptions);

    const content = [];
    // A tool-call-only reply may carry null/empty text; don't emit a text part with non-string text.
    if (response.content) {
      content.push({ type: 'text', text: response.content });
    }
    for (const toolCall of response.tool_calls ?? []) {
      content.push({
        type: 'tool-call',
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        input: JSON.stringify(toolCall.function.arguments || {}),
      });
    }

    return {
      content,
      finishReason: convertFinishReason(response.finish_reason),
      usage: convertUsage(response.usage),
      providerMetadata: buildProviderMetadata(response.usage),
      warnings: [],
    };
  }

  /**
   * Starts a streaming generation and exposes it as a ReadableStream of SDK
   * stream parts: text-start, zero or more text-delta, text-end, finish.
   *
   * Requires a global `ReadableStream` (throws otherwise). The native stream
   * id returned by `streamText` is used as the text-part id.
   *
   * NOTE(review): the chat update/complete listeners are registered only
   * after `streamText` resolves. They are not filtered by stream id. Confirm
   * the native module cannot emit events before its promise settles.
   *
   * @param {object} options - SDK call options.
   * @returns {Promise<{stream: ReadableStream, rawCall: object}>}
   */
  async doStream(options) {
    const messages = this.prepareMessages(options.prompt);
    if (typeof ReadableStream === 'undefined') {
      throw new Error(`ReadableStream is not available in this environment. Please load a polyfill, such as web-streams-polyfill.`);
    }
    const generationOptions = buildGenerationOptions(options);

    let streamId;
    // Set by cancel(); checked once streamText resolves so a cancel that
    // arrives before the id is known still stops the native stream.
    let cancelled = false;
    let listeners = [];
    const cleanup = () => {
      listeners.forEach((listener) => listener.remove());
      listeners = [];
    };

    const stream = new ReadableStream({
      async start(controller) {
        try {
          const id = (streamId = await NativeMLCEngine.streamText(messages, generationOptions));
          if (cancelled) {
            // cancel() ran while streamText was pending and had no id to cancel.
            NativeMLCEngine.cancelStream(id);
            return;
          }
          controller.enqueue({ type: 'text-start', id });

          const updateListener = NativeMLCEngine.onChatUpdate((data) => {
            if (data.delta?.content) {
              controller.enqueue({ type: 'text-delta', delta: data.delta.content, id });
            }
          });
          const completeListener = NativeMLCEngine.onChatComplete((data) => {
            controller.enqueue({ type: 'text-end', id });
            controller.enqueue({
              type: 'finish',
              finishReason: convertFinishReason(data.finish_reason),
              usage: convertUsage(data.usage),
              providerMetadata: buildProviderMetadata(data.usage),
            });
            cleanup();
            controller.close();
          });
          listeners = [updateListener, completeListener];
        } catch (error) {
          cleanup();
          controller.error(new Error(`MLC stream failed: ${error}`, { cause: error }));
        }
      },
      cancel() {
        cancelled = true;
        cleanup();
        if (streamId) {
          NativeMLCEngine.cancelStream(streamId);
        }
      },
    });

    return {
      stream,
      rawCall: { rawPrompt: options.prompt, rawSettings: {} },
    };
  }
}
//# sourceMappingURL=ai-sdk.js.map