@rexdug7005/nvidia-llama4
Version:
Integración de NVIDIA Llama4 con LangChain.js, con soporte para Tools Agent de n8n
419 lines (418 loc) • 16.2 kB
JavaScript
import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
import { AIMessage, AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import axios from "axios";
import { zodToJsonSchema } from "zod-to-json-schema";
/** Default OpenAI-compatible chat-completions endpoint used when `baseUrl` is not given. */
const DEFAULT_BASE_URL = "https://integrate.api.nvidia.com/v1/chat/completions";

/** Model used when the caller does not pass `model`. */
const DEFAULT_MODEL = "meta/llama-4-maverick-17b-128e-instruct";

/**
 * Option keys that configure LangChain or the HTTP client rather than the model.
 * They are stripped from the request body: values such as callback handlers or an
 * AbortSignal are not API parameters and may not serialize cleanly to JSON.
 */
const NON_API_OPTION_KEYS = new Set([
  "callbacks",
  "callbackManager",
  "verbose",
  "cache",
  "maxRetries",
  "maxConcurrency",
  "onFailedAttempt",
  "metadata",
  "tags",
  "runName",
  "runId",
  "configurable",
  "recursionLimit",
  "signal",
  "timeout",
  "ls_structured_output_format",
]);

/**
 * Returns a shallow copy of `source` without the given keys.
 * @param {object} source - Object to copy.
 * @param {Set<string>} keys - Keys to drop.
 * @returns {object} New object without those keys.
 */
function omitKeys(source, keys) {
  return Object.fromEntries(Object.entries(source).filter(([key]) => !keys.has(key)));
}

/**
 * Flattens LangChain message content to plain text.
 * Strings pass through; for arrays, string items and the `text` of `type: "text"`
 * parts are concatenated (other parts, e.g. images, are dropped); null/undefined
 * become "" and anything else goes through String().
 * @param {*} content - Message content.
 * @returns {string} Text content.
 */
function contentToText(content) {
  if (typeof content === "string") {
    return content;
  }
  if (Array.isArray(content)) {
    return content
      .map((part) => {
        if (typeof part === "string") {
          return part;
        }
        return part?.type === "text" ? (part.text ?? "") : "";
      })
      .join("");
  }
  return content == null ? "" : String(content);
}

/**
 * Converts LangChain multimodal human-message parts into OpenAI-style content parts.
 * Text parts become `{ type: "text", text }`; `image_url` parts (whose `image_url` may
 * be a string or an object with `url`) become `{ type: "image_url", image_url: { url } }`.
 * Parts of any other type are skipped.
 * @param {Array} parts - LangChain content parts.
 * @returns {Array<object>} Content parts for the request body.
 */
function formatMultimodalContent(parts) {
  const content = [];
  for (const part of parts) {
    if (part?.type === "text") {
      content.push({ type: "text", text: part.text });
    } else if (part?.type === "image_url") {
      const url = typeof part.image_url === "string" ? part.image_url : part.image_url?.url;
      content.push({ type: "image_url", image_url: { url } });
    }
  }
  return content;
}

/**
 * Extracts OpenAI-format tool calls from an AI message so the conversation history
 * keeps the assistant's tool requests. Uses LangChain `tool_calls` ({ id, name, args })
 * when present, otherwise falls back to `additional_kwargs.tool_calls` as-is.
 * @param {object} message - An AI message.
 * @returns {Array<object>} Tool calls in `{ id, type: "function", function: { name, arguments } }` form.
 */
function toOpenAIToolCalls(message) {
  if (Array.isArray(message.tool_calls) && message.tool_calls.length > 0) {
    return message.tool_calls.map((toolCall) => ({
      id: toolCall.id,
      type: "function",
      function: {
        name: toolCall.name,
        arguments: JSON.stringify(toolCall.args ?? {}),
      },
    }));
  }
  const rawToolCalls = message.additional_kwargs?.tool_calls;
  return Array.isArray(rawToolCalls) ? rawToolCalls : [];
}

/**
 * Parses the `arguments` of an OpenAI-style tool call into an object.
 * Plain objects pass through, missing/empty arguments become `{}`, and anything that
 * is not a JSON object (invalid JSON, arrays, scalars) is wrapped as `{ raw }`.
 * @param {*} rawArguments - The `function.arguments` value from the API.
 * @returns {object} Parsed arguments.
 */
function parseToolArguments(rawArguments) {
  if (typeof rawArguments === "object" && rawArguments !== null && !Array.isArray(rawArguments)) {
    return rawArguments;
  }
  if (rawArguments === undefined || rawArguments === null || rawArguments === "") {
    return {};
  }
  if (typeof rawArguments === "string") {
    try {
      const parsed = JSON.parse(rawArguments);
      if (parsed && typeof parsed === "object" && !Array.isArray(parsed)) {
        return parsed;
      }
    } catch {
      // Invalid JSON: fall through and expose the raw text instead.
    }
  }
  return { raw: rawArguments };
}

/**
 * Converts a tool's `schema` into a JSON Schema for the `parameters` field.
 * - an object with a `schema()` method: its return value is used;
 * - an object with `_def` (assumed to be a zod schema): converted with zodToJsonSchema;
 * - any other object: assumed to already be JSON Schema and used as-is;
 * - missing schema or a conversion error: an empty object schema.
 * @param {*} schema - The tool's schema.
 * @returns {object} JSON Schema.
 */
function schemaToJsonSchema(schema) {
  if (!schema || typeof schema !== "object") {
    return { type: "object", properties: {} };
  }
  try {
    if (typeof schema.schema === "function") {
      return schema.schema();
    }
    if ("_def" in schema) {
      return zodToJsonSchema(schema);
    }
    return schema;
  } catch (e) {
    console.error("Error al convertir schema de herramienta:", e);
    return { type: "object", properties: {} };
  }
}

/**
 * Returns the payload of one Server-Sent Events event, or undefined when the event
 * has no `data:` field. Multiple `data:` lines are joined with "\n" and one space
 * following the colon is removed, as in the SSE format.
 * @param {string} eventText - One event block using "\n" line endings.
 * @returns {string|undefined} The event data.
 */
function extractSseData(eventText) {
  const dataLines = eventText
    .split("\n")
    .filter((line) => line.startsWith("data:"))
    .map((line) => (line.startsWith("data: ") ? line.slice(6) : line.slice(5)));
  return dataLines.length > 0 ? dataLines.join("\n") : undefined;
}

/**
 * Builds the axios request config shared by the streaming and non-streaming calls.
 * Forwards the call's AbortSignal and numeric timeout (ms) when provided.
 * @param {string} apiKey - Bearer token.
 * @param {object} [options] - LangChain call options.
 * @param {string} accept - Value of the Accept header.
 * @returns {object} axios request config.
 */
function buildRequestConfig(apiKey, options, accept) {
  const config = {
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
      Accept: accept,
    },
  };
  if (options?.signal) {
    config.signal = options.signal;
  }
  if (typeof options?.timeout === "number") {
    config.timeout = options.timeout;
  }
  return config;
}

/**
 * Describes a request error, appending the HTTP response body when it is available
 * and not a stream, so API-side validation messages are not lost.
 * @param {*} error - The thrown error.
 * @returns {string} Human-readable description.
 */
function describeError(error) {
  const base = String(error);
  const body = error?.response?.data;
  if (body === undefined || body === null || typeof body.pipe === "function") {
    return base;
  }
  try {
    return `${base} - ${typeof body === "string" ? body : JSON.stringify(body)}`;
  } catch {
    return base;
  }
}

/**
 * NVIDIA Llama 4 chat model for LangChain.js, tuned to work with the n8n Tools Agent.
 *
 * Sends OpenAI-compatible chat-completion requests with axios to `baseUrl` and
 * supports tool calling in both the non-streaming (`_generate`) and the streaming
 * (`_streamResponseChunks`) paths.
 */
export class ChatNvidiaLlama4Tools extends BaseChatModel {
  static lc_name() {
    return "ChatNvidiaLlama4Tools";
  }

  /**
   * @param {object} [fields={}] - Model configuration (also forwarded to BaseChatModel).
   * @param {string} [fields.apiKey] - API key, sent as a Bearer token.
   * @param {string} [fields.baseUrl] - Full chat-completions URL (defaults to NVIDIA's endpoint).
   * @param {string} [fields.model] - Model id (defaults to Llama 4 Maverick).
   * @param {boolean} [fields.streaming=false] - Whether `_call` uses the streaming path.
   * @param {Array} [fields.tools] - Tools to bind up front (LangChain or OpenAI format).
   * @param {string|object} [fields.toolChoice] - Default `tool_choice` when tools are sent.
   * Remaining fields (temperature, maxTokens, topP, ...) become default request options;
   * LangChain/base-class settings such as `callbacks` are not forwarded to the API.
   */
  constructor(fields = {}) {
    super(fields);
    // Declare instance properties the same way the TypeScript compiler emits class fields.
    const defineField = (name, value) =>
      Object.defineProperty(this, name, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    defineField("apiKey", void 0);
    defineField("baseUrl", void 0);
    defineField("modelName", void 0);
    defineField("streaming", void 0);
    defineField("defaultOptions", void 0);
    // Tools bound via the constructor or bindTools(); reused on every call.
    defineField("linkedTools", []);
    // Compatibility flag for the n8n Tools Agent.
    defineField("toolCallModel", true);

    this.apiKey = fields.apiKey;
    this.baseUrl = fields.baseUrl || DEFAULT_BASE_URL;
    this.modelName = fields.model || DEFAULT_MODEL;
    this.streaming = fields.streaming ?? false;
    if (Array.isArray(fields.tools)) {
      this.linkedTools = fields.tools;
    }
    // Keep only request options: drop the properties handled above and base-class settings.
    const { apiKey, baseUrl, model, streaming, tools, toolChoice, ...rest } = fields;
    this.defaultOptions = {
      ...omitKeys(rest, NON_API_OPTION_KEYS),
      ...(toolChoice !== undefined && { tool_choice: toolChoice }),
    };
  }

  _llmType() {
    return "nvidia-llama4-n8n-tools";
  }

  /**
   * Binds tools for later calls (used by the n8n Tools Agent).
   * NOTE: this mutates and returns the same instance, so the bound tools are shared by
   * every user of this instance. An empty or missing `tools` keeps the current tools.
   * @param {Array} tools - Tools in LangChain or OpenAI format.
   * @returns {this} This model instance.
   */
  bindTools(tools) {
    if (tools && tools.length > 0) {
      this.linkedTools = tools;
    }
    // Returning `this` keeps compatibility with n8n.
    return this;
  }

  /**
   * Converts a tool to the OpenAI-compatible `{ type: "function", function: {...} }` shape.
   * Tools already in that shape are returned unchanged.
   * @param {object} tool - LangChain tool (name, description, schema) or OpenAI tool.
   * @returns {object} OpenAI-format tool definition.
   */
  convertToOpenAITool(tool) {
    if (tool && typeof tool === "object" && "type" in tool && "function" in tool) {
      return tool;
    }
    return {
      type: "function",
      function: {
        name: tool.name,
        description: tool.description,
        parameters: schemaToJsonSchema(tool.schema),
      },
    };
  }

  /**
   * Converts LangChain messages into the request's `messages` array.
   * - system  -> role "system" (content passed through);
   * - human   -> role "user" (multimodal parts converted to OpenAI content parts);
   * - ai      -> role "assistant" with text content and, when present, `tool_calls`;
   * - tool    -> role "tool" with `tool_call_id`;
   * - generic (ChatMessage) -> its role if "system"/"assistant", otherwise "user";
   * - anything else -> role "user" with text content.
   * @param {Array} messages - LangChain messages.
   * @returns {Array<object>} Request messages.
   */
  formatMessagesForNvidia(messages) {
    return messages.map((message) => {
      switch (message._getType()) {
        case "system":
          return { role: "system", content: message.content };
        case "human":
          return {
            role: "user",
            content:
              typeof message.content === "string"
                ? message.content
                : formatMultimodalContent(message.content),
          };
        case "ai": {
          const formatted = { role: "assistant", content: contentToText(message.content) };
          const toolCalls = toOpenAIToolCalls(message);
          if (toolCalls.length > 0) {
            formatted.tool_calls = toolCalls;
          }
          return formatted;
        }
        case "tool":
          return {
            role: "tool",
            content: contentToText(message.content),
            tool_call_id: message.tool_call_id,
          };
        case "generic": {
          const role =
            message.role === "system" || message.role === "assistant" ? message.role : "user";
          return { role, content: message.content };
        }
        default:
          return { role: "user", content: contentToText(message.content) };
      }
    });
  }

  /**
   * Builds the request body.
   * Call options override `defaultOptions`. camelCase sampling options are mapped to
   * their snake_case API names; other options are forwarded unchanged except the
   * LangChain/HTTP-only keys. Tools come from `options.tools` when it is an array,
   * otherwise from the bound tools; `tool_choice` (default "auto", `toolChoice` accepted
   * as an alias) is only sent together with a non-empty `tools` list.
   * @param {Array} messages - LangChain messages.
   * @param {object} options - Call options.
   * @param {boolean} [streaming=false] - Value of the `stream` flag.
   * @returns {object} JSON request body.
   */
  getParams(messages, options, streaming = false) {
    const {
      temperature,
      maxTokens,
      topP,
      topK,
      frequencyPenalty,
      presencePenalty,
      stop,
      images,
      tools: callTools,
      tool_choice: callToolChoice,
      toolChoice: camelToolChoice,
      ...restOptions
    } = { ...this.defaultOptions, ...options };

    const payload = { model: this.modelName, stream: streaming };
    const samplingOptions = {
      temperature,
      max_tokens: maxTokens,
      top_p: topP,
      top_k: topK,
      frequency_penalty: frequencyPenalty,
      presence_penalty: presencePenalty,
      stop,
      images,
    };
    // Only send the options that were actually provided.
    for (const [key, value] of Object.entries(samplingOptions)) {
      if (value !== undefined) {
        payload[key] = value;
      }
    }

    const toolsToUse = Array.isArray(callTools) ? callTools : this.linkedTools;
    const toolsParam =
      Array.isArray(toolsToUse) && toolsToUse.length > 0
        ? toolsToUse.map((tool) => this.convertToOpenAITool(tool))
        : undefined;
    const toolChoice = callToolChoice ?? camelToolChoice ?? "auto";

    return {
      ...payload,
      ...omitKeys(restOptions, NON_API_OPTION_KEYS),
      messages: this.formatMessagesForNvidia(messages),
      ...(toolsParam && { tools: toolsParam, tool_choice: toolChoice }),
    };
  }

  /**
   * Converts a non-streaming API response into an AIMessage.
   * Tool calls become LangChain tool calls (`{ id, type: "tool_call", name, args }`);
   * arguments that are not a JSON object are wrapped as `{ raw }`. `additional_kwargs`
   * carries `finish_reason`, `token_usage` and, when present, the raw `tool_calls`.
   * @param {object} responseData - Parsed JSON response body.
   * @returns {AIMessage} The assistant message.
   */
  convertResponseToMessage(responseData) {
    const choice = responseData?.choices?.[0];
    const content = choice?.message?.content || "";
    const rawToolCalls = choice?.message?.tool_calls;
    const additionalKwargs = {
      finish_reason: choice?.finish_reason,
      token_usage: responseData?.usage,
    };
    if (!Array.isArray(rawToolCalls) || rawToolCalls.length === 0) {
      return new AIMessage({ content, additional_kwargs: additionalKwargs });
    }
    return new AIMessage({
      content,
      tool_calls: rawToolCalls.map((toolCall) => ({
        id: toolCall.id,
        type: "tool_call",
        name: toolCall.function?.name,
        args: parseToolArguments(toolCall.function?.arguments),
      })),
      additional_kwargs: { ...additionalKwargs, tool_calls: rawToolCalls },
    });
  }

  /**
   * Performs one non-streaming request and returns a single generation.
   * @param {Array} messages - LangChain messages.
   * @param {object} [options={}] - Call options (`signal`/`timeout` are passed to axios).
   * @returns {Promise<{generations: Array<object>}>} The generation result.
   * @throws {Error} When the request or response handling fails; the original error is `cause`.
   */
  async _generate(messages, options = {}) {
    const params = this.getParams(messages, options, false);
    try {
      const response = await axios.post(
        this.baseUrl,
        params,
        buildRequestConfig(this.apiKey, options, "application/json"),
      );
      const responseData = response.data;
      const message = this.convertResponseToMessage(responseData);
      return {
        generations: [
          {
            text:
              typeof message.content === "string"
                ? message.content
                : JSON.stringify(message.content),
            message,
            generationInfo: {
              finishReason: responseData?.choices?.[0]?.finish_reason,
              tokenUsage: responseData?.usage,
            },
          },
        ],
      };
    } catch (error) {
      throw new Error(`Error al llamar a la API de NVIDIA Llama4: ${describeError(error)}`, {
        cause: error,
      });
    }
  }

  /**
   * Streams the response as ChatGenerationChunks by parsing the Server-Sent Events body.
   * Stops at the `[DONE]` event.
   * @param {Array} messages - LangChain messages.
   * @param {object} [options={}] - Call options (`signal`/`timeout` are passed to axios).
   * @param {object} [runManager] - LangChain callback manager for new-token events.
   * @yields {ChatGenerationChunk} Text and tool-call chunks.
   * @throws {Error} When the request or stream fails; the original error is `cause`.
   */
  async *_streamResponseChunks(messages, options = {}, runManager) {
    const params = this.getParams(messages, options, true);
    try {
      const response = await axios.post(this.baseUrl, params, {
        ...buildRequestConfig(this.apiKey, options, "text/event-stream"),
        responseType: "stream",
      });
      // A streaming decoder keeps multi-byte UTF-8 characters intact when they are
      // split across network chunks.
      const decoder = new TextDecoder("utf-8");
      let buffer = "";
      for await (const chunk of response.data) {
        const text = typeof chunk === "string" ? chunk : decoder.decode(chunk, { stream: true });
        // Normalize CRLF so that "\n\n" reliably marks the end of an event.
        buffer = (buffer + text).replace(/\r\n/g, "\n");
        let boundary;
        while ((boundary = buffer.indexOf("\n\n")) !== -1) {
          const data = extractSseData(buffer.slice(0, boundary));
          buffer = buffer.slice(boundary + 2);
          if (data === "[DONE]") {
            return;
          }
          if (data !== undefined) {
            yield* this.streamEventToChunks(data, runManager);
          }
        }
      }
      // Flush bytes held by the decoder and a final event that lacked a trailing blank line.
      const trailing = extractSseData((buffer + decoder.decode()).replace(/\r\n/g, "\n").trim());
      if (trailing !== undefined && trailing !== "[DONE]") {
        yield* this.streamEventToChunks(trailing, runManager);
      }
    } catch (error) {
      throw new Error(`Error al procesar stream de NVIDIA Llama4: ${describeError(error)}`, {
        cause: error,
      });
    }
  }

  /**
   * Turns the data of one SSE event into generation chunks.
   * Text deltas yield a text chunk and notify `runManager`; tool-call deltas yield an
   * empty-text chunk carrying `tool_call_chunks` (arguments stay partial strings so
   * LangChain can concatenate them). Unparseable events are logged and skipped.
   * @param {string} data - JSON payload of the event.
   * @param {object} [runManager] - LangChain callback manager.
   * @yields {ChatGenerationChunk} Chunks for this event.
   */
  async *streamEventToChunks(data, runManager) {
    let parsed;
    try {
      parsed = JSON.parse(data);
    } catch (err) {
      console.error("Error al analizar chunk de evento:", err);
      return;
    }
    const delta = parsed?.choices?.[0]?.delta;

    const deltaContent = delta?.content || "";
    if (deltaContent) {
      yield new ChatGenerationChunk({
        text: deltaContent,
        message: new AIMessageChunk({ content: deltaContent }),
      });
      await runManager?.handleLLMNewToken(deltaContent);
    }

    const deltaToolCalls = delta?.tool_calls;
    if (Array.isArray(deltaToolCalls) && deltaToolCalls.length > 0) {
      yield new ChatGenerationChunk({
        text: "",
        message: new AIMessageChunk({
          content: "",
          tool_call_chunks: deltaToolCalls.map((toolCall, position) => {
            const rawArgs = toolCall.function?.arguments;
            return {
              type: "tool_call_chunk",
              index: toolCall.index ?? position,
              id: toolCall.id,
              name: toolCall.function?.name,
              args:
                rawArgs === undefined || typeof rawArgs === "string"
                  ? rawArgs
                  : JSON.stringify(rawArgs),
            };
          }),
        }),
      });
    }
  }

  /**
   * Legacy text-only entry point: returns the response text, using the streaming
   * path when `this.streaming` is true and a single request otherwise.
   * @param {Array} messages - LangChain messages.
   * @param {object} [options={}] - Call options.
   * @returns {Promise<string>} Generated text.
   */
  async _call(messages, options = {}) {
    if (!this.streaming) {
      const result = await this._generate(messages, options);
      return result.generations[0].text;
    }
    let text = "";
    for await (const chunk of this._streamResponseChunks(messages, options)) {
      text += chunk.text;
    }
    return text;
  }
}