UNPKG

autosnippet

Version:

Extract code patterns into a knowledge base for AI coding assistants

298 lines (297 loc) 12.7 kB
/** * MemoryRetriever — 记忆检索与 Prompt 生成 * * 从 PersistentMemory.js 提取的检索逻辑。 * 负责: * - 三维打分检索 (Generative Agents: recency × importance × relevance) * - 简单文本搜索 * - Prompt section 生成 (预算感知) * - Memory.js 兼容层: load(), append() * * @module MemoryRetriever */ import { cosineSimilarity } from '#shared/similarity.js'; import { MemoryStore } from './MemoryStore.js'; // ─── 常量 (Generative Agents 三维打分) ──────────────── /** 检索打分权重 */ const WEIGHT_RECENCY = 0.2; const WEIGHT_IMPORTANCE = 0.3; const WEIGHT_RELEVANCE = 0.5; /** Recency 半衰期 (天) — 7 天未访问的记忆分数下降一半 */ const RECENCY_HALF_LIFE_DAYS = 7; /** 相似度阈值 (用于 append 去重) */ const SIMILARITY_UPDATE = 0.85; export class MemoryRetriever { #store; /** 向量嵌入函数 */ #embeddingFn; /** 向量嵌入存储 (JSON sidecar) */ #embeddingStore; /** @param [opts.embeddingFn] 向量嵌入函数 (异步) */ constructor(store, opts = {}) { this.#store = store; this.#embeddingFn = typeof opts.embeddingFn === 'function' ? opts.embeddingFn : null; this.#embeddingStore = opts.embeddingStore ?? null; } // ═══════════════════════════════════════════════════════════ // 综合检索 // ═══════════════════════════════════════════════════════════ /** * 综合检索: recency × importance × relevance * * 借鉴 Generative Agents 的三维打分模型: * score = α * recency + β * importance + γ * relevance * * @param query 查询文本 * @returns 按 score 降序排列 */ async retrieve(query, { limit = 10, source, type } = {}) { const all = this.#store.getAllActive({ source, type }); if (all.length === 0) { return []; } const now = Date.now(); const lowerQuery = (query || '').toLowerCase(); const queryTokens = MemoryRetriever.#tokenizeWords(lowerQuery); // 向量检索: 嵌入 query,然后与存储的 embedding 做余弦相似度 let queryVec = null; if (this.#embeddingFn) { try { queryVec = await this.#embeddingFn(query); } catch { // embedding 不可用时 graceful degrade 到纯词汇 } } const scored = all.map((m) => { // Recency: 指数衰减 (半衰期 7 天) const lastAccess = m.last_accessed_at ? new Date(m.last_accessed_at).getTime() : new Date(m.updated_at).getTime(); const daysSinceAccess = (now - lastAccess) / 86400_000; const recency = Math.exp((-daysSinceAccess * Math.LN2) / RECENCY_HALF_LIFE_DAYS); // Importance: 归一化到 0-1 const importance = (m.importance || 5) / 10; // Relevance: 词汇相关性 (lexical) const lexicalRelevance = MemoryRetriever.#computeRelevance(lowerQuery, queryTokens, m.content); // 向量相关性: 从 embeddingStore 查找 embedding 做余弦相似度 const deserialized = MemoryStore.deserialize(m); let vectorRelevance = 0; const storedEmbedding = this.#embeddingStore?.get(m.id) ?? null; if (queryVec && storedEmbedding) { vectorRelevance = Math.max(0, cosineSimilarity(queryVec, storedEmbedding)); } // 混合相关性: 有向量时 0.6 * vector + 0.4 * lexical,否则纯 lexical const relevance = queryVec && storedEmbedding ? 0.6 * vectorRelevance + 0.4 * lexicalRelevance : lexicalRelevance; const score = WEIGHT_RECENCY * recency + WEIGHT_IMPORTANCE * importance + WEIGHT_RELEVANCE * relevance; return { ...deserialized, _score: score, _recency: recency, _relevance: relevance, }; }); scored.sort((a, b) => b._score - a._score); // 更新访问计数 (只更新返回的) const topN = scored.slice(0, limit); for (const m of topN) { this.#store.touchAccess(m.id); } return topN; } /** 简单文本搜索 (不打分, 用于去重检查) */ search(content, { limit = 5 } = {}) { const results = this.#store.findSimilar(content, null, limit); return results.map((r) => MemoryStore.deserialize(r)); } // ═══════════════════════════════════════════════════════════ // Prompt 生成 (预算感知) // ═══════════════════════════════════════════════════════════ /** * 生成供系统提示词的记忆摘要 (预算感知) * * @returns Markdown 格式 */ async toPromptSection({ source, query, limit = 15, tokenBudget, } = {}) { if (tokenBudget && tokenBudget > 0) { const EST_TOKENS_PER_MEMORY = 30; const HEADER_TOKENS = 15; const maxByBudget = Math.max(3, Math.floor((tokenBudget - HEADER_TOKENS) / EST_TOKENS_PER_MEMORY)); limit = Math.min(limit, maxByBudget); } let memories; if (query) { memories = await this.retrieve(query, { limit, source }); } else { memories = this.#store .getAllActive({ source }) .sort((a, b) => { const scoreA = (a.importance || 5) * 0.6 + (a.access_count || 0) * 0.4; const scoreB = (b.importance || 5) * 0.6 + (b.access_count || 0) * 0.4; return scoreB - scoreA; }) .slice(0, limit) .map((m) => MemoryStore.deserialize(m)); } if (memories.length === 0) { return ''; } const lines = memories.map((m) => { const badge = m.importance >= 8 ? '⚠️' : m.importance >= 5 ? '📌' : '💡'; return `- ${badge} [${m.type}] ${m.content}`; }); return `\n## 项目记忆 (${memories.length} 条最相关)\n${lines.join('\n')}\n`; } // ═══════════════════════════════════════════════════════════ // Memory.js 兼容层 // ═══════════════════════════════════════════════════════════ /** 兼容 Memory.load() — 返回最近 N 条记忆 */ load(limit = 20, { source } = {}) { const rows = this.#store .getAllActive({ source }) .sort((a, b) => { const tA = new Date(a.updated_at).getTime(); const tB = new Date(b.updated_at).getTime(); return tB - tA; }) .slice(0, limit); return rows.map((r) => ({ ts: r.updated_at, type: r.type, content: r.content, source: r.source, importance: r.importance, })); } /** 兼容 Memory.append() — 添加一条记忆 (自动去重) */ append(entry) { const content = (entry.content || '').trim().substring(0, 500); if (!content) { return; } // 去重: 检查是否已有高相似度记忆 const similar = this.#store.findSimilar(content, entry.type ?? null, 1); if (similar.length > 0 && (similar[0].similarity ?? 0) >= SIMILARITY_UPDATE) { this.#store.touchAccess(similar[0].id); return; } this.#store.add({ type: entry.type || 'context', content, source: entry.source || 'user', importance: 5, ttlDays: entry.ttl || null, }); } // ═══════════════════════════════════════════════════════════ // 向量嵌入接口 // ═══════════════════════════════════════════════════════════ /** 设置向量嵌入函数 */ setEmbeddingFunction(fn) { this.#embeddingFn = typeof fn === 'function' ? fn : null; } /** 获取当前嵌入函数 */ getEmbeddingFunction() { return this.#embeddingFn; } /** * 为所有缺少 embedding 的记忆批量生成向量嵌入 * @param batchSize 每批数量 (默认 20) * @returns 成功嵌入的记忆数 */ async embedAllMemories(batchSize = 20) { if (!this.#embeddingFn || !this.#embeddingStore) { return 0; } // 从 MemoryStore 获取所有活跃记忆 ID,找出 embeddingStore 中缺失的 const allActive = this.#store.getAllActive(); const allIds = allActive.map((m) => m.id); const missingIds = this.#embeddingStore.getMissingIds(allIds); if (missingIds.length === 0) { return 0; } // 取前 batchSize 条 const batch = missingIds.slice(0, batchSize); const contentMap = new Map(allActive.map((m) => [m.id, m.content])); const entries = []; for (const id of batch) { const content = contentMap.get(id); if (!content) { continue; } try { const vec = await this.#embeddingFn(content); entries.push({ id, embedding: vec }); } catch { // 单条失败不阻塞 } } if (entries.length === 0) { return 0; } return this.#embeddingStore.batchSet(entries); } /** * 使用嵌入函数计算语义相关性 (余弦相似度) * @param query 查询文本 * @param content 记忆内容 * @returns 相似度分数 或 null */ async computeEmbeddingRelevance(query, content) { if (!this.#embeddingFn) { return null; } try { const [queryVec, contentVec] = await Promise.all([ this.#embeddingFn(query), this.#embeddingFn(content), ]); return cosineSimilarity(queryVec, contentVec); } catch { return null; } } // ═══════════════════════════════════════════════════════════ // Private: 相关性计算 // ═══════════════════════════════════════════════════════════ static #computeRelevance(lowerQuery, queryTokens, content) { if (!lowerQuery || !content) { return 0; } const lowerContent = content.toLowerCase(); const contentTokens = MemoryRetriever.#tokenizeWords(lowerContent); if (queryTokens.size === 0) { return 0; } let matchCount = 0; for (const t of queryTokens) { if (contentTokens.has(t)) { matchCount++; } } const tokenOverlap = matchCount / queryTokens.size; const substringMatch = lowerContent.includes(lowerQuery) ? 0.4 : 0; let partialMatch = 0; for (const qt of queryTokens) { if (qt.length >= 3 && lowerContent.includes(qt)) { partialMatch += 0.1; } } partialMatch = Math.min(0.3, partialMatch); return Math.min(1.0, tokenOverlap * 0.5 + substringMatch + partialMatch); } static #tokenizeWords(text) { if (!text) { return new Set(); } return new Set(text .split(/[\s,;:!?。,;:!?\-_/\\|()[\]{}'"<>]+/) .filter((t) => t.length >= 2) .map((t) => t.toLowerCase())); } }