intellinode
Version:
Create AI agents using the latest models, including ChatGPT, Llama, Diffusion, Cohere, Gemini, and Hugging Face.
150 lines (132 loc) • 5.57 kB
JavaScript
const OpenAIWrapper = require('../wrappers/OpenAIWrapper');
const CohereAIWrapper = require('../wrappers/CohereAIWrapper');
const ReplicateWrapper = require('../wrappers/ReplicateWrapper');
const GeminiAIWrapper = require('../wrappers/GeminiAIWrapper');
const EmbedInput = require('../model/input/EmbedInput');
const VLLMWrapper = require('../wrappers/VLLMWrapper');
/**
 * Provider identifiers accepted by RemoteEmbedModel.
 *
 * The values are the strings callers pass as the `provider` argument; the
 * declaration order is the order in which they are listed when an
 * unsupported provider is reported.
 */
const SupportedEmbedModels = {
  OPENAI: 'openai',
  COHERE: 'cohere',
  REPLICATE: 'replicate',
  GEMINI: 'gemini',
  NVIDIA: 'nvidia',
  VLLM: 'vllm',
};
/**
 * Single entry point for requesting embeddings from one of the providers
 * listed in SupportedEmbedModels. The constructor builds the wrapper for the
 * chosen provider; getEmbeddings() forwards the request to that wrapper and
 * shapes the response.
 */
class RemoteEmbedModel {
  /**
   * @param {*} keyValue - Forwarded as-is to the provider wrapper's
   *   constructor (not used for the vllm provider).
   * @param {string} [provider] - One of the SupportedEmbedModels values.
   *   Any falsy value selects SupportedEmbedModels.OPENAI.
   * @param {object|null} [customProxyHelper=null] - Forwarded to the OpenAI
   *   and Nvidia wrappers; for vllm its `baseUrl` property is read.
   * @throws {Error} If the provider is not returned by getSupportedModels(),
   *   or if vllm is selected without a customProxyHelper.
   */
  constructor(keyValue, provider, customProxyHelper = null) {
    const selectedProvider = provider || SupportedEmbedModels.OPENAI;
    const supportedModels = this.getSupportedModels();

    if (!supportedModels.includes(selectedProvider)) {
      const models = supportedModels.join(' - ');
      // The rejected value is the provider, not the key, so say so.
      throw new Error(`The received provider is not supported. Send any model from: ${models}`);
    }

    // Delay in milliseconds between two Replicate status checks. Exposed as
    // an instance property so callers can tune it; set before initiate() so
    // it exists regardless of the provider.
    this.replicatePollIntervalMs = 1000;

    this.initiate(keyValue, selectedProvider, customProxyHelper);
  }

  /**
   * Records the provider in `this.keyType` and creates the matching wrapper
   * on the instance (openaiWrapper, cohereWrapper, replicateWrapper,
   * geminiWrapper, nvidiaWrapper or vllmWrapper).
   *
   * @param {*} keyValue - Forwarded to the wrapper constructor.
   * @param {string} keyType - One of the SupportedEmbedModels values.
   * @param {object|null} [customProxyHelper=null] - See the constructor.
   * @throws {Error} 'Invalid provider name' for an unknown keyType, or a
   *   descriptive error when vllm is selected without a customProxyHelper.
   */
  initiate(keyValue, keyType, customProxyHelper = null) {
    this.keyType = keyType;

    switch (keyType) {
      case SupportedEmbedModels.OPENAI:
        this.openaiWrapper = new OpenAIWrapper(keyValue, customProxyHelper);
        break;
      case SupportedEmbedModels.COHERE:
        this.cohereWrapper = new CohereAIWrapper(keyValue);
        break;
      case SupportedEmbedModels.REPLICATE:
        this.replicateWrapper = new ReplicateWrapper(keyValue);
        break;
      case SupportedEmbedModels.GEMINI:
        this.geminiWrapper = new GeminiAIWrapper(keyValue);
        break;
      case SupportedEmbedModels.NVIDIA: {
        // Loaded on demand so this branch does not depend on a module-level
        // binding for NvidiaWrapper.
        // NOTE(review): assumes the wrapper lives next to the other wrappers
        // and is exported the same way (module.exports = class) - confirm.
        const NvidiaWrapper = require('../wrappers/NvidiaWrapper');
        this.nvidiaWrapper = new NvidiaWrapper(keyValue, customProxyHelper);
        break;
      }
      case SupportedEmbedModels.VLLM: {
        // customProxyHelper defaults to null; fail with a clear message
        // instead of a property-access TypeError.
        if (customProxyHelper == null) {
          throw new Error('The vllm provider requires a customProxyHelper with a baseUrl');
        }
        this.vllmWrapper = new VLLMWrapper(customProxyHelper.baseUrl);
        break;
      }
      default:
        throw new Error('Invalid provider name');
    }
  }

  /**
   * @returns {string[]} The values of SupportedEmbedModels, in declaration order.
   */
  getSupportedModels() {
    return Object.values(SupportedEmbedModels);
  }

  /**
   * Wraps raw embedding vectors in the unified result shape.
   *
   * @private
   * @param {Array} vectors - One entry per embedded input.
   * @returns {Array<{object: string, index: number, embedding: *}>}
   */
  static toEmbeddingObjects(vectors) {
    return vectors.map((embedding, index) => ({
      object: 'embedding',
      index,
      embedding,
    }));
  }

  /**
   * Turns the caller's argument into the payload handed to the wrapper.
   * An EmbedInput is converted with its provider-specific getter; any other
   * non-null object is passed through unchanged.
   *
   * @private
   * @param {EmbedInput|object} embedInput
   * @returns {*} The provider-specific request payload.
   * @throws {Error} If embedInput is null or not an object, or if
   *   `this.keyType` is not a known provider.
   */
  resolveInputs(embedInput) {
    if (embedInput instanceof EmbedInput) {
      switch (this.keyType) {
        case SupportedEmbedModels.OPENAI:
          return embedInput.getOpenAIInputs();
        case SupportedEmbedModels.COHERE:
          return embedInput.getCohereInputs();
        case SupportedEmbedModels.REPLICATE:
          return embedInput.getLlamaReplicateInput();
        case SupportedEmbedModels.GEMINI:
          return embedInput.getGeminiInputs();
        case SupportedEmbedModels.NVIDIA:
          return embedInput.getNvidiaInputs();
        case SupportedEmbedModels.VLLM:
          return embedInput.getVLLMInputs();
        default:
          throw new Error('The keyType is not supported');
      }
    }

    // typeof null is 'object', so null has to be excluded explicitly.
    if (embedInput !== null && typeof embedInput === 'object') {
      return embedInput;
    }

    throw new Error('Invalid input: Must be an instance of EmbedInput or a dictionary');
  }

  /**
   * Checks a Replicate prediction every `replicatePollIntervalMs`
   * milliseconds until it reaches a terminal status. Checks run strictly one
   * after another, so a slow status request never overlaps the next one.
   *
   * @private
   * @param {*} predictionId - Identifier passed to getPredictionStatus().
   * @returns {Promise<Array<{object: string, index: number, embedding: *}>>}
   * @throws {Error} If a status request fails, if the prediction ends as
   *   'failed' or 'canceled', or if a succeeded prediction has no array output.
   */
  async pollReplicatePrediction(predictionId) {
    for (;;) {
      // Wait first: the prediction was only just created by the caller.
      await new Promise((resolve) => setTimeout(resolve, this.replicatePollIntervalMs));

      let status;
      let state;
      try {
        status = await this.replicateWrapper.getPredictionStatus(predictionId);
        state = status.status;
      } catch (error) {
        throw new Error(
          `Error while polling for Replicate prediction status: ${error.message}`,
          { cause: error },
        );
      }

      if (state === 'succeeded') {
        if (!Array.isArray(status.output)) {
          throw new Error('Replicate prediction succeeded but its output is not an array of embeddings');
        }
        return RemoteEmbedModel.toEmbeddingObjects(status.output);
      }

      // 'canceled' is terminal as well; without this check polling would
      // never stop for a canceled prediction.
      if (state === 'failed' || state === 'canceled') {
        throw new Error(`Replicate prediction ${state}: ${status.error}`);
      }
    }
  }

  /**
   * Requests embeddings from the configured provider.
   *
   * @param {EmbedInput|object} embedInput - An EmbedInput, or an already
   *   provider-shaped payload object that is forwarded unchanged.
   * @returns {Promise<*>} OpenAI: the `data` property of the wrapper result.
   *   Cohere, Replicate and vllm: an array of
   *   `{ object: 'embedding', index, embedding }`. Gemini: the wrapper result
   *   unchanged. Nvidia: the wrapper result if it is an array, otherwise [].
   * @throws {Error} On invalid input, an unknown provider, or a failed or
   *   canceled Replicate prediction; wrapper errors propagate.
   */
  async getEmbeddings(embedInput) {
    const inputs = this.resolveInputs(embedInput);

    switch (this.keyType) {
      case SupportedEmbedModels.OPENAI: {
        const results = await this.openaiWrapper.getEmbeddings(inputs);
        return results.data;
      }
      case SupportedEmbedModels.COHERE: {
        const results = await this.cohereWrapper.getEmbeddings(inputs);
        return RemoteEmbedModel.toEmbeddingObjects(results.embeddings);
      }
      case SupportedEmbedModels.REPLICATE: {
        const prediction = await this.replicateWrapper.predict('replicate', inputs);
        return this.pollReplicatePrediction(prediction.id);
      }
      case SupportedEmbedModels.GEMINI:
        return await this.geminiWrapper.getEmbeddings(inputs);
      case SupportedEmbedModels.NVIDIA: {
        const result = await this.nvidiaWrapper.generateRetrieval(inputs);
        return Array.isArray(result) ? result : [];
      }
      case SupportedEmbedModels.VLLM: {
        const results = await this.vllmWrapper.getEmbeddings(inputs.texts);
        return RemoteEmbedModel.toEmbeddingObjects(results.embeddings);
      }
      default:
        throw new Error('The keyType is not supported');
    }
  }
}
// Public surface of this module: the embedding client and its provider identifiers.
module.exports = { RemoteEmbedModel, SupportedEmbedModels };