@rexdug7005/nvidia-llama4
Version:
Integración de NVIDIA Llama4 con LangChain.js
325 lines (320 loc) • 11.7 kB
JavaScript
'use strict';
var chat_models = require('@langchain/core/language_models/chat_models');
var messages = require('@langchain/core/messages');
var outputs = require('@langchain/core/outputs');
var axios = require('axios');
var zod = require('zod');
/**
* Convierte opciones en formato camelCase a los parámetros esperados por la API de NVIDIA
*/
function convertOptionsToNvidiaParams(options) {
const result = {};
// Mapeo de nombres camelCase a los nombres de la API
if (options.model !== undefined)
result.model = options.model;
if (options.maxTokens !== undefined)
result.max_tokens = options.maxTokens;
if (options.temperature !== undefined)
result.temperature = options.temperature;
if (options.topP !== undefined)
result.top_p = options.topP;
if (options.topK !== undefined)
result.top_k = options.topK;
if (options.presencePenalty !== undefined)
result.presence_penalty = options.presencePenalty;
if (options.frequencyPenalty !== undefined)
result.frequency_penalty = options.frequencyPenalty;
if (options.stop !== undefined)
result.stop = options.stop;
if (options.images !== undefined)
result.images = options.images;
return result;
}
/**
* Definición del tipo para los mensajes en formato NVIDIA
*/
zod.z.object({
role: zod.z.enum(["system", "user", "assistant"]),
content: zod.z.string().or(zod.z.array(zod.z.union([
zod.z.string(),
zod.z.object({
type: zod.z.literal("image"),
image_url: zod.z.object({
url: zod.z.string(),
}),
}),
]))),
});
/**
* Formatea los mensajes de LangChain para la API de NVIDIA
*/
function formatMessagesForNvidia(messages) {
return messages.map((message) => {
// Convertir de mensajes de LangChain a formato NVIDIA
const messageType = message.constructor.name;
if (messageType === "SystemMessage") {
return {
role: "system",
content: message.content,
};
}
else if (messageType === "HumanMessage") {
// Manejar contenido multimodal para HumanMessage
if (typeof message.content === "string") {
return {
role: "user",
content: message.content,
};
}
else {
// Procesar contenido multimodal (texto + imagen)
const content = [];
const parts = message.content;
for (const part of parts) {
if (part.type === "text") {
content.push(part.text);
}
else if (part.type === "image_url") {
content.push({
type: "image",
image_url: {
url: part.image_url.url,
},
});
}
}
return {
role: "user",
content,
};
}
}
else if (messageType === "AIMessage") {
return {
role: "assistant",
content: message.content.toString(),
};
}
else if (messageType === "ChatMessage") {
// Mapear los roles de ChatMessage a los roles de NVIDIA
let role = "user";
const chatMessage = message;
if (chatMessage.role === "system") {
role = "system";
}
else if (chatMessage.role === "assistant") {
role = "assistant";
}
else {
// Por defecto, asignar cualquier otro rol como "user"
role = "user";
}
return {
role,
content: message.content,
};
}
else {
// Para cualquier otro tipo de mensaje, usar el rol de usuario
return {
role: "user",
content: message.content.toString(),
};
}
});
}
/**
* Convierte la respuesta de NVIDIA a un mensaje de LangChain
*/
function convertResponseToLangChainMessage(response) {
// Extraer el contenido del mensaje de la respuesta
const responseObj = response;
const content = responseObj.choices?.[0]?.message?.content || "";
// Crear un mensaje de IA con el contenido extraído
return new messages.AIMessage({
content,
// Opcional: Incluir metadatos adicionales si están disponibles
additional_kwargs: {
finish_reason: responseObj.choices?.[0]?.finish_reason,
token_usage: responseObj.usage,
},
});
}
/**
* Implementación del modelo de chat NVIDIA Llama4 para LangChain
*/
class ChatNvidiaLlama4 extends chat_models.BaseChatModel {
static lc_name() {
return "ChatNvidiaLlama4";
}
constructor(fields) {
super(fields);
Object.defineProperty(this, "apiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "baseUrl", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "modelName", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "defaultOptions", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "streaming", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.apiKey = fields.apiKey;
this.baseUrl =
fields.baseUrl || "https://integrate.api.nvidia.com/v1/chat/completions";
this.modelName = fields.model || "meta/llama-4-maverick-17b-128e-instruct";
this.streaming = fields.streaming ?? false;
// Extraer opciones predeterminadas eliminando las propiedades que no son opciones del modelo
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { apiKey, baseUrl, model, streaming, ...rest } = fields;
this.defaultOptions = rest;
}
_llmType() {
return "nvidia-llama4";
}
/**
* Obtiene los parámetros para la llamada a la API
*/
getParams(messages, options, streaming = false) {
// Convertir las opciones a formato NVIDIA
const baseOptions = convertOptionsToNvidiaParams({
...this.defaultOptions,
...options,
model: this.modelName,
});
// Formatear los mensajes para la API de NVIDIA
const formattedMessages = formatMessagesForNvidia(messages);
// Construir el payload final
return {
...baseOptions,
messages: formattedMessages,
stream: streaming,
};
}
/**
* Genera una respuesta sincrónica (no streaming)
*/
async _generate(messages, options) {
const requestOptions = {
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
Accept: "application/json",
},
};
const params = this.getParams(messages, options, false);
try {
const response = await axios.post(this.baseUrl, params, requestOptions);
const responseData = response.data;
const message = convertResponseToLangChainMessage(responseData);
const generation = {
text: message.content.toString(),
message,
generationInfo: {
finishReason: responseData.choices?.[0]?.finish_reason,
tokenUsage: responseData.usage,
},
};
return {
generations: [generation],
};
}
catch (error) {
throw new Error(`Error al llamar a la API de NVIDIA Llama4: ${String(error)}`);
}
}
/**
* Procesa la respuesta de streaming de la API
*/
async *_streamResponseChunks(messages$1, options, runManager) {
const requestOptions = {
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
Accept: "text/event-stream",
},
responseType: "stream",
};
const params = this.getParams(messages$1, options, true);
try {
const response = await axios.post(this.baseUrl, params, requestOptions);
const stream = response.data;
// Un buffer para acumular los datos del stream
let buffer = "";
for await (const chunk of stream) {
const chunkText = Buffer.from(chunk).toString("utf-8");
buffer += chunkText;
// Procesar líneas completas
while (buffer.includes("\n")) {
const newlineIndex = buffer.indexOf("\n");
const line = buffer.substring(0, newlineIndex).trim();
buffer = buffer.substring(newlineIndex + 1);
if (line.startsWith("data: ")) {
const data = line.substring(6).trim();
// Fin del stream
if (data === "[DONE]") {
return;
}
try {
const parsedData = JSON.parse(data);
const content = parsedData.choices?.[0]?.delta?.content || "";
if (content) {
const messageChunk = new messages.AIMessageChunk({
content,
});
const chunk = new outputs.ChatGenerationChunk({
text: content,
message: messageChunk,
generationInfo: {
finishReason: parsedData.choices?.[0]?.finish_reason,
},
});
yield chunk;
// Notificar al manager de callbacks si existe
if (runManager) {
await runManager.handleLLMNewToken(content);
}
}
}
catch (error) {
// Ignorar líneas no válidas
continue;
}
}
}
}
}
catch (error) {
throw new Error(`Error al procesar el stream de NVIDIA Llama4: ${String(error)}`);
}
}
/**
* Implementación del método _call requerido para los modelos de chat
*/
async _call(messages, options) {
const result = await this._generate(messages, options);
const generation = result.generations[0];
return generation.text;
}
}
exports.ChatNvidiaLlama4 = ChatNvidiaLlama4;