UNPKG

@rexdug7005/nvidia-llama4

Version:

Integración de NVIDIA Llama4 con LangChain.js

183 lines (182 loc) 7.11 kB
import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import axios from "axios";
import { convertOptionsToNvidiaParams, convertResponseToLangChainMessage, formatMessagesForNvidia, } from "./utils.js";

/** Endpoint used when the caller does not supply `baseUrl`. */
const DEFAULT_BASE_URL = "https://integrate.api.nvidia.com/v1/chat/completions";
/** Model used when the caller does not supply `model`. */
const DEFAULT_MODEL = "meta/llama-4-maverick-17b-128e-instruct";

/**
 * Builds the HTTP headers shared by the streaming and non-streaming requests.
 *
 * @param {string} apiKey - Token sent as `Authorization: Bearer <apiKey>`.
 * @param {string} accept - Value for the `Accept` header.
 * @returns {Record<string, string>} Header map passed to axios.
 */
function buildHeaders(apiKey, accept) {
    return {
        "Content-Type": "application/json",
        Authorization: `Bearer ${apiKey}`,
        Accept: accept,
    };
}

/**
 * Interprets one trimmed line of a server-sent-events stream.
 *
 * @param {string} line - A single line with surrounding whitespace removed.
 * @returns {{ done: true } | { done: false, data: unknown } | null}
 *   `{ done: true }` for the `[DONE]` sentinel, `{ done: false, data }` for a
 *   `data:` line holding valid JSON, and `null` for anything else (blank
 *   lines, SSE comments, other SSE fields, malformed JSON), which callers skip.
 */
function parseSseLine(line) {
    if (!line.startsWith("data:")) {
        return null;
    }
    const payload = line.slice("data:".length).trim();
    if (payload === "[DONE]") {
        return { done: true };
    }
    try {
        return { done: false, data: JSON.parse(payload) };
    }
    catch {
        // Best-effort: malformed lines are skipped, as before. Only the parse is
        // guarded here so that consumer/callback errors are never swallowed.
        return null;
    }
}

/**
 * Decodes an SSE byte stream into the parsed JSON payloads of its `data:` lines.
 * Stops at the `[DONE]` sentinel or when the underlying stream ends.
 *
 * One streaming `TextDecoder` is shared across all chunks so that a multi-byte
 * UTF-8 character split between two network chunks is reassembled correctly
 * instead of being decoded as two replacement characters.
 *
 * @param {AsyncIterable<Uint8Array | string>} stream - Response body obtained
 *   with axios `responseType: "stream"`.
 * @returns {AsyncGenerator<unknown>} Parsed JSON of each event, in order.
 */
async function* readSseEvents(stream) {
    const decoder = new TextDecoder("utf-8");
    let buffer = "";
    for await (const piece of stream) {
        buffer += typeof piece === "string" ? piece : decoder.decode(piece, { stream: true });
        let newlineIndex = buffer.indexOf("\n");
        while (newlineIndex !== -1) {
            const event = parseSseLine(buffer.substring(0, newlineIndex).trim());
            buffer = buffer.substring(newlineIndex + 1);
            if (event?.done) {
                // Returning exits the for-await, which closes the HTTP stream.
                return;
            }
            if (event) {
                yield event.data;
            }
            newlineIndex = buffer.indexOf("\n");
        }
    }
    // Flush bytes the decoder held back, then handle a final line without "\n".
    buffer += decoder.decode();
    const tail = parseSseLine(buffer.trim());
    if (tail && !tail.done) {
        yield tail.data;
    }
}

/**
 * Concatenates streamed generation chunks into a single generation.
 *
 * @param {AsyncIterable<ChatGenerationChunk>} chunks - Chunks to merge.
 * @returns {Promise<ChatGenerationChunk>} The merged chunk, or an empty one
 *   when the stream produced nothing.
 */
async function collectChunks(chunks) {
    let aggregate;
    for await (const chunk of chunks) {
        aggregate = aggregate === undefined ? chunk : aggregate.concat(chunk);
    }
    return aggregate ?? new ChatGenerationChunk({
        text: "",
        message: new AIMessageChunk({ content: "" }),
    });
}

/**
 * LangChain chat model backed by NVIDIA's Llama 4 chat-completions API.
 */
export class ChatNvidiaLlama4 extends BaseChatModel {
    static lc_name() {
        return "ChatNvidiaLlama4";
    }
    /**
     * @param {object} fields - Constructor options. `apiKey`, `baseUrl`, `model`
     *   and `streaming` configure the client (when `streaming` is true,
     *   `_generate` also uses the streaming endpoint); every other key is kept
     *   in `defaultOptions` and merged into each request.
     */
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "apiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "baseUrl", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "defaultOptions", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.apiKey = fields.apiKey;
        this.baseUrl = fields.baseUrl || DEFAULT_BASE_URL;
        this.modelName = fields.model || DEFAULT_MODEL;
        this.streaming = fields.streaming ?? false;
        // Remove client-only settings; what remains are model options sent with every request.
        // eslint-disable-next-line @typescript-eslint/no-unused-vars
        const { apiKey, baseUrl, model, streaming, ...rest } = fields;
        this.defaultOptions = rest;
    }
    _llmType() {
        return "nvidia-llama4";
    }
    /**
     * Builds the request payload for the API.
     *
     * @param {Array} messages - LangChain messages to send.
     * @param {object} options - Per-call options; they override `defaultOptions`.
     * @param {boolean} [streaming=false] - Value sent as the payload's `stream` flag.
     * @returns {object} Payload with the converted options, the formatted
     *   messages and the `stream` flag.
     */
    getParams(messages, options, streaming = false) {
        // Per-call options win over constructor defaults; the model is always ours.
        const baseOptions = convertOptionsToNvidiaParams({
            ...this.defaultOptions,
            ...options,
            model: this.modelName,
        });
        const formattedMessages = formatMessagesForNvidia(messages);
        return {
            ...baseOptions,
            messages: formattedMessages,
            stream: streaming,
        };
    }
    /**
     * Produces a complete response.
     *
     * When `this.streaming` is true, the response is streamed (so token
     * callbacks fire) and the chunks are merged into one generation.
     *
     * @param {Array} messages - LangChain messages to send.
     * @param {object} options - Per-call options; `options.signal`, if present,
     *   is forwarded to axios so the request can be aborted.
     * @param {object} [runManager] - LangChain callback manager for this run.
     * @returns {Promise<{ generations: Array }>} A single generation.
     * @throws {Error} Wrapping the underlying failure (available as `cause`).
     */
    async _generate(messages, options, runManager) {
        if (this.streaming) {
            const aggregate = await collectChunks(this._streamResponseChunks(messages, options, runManager));
            return { generations: [aggregate] };
        }
        const requestOptions = {
            headers: buildHeaders(this.apiKey, "application/json"),
            signal: options?.signal,
        };
        const params = this.getParams(messages, options, false);
        try {
            const response = await axios.post(this.baseUrl, params, requestOptions);
            const responseData = response.data;
            const message = convertResponseToLangChainMessage(responseData);
            const generation = {
                text: message.content.toString(),
                message,
                generationInfo: {
                    finishReason: responseData.choices?.[0]?.finish_reason,
                    tokenUsage: responseData.usage,
                },
            };
            return {
                generations: [generation],
            };
        }
        catch (error) {
            throw new Error(`Error al llamar a la API de NVIDIA Llama4: ${String(error)}`, { cause: error });
        }
    }
    /**
     * Streams the response as `ChatGenerationChunk`s.
     *
     * A chunk is yielded for every event that carries content or a
     * `finish_reason` (the latter typically arrives with an empty delta).
     * `runManager.handleLLMNewToken` is called for every non-empty token.
     *
     * @param {Array} messages - LangChain messages to send.
     * @param {object} options - Per-call options; `options.signal` is forwarded to axios.
     * @param {object} [runManager] - LangChain callback manager for this run.
     * @returns {AsyncGenerator<ChatGenerationChunk>}
     * @throws {Error} Wrapping request or stream failures (available as `cause`).
     */
    async *_streamResponseChunks(messages, options, runManager) {
        const requestOptions = {
            headers: buildHeaders(this.apiKey, "text/event-stream"),
            responseType: "stream",
            signal: options?.signal,
        };
        const params = this.getParams(messages, options, true);
        try {
            const response = await axios.post(this.baseUrl, params, requestOptions);
            for await (const data of readSseEvents(response.data)) {
                const choice = data?.choices?.[0];
                const content = choice?.delta?.content || "";
                const finishReason = choice?.finish_reason;
                if (!content && finishReason == null) {
                    continue;
                }
                const generationChunk = new ChatGenerationChunk({
                    text: content,
                    message: new AIMessageChunk({
                        content,
                    }),
                    generationInfo: {
                        finishReason,
                    },
                });
                yield generationChunk;
                if (content && runManager) {
                    await runManager.handleLLMNewToken(content);
                }
            }
        }
        catch (error) {
            throw new Error(`Error al procesar el stream de NVIDIA Llama4: ${String(error)}`, { cause: error });
        }
    }
    /**
     * Convenience wrapper returning only the text of the first generation.
     *
     * @param {Array} messages - LangChain messages to send.
     * @param {object} options - Per-call options.
     * @returns {Promise<string>} The generated text.
     */
    async _call(messages, options) {
        const result = await this._generate(messages, options);
        const generation = result.generations[0];
        return generation.text;
    }
}