intellinode
Create AI agents using the latest models, including ChatGPT, Llama, Diffusion, Cohere, Gemini, and Hugging Face.
/*
Apache License
Copyright 2023 Github.com/Barqawiz/IntelliNode
Licensed under the Apache License, Version 2.0 (the "License");
*/
const OpenAIWrapper = require("../wrappers/OpenAIWrapper");
const ReplicateWrapper = require('../wrappers/ReplicateWrapper');
const AWSEndpointWrapper = require('../wrappers/AWSEndpointWrapper');
const { GPTStreamParser, CohereStreamParser, VLLMStreamParser } = require('../utils/StreamParser');
const CohereAIWrapper = require('../wrappers/CohereAIWrapper');
const IntellicloudWrapper = require("../wrappers/IntellicloudWrapper");
const MistralAIWrapper = require('../wrappers/MistralAIWrapper');
const GeminiAIWrapper = require('../wrappers/GeminiAIWrapper');
const AnthropicWrapper = require('../wrappers/AnthropicWrapper');
const SystemHelper = require("../utils/SystemHelper");
const NvidiaWrapper = require("../wrappers/NvidiaWrapper");
const VLLMWrapper = require('../wrappers/VLLMWrapper');
const {
ChatGPTInput,
ChatModelInput,
ChatGPTMessage,
ChatLLamaInput,
LLamaReplicateInput,
CohereInput,
LLamaSageInput,
MistralInput,
GeminiInput,
AnthropicInput,
NvidiaInput,
VLLMInput
} = require("../model/input/ChatModelInput");
const SupportedChatModels = {
OPENAI: "openai",
REPLICATE: "replicate",
SAGEMAKER: "sagemaker",
COHERE: "cohere",
MISTRAL: "mistral",
GEMINI: "gemini",
ANTHROPIC: "anthropic",
NVIDIA: "nvidia",
VLLM: "vllm"
};
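/*
Usage sketch (hypothetical keys and URLs; adjust to your environment):
  const bot = new Chatbot(process.env.OPENAI_API_KEY, SupportedChatModels.OPENAI);
  const geminiBot = new Chatbot(process.env.GEMINI_API_KEY, SupportedChatModels.GEMINI);
  // vLLM targets a self-hosted server, so it takes a base URL instead of a key:
  const vllmBot = new Chatbot(null, SupportedChatModels.VLLM, null, { baseUrl: 'http://localhost:8000' });
*/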
class Chatbot {
constructor(keyValue, provider = SupportedChatModels.OPENAI, customProxyHelper = null, options = {}) {
const supportedModels = this.getSupportedModels();
if (supportedModels.includes(provider)) {
this.initiate(keyValue, provider, customProxyHelper, options);
} else {
const models = supportedModels.join(" - ");
throw new Error(
`The received provider is not supported. Select a provider from: ${models}`
);
}
}
initiate(keyValue, provider, customProxyHelper = null, options = {}) {
this.provider = provider;
if (provider === SupportedChatModels.OPENAI) {
this.openaiWrapper = new OpenAIWrapper(keyValue, customProxyHelper);
} else if (provider === SupportedChatModels.REPLICATE) {
this.replicateWrapper = new ReplicateWrapper(keyValue);
} else if (provider === SupportedChatModels.SAGEMAKER) {
if (!customProxyHelper || !customProxyHelper.url) {
throw new Error("SageMaker requires 'customProxyHelper' with a 'url' field.");
}
this.sagemakerWrapper = new AWSEndpointWrapper(customProxyHelper.url, keyValue);
} else if (provider === SupportedChatModels.COHERE) {
this.cohereWrapper = new CohereAIWrapper(keyValue);
} else if (provider === SupportedChatModels.MISTRAL) {
this.mistralWrapper = new MistralAIWrapper(keyValue);
} else if (provider === SupportedChatModels.GEMINI) {
this.geminiWrapper = new GeminiAIWrapper(keyValue);
} else if (provider === SupportedChatModels.ANTHROPIC) {
this.anthropicWrapper = new AnthropicWrapper(keyValue);
} else if (provider === SupportedChatModels.NVIDIA) {
const opts = options || {};
const baseUrl = (opts.nvidiaOptions && opts.nvidiaOptions.baseUrl) || opts.baseUrl;
if (baseUrl) {
this.nvidiaWrapper = new NvidiaWrapper(keyValue, { baseUrl: baseUrl });
} else {
this.nvidiaWrapper = new NvidiaWrapper(keyValue);
}
} else if (provider === SupportedChatModels.VLLM) {
const baseUrl = options.baseUrl;
if (!baseUrl) throw new Error("VLLM requires 'baseUrl' in options.");
this.vllmWrapper = new VLLMWrapper(baseUrl);
} else {
throw new Error("Invalid provider name");
}
// initiate the optional search feature
if (options && options.oneKey) {
const apiBase = options.intelliBase ? options.intelliBase : null;
this.extendedController = options.oneKey.startsWith("in") ? new IntellicloudWrapper(options.oneKey, apiBase) : null;
}
}
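/*
Optional semantic search sketch: a oneKey value that starts with "in" activates the
IntelliCloud augmentation, and chat()/stream() then enrich the last user message with
retrieved context (the key below is a placeholder):
  const bot = new Chatbot(apiKey, SupportedChatModels.OPENAI, null, { oneKey: 'in-xxxxxxxx' });
*/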
getSupportedModels() {
return Object.values(SupportedChatModels);
}
async chat(modelInput, functions = null, function_call = null, debugMode = true) {
// call semantic search
let references = await this.getSemanticSearchContext(modelInput);
// verify the extra params
if (this.provider !== SupportedChatModels.OPENAI && (functions != null || function_call != null)) {
throw new Error('The functions and function_call parameters are supported for OpenAI (ChatGPT) models only.');
}
// call the chatbot
if (this.provider === SupportedChatModels.OPENAI) {
const result = await this._chatGPT(modelInput, functions, function_call);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.REPLICATE) {
const result = await this._chatReplicateLLama(modelInput, debugMode);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.SAGEMAKER) {
const result = await this._chatSageMaker(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.COHERE) {
const result = await this._chatCohere(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.MISTRAL) {
const result = await this._chatMistral(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.GEMINI) {
const result = await this._chatGemini(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.ANTHROPIC) {
const result = await this._chatAnthropic(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.NVIDIA) {
const result = await this._chatNvidia(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else if (this.provider === SupportedChatModels.VLLM) {
const result = await this._chatVLLM(modelInput);
return modelInput.attachReference ? { result, references } : result;
} else {
throw new Error("The provider is not supported");
}
}
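/*
chat() usage sketch with ChatGPTInput (the model name and constructor options are
assumptions; see ChatModelInput for the exact signature):
  const input = new ChatGPTInput('You are a helpful assistant.', { model: 'gpt-4o' });
  input.addUserMessage('Summarize the Apache License in one sentence.');
  const responses = await bot.chat(input); // array of response strings
  // with attachReference enabled on the input, chat() returns { result, references } instead.
*/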
async *stream(modelInput) {
await this.getSemanticSearchContext(modelInput);
if (this.provider === SupportedChatModels.OPENAI) {
yield* this._chatGPTStream(modelInput);
} else if (this.provider === SupportedChatModels.COHERE) {
yield* this._streamCohere(modelInput);
} else if (this.provider === SupportedChatModels.NVIDIA) {
yield* this._streamNvidia(modelInput);
} else if (this.provider === SupportedChatModels.VLLM) {
yield* this._streamVLLM(modelInput);
} else {
throw new Error("The stream function support only chatGPT, for other providers use chat function.");
}
}
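/*
stream() usage sketch; the generator yields text chunks as they arrive:
  for await (const token of bot.stream(input)) {
    process.stdout.write(token);
  }
*/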
async *_streamVLLM(modelInput) {
let params = modelInput instanceof VLLMInput ? modelInput.getChatInput() : modelInput;
params.stream = true;
// Check for completion-only models
const completionOnlyModels = ["google/gemma-2-2b-it"];
const isCompletionOnly = completionOnlyModels.includes(params.model);
let stream;
if (isCompletionOnly) {
// Convert messages to prompt string
const promptMessages = params.messages
.map(msg => `${msg.role.charAt(0).toUpperCase() + msg.role.slice(1)}: ${msg.content}`)
.join("\n") + "\nAssistant:";
const completionParams = {
model: params.model,
prompt: promptMessages,
max_tokens: params.max_tokens || 100,
temperature: params.temperature || 0.7,
stream: true
};
stream = await this.vllmWrapper.generateText(completionParams);
} else {
stream = await this.vllmWrapper.generateChatText(params);
}
const streamParser = new VLLMStreamParser();
// Process the streaming response
for await (const chunk of stream) {
const chunkText = chunk.toString('utf8');
yield* streamParser.feed(chunkText);
}
}
async getSemanticSearchContext(modelInput) {
let references = {};
if (!this.extendedController) {
return references;
}
// Initialize variables for messages or prompt
let messages, lastMessage;
if (modelInput instanceof ChatLLamaInput && typeof modelInput.prompt === "string") {
messages = modelInput.prompt.split('\n').map(line => {
const role = line.startsWith('User:') ? 'user' : 'assistant';
const content = line.replace(/^(User|Assistant): /, '');
return { role, content };
});
} else if (modelInput instanceof GeminiInput) {
messages = modelInput.messages.map(message => {
const role = message.role;
const content = message.parts.map(part => part.text).join(" ");
return { role, content };
});
} else if (Array.isArray(modelInput.messages)) {
messages = modelInput.messages;
} else {
console.log('The input format does not support augmented search.');
return references;
}
lastMessage = messages[messages.length - 1];
if (lastMessage && lastMessage.role === "user") {
const semanticResult = await this.extendedController.semanticSearch(lastMessage.content, modelInput.searchK);
if (semanticResult && semanticResult.length > 0) {
references = semanticResult.reduce((acc, doc) => {
// check if the document_name exists in the accumulator
if (!acc[doc.document_name]) {
acc[doc.document_name] = { pages: [] };
}
return acc;
}, {});
let contextData = semanticResult.map(doc => doc.data.map(dataItem => dataItem.text).join('\n')).join('\n').trim();
const templateWrapper = new SystemHelper().loadStaticPrompt("augmented_chatbot");
const augmentedMessage = templateWrapper.replace('${semantic_search}', contextData).replace('${user_query}', lastMessage.content);
if (modelInput instanceof ChatLLamaInput && modelInput.prompt) {
const promptLines = modelInput.prompt.trim().split('\n');
promptLines.pop();
promptLines.push(`User: ${augmentedMessage}`);
modelInput.prompt = promptLines.join('\n');
} else if (modelInput instanceof ChatModelInput) {
modelInput.deleteLastMessage(lastMessage);
modelInput.addUserMessage(augmentedMessage);
} else if (typeof modelInput === "object" && Array.isArray(modelInput.messages) && messages.length > 0) {
// replace the user message directly in the array
if (lastMessage.content) {
lastMessage.content = augmentedMessage;
}
}
}
}
return references;
}
async _chatVLLM(modelInput) {
let params = modelInput instanceof ChatModelInput ? modelInput.getChatInput() : modelInput;
// Handle completion-only models (e.g., Gemma) explicitly
const completionOnlyModels = ["google/gemma-2-2b-it"];
const isCompletionOnly = completionOnlyModels.includes(params.model);
if (isCompletionOnly) {
// Convert messages to prompt string
const promptMessages = params.messages
.map(msg => `${msg.role.charAt(0).toUpperCase() + msg.role.slice(1)}: ${msg.content}`)
.join("\n") + "\nAssistant:";
const completionParams = {
model: params.model,
prompt: promptMessages,
max_tokens: params.max_tokens || 100,
temperature: params.temperature || 0.7,
};
const result = await this.vllmWrapper.generateText(completionParams);
return result.choices.map(c => c.text.trim());
} else {
const result = await this.vllmWrapper.generateChatText(params);
return result.choices.map(c => c.message.content);
}
}
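/*
_chatVLLM also accepts a plain object shaped like an OpenAI-style chat payload
(the model name is illustrative; it must match a model served by your vLLM instance):
  const responses = await vllmBot.chat({
    model: 'mistralai/Mistral-7B-Instruct-v0.2',
    messages: [{ role: 'user', content: 'Hello!' }],
    max_tokens: 100,
    temperature: 0.7
  });
*/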
async *_chatGPTStream(modelInput) {
let params;
if (modelInput instanceof ChatModelInput) {
params = modelInput.getChatInput();
params.stream = true;
} else if (typeof modelInput === "object") {
params = modelInput;
params.stream = true;
} else {
throw new Error("Invalid input: Must be an instance of ChatGPTInput or a dictionary");
}
// Check if this is GPT-5
const isGPT5 = params.model && params.model.toLowerCase().includes('gpt-5');
if (isGPT5) {
// GPT-5 doesn't support streaming in the same way
// For now, throw an error or handle as non-streaming
throw new Error("GPT-5 streaming is not yet supported. Please use the chat() method instead.");
}
const streamParser = new GPTStreamParser();
const stream = await this.openaiWrapper.generateChatText(params);
// Collect data from the stream
for await (const chunk of stream) {
const chunkText = chunk.toString('utf8');
yield* streamParser.feed(chunkText);
}
}
async _chatGPT(modelInput, functions = null, function_call = null) {
let params;
if (modelInput instanceof ChatModelInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatGPTInput or a dictionary");
}
// Check if this is GPT-5
const isGPT5 = params.model && params.model.toLowerCase().includes('gpt-5');
if (isGPT5) {
// GPT-5 uses different endpoint and response format
const results = await this.openaiWrapper.generateGPT5Response(params);
// GPT-5 response format: { output: [ {type: 'reasoning'}, {type: 'message', content: [...]} ] }
if (results.output && Array.isArray(results.output)) {
// Extract text from the message content
const messageObjects = results.output.filter(item => item.type === 'message');
const responses = messageObjects.map(msg => {
if (msg.content && Array.isArray(msg.content)) {
return msg.content.map(c => c.text || c).join('');
}
return msg.content || '';
});
return responses.length > 0 ? responses : [''];
} else if (results.choices && results.choices.length > 0) {
// Fallback to choices format if available
return results.choices.map(choice => choice.output || choice.text || choice.message?.content);
}
return [''];
} else {
// Standard chat completion for GPT-4 and others
const results = await this.openaiWrapper.generateChatText(params, functions, function_call);
return results.choices.map((choice) => {
if (choice.finish_reason === 'function_call' && choice.message.function_call) {
return {
content: choice.message.content,
function_call: choice.message.function_call
};
} else {
return choice.message.content;
}
});
}
}
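/*
Function-calling sketch for the OpenAI provider, using the standard OpenAI
functions schema (the function name and parameters are illustrative):
  const functions = [{
    name: 'get_weather',
    description: 'Get the current weather for a city',
    parameters: {
      type: 'object',
      properties: { city: { type: 'string' } },
      required: ['city']
    }
  }];
  const out = await bot.chat(input, functions, 'auto');
  // entries where the model calls a function look like
  // { content, function_call: { name, arguments } } instead of plain strings.
*/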
async _chatReplicateLLama(modelInput, debugMode) {
let params;
const waitTime = 2500;
const maxIterate = 200;
let iteration = 0;
if (modelInput instanceof ChatModelInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatLLamaInput or a dictionary");
}
try {
const modelName = params.model;
const inputData = params.inputData;
const prediction = await this.replicateWrapper.predict(modelName, inputData);
return new Promise((resolve, reject) => {
const poll = setInterval(async () => {
const status = await this.replicateWrapper.getPredictionStatus(prediction.id);
if (debugMode) {
console.log('The current status:', status.status);
}
if (status.status === 'succeeded' || status.status === 'failed') {
// stop the loop if prediction has completed or failed
clearInterval(poll);
if (status.status === 'succeeded') {
resolve([status.output.join('')]);
} else {
console.error('LLama prediction failed:', status.error);
reject(new Error('LLama prediction failed.'));
}
}
if (iteration > maxIterate) {
// stop polling once the maximum number of attempts is reached
clearInterval(poll);
reject(new Error('Replicate is taking too long to process the input, try again later!'));
}
iteration += 1;
}, waitTime);
});
} catch (error) {
console.error('LLama Error:', error);
throw error;
}
}
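/*
Replicate usage sketch; LLamaReplicateInput builds the canonical payload, and its
constructor is assumed here to mirror the other input classes:
  const replicateBot = new Chatbot(process.env.REPLICATE_API_KEY, SupportedChatModels.REPLICATE);
  const input = new LLamaReplicateInput('You are a helpful assistant.');
  input.addUserMessage('Hello!');
  const responses = await replicateBot.chat(input);
*/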
async _chatSageMaker(modelInput) {
let params;
if (modelInput instanceof LLamaSageInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of LLamaSageInput or a dictionary");
}
const results = await this.sagemakerWrapper.predict(params);
return results.map(result => result.generation ? result.generation.content : result);
}
async _chatCohere(modelInput) {
let params;
if (modelInput instanceof CohereInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of ChatGPTInput or an object");
}
const results = await this.cohereWrapper.generateChatText(params);
const responseText = results.text;
return [responseText];
}
async *_streamCohere(modelInput) {
let params;
if (modelInput instanceof CohereInput) {
params = modelInput.getChatInput();
params.stream = true;
} else if (typeof modelInput === "object") {
params = modelInput;
params.stream = true;
} else {
throw new Error("Invalid input: Must be an instance of ChatGPTInput or a dictionary");
}
const streamParser = new CohereStreamParser();
const stream = await this.cohereWrapper.generateChatText(params);
// Collect data from the stream
for await (const chunk of stream) {
const chunkText = chunk.toString('utf8');
yield* streamParser.feed(chunkText);
}
}
async _chatMistral(modelInput) {
let params;
if (modelInput instanceof MistralInput || modelInput instanceof ChatGPTInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of MistralInput or an object");
}
const results = await this.mistralWrapper.generateText(params);
return results.choices.map(choice => choice.message.content);
}
async _chatGemini(modelInput) {
let params;
if (modelInput instanceof GeminiInput) {
params = modelInput.getChatInput();
} else if (typeof modelInput === "object") {
params = modelInput;
} else {
throw new Error("Invalid input: Must be an instance of GeminiInput");
}
// call Gemini
const result = await this.geminiWrapper.generateContent(params);
if (!Array.isArray(result.candidates) || result.candidates.length === 0) {
throw new Error("Invalid response from Gemini API: Expected 'candidates' array with content");
}
// iterate over all the candidates
const responses = result.candidates.map(candidate => {
// combine text from all parts
return candidate.content.parts
.map(part => part.text)
.join(' ');
});
return responses;
}
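/*
A plain object passed to the Gemini provider is forwarded to generateContent()
as-is, so it should follow the public Gemini REST payload shape (a hedged sketch):
  const responses = await geminiBot.chat({
    contents: [{ role: 'user', parts: [{ text: 'Hello!' }] }]
  });
*/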
async _chatAnthropic(modelInput) {
let params;
if (modelInput instanceof AnthropicInput) {
params = modelInput.getChatInput();
} else {
throw new Error("Invalid input: Must be an instance of AnthropicInput");
}
const results = await this.anthropicWrapper.generateText(params);
return results.content.map(choice => choice.text);
}
async _chatNvidia(modelInput) {
let params = modelInput instanceof NvidiaInput ? modelInput.getChatInput() : modelInput;
if (params.stream) throw new Error("Use stream() for NVIDIA streaming.");
let resp = await this.nvidiaWrapper.generateText(params);
return resp.choices.map(c => c.message.content);
}
async *_streamNvidia(modelInput) {
let params = modelInput instanceof NvidiaInput ? modelInput.getChatInput() : modelInput;
params.stream = true;
const stream = await this.nvidiaWrapper.generateTextStream(params);
let buffer = '';
for await (const chunk of stream) {
const lines = chunk.toString('utf8').split('\n');
for (let line of lines) {
line = line.trim();
if (!line) continue;
if (line.startsWith('data: [DONE]')) {
yield buffer;
return;
}
if (line.startsWith('data: ')) {
try {
let parsed = JSON.parse(line.replace('data: ', ''));
let content = parsed.choices?.[0]?.delta?.content || '';
buffer += content;
yield content;
} catch (e) { /* ignore partial JSON lines; the next chunk completes them */ }
}
}
}
}
} /* end of Chatbot class */
module.exports = {
Chatbot,
SupportedChatModels,
};
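/*
Import sketch (assuming the package re-exports these from its entry point):
  const { Chatbot, SupportedChatModels } = require('intellinode');
  const bot = new Chatbot(apiKey, SupportedChatModels.COHERE);
*/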