@rexdug7005/nvidia-llama4
Version:
Integración de NVIDIA Llama4 con LangChain.js, con soporte para Tools Agent de n8n
305 lines (304 loc) • 14.7 kB
JavaScript
import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import axios from "axios";
import { convertOptionsToNvidiaParams, convertResponseToLangChainMessage, formatMessagesForNvidia, convertToOpenAITool, processToolCallsFromContent, } from "./utils.js";
/**
* Implementación del modelo de chat NVIDIA Llama4 para LangChain
*/
export class ChatNvidiaLlama4 extends BaseChatModel {
static lc_name() {
return "ChatNvidiaLlama4";
}
constructor(fields) {
super(fields);
Object.defineProperty(this, "apiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "baseUrl", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "modelName", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "defaultOptions", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "streaming", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.apiKey = fields.apiKey;
this.baseUrl =
fields.baseUrl || "https://integrate.api.nvidia.com/v1/chat/completions";
this.modelName = fields.model || "meta/llama-4-maverick-17b-128e-instruct";
this.streaming = fields.streaming ?? false;
// Extraer opciones predeterminadas eliminando las propiedades que no son opciones del modelo
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { apiKey, baseUrl, model, streaming, ...rest } = fields;
this.defaultOptions = rest;
}
_llmType() {
return "nvidia-llama4";
}
/**
* Vincula herramientas al modelo para habilitar la funcionalidad de agente
* @param tools Lista de herramientas para vincular al modelo
* @param kwargs Opciones adicionales para la llamada
*/
bindTools(tools, kwargs) {
return this.bind({
tools: tools.map((tool) => convertToOpenAITool(tool)),
...kwargs,
});
}
/**
* Obtiene los parámetros para la llamada a la API
*/
getParams(messages, options, streaming = false) {
// Convertir las opciones a formato NVIDIA
const baseOptions = convertOptionsToNvidiaParams({
...this.defaultOptions,
...options,
model: this.modelName,
});
// Formatear los mensajes para la API de NVIDIA
const formattedMessages = formatMessagesForNvidia(messages);
// Manejar herramientas si están presentes
let toolsParam;
if (options.tools && options.tools.length > 0) {
toolsParam = options.tools.map((tool) => convertToOpenAITool(tool));
}
// Manejar tool_choice si está presente
const toolChoice = options.tool_choice;
// Construir el payload final
return {
...baseOptions,
messages: formattedMessages,
stream: streaming,
...(toolsParam && { tools: toolsParam }),
...(toolChoice && { tool_choice: toolChoice }),
};
}
/**
* Genera una respuesta sincrónica (no streaming)
*/
async _generate(messages, options) {
const requestOptions = {
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
Accept: "application/json",
},
};
const params = this.getParams(messages, options, false);
try {
const response = await axios.post(this.baseUrl, params, requestOptions);
const responseData = response.data;
const message = convertResponseToLangChainMessage(responseData);
const generation = {
text: message.content.toString(),
message,
generationInfo: {
finishReason: responseData.choices?.[0]?.finish_reason,
tokenUsage: responseData.usage,
},
};
return {
generations: [generation],
};
}
catch (error) {
throw new Error(`Error al llamar a la API de NVIDIA Llama4: ${String(error)}`);
}
}
/**
* Procesa la respuesta de streaming de la API
*/
async *_streamResponseChunks(messages, options, runManager) {
const requestOptions = {
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
Accept: "text/event-stream",
},
responseType: "stream",
};
const params = this.getParams(messages, options, true);
try {
const response = await axios.post(this.baseUrl, params, requestOptions);
const stream = response.data;
// Buffer para acumular datos del stream
let buffer = "";
// Acumulador para contenido completo para detectar posibles herramientas
let completeContent = "";
// Estado para seguimiento de herramientas
let toolCallsDetected = false;
// Buffer para herramientas usando un objeto en lugar de array
const toolsBuffer = {};
for await (const chunk of stream) {
const chunkText = Buffer.from(chunk).toString("utf-8");
buffer += chunkText;
// Buscar eventos completos en el buffer
let eventIndex;
while ((eventIndex = buffer.indexOf("\n\n")) !== -1) {
const eventText = buffer.substring(0, eventIndex);
buffer = buffer.substring(eventIndex + 2); // +2 para saltar los dos saltos de línea
if (eventText.startsWith("data: ")) {
const data = eventText.substring(6); // saltar "data: "
if (data === "[DONE]") {
// Fin del stream
// Si acumulamos contenido que podría contener herramientas, procesarlo
if (completeContent && !toolCallsDetected) {
const { processedContent, extractedToolCalls } = processToolCallsFromContent(completeContent);
if (extractedToolCalls && extractedToolCalls.length > 0) {
// Crear un nuevo chunk con las herramientas extraídas
const toolChunk = new ChatGenerationChunk({
text: processedContent,
message: new AIMessageChunk({
content: processedContent,
tool_calls: extractedToolCalls,
}),
});
yield toolChunk;
if (runManager) {
await runManager.handleLLMNewToken("Tool calls processed");
}
}
}
break;
}
try {
const parsed = JSON.parse(data);
const deltaContent = parsed.choices?.[0]?.delta?.content || "";
const deltaToolCalls = parsed.choices?.[0]?.delta?.tool_calls;
// Manejar llamadas a herramientas explícitas
if (deltaToolCalls && deltaToolCalls.length > 0) {
toolCallsDetected = true;
// Acumular llamadas a herramientas
for (const toolCall of deltaToolCalls) {
// Si es una nueva llamada a herramienta, inicializar en el buffer
if (toolCall.index !== undefined) {
const toolIndex = toolCall.index;
if (!toolsBuffer[toolIndex]) {
toolsBuffer[toolIndex] = {
id: toolCall.id || `call-${toolIndex}`,
type: "function",
function: {
name: toolCall.function?.name || "",
arguments: toolCall.function?.arguments || "",
},
};
}
else {
// Actualizar llamada existente
if (toolCall.function?.name) {
toolsBuffer[toolIndex].function.name =
toolCall.function.name;
}
if (toolCall.function?.arguments) {
toolsBuffer[toolIndex].function.arguments +=
toolCall.function.arguments;
}
}
}
}
// Convertir las herramientas acumuladas al formato de LangChain
const toolCalls = Object.values(toolsBuffer).map((tool) => ({
id: tool.id,
type: "function",
function: {
name: tool.function.name,
arguments: tool.function.arguments,
},
}));
const toolChunk = new ChatGenerationChunk({
text: completeContent,
message: new AIMessageChunk({
content: completeContent,
tool_calls: toolCalls,
}),
});
yield toolChunk;
if (runManager) {
await runManager.handleLLMNewToken("Tool call received");
}
}
// Contenido de texto normal
else if (deltaContent) {
completeContent += deltaContent;
// Verificar si el contenido parece contener herramientas
// Solo si no hemos detectado herramientas explícitas
if (!toolCallsDetected &&
(completeContent.includes('{"function":') ||
completeContent.includes('"tool_calls":') ||
(completeContent.startsWith("[") &&
completeContent.includes('"name":')))) {
// Esperar a acumular más contenido antes de intentar procesarlo
// Solo emitimos chunk si parece un texto normal
if (!completeContent.startsWith("{") &&
!completeContent.startsWith("[") &&
!completeContent.includes("```json")) {
const chunkGen = new ChatGenerationChunk({
text: deltaContent,
message: new AIMessageChunk({ content: deltaContent }),
});
yield chunkGen;
if (runManager) {
await runManager.handleLLMNewToken(deltaContent);
}
}
}
else {
// Contenido normal, emitir directamente
const chunkGen = new ChatGenerationChunk({
text: deltaContent,
message: new AIMessageChunk({ content: deltaContent }),
});
yield chunkGen;
if (runManager) {
await runManager.handleLLMNewToken(deltaContent);
}
}
}
}
catch (err) {
console.error("Error al analizar chunk de evento:", err);
}
}
}
}
}
catch (error) {
throw new Error(`Error al procesar stream de NVIDIA Llama4: ${String(error)}`);
}
}
async _call(messages, options) {
if (this.streaming) {
const stream = await this._streamResponseChunks(messages, options);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk.text);
}
return chunks.join("");
}
else {
const result = await this._generate(messages, options);
return result.generations[0].text;
}
}
}