UNPKG

@langgraph-js/memory

Version:

A memory management system based on PostgreSQL + pgvector for LangGraph workflows

396 lines (351 loc) 15.8 kB
import { AIMessage, HumanMessage, SystemMessage, ToolMessage } from '@langchain/core/messages';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { v4 as uuidv4 } from 'uuid';
import { Pool } from 'pg';
import {
  DeleteAllMemoryOptions,
  GetAllMemoryOptions,
  IdSet,
  MemoryBase,
  MemoryFilters,
  MemoryItem,
  SearchResult,
} from './types.js';
import { PostgresVectorStore } from './vector-store/pg.js';
import { FactRetrievalSchema, getFactRetrievalMessages } from './prompts/fact_extract.js';
import { getUpdateMemoryMessages, UpdateMemorySchema } from './prompts/conflict_message.js';
import { z } from 'zod';

/**
 * Minimum search score an existing memory needs before it is passed to the LLM
 * alongside a new fact in `decideMemoryAction`.
 */
const RELEVANT_MEMORY_MIN_SCORE = 0.5;

/** Number of nearest existing memories fetched per new fact so conflicts can be detected. */
const CONFLICT_SEARCH_LIMIT = 10;

/**
 * Embedder interface - implemented by the caller.
 */
export interface Embedder {
  /** Embed a single text. */
  embed(text: string): Promise<number[]>;
  /** Embed several texts; each result carries the input text (`original`) it was computed from. */
  embedBatch(text: string[]): Promise<
    {
      embedding: number[];
      original: string;
    }[]
  >;
}

/**
 * Flatten a LangChain message `content` value to plain text.
 *
 * String content is returned unchanged. Array (multi-part) content is reduced to the
 * concatenation of its string parts and string `text` fields; other parts are dropped.
 * Interpolating such an array directly would produce "[object Object]".
 */
const contentToText = (content: unknown): string => {
  if (typeof content === 'string') return content;
  if (Array.isArray(content)) {
    return content
      .map((part: unknown) => {
        if (typeof part === 'string') return part;
        if (typeof part === 'object' && part !== null) {
          const text = (part as { text?: unknown }).text;
          if (typeof text === 'string') return text;
        }
        return '';
      })
      .join('');
  }
  return content == null ? '' : String(content);
};

/**
 * Render a conversation as a pseudo-XML transcript for the fact-extraction prompt.
 *
 * Human, AI and tool messages become one `<message type="...">` element per line. An AI
 * message that carries tool calls is rendered as its `<tool_call>` elements (its text
 * content is not included). System messages and any other message types are omitted.
 */
export const messagesToText = (messages: (HumanMessage | SystemMessage | AIMessage | ToolMessage)[]) => {
  return messages
    .map((message) => {
      switch (message.getType()) {
        case 'human':
          return `<message type="human">${contentToText(message.content)}</message>`;
        case 'ai': {
          const toolCalls = (message as AIMessage).tool_calls;
          if (toolCalls?.length) {
            // Join explicitly: interpolating the array would insert commas between calls.
            const rendered = toolCalls
              .map((t) => {
                return `<tool_call name="${t.name}" id="${t.id}"><args>${JSON.stringify(
                  t.args,
                )}</args></tool_call>`;
              })
              .join('');
            return `<message type="ai">${rendered}</message>`;
          }
          return `<message type="ai">${contentToText(message.content)}</message>`;
        }
        case 'tool':
          return `<message type="tool">${contentToText(message.content)}</message>`;
        default:
          // System prompts and unrecognised message types are left out of the transcript.
          return '';
      }
    })
    .filter((line) => line !== '')
    .join('\n');
};

/**
 * Configuration bundle describing the dependencies of a memory database.
 * NOTE(review): the `MemoryDataBase` constructor in this module takes positional
 * arguments rather than this object — confirm where this type is consumed.
 */
export interface MemoryDatabaseConfig {
  pool: Pool;
  llm: BaseChatModel;
  embedder: Embedder;
  tableName?: string;
  dimension?: number;
}

/**
 * Memory database backed by PostgreSQL + pgvector.
 *
 * All operations are scoped to `org_id`. Searches, listings and bulk deletes add an
 * `org_id` filter. Single-record reads and writes refuse records whose metadata names a
 * different organisation.
 */
export class MemoryDataBase implements MemoryBase {
  /**
   * @param org_id - Organisation every operation is scoped to.
   * @param llm - Chat model used for fact extraction and for add/update/delete decisions.
   * @param embedder - Produces embeddings for facts and queries.
   * @param vectorStore - pgvector-backed storage.
   * @param customPrompt - Kept on the instance; not read by any method of this class.
   */
  constructor(
    public org_id: string,
    private llm: BaseChatModel,
    private embedder: Embedder,
    public vectorStore: PostgresVectorStore,
    public customPrompt?: string,
  ) {}

  /**
   * Initialise the underlying vector store.
   */
  async setup(): Promise<void> {
    await this.vectorStore.initialize();
  }

  /**
   * Extract facts from a conversation and reconcile them with existing memories.
   *
   * For every extracted fact the nearest existing memories (same org and scope) are
   * fetched, and the LLM decides whether to ADD, UPDATE or DELETE. UPDATE/DELETE are only
   * applied to memories that were among those search results.
   *
   * @param messages - Conversation to mine for facts.
   * @param config - Scope ids (`userId`/`agentId`/`runId`, either top-level or inside
   *   `filters`), optional user `metadata` stored with new/updated memories. `infer` is
   *   accepted but currently not used.
   * @returns One entry per applied action.
   * @throws Error when none of userId, agentId or runId is provided.
   */
  async add(
    messages: (HumanMessage | SystemMessage | AIMessage | ToolMessage)[],
    config: { metadata?: Record<string, any>; filters?: MemoryFilters; infer?: boolean } & IdSet,
  ): Promise<SearchResult> {
    const { userId, agentId, runId, metadata = {} } = config;

    // Merge the id shortcuts into a copy so the caller's `filters` object is not mutated.
    // The ids go into the filters only, never into the user metadata.
    const filters: MemoryFilters = { ...(config.filters ?? {}) };
    if (userId) filters.userId = userId;
    if (agentId) filters.agentId = agentId;
    if (runId) filters.runId = runId;

    if (!filters.userId && !filters.agentId && !filters.runId) {
      throw new Error('One of the filters: userId, agentId or runId is required!');
    }

    // Effective scope ids, taken from the merged filters. New records must carry exactly
    // these so that later searches using the same filters can find them.
    const scope = { userId: filters.userId, agentId: filters.agentId, runId: filters.runId };

    const facts = await this.extractFacts(messagesToText(messages));
    if (facts.length === 0) {
      // Nothing to store; avoid calling the embedder with an empty batch.
      return { results: [] };
    }

    const results: MemoryItem[] = [];
    const embeddings = await this.embedder.embedBatch(facts);
    // Always restrict the conflict search to the current organisation.
    const searchFilters = { ...filters, org_id: this.org_id };

    for (const { original: fact, embedding } of embeddings) {
      const similarMemories = await this.vectorStore.search(embedding, CONFLICT_SEARCH_LIMIT, searchFilters);
      // The LLM may only modify memories it could have been shown. Any other id is
      // hallucinated (or injected) and could reference another tenant's record.
      const knownIds = new Set(similarMemories.map((m: { id: string }) => m.id));
      const actions = await this.decideMemoryAction(fact, similarMemories);

      for (const action of actions) {
        switch (action.event) {
          case 'ADD': {
            const memoryId = uuidv4();
            // Reuse the fact's embedding only when the stored text is the fact itself.
            const textEmbedding = action.text === fact ? embedding : await this.embedder.embed(action.text);
            const insertResult = await this.vectorStore.insert(memoryId, this.org_id, action.text, textEmbedding, {
              ...scope,
              categories: action.categories,
              userMetadata: metadata, // user-defined data is stored separately
            });
            results.push({
              id: insertResult.id,
              org_id: this.org_id,
              user_id: scope.userId,
              agent_id: scope.agentId,
              run_id: scope.runId,
              memory: action.text,
              categories: action.categories,
              metadata: metadata, // only the user-defined metadata is returned
              created_at: insertResult.created_at,
              updated_at: insertResult.updated_at,
            });
            break;
          }
          case 'UPDATE': {
            if (!action.id) {
              console.warn('UPDATE action missing or empty id field, skipping:', action);
              break;
            }
            if (!knownIds.has(action.id)) {
              console.warn('UPDATE action references a memory outside the search results, skipping:', action);
              break;
            }
            const newEmbedding = await this.embedder.embed(action.text);
            const updateResult = await this.vectorStore.update(action.id, action.text, newEmbedding, {
              categories: action.categories,
              userMetadata: {
                ...metadata,
                event: action.event,
                previousMemory: action.old_memory,
              },
            });
            // Re-read the record to report its full, current state.
            const updatedRecord = await this.vectorStore.get(action.id);
            results.push({
              id: action.id,
              org_id: this.org_id,
              user_id: updatedRecord?.metadata.userId,
              agent_id: updatedRecord?.metadata.agentId,
              run_id: updatedRecord?.metadata.runId,
              memory: action.text,
              categories: action.categories,
              metadata: updatedRecord?.metadata.userMetadata || metadata,
              created_at: updatedRecord?.metadata.createdAt || new Date().toISOString(),
              updated_at: updateResult.updated_at,
            });
            break;
          }
          case 'DELETE': {
            if (!action.id) {
              console.warn('DELETE action missing or empty id field, skipping:', action);
              break;
            }
            if (!knownIds.has(action.id)) {
              console.warn('DELETE action references a memory outside the search results, skipping:', action);
              break;
            }
            // Read the record before deleting so the result can describe what was removed.
            const recordToDelete = await this.vectorStore.get(action.id);
            await this.vectorStore.delete(action.id);
            results.push({
              id: action.id,
              org_id: this.org_id,
              user_id: recordToDelete?.metadata.userId,
              agent_id: recordToDelete?.metadata.agentId,
              run_id: recordToDelete?.metadata.runId,
              memory: action.text,
              categories: action.categories,
              metadata: {
                ...(recordToDelete?.metadata.userMetadata || {}),
                event: action.event,
              },
              created_at: recordToDelete?.metadata.createdAt || new Date().toISOString(),
              updated_at: recordToDelete?.metadata.updatedAt || new Date().toISOString(),
            });
            break;
          }
        }
      }
    }

    return { results };
  }

  /**
   * Fetch a single memory.
   *
   * @returns The memory, or `null` if it does not exist or its metadata names a
   *   different organisation.
   */
  async get(memoryId: string): Promise<MemoryItem | null> {
    const result = await this.vectorStore.get(memoryId);
    if (!result) return null;

    // Hide memories that belong to another organisation.
    if (result.metadata.org_id && result.metadata.org_id !== this.org_id) {
      return null;
    }

    return {
      id: result.id,
      org_id: this.org_id,
      user_id: result.metadata.userId,
      agent_id: result.metadata.agentId,
      run_id: result.metadata.runId,
      memory: result.memory,
      categories: result.metadata.categories,
      metadata: result.metadata.userMetadata || {}, // only the user-defined metadata is returned
      created_at: result.metadata.createdAt || new Date().toISOString(),
      updated_at: result.metadata.updatedAt || new Date().toISOString(),
    };
  }

  /**
   * Semantic search over the current organisation's memories.
   *
   * @param query - Free-text query; embedded with the configured embedder.
   * @param config - Scope ids (top-level or inside `filters`) and `limit` (default 100).
   * @throws Error when none of userId, agentId or runId is provided.
   */
  async search(query: string, config: { limit?: number; filters?: MemoryFilters } & IdSet): Promise<SearchResult> {
    const { userId, agentId, runId, limit = 100 } = config;

    // Work on a copy so the caller's `filters` object is not mutated.
    const filters: MemoryFilters = { ...(config.filters ?? {}) };
    if (userId) filters.userId = userId;
    if (agentId) filters.agentId = agentId;
    if (runId) filters.runId = runId;

    if (!filters.userId && !filters.agentId && !filters.runId) {
      throw new Error('One of the filters: userId, agentId or runId is required!');
    }

    // Always add the org_id filter so only this organisation's memories are reachable.
    const searchFilters = { ...filters, org_id: this.org_id };
    const queryEmbedding = await this.embedder.embed(query);
    const searchResults = await this.vectorStore.search(queryEmbedding, limit, searchFilters);

    const results: MemoryItem[] = searchResults.map((result) => ({
      id: result.id,
      org_id: this.org_id,
      user_id: result.metadata.userId,
      agent_id: result.metadata.agentId,
      run_id: result.metadata.runId,
      memory: result.memory,
      categories: result.metadata.categories,
      metadata: result.metadata.userMetadata || {}, // only the user-defined metadata is returned
      score: result.score,
      created_at: result.metadata.createdAt || new Date().toISOString(),
      updated_at: result.metadata.updatedAt || new Date().toISOString(),
    }));

    return { results };
  }

  /**
   * Replace the text (and embedding) of a memory.
   *
   * @throws Error when the memory belongs to a different organisation.
   */
  async update(memoryId: string, data: string): Promise<{ message: string }> {
    await this.assertNotOtherOrg(memoryId);
    const embedding = await this.embedder.embed(data);
    await this.vectorStore.update(memoryId, data, embedding, {});
    return { message: 'Memory updated successfully!' };
  }

  /**
   * Delete a single memory.
   *
   * @throws Error when the memory belongs to a different organisation.
   */
  async delete(memoryId: string): Promise<{ message: string }> {
    await this.assertNotOtherOrg(memoryId);
    await this.vectorStore.delete(memoryId);
    return { message: 'Memory deleted successfully!' };
  }

  /**
   * Delete every memory of the current organisation matching the given filters.
   *
   * @throws Error when no filter is given (use `reset()` to clear everything).
   */
  async deleteAll(config: DeleteAllMemoryOptions): Promise<{ message: string }> {
    const { userId, agentId, runId, ...filters } = config;
    if (userId) filters.userId = userId;
    if (agentId) filters.agentId = agentId;
    if (runId) filters.runId = runId;

    if (!Object.keys(filters).length) {
      throw new Error(
        'At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method.',
      );
    }

    // Always add the org_id filter so only this organisation's memories can be deleted.
    const deleteFilters = { ...filters, org_id: this.org_id };
    const count = await this.vectorStore.deleteAll(deleteFilters);
    return { message: `${count} memories deleted successfully!` };
  }

  /**
   * Delete all memories of the current organisation (not the whole table).
   */
  async reset(): Promise<void> {
    await this.vectorStore.deleteAll({ org_id: this.org_id });
  }

  /**
   * List the current organisation's memories matching optional filters.
   * Unlike `search`, no scope id is required. `limit` defaults to 100.
   */
  async getAll(config: GetAllMemoryOptions): Promise<SearchResult> {
    const { userId, agentId, runId, limit = 100, ...filters } = config;
    if (userId) filters.userId = userId;
    if (agentId) filters.agentId = agentId;
    if (runId) filters.runId = runId;

    // Always add the org_id filter so only this organisation's memories are reachable.
    const listFilters = { ...filters, org_id: this.org_id };
    const memories = await this.vectorStore.list(listFilters, limit);

    const results: MemoryItem[] = memories.map((mem) => ({
      id: mem.id,
      org_id: this.org_id,
      user_id: mem.metadata.userId,
      agent_id: mem.metadata.agentId,
      run_id: mem.metadata.runId,
      memory: mem.memory,
      categories: mem.metadata.categories,
      metadata: mem.metadata.userMetadata || {}, // only the user-defined metadata is returned
      created_at: mem.metadata.createdAt || new Date().toISOString(),
      updated_at: mem.metadata.updatedAt || new Date().toISOString(),
      expiration_date: mem.metadata.expirationDate,
    }));

    return { results };
  }

  /**
   * Reject writes to a memory whose metadata names another organisation, applying the
   * same visibility rule as `get()`. Missing records, and records without an `org_id` in
   * their metadata, are let through so the vector store's own behaviour applies.
   *
   * @throws Error when the memory belongs to a different organisation.
   */
  private async assertNotOtherOrg(memoryId: string): Promise<void> {
    const existing = await this.vectorStore.get(memoryId);
    if (existing && existing.metadata.org_id && existing.metadata.org_id !== this.org_id) {
      // Same wording as a missing record, so other tenants' ids are not revealed.
      throw new Error(`Memory with ID ${memoryId} not found`);
    }
  }

  /**
   * Ask the LLM (structured output) for the facts contained in a transcript.
   *
   * @returns The extracted facts, or `[]` when the model returns none.
   */
  private async extractFacts(messageText: string): Promise<string[]> {
    const response = await this.llm
      .withStructuredOutput(FactRetrievalSchema)
      .invoke(getFactRetrievalMessages(messageText));
    return response.facts || [];
  }

  /**
   * Decide how a new fact should change memory (add, update or delete).
   *
   * Only existing memories scoring at least `RELEVANT_MEMORY_MIN_SCORE` are passed to the
   * LLM. The LLM is consulted even when none qualify, so it still assigns categories.
   * If the LLM call fails, falls back to a plain ADD with category `general`.
   */
  private async decideMemoryAction(
    newFact: string,
    similarMemories: Array<{ id: string; memory: string; score: number }>,
  ): Promise<z.infer<typeof UpdateMemorySchema>['memory']> {
    const relevantMemories = similarMemories.filter((m) => m.score >= RELEVANT_MEMORY_MIN_SCORE);

    try {
      const response = (await this.llm.withStructuredOutput(UpdateMemorySchema).invoke(
        getUpdateMemoryMessages(
          relevantMemories.map((m) => ({ id: m.id, text: m.memory })),
          [newFact],
        ),
      )) as z.infer<typeof UpdateMemorySchema>;
      return response.memory;
    } catch (error) {
      console.error('Failed to decide memory action:', error);
      // Degrade to a simple add with the default category.
      return [{ event: 'ADD', text: newFact, id: '', categories: ['general'], old_memory: '' }];
    }
  }
}