// @react-native-ai/mlc — MLC LLM provider for the Vercel AI SDK
// (compiled module output; see ai-sdk.js.map)
import NativeMLCEngine from './NativeMLCEngine';
/**
 * MLC provider entry point for the Vercel AI SDK.
 * `languageModel` returns a chat model bound to the given model id
 * (defaults to 'Llama-3.2-3B-Instruct').
 */
export const mlc = {
  languageModel: (modelId = 'Llama-3.2-3B-Instruct') =>
    new MlcChatLanguageModel(modelId)
};
/**
 * Convert AI SDK function tools into the native MLC tool format.
 * Only `type: 'function'` tools are forwarded; the native bridge consumes a
 * flat `{ paramName: description }` map rather than a full JSON schema.
 *
 * @param {Array<object>} tools - AI SDK tool definitions.
 * @returns {Array<object>} Tools in the native `{ type, function }` shape.
 */
const convertToolsToNativeFormat = tools => {
  return tools.filter(tool => tool.type === 'function').map(tool => {
    const parameters = {};
    // Guard: inputSchema may be absent on loosely-typed tool definitions;
    // previously this threw a TypeError instead of producing empty parameters.
    const properties = tool.inputSchema?.properties;
    if (properties) {
      Object.entries(properties).forEach(([key, value]) => {
        if (!value) {
          return;
        }
        // Flatten each property to its description only.
        parameters[key] = value?.description || '';
      });
    }
    return {
      type: 'function',
      function: {
        name: tool.name,
        description: tool.description,
        parameters
      }
    };
  });
};
/**
 * Map an AI SDK toolChoice value onto the native engine's string enum.
 * Only 'none' and 'auto' are supported by the native bridge; anything else
 * (e.g. 'required' or a specific tool selection) falls back to 'none' with
 * a console warning.
 *
 * @param {{ type: string } | undefined | null} toolChoice
 * @returns {'none' | 'auto'}
 */
const convertToolChoice = toolChoice => {
  if (!toolChoice) {
    return 'none';
  }
  if (toolChoice.type === 'none' || toolChoice.type === 'auto') {
    return toolChoice.type;
  }
  console.warn(`Unsupported toolChoice value: ${JSON.stringify(toolChoice)}. Defaulting to 'none'.`);
  // Bug fix: previously returned undefined despite warning about a 'none' default.
  return 'none';
};
/**
 * Translate a native finish_reason string into the AI SDK's unified
 * finish-reason shape, preserving the raw native value.
 *
 * @param {string | undefined} finishReason - Raw value from the native engine.
 * @returns {{ unified: string, raw: string | undefined }}
 */
const convertFinishReason = finishReason => {
  // Map lookup instead of an if/else chain; unknown values become 'other'.
  const unifiedByRaw = new Map([
    ['tool_calls', 'tool-calls'],
    ['stop', 'stop'],
    ['length', 'length']
  ]);
  return {
    unified: unifiedByRaw.get(finishReason) ?? 'other',
    raw: finishReason
  };
};
/**
 * LanguageModel (spec v3) implementation backed by the on-device MLC engine.
 * Bridges AI SDK generate/stream calls to the NativeMLCEngine module.
 */
class MlcChatLanguageModel {
  specificationVersion = 'v3';
  supportedUrls = {};
  provider = 'mlc';

  /**
   * @param {string} modelId - Identifier of the MLC model to use.
   */
  constructor(modelId) {
    this.modelId = modelId;
  }

  /** Prepare the model on the native side so it is ready for inference. */
  prepare() {
    return NativeMLCEngine.prepareModel(this.modelId);
  }

  /**
   * Download the model, reporting progress through the optional callback.
   * @param {(event: unknown) => void} [progressCallback]
   */
  async download(progressCallback) {
    const removeListener = NativeMLCEngine.onDownloadProgress(event => {
      progressCallback?.(event);
    });
    try {
      await NativeMLCEngine.downloadModel(this.modelId);
    } finally {
      // Fix: previously the progress listener leaked when downloadModel rejected.
      removeListener.remove();
    }
  }

  /** Unload the currently loaded model. */
  unload() {
    return NativeMLCEngine.unloadModel();
  }

  /** Remove this model's downloaded artifacts. */
  remove() {
    return NativeMLCEngine.removeModel(this.modelId);
  }

  /**
   * Flatten AI SDK prompt messages into `{ role, content }` pairs with plain
   * string content. Non-text parts are dropped with a warning because the
   * native bridge only accepts text.
   *
   * @param {Array<object>} messages
   * @returns {Array<{ role: string, content: string }>}
   */
  prepareMessages(messages) {
    return messages.map(message => {
      const content = Array.isArray(message.content) ? message.content.reduce((acc, part) => {
        if (part.type === 'text') {
          return acc + part.text;
        }
        console.warn('Unsupported message content type:', part);
        return acc;
      }, '') : message.content;
      return {
        role: message.role,
        content
      };
    });
  }

  /**
   * Translate AI SDK call options into the native generation-options shape.
   * Shared by doGenerate and doStream (previously duplicated verbatim).
   */
  #buildGenerationOptions(options) {
    return {
      temperature: options.temperature,
      maxTokens: options.maxOutputTokens,
      topP: options.topP,
      topK: options.topK,
      responseFormat: options.responseFormat?.type === 'json' ? {
        type: 'json_object',
        schema: JSON.stringify(options.responseFormat.schema)
      } : undefined,
      tools: convertToolsToNativeFormat(options.tools || []),
      toolChoice: convertToolChoice(options.toolChoice)
    };
  }

  /**
   * One-shot generation. Converts the prompt and options, calls the native
   * engine, and maps the response into the AI SDK content/usage shape.
   */
  async doGenerate(options) {
    const messages = this.prepareMessages(options.prompt);
    const generationOptions = this.#buildGenerationOptions(options);
    const response = await NativeMLCEngine.generateText(messages, generationOptions);
    return {
      content: [{
        type: 'text',
        text: response.content
      // Guard: tool_calls may be absent on the native response when no tools ran.
      }, ...(response.tool_calls ?? []).map(toolCall => ({
        type: 'tool-call',
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        input: JSON.stringify(toolCall.function.arguments || {})
      }))],
      finishReason: convertFinishReason(response.finish_reason),
      usage: {
        inputTokens: {
          total: response.usage.prompt_tokens,
          noCache: undefined,
          cacheRead: undefined,
          cacheWrite: undefined
        },
        outputTokens: {
          total: response.usage.completion_tokens,
          text: undefined,
          reasoning: undefined
        }
      },
      providerMetadata: {
        mlc: {
          extraUsage: {
            ...response.usage.extra
          }
        }
      },
      warnings: []
    };
  }

  /**
   * Streaming generation. Returns a ReadableStream of AI SDK stream parts
   * (text-start / text-delta / text-end / finish) fed by native events.
   * Cancelling the stream removes listeners and asks the engine to stop.
   */
  async doStream(options) {
    const messages = this.prepareMessages(options.prompt);
    if (typeof ReadableStream === 'undefined') {
      throw new Error(`ReadableStream is not available in this environment. Please load a polyfill, such as web-streams-polyfill.`);
    }
    const generationOptions = this.#buildGenerationOptions(options);
    let streamId;
    let listeners = [];
    const cleanup = () => {
      listeners.forEach(listener => listener.remove());
      listeners = [];
    };
    const stream = new ReadableStream({
      async start(controller) {
        try {
          const id = streamId = await NativeMLCEngine.streamText(messages, generationOptions);
          controller.enqueue({
            type: 'text-start',
            id
          });
          const updateListener = NativeMLCEngine.onChatUpdate(data => {
            if (data.delta?.content) {
              controller.enqueue({
                type: 'text-delta',
                delta: data.delta.content,
                id
              });
            }
          });
          const completeListener = NativeMLCEngine.onChatComplete(data => {
            controller.enqueue({
              type: 'text-end',
              id
            });
            controller.enqueue({
              type: 'finish',
              finishReason: convertFinishReason(data.finish_reason),
              usage: {
                inputTokens: {
                  total: data.usage.prompt_tokens,
                  noCache: undefined,
                  cacheRead: undefined,
                  cacheWrite: undefined
                },
                outputTokens: {
                  total: data.usage.completion_tokens,
                  text: undefined,
                  reasoning: undefined
                }
              },
              providerMetadata: {
                mlc: {
                  extraUsage: {
                    ...data.usage.extra
                  }
                }
              }
            });
            cleanup();
            controller.close();
          });
          listeners = [updateListener, completeListener];
        } catch (error) {
          cleanup();
          // Preserve the original failure for callers inspecting error.cause.
          controller.error(new Error(`MLC stream failed: ${error}`, { cause: error }));
        }
      },
      cancel() {
        cleanup();
        if (streamId) {
          NativeMLCEngine.cancelStream(streamId);
        }
      }
    });
    return {
      stream,
      rawCall: {
        rawPrompt: options.prompt,
        rawSettings: {}
      }
    };
  }
}
//# sourceMappingURL=ai-sdk.js.map