@xiaohui-wang/mcpadvisor
Version:
MCP Advisor & Installation - Find the right MCP server for your needs
419 lines (418 loc) • 13.4 kB
JavaScript
import * as use from '@tensorflow-models/universal-sentence-encoder';
import logger from './logger.js';
// 缓存配置
const CACHE_SIZE_LIMIT = 1000; // 最大缓存条目数
const DEFAULT_DIMENSIONS = 512; // Universal Sentence Encoder 默认输出维度
/**
* 创建固定长度的零向量
*/
export const createZeroVector = (size) => new Array(size).fill(0);
/**
* 向量归一化
*/
export const normalizeVector = (vector) => {
const magnitude = calculateMagnitude(vector);
if (magnitude === 0 || !isFinite(magnitude)) {
return vector;
}
return vector.map(val => val / magnitude);
};
/**
* 计算向量幅度
*/
export const calculateMagnitude = (vector) => {
const sumOfSquares = vector.reduce((sum, val) => sum + val * val, 0);
return Math.sqrt(sumOfSquares);
};
/**
* 计算两个向量的余弦相似度
*/
export const cosineSimilarity = (vecA, vecB) => {
if (vecA.length !== vecB.length) {
throw new Error(`Vector dimensions do not match: ${vecA.length} vs ${vecB.length}`);
}
let dotProduct = 0;
let magA = 0;
let magB = 0;
for (let i = 0; i < vecA.length; i++) {
dotProduct += vecA[i] * vecB[i];
magA += vecA[i] * vecA[i];
magB += vecB[i] * vecB[i];
}
magA = Math.sqrt(magA);
magB = Math.sqrt(magB);
if (magA === 0 || magB === 0) {
return 0;
}
return dotProduct / (magA * magB);
};
/**
* Universal Sentence Encoder 嵌入提供者
* 使用 TensorFlow.js 的 USE 模型生成高质量的文本嵌入
*/
class UniversalSentenceEncoderProvider {
model = null;
modelLoading = null;
fallbackProvider = null;
constructor(fallbackProvider = null) {
this.fallbackProvider = fallbackProvider;
// 初始化时开始加载模型
this.ensureModelLoaded().catch(err => {
logger.error(`Failed to load Universal Sentence Encoder model: ${err.message}`, { error: err });
});
}
/**
* 获取提供者名称
*/
getName() {
return 'UniversalSentenceEncoderProvider';
}
/**
* 检查模型是否已加载
*/
isModelLoaded() {
return this.model !== null;
}
/**
* 确保模型已加载
*/
async ensureModelLoaded() {
// 如果模型已经加载,直接返回
if (this.model !== null) {
return;
}
// 如果模型正在加载,等待完成
if (this.modelLoading !== null) {
return this.modelLoading;
}
// 开始加载模型
this.modelLoading = (async () => {
try {
logger.info('Loading Universal Sentence Encoder model...');
// 加载模型
this.model = await use.load();
logger.info('Universal Sentence Encoder model loaded successfully');
}
catch (error) {
const message = error instanceof Error ? error.message : String(error);
logger.error(`Error loading Universal Sentence Encoder model: ${message}`, { error });
// 重置加载状态,允许重试
this.modelLoading = null;
throw error;
}
})();
return this.modelLoading;
}
/**
* 生成单个文本的嵌入
*/
async generateEmbedding(text, dimensions) {
try {
// 确保模型已加载
await this.ensureModelLoaded();
if (!text) {
return createZeroVector(DEFAULT_DIMENSIONS);
}
if (!this.model) {
throw new Error('Model not loaded');
}
// 使用模型生成嵌入
const embeddings = await this.model.embed(text);
// 转换为 JavaScript 数组
const embeddingArray = await embeddings.array();
const result = embeddingArray[0];
// 释放张量资源
embeddings.dispose();
// 归一化向量
return normalizeVector(result);
}
catch (error) {
const message = error instanceof Error ? error.message : String(error);
logger.error(`Error generating embedding with Universal Sentence Encoder: ${message}`, { error, text });
// 如果有备用提供者,使用备用提供者
if (this.fallbackProvider) {
logger.warn(`Falling back to ${this.fallbackProvider.getName()} for embedding generation`);
return this.fallbackProvider.generateEmbedding(text, dimensions);
}
// 如果没有备用提供者,抛出错误
throw error;
}
}
/**
* 批量生成文本嵌入
*/
async generateEmbeddings(texts, dimensions) {
try {
// 确保模型已加载
await this.ensureModelLoaded();
if (!texts || texts.length === 0) {
return [];
}
if (!this.model) {
throw new Error('Model not loaded');
}
// 使用模型生成嵌入
const embeddings = await this.model.embed(texts);
// 转换为 JavaScript 数组
const embeddingArrays = await embeddings.array();
// 释放张量资源
embeddings.dispose();
// 归一化所有向量
return embeddingArrays.map(vector => normalizeVector(vector));
}
catch (error) {
const message = error instanceof Error ? error.message : String(error);
logger.error(`Error generating batch embeddings with Universal Sentence Encoder: ${message}`, { error, textCount: texts.length });
// 如果有备用提供者,使用备用提供者
if (this.fallbackProvider) {
logger.warn(`Falling back to ${this.fallbackProvider.getName()} for batch embedding generation`);
return this.fallbackProvider.generateEmbeddings(texts, dimensions);
}
// 如果没有备用提供者,抛出错误
throw error;
}
}
}
/**
* 简单嵌入提供者
* 使用基本的哈希和字符编码方法生成向量
* 作为备用提供者,当主要提供者失败时使用
*/
class SimpleEmbeddingProvider {
/**
* 获取提供者名称
*/
getName() {
return 'SimpleEmbeddingProvider';
}
/**
* 模型始终被认为已加载
*/
isModelLoaded() {
return true;
}
/**
* 简单提供者不需要加载模型
*/
async ensureModelLoaded() {
return Promise.resolve();
}
/**
* 生成单个文本的嵌入
*/
async generateEmbedding(text, dimensions = DEFAULT_DIMENSIONS) {
if (!text) {
return createZeroVector(dimensions);
}
// 创建初始向量
const vector = createZeroVector(dimensions);
// 使用简单的算法更新向量
const updatedVector = this.generateVectorFromText(text, vector, dimensions);
// 归一化向量
return normalizeVector(updatedVector);
}
/**
* 批量生成文本嵌入
*/
async generateEmbeddings(texts, dimensions = DEFAULT_DIMENSIONS) {
return Promise.all(texts.map(text => this.generateEmbedding(text, dimensions)));
}
/**
* 使用简单的算法从文本生成向量
*/
generateVectorFromText(text, vector, dimensions) {
const result = [...vector];
const normalizedText = text.toLowerCase();
// 基本字符编码映射
for (let i = 0; i < normalizedText.length; i++) {
const charCode = normalizedText.charCodeAt(i);
result[i % dimensions] += charCode / 255;
}
// 添加n-gram特征
this.addNgramFeatures(normalizedText, result, dimensions);
return result;
}
/**
* 添加n-gram特征
*/
addNgramFeatures(text, vector, dimensions) {
// 处理2-gram和3-gram
for (let n = 2; n <= 3; n++) {
for (let i = 0; i < text.length - n + 1; i++) {
const ngram = text.substring(i, i + n);
const hash = this.simpleHash(ngram);
const position = hash % dimensions;
vector[position] += 0.5; // 较小的权重
}
}
}
/**
* 简单的字符串哈希函数
*/
simpleHash(str) {
let hash = 0;
for (let i = 0; i < str.length; i++) {
const char = str.charCodeAt(i);
hash = (hash << 5) - hash + char;
hash = hash & hash; // Convert to 32bit integer
}
return Math.abs(hash);
}
}
/**
* 嵌入缓存类
* 使用LRU策略缓存生成的嵌入
*/
class EmbeddingCache {
cache;
keyOrder;
maxSize;
constructor(maxSize = CACHE_SIZE_LIMIT) {
this.cache = new Map();
this.keyOrder = [];
this.maxSize = maxSize;
}
/**
* 获取缓存的嵌入
*/
get(key) {
const value = this.cache.get(key);
if (value) {
// 更新LRU顺序
this.keyOrder = this.keyOrder.filter(k => k !== key);
this.keyOrder.push(key);
}
return value;
}
/**
* 设置缓存的嵌入
*/
set(key, value) {
// 如果缓存已满,移除最久未使用的项
if (this.cache.size >= this.maxSize && !this.cache.has(key)) {
const oldestKey = this.keyOrder.shift();
if (oldestKey) {
this.cache.delete(oldestKey);
}
}
// 添加新项
this.cache.set(key, value);
this.keyOrder = this.keyOrder.filter(k => k !== key);
this.keyOrder.push(key);
}
/**
* 清除缓存
*/
clear() {
this.cache.clear();
this.keyOrder = [];
}
/**
* 获取缓存大小
*/
size() {
return this.cache.size;
}
}
// 创建全局缓存实例
const embeddingCache = new EmbeddingCache();
// 创建简单嵌入提供者作为备用
const simpleProvider = new SimpleEmbeddingProvider();
// 创建主要嵌入提供者
const mainProvider = new UniversalSentenceEncoderProvider(simpleProvider);
// 默认使用 Universal Sentence Encoder 提供者
let currentProvider = mainProvider;
/**
* 设置当前嵌入提供者
*/
export const setEmbeddingProvider = (provider) => {
currentProvider = provider;
// 切换提供者时清除缓存
embeddingCache.clear();
};
/**
* 获取当前嵌入提供者
*/
export const getEmbeddingProvider = () => currentProvider;
/**
* 生成缓存键
*/
const generateCacheKey = (text, dimensions) => `${text}_${dimensions}_${currentProvider.getName()}`;
/**
* 将文本转换为向量表示
* @param text 输入文本
* @param dimensions 向量维度
* @returns 归一化的向量表示
*/
export const getTextEmbedding = async (text, dimensions = DEFAULT_DIMENSIONS) => {
if (!text) {
return createZeroVector(dimensions);
}
// 生成缓存键
const cacheKey = generateCacheKey(text, dimensions);
// 检查缓存
const cachedVector = embeddingCache.get(cacheKey);
if (cachedVector) {
return cachedVector;
}
try {
// 使用当前提供者生成嵌入
const vector = await currentProvider.generateEmbedding(text, dimensions);
// 缓存结果
embeddingCache.set(cacheKey, vector);
return vector;
}
catch (error) {
logger.error(`Error generating embedding: ${error instanceof Error ? error.message : String(error)}`, { error });
// 如果出错,返回零向量
return createZeroVector(dimensions);
}
};
/**
* 批量生成文本嵌入
* @param texts 输入文本数组
* @param dimensions 向量维度
* @returns 归一化的向量表示数组
*/
export const getTextEmbeddings = async (texts, dimensions = DEFAULT_DIMENSIONS) => {
if (!texts || texts.length === 0) {
return [];
}
try {
// 尝试使用批量生成
return await currentProvider.generateEmbeddings(texts, dimensions);
}
catch (error) {
logger.error(`Error generating batch embeddings: ${error instanceof Error ? error.message : String(error)}`, { error });
// 如果出错,返回零向量数组
return texts.map(() => createZeroVector(dimensions));
}
};
/**
* 清除嵌入缓存
*/
export const clearEmbeddingCache = () => {
embeddingCache.clear();
};
/**
* 切换到简单嵌入提供者
* 当需要轻量级处理或主要提供者失败时使用
*/
export const useSimpleEmbeddingProvider = () => {
setEmbeddingProvider(simpleProvider);
logger.info('Switched to SimpleEmbeddingProvider');
};
/**
* 切换到 Universal Sentence Encoder 嵌入提供者
* 提供高质量的文本嵌入
*/
export const useUniversalSentenceEncoder = () => {
setEmbeddingProvider(mainProvider);
logger.info('Switched to UniversalSentenceEncoderProvider');
};
// 初始化时预加载模型
mainProvider.ensureModelLoaded().catch(err => {
logger.warn(`Failed to preload Universal Sentence Encoder model: ${err.message}. Will use fallback provider.`, { error: err });
// 如果预加载失败,切换到简单提供者
useSimpleEmbeddingProvider();
});