@langgraph-js/memory
Version:
A memory management system based on PostgreSQL + pgvector for LangGraph workflows
396 lines (351 loc) • 15.8 kB
text/typescript
import { AIMessage, HumanMessage, SystemMessage, ToolMessage } from '@langchain/core/messages';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { v4 as uuidv4 } from 'uuid';
import { Pool } from 'pg';
import {
DeleteAllMemoryOptions,
GetAllMemoryOptions,
IdSet,
MemoryBase,
MemoryFilters,
MemoryItem,
SearchResult,
} from './types.js';
import { PostgresVectorStore } from './vector-store/pg.js';
import { FactRetrievalSchema, getFactRetrievalMessages } from './prompts/fact_extract.js';
import { getUpdateMemoryMessages, UpdateMemorySchema } from './prompts/conflict_message.js';
import { z } from 'zod';
/**
 * Embedder contract. The implementation is supplied by the caller.
 */
export interface Embedder {
  /** Resolves to the embedding vector computed for a single text. */
  embed(text: string): Promise<number[]>;

  /**
   * Embeds several texts in one call.
   *
   * Resolves to `{ original, embedding }` pairs, where `original` is the text
   * the vector belongs to (presumably one pair per input; confirm against the
   * concrete implementation).
   */
  embedBatch(text: string[]): Promise<Array<{ original: string; embedding: number[] }>>;
}
/**
 * Serializes a chat transcript into XML-like `<message>` tags joined by newlines.
 *
 * Rendering rules:
 * - `human` and `tool` messages: the content wrapped in a tag of that type.
 * - `ai` messages that carry tool calls: only the tool calls are rendered
 *   (name, id and JSON-encoded args); the textual content is omitted.
 * - `ai` messages without tool calls: the content.
 * - `system` messages and unrecognized types: an empty line.
 *
 * @param messages Messages to serialize, in order.
 * @returns One rendered entry per input message, separated by `\n`.
 */
export const messagesToText = (
  messages: (HumanMessage | SystemMessage | AIMessage | ToolMessage)[],
): string => {
  // Structured (non-string) content would otherwise be interpolated as
  // "[object Object]", so it is JSON-encoded instead.
  const renderContent = (content: unknown): string =>
    typeof content === 'string' ? content : JSON.stringify(content);

  return messages
    .map((message) => {
      switch (message.getType()) {
        case 'human':
          return `<message type="human">${renderContent(message.content)}</message>`;
        case 'ai': {
          const toolCalls = (message as AIMessage).tool_calls;
          if (toolCalls?.length) {
            // Join explicitly: interpolating the array itself would insert
            // commas between the <tool_call> elements.
            const renderedCalls = toolCalls
              .map(
                (call) =>
                  `<tool_call name="${call.name}" id="${call.id}"><args>${JSON.stringify(
                    call.args,
                  )}</args></tool_call>`,
              )
              .join('');
            return `<message type="ai">${renderedCalls}</message>`;
          }
          return `<message type="ai">${renderContent(message.content)}</message>`;
        }
        case 'tool':
          return `<message type="tool">${renderContent(message.content)}</message>`;
        default:
          // System messages (and any unknown type) contribute an empty line.
          return '';
      }
    })
    .join('\n');
};
/**
 * Dependencies and options used to configure a memory database.
 *
 * NOTE(review): nothing in the visible part of this file consumes this
 * interface; the per-field notes for the optional members are assumptions
 * based on their names and should be confirmed against the consumer.
 */
export interface MemoryDatabaseConfig {
  // --- required collaborators ---
  /** `pg` connection pool. */
  pool: Pool;
  /** Chat model. */
  llm: BaseChatModel;
  /** Embedding provider. */
  embedder: Embedder;

  // --- optional settings ---
  /** Presumably the name of the backing table (TODO confirm). */
  tableName?: string;
  /** Presumably the embedding vector dimension (TODO confirm). */
  dimension?: number;
}
/**
 * Memory database implementation backed by PostgreSQL + pgvector.
 *
 * Every operation is scoped to the `org_id` passed to the constructor:
 * searches, listings and bulk deletes always add it as a filter, and
 * single-record operations (`get`, `update`, `delete`) reject records whose
 * stored `org_id` differs from it.
 *
 * NOTE(review): `customPrompt` is stored on the instance but is not read by
 * any method of this class.
 */
export class MemoryDataBase implements MemoryBase {
  /** Minimum similarity score for an existing memory to be shown to the LLM as a conflict candidate. */
  private static readonly RELEVANCE_THRESHOLD = 0.5;
  /** Number of nearest memories fetched per extracted fact when looking for conflicts. */
  private static readonly CANDIDATE_LIMIT = 10;

  constructor(
    public org_id: string,
    private llm: BaseChatModel,
    private embedder: Embedder,
    public vectorStore: PostgresVectorStore,
    public customPrompt?: string,
  ) {}

  /**
   * Initializes the database by delegating to the vector store.
   */
  async setup(): Promise<void> {
    await this.vectorStore.initialize();
  }

  /**
   * Extracts facts from a conversation and reconciles them with stored memories.
   *
   * For each extracted fact the nearest existing memories (within this
   * organization and the given id filters) are retrieved, and the LLM decides
   * whether to ADD a new memory or UPDATE / DELETE an existing one.
   *
   * NOTE(review): `config.infer` is accepted but not consulted; facts are
   * always extracted through the LLM.
   *
   * @param messages Conversation to extract facts from.
   * @param config Owner ids (directly or inside `filters`) plus optional user metadata.
   * @returns The memories that were added, updated or deleted.
   * @throws Error when none of userId / agentId / runId is provided.
   */
  async add(
    messages: (HumanMessage | SystemMessage | AIMessage | ToolMessage)[],
    config: { metadata?: Record<string, any>; filters?: MemoryFilters; infer?: boolean } & IdSet,
  ): Promise<SearchResult> {
    const { metadata = {} } = config;
    // Ids are merged into a copy of the filters only; user metadata stays untouched.
    const filters = this.mergeIdFilters(config);
    MemoryDataBase.requireIdFilter(filters);
    // Use the merged ids as owners so that ids supplied only through
    // `filters` are persisted with the new memory as well.
    const { userId, agentId, runId } = filters;

    const facts = await this.extractFacts(messagesToText(messages));
    const results: MemoryItem[] = [];
    if (facts.length === 0) {
      // Nothing to store; skip the embedding call entirely.
      return { results };
    }

    const embeddings = await this.embedder.embedBatch(facts);
    for (const { original: fact, embedding } of embeddings) {
      // Look up similar memories inside the current organization; a generous
      // limit gives the LLM a better chance of spotting conflicts.
      const searchFilters = { ...filters, org_id: this.org_id };
      const similarMemories = await this.vectorStore.search(
        embedding,
        MemoryDataBase.CANDIDATE_LIMIT,
        searchFilters,
      );
      // Ids the LLM may legitimately refer to. Anything else was not
      // retrieved under the org/owner filters and must not be modified.
      const knownIds = new Set<string>(similarMemories.map((m: { id: string }) => m.id));

      // Let the LLM decide what to do with this fact (add, update or delete).
      const actions = await this.decideMemoryAction(fact, similarMemories);
      for (const action of actions) {
        switch (action.event) {
          case 'ADD': {
            const memoryId = uuidv4();
            const insertResult = await this.vectorStore.insert(
              memoryId,
              this.org_id,
              action.text,
              embedding,
              {
                userId,
                agentId,
                runId,
                categories: action.categories,
                userMetadata: metadata, // user-defined data is stored separately
              },
            );
            results.push({
              id: insertResult.id,
              org_id: this.org_id,
              user_id: userId,
              agent_id: agentId,
              run_id: runId,
              memory: action.text,
              categories: action.categories,
              metadata: metadata, // only the user-defined metadata is returned
              created_at: insertResult.created_at,
              updated_at: insertResult.updated_at,
            });
            break;
          }
          case 'UPDATE': {
            if (!action.id) {
              console.warn('UPDATE action missing or empty id field, skipping:', action);
              break;
            }
            if (!knownIds.has(action.id)) {
              console.warn('UPDATE action references an id outside the retrieved memories, skipping:', action);
              break;
            }
            const newEmbedding = await this.embedder.embed(action.text);
            const updateResult = await this.vectorStore.update(action.id, action.text, newEmbedding, {
              categories: action.categories,
              userMetadata: {
                ...metadata,
                event: action.event,
                previousMemory: action.old_memory,
              },
            });
            // Re-read the record so the result reflects what is stored.
            const updatedRecord = await this.vectorStore.get(action.id);
            results.push({
              id: action.id,
              org_id: this.org_id,
              user_id: updatedRecord?.metadata.userId,
              agent_id: updatedRecord?.metadata.agentId,
              run_id: updatedRecord?.metadata.runId,
              memory: action.text,
              categories: action.categories,
              metadata: updatedRecord?.metadata.userMetadata || metadata,
              created_at: updatedRecord?.metadata.createdAt || new Date().toISOString(),
              updated_at: updateResult.updated_at,
            });
            break;
          }
          case 'DELETE': {
            if (!action.id) {
              console.warn('DELETE action missing or empty id field, skipping:', action);
              break;
            }
            if (!knownIds.has(action.id)) {
              console.warn('DELETE action references an id outside the retrieved memories, skipping:', action);
              break;
            }
            // Capture the record before it disappears so it can be reported.
            const recordToDelete = await this.vectorStore.get(action.id);
            await this.vectorStore.delete(action.id);
            results.push({
              id: action.id,
              org_id: this.org_id,
              user_id: recordToDelete?.metadata.userId,
              agent_id: recordToDelete?.metadata.agentId,
              run_id: recordToDelete?.metadata.runId,
              memory: action.text,
              categories: action.categories,
              metadata: {
                ...(recordToDelete?.metadata.userMetadata || {}),
                event: action.event,
              },
              created_at: recordToDelete?.metadata.createdAt || new Date().toISOString(),
              updated_at: recordToDelete?.metadata.updatedAt || new Date().toISOString(),
            });
            break;
          }
        }
      }
    }
    return { results };
  }

  /**
   * Fetches a single memory.
   *
   * @returns The memory, or `null` when it does not exist or its stored
   *   `org_id` differs from this instance's organization.
   */
  async get(memoryId: string): Promise<MemoryItem | null> {
    const record = await this.fetchOwnedRecord(memoryId);
    return record ? this.toMemoryItem(record) : null;
  }

  /**
   * Runs a similarity search for `query` within this organization.
   *
   * @param query Text to embed and search with.
   * @param config Owner ids (directly or inside `filters`) and an optional `limit` (default 100).
   * @throws Error when none of userId / agentId / runId is provided.
   */
  async search(query: string, config: { limit?: number; filters?: MemoryFilters } & IdSet): Promise<SearchResult> {
    const { limit = 100 } = config;
    const filters = this.mergeIdFilters(config);
    MemoryDataBase.requireIdFilter(filters);
    // Always add org_id so only this organization's memories are visible.
    const searchFilters = { ...filters, org_id: this.org_id };
    const queryEmbedding = await this.embedder.embed(query);
    const searchResults = await this.vectorStore.search(queryEmbedding, limit, searchFilters);
    const results: MemoryItem[] = searchResults.map(
      (hit: { id: string; memory: string; metadata: Record<string, any>; score: number }) => ({
        ...this.toMemoryItem(hit),
        score: hit.score,
      }),
    );
    return { results };
  }

  /**
   * Replaces the text (and embedding) of an existing memory.
   *
   * @throws Error when the memory does not exist or belongs to another organization.
   */
  async update(memoryId: string, data: string): Promise<{ message: string }> {
    await this.requireOwnedRecord(memoryId);
    const embedding = await this.embedder.embed(data);
    await this.vectorStore.update(memoryId, data, embedding, {});
    return { message: 'Memory updated successfully!' };
  }

  /**
   * Deletes a single memory.
   *
   * @throws Error when the memory does not exist or belongs to another organization.
   */
  async delete(memoryId: string): Promise<{ message: string }> {
    await this.requireOwnedRecord(memoryId);
    await this.vectorStore.delete(memoryId);
    return { message: 'Memory deleted successfully!' };
  }

  /**
   * Deletes every memory of this organization that matches the given filters.
   *
   * @throws Error when no filter is given (use `reset()` to wipe the organization).
   */
  async deleteAll(config: DeleteAllMemoryOptions): Promise<{ message: string }> {
    const { userId, agentId, runId, ...filters } = config;
    if (userId) filters.userId = userId;
    if (agentId) filters.agentId = agentId;
    if (runId) filters.runId = runId;
    if (!Object.keys(filters).length) {
      throw new Error(
        'At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method.',
      );
    }
    // Always add org_id so only this organization's memories can be deleted.
    const deleteFilters = { ...filters, org_id: this.org_id };
    const count = await this.vectorStore.deleteAll(deleteFilters);
    return { message: `${count} memories deleted successfully!` };
  }

  /**
   * Deletes all memories of this organization (not the whole table).
   */
  async reset(): Promise<void> {
    await this.vectorStore.deleteAll({ org_id: this.org_id });
  }

  /**
   * Lists memories of this organization that match the given filters.
   *
   * @param config Optional owner ids, extra filters and a `limit` (default 100).
   */
  async getAll(config: GetAllMemoryOptions): Promise<SearchResult> {
    const { userId, agentId, runId, limit = 100, ...filters } = config;
    if (userId) filters.userId = userId;
    if (agentId) filters.agentId = agentId;
    if (runId) filters.runId = runId;
    // Always add org_id so only this organization's memories are visible.
    const listFilters = { ...filters, org_id: this.org_id };
    const memories = await this.vectorStore.list(listFilters, limit);
    const results: MemoryItem[] = memories.map(
      (mem: { id: string; memory: string; metadata: Record<string, any> }) => ({
        ...this.toMemoryItem(mem),
        expiration_date: mem.metadata.expirationDate,
      }),
    );
    return { results };
  }

  /**
   * Builds the effective filters from a config without touching the caller's
   * `filters` object: explicit ids override same-named entries in `filters`.
   */
  private mergeIdFilters(config: { filters?: MemoryFilters } & IdSet): MemoryFilters {
    const merged: MemoryFilters = { ...config.filters };
    if (config.userId) merged.userId = config.userId;
    if (config.agentId) merged.agentId = config.agentId;
    if (config.runId) merged.runId = config.runId;
    return merged;
  }

  /** Throws unless at least one of userId / agentId / runId is present in the filters. */
  private static requireIdFilter(filters: MemoryFilters): void {
    if (!filters.userId && !filters.agentId && !filters.runId) {
      throw new Error('One of the filters: userId, agentId or runId is required!');
    }
  }

  /**
   * Loads a raw record and hides it (returns `null`) when its stored `org_id`
   * is set and differs from this instance's organization.
   */
  private async fetchOwnedRecord(memoryId: string) {
    const record = await this.vectorStore.get(memoryId);
    if (!record) return null;
    if (record.metadata.org_id && record.metadata.org_id !== this.org_id) {
      return null;
    }
    return record;
  }

  /** Like `fetchOwnedRecord`, but throws when the record is missing or foreign. */
  private async requireOwnedRecord(memoryId: string) {
    const record = await this.fetchOwnedRecord(memoryId);
    if (!record) {
      throw new Error(`Memory with ID ${memoryId} not found`);
    }
    return record;
  }

  /**
   * Maps a stored record to the public `MemoryItem` shape. Only the
   * user-defined metadata is exposed; missing timestamps fall back to "now".
   */
  private toMemoryItem(record: { id: string; memory: string; metadata: Record<string, any> }): MemoryItem {
    const { metadata } = record;
    return {
      id: record.id,
      org_id: this.org_id,
      user_id: metadata.userId,
      agent_id: metadata.agentId,
      run_id: metadata.runId,
      memory: record.memory,
      categories: metadata.categories,
      metadata: metadata.userMetadata || {},
      created_at: metadata.createdAt || new Date().toISOString(),
      updated_at: metadata.updatedAt || new Date().toISOString(),
    };
  }

  /**
   * Uses the LLM (structured output) to extract facts from the serialized conversation.
   *
   * @returns The extracted facts, or an empty list when the response has none.
   */
  private async extractFacts(messageText: string): Promise<string[]> {
    const response = await this.llm
      .withStructuredOutput(FactRetrievalSchema)
      .invoke(getFactRetrievalMessages(messageText));
    return response.facts || [];
  }

  /**
   * Asks the LLM which operations (add / update / delete) a new fact implies.
   *
   * Only memories scoring at least `RELEVANCE_THRESHOLD` are shown to the LLM.
   * The LLM is consulted even when there are none, so that it can still
   * produce suitable categories. If the call fails, the fact is simply added
   * under the `general` category.
   */
  private async decideMemoryAction(
    newFact: string,
    similarMemories: Array<{ id: string; memory: string; score: number }>,
  ): Promise<z.infer<typeof UpdateMemorySchema>['memory']> {
    const relevantMemories = similarMemories.filter((m) => m.score >= MemoryDataBase.RELEVANCE_THRESHOLD);
    try {
      const response = (await this.llm.withStructuredOutput(UpdateMemorySchema).invoke(
        getUpdateMemoryMessages(
          relevantMemories.map((m) => ({ id: m.id, text: m.memory })),
          [newFact],
        ),
      )) as z.infer<typeof UpdateMemorySchema>;
      return response.memory;
    } catch (error) {
      console.error('Failed to decide memory action:', error);
      // Degrade to a plain ADD with the default category.
      return [{ event: 'ADD', text: newFact, id: '', categories: ['general'], old_memory: '' }];
    }
  }
}