UNPKG

@rexdug7005/nvidia-llama4

Version:

Integración de NVIDIA Llama4 con LangChain.js, con soporte para Tools Agent de n8n

305 lines (304 loc) 14.7 kB
import { BaseChatModel, } from "@langchain/core/language_models/chat_models"; import { AIMessageChunk } from "@langchain/core/messages"; import { ChatGenerationChunk } from "@langchain/core/outputs"; import axios from "axios"; import { convertOptionsToNvidiaParams, convertResponseToLangChainMessage, formatMessagesForNvidia, convertToOpenAITool, processToolCallsFromContent, } from "./utils.js"; /** * Implementación del modelo de chat NVIDIA Llama4 para LangChain */ export class ChatNvidiaLlama4 extends BaseChatModel { static lc_name() { return "ChatNvidiaLlama4"; } constructor(fields) { super(fields); Object.defineProperty(this, "apiKey", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "baseUrl", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "modelName", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "defaultOptions", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "streaming", { enumerable: true, configurable: true, writable: true, value: void 0 }); this.apiKey = fields.apiKey; this.baseUrl = fields.baseUrl || "https://integrate.api.nvidia.com/v1/chat/completions"; this.modelName = fields.model || "meta/llama-4-maverick-17b-128e-instruct"; this.streaming = fields.streaming ?? false; // Extraer opciones predeterminadas eliminando las propiedades que no son opciones del modelo // eslint-disable-next-line @typescript-eslint/no-unused-vars const { apiKey, baseUrl, model, streaming, ...rest } = fields; this.defaultOptions = rest; } _llmType() { return "nvidia-llama4"; } /** * Vincula herramientas al modelo para habilitar la funcionalidad de agente * @param tools Lista de herramientas para vincular al modelo * @param kwargs Opciones adicionales para la llamada */ bindTools(tools, kwargs) { return this.bind({ tools: tools.map((tool) => convertToOpenAITool(tool)), ...kwargs, }); } /** * Obtiene los parámetros para la llamada a la API */ getParams(messages, options, streaming = false) { // Convertir las opciones a formato NVIDIA const baseOptions = convertOptionsToNvidiaParams({ ...this.defaultOptions, ...options, model: this.modelName, }); // Formatear los mensajes para la API de NVIDIA const formattedMessages = formatMessagesForNvidia(messages); // Manejar herramientas si están presentes let toolsParam; if (options.tools && options.tools.length > 0) { toolsParam = options.tools.map((tool) => convertToOpenAITool(tool)); } // Manejar tool_choice si está presente const toolChoice = options.tool_choice; // Construir el payload final return { ...baseOptions, messages: formattedMessages, stream: streaming, ...(toolsParam && { tools: toolsParam }), ...(toolChoice && { tool_choice: toolChoice }), }; } /** * Genera una respuesta sincrónica (no streaming) */ async _generate(messages, options) { const requestOptions = { headers: { "Content-Type": "application/json", Authorization: `Bearer ${this.apiKey}`, Accept: "application/json", }, }; const params = this.getParams(messages, options, false); try { const response = await axios.post(this.baseUrl, params, requestOptions); const responseData = response.data; const message = convertResponseToLangChainMessage(responseData); const generation = { text: message.content.toString(), message, generationInfo: { finishReason: responseData.choices?.[0]?.finish_reason, tokenUsage: responseData.usage, }, }; return { generations: [generation], }; } catch (error) { throw new Error(`Error al llamar a la API de NVIDIA Llama4: ${String(error)}`); } } /** * Procesa la respuesta de streaming de la API */ async *_streamResponseChunks(messages, options, runManager) { const requestOptions = { headers: { "Content-Type": "application/json", Authorization: `Bearer ${this.apiKey}`, Accept: "text/event-stream", }, responseType: "stream", }; const params = this.getParams(messages, options, true); try { const response = await axios.post(this.baseUrl, params, requestOptions); const stream = response.data; // Buffer para acumular datos del stream let buffer = ""; // Acumulador para contenido completo para detectar posibles herramientas let completeContent = ""; // Estado para seguimiento de herramientas let toolCallsDetected = false; // Buffer para herramientas usando un objeto en lugar de array const toolsBuffer = {}; for await (const chunk of stream) { const chunkText = Buffer.from(chunk).toString("utf-8"); buffer += chunkText; // Buscar eventos completos en el buffer let eventIndex; while ((eventIndex = buffer.indexOf("\n\n")) !== -1) { const eventText = buffer.substring(0, eventIndex); buffer = buffer.substring(eventIndex + 2); // +2 para saltar los dos saltos de línea if (eventText.startsWith("data: ")) { const data = eventText.substring(6); // saltar "data: " if (data === "[DONE]") { // Fin del stream // Si acumulamos contenido que podría contener herramientas, procesarlo if (completeContent && !toolCallsDetected) { const { processedContent, extractedToolCalls } = processToolCallsFromContent(completeContent); if (extractedToolCalls && extractedToolCalls.length > 0) { // Crear un nuevo chunk con las herramientas extraídas const toolChunk = new ChatGenerationChunk({ text: processedContent, message: new AIMessageChunk({ content: processedContent, tool_calls: extractedToolCalls, }), }); yield toolChunk; if (runManager) { await runManager.handleLLMNewToken("Tool calls processed"); } } } break; } try { const parsed = JSON.parse(data); const deltaContent = parsed.choices?.[0]?.delta?.content || ""; const deltaToolCalls = parsed.choices?.[0]?.delta?.tool_calls; // Manejar llamadas a herramientas explícitas if (deltaToolCalls && deltaToolCalls.length > 0) { toolCallsDetected = true; // Acumular llamadas a herramientas for (const toolCall of deltaToolCalls) { // Si es una nueva llamada a herramienta, inicializar en el buffer if (toolCall.index !== undefined) { const toolIndex = toolCall.index; if (!toolsBuffer[toolIndex]) { toolsBuffer[toolIndex] = { id: toolCall.id || `call-${toolIndex}`, type: "function", function: { name: toolCall.function?.name || "", arguments: toolCall.function?.arguments || "", }, }; } else { // Actualizar llamada existente if (toolCall.function?.name) { toolsBuffer[toolIndex].function.name = toolCall.function.name; } if (toolCall.function?.arguments) { toolsBuffer[toolIndex].function.arguments += toolCall.function.arguments; } } } } // Convertir las herramientas acumuladas al formato de LangChain const toolCalls = Object.values(toolsBuffer).map((tool) => ({ id: tool.id, type: "function", function: { name: tool.function.name, arguments: tool.function.arguments, }, })); const toolChunk = new ChatGenerationChunk({ text: completeContent, message: new AIMessageChunk({ content: completeContent, tool_calls: toolCalls, }), }); yield toolChunk; if (runManager) { await runManager.handleLLMNewToken("Tool call received"); } } // Contenido de texto normal else if (deltaContent) { completeContent += deltaContent; // Verificar si el contenido parece contener herramientas // Solo si no hemos detectado herramientas explícitas if (!toolCallsDetected && (completeContent.includes('{"function":') || completeContent.includes('"tool_calls":') || (completeContent.startsWith("[") && completeContent.includes('"name":')))) { // Esperar a acumular más contenido antes de intentar procesarlo // Solo emitimos chunk si parece un texto normal if (!completeContent.startsWith("{") && !completeContent.startsWith("[") && !completeContent.includes("```json")) { const chunkGen = new ChatGenerationChunk({ text: deltaContent, message: new AIMessageChunk({ content: deltaContent }), }); yield chunkGen; if (runManager) { await runManager.handleLLMNewToken(deltaContent); } } } else { // Contenido normal, emitir directamente const chunkGen = new ChatGenerationChunk({ text: deltaContent, message: new AIMessageChunk({ content: deltaContent }), }); yield chunkGen; if (runManager) { await runManager.handleLLMNewToken(deltaContent); } } } } catch (err) { console.error("Error al analizar chunk de evento:", err); } } } } } catch (error) { throw new Error(`Error al procesar stream de NVIDIA Llama4: ${String(error)}`); } } async _call(messages, options) { if (this.streaming) { const stream = await this._streamResponseChunks(messages, options); const chunks = []; for await (const chunk of stream) { chunks.push(chunk.text); } return chunks.join(""); } else { const result = await this._generate(messages, options); return result.generations[0].text; } } }