UNPKG

@langchain/langgraph-checkpoint

Version:

Library with base interfaces for LangGraph checkpoint savers.

259 lines 12.1 kB
import { BaseCheckpointSaver, copyCheckpoint, getCheckpointId, WRITES_IDX_MAP, } from "./base.js"; import { TASKS } from "./serde/types.js"; function _generateKey(threadId, checkpointNamespace, checkpointId) { return JSON.stringify([threadId, checkpointNamespace, checkpointId]); } export class MemorySaver extends BaseCheckpointSaver { constructor(serde) { super(serde); // thread ID -> checkpoint namespace -> checkpoint ID -> checkpoint mapping Object.defineProperty(this, "storage", { enumerable: true, configurable: true, writable: true, value: {} }); Object.defineProperty(this, "writes", { enumerable: true, configurable: true, writable: true, value: {} }); } async _getPendingSends(threadId, checkpointNs, parentCheckpointId) { let pendingSends = []; if (parentCheckpointId !== undefined) { const key = _generateKey(threadId, checkpointNs, parentCheckpointId); pendingSends = await Promise.all(Object.values(this.writes[key] || {}) ?.filter(([_taskId, channel]) => { return channel === TASKS; }) .map(([_taskId, _channel, writes]) => { return this.serde.loadsTyped("json", writes); }) ?? []); } return pendingSends; } async getTuple(config) { const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns ?? ""; let checkpoint_id = getCheckpointId(config); if (checkpoint_id) { const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id]; if (saved !== undefined) { const [checkpoint, metadata, parentCheckpointId] = saved; const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id); const pending_sends = await this._getPendingSends(thread_id, checkpoint_ns, parentCheckpointId); const deserializedCheckpoint = { ...(await this.serde.loadsTyped("json", checkpoint)), pending_sends, }; const pendingWrites = await Promise.all(Object.values(this.writes[key] || {}).map(async ([taskId, channel, value]) => { return [ taskId, channel, await this.serde.loadsTyped("json", value), ]; })); const checkpointTuple = { config, checkpoint: deserializedCheckpoint, metadata: (await this.serde.loadsTyped("json", metadata)), pendingWrites, }; if (parentCheckpointId !== undefined) { checkpointTuple.parentConfig = { configurable: { thread_id, checkpoint_ns, checkpoint_id: parentCheckpointId, }, }; } return checkpointTuple; } } else { const checkpoints = this.storage[thread_id]?.[checkpoint_ns]; if (checkpoints !== undefined) { // eslint-disable-next-line prefer-destructuring checkpoint_id = Object.keys(checkpoints).sort((a, b) => b.localeCompare(a))[0]; const saved = checkpoints[checkpoint_id]; const [checkpoint, metadata, parentCheckpointId] = saved; const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id); const pending_sends = await this._getPendingSends(thread_id, checkpoint_ns, parentCheckpointId); const deserializedCheckpoint = { ...(await this.serde.loadsTyped("json", checkpoint)), pending_sends, }; const pendingWrites = await Promise.all(Object.values(this.writes[key] || {}).map(async ([taskId, channel, value]) => { return [ taskId, channel, await this.serde.loadsTyped("json", value), ]; })); const checkpointTuple = { config: { configurable: { thread_id, checkpoint_id, checkpoint_ns, }, }, checkpoint: deserializedCheckpoint, metadata: (await this.serde.loadsTyped("json", metadata)), pendingWrites, }; if (parentCheckpointId !== undefined) { checkpointTuple.parentConfig = { configurable: { thread_id, checkpoint_ns, checkpoint_id: parentCheckpointId, }, }; } return checkpointTuple; } } return undefined; } async *list(config, options) { // eslint-disable-next-line prefer-const let { before, limit, filter } = options ?? {}; const threadIds = config.configurable?.thread_id ? [config.configurable?.thread_id] : Object.keys(this.storage); const configCheckpointNamespace = config.configurable?.checkpoint_ns; const configCheckpointId = config.configurable?.checkpoint_id; for (const threadId of threadIds) { for (const checkpointNamespace of Object.keys(this.storage[threadId] ?? {})) { if (configCheckpointNamespace !== undefined && checkpointNamespace !== configCheckpointNamespace) { continue; } const checkpoints = this.storage[threadId]?.[checkpointNamespace] ?? {}; const sortedCheckpoints = Object.entries(checkpoints).sort((a, b) => b[0].localeCompare(a[0])); for (const [checkpointId, [checkpoint, metadataStr, parentCheckpointId],] of sortedCheckpoints) { // Filter by checkpoint ID from config if (configCheckpointId && checkpointId !== configCheckpointId) { continue; } // Filter by checkpoint ID from before config if (before && before.configurable?.checkpoint_id && checkpointId >= before.configurable.checkpoint_id) { continue; } // Parse metadata const metadata = (await this.serde.loadsTyped("json", metadataStr)); if (filter && !Object.entries(filter).every(([key, value]) => metadata[key] === value)) { continue; } // Limit search results if (limit !== undefined) { if (limit <= 0) break; limit -= 1; } const key = _generateKey(threadId, checkpointNamespace, checkpointId); const writes = Object.values(this.writes[key] || {}); const pending_sends = await this._getPendingSends(threadId, checkpointNamespace, parentCheckpointId); const pendingWrites = await Promise.all(writes.map(async ([taskId, channel, value]) => { return [ taskId, channel, await this.serde.loadsTyped("json", value), ]; })); const deserializedCheckpoint = { ...(await this.serde.loadsTyped("json", checkpoint)), pending_sends, }; const checkpointTuple = { config: { configurable: { thread_id: threadId, checkpoint_ns: checkpointNamespace, checkpoint_id: checkpointId, }, }, checkpoint: deserializedCheckpoint, metadata, pendingWrites, }; if (parentCheckpointId !== undefined) { checkpointTuple.parentConfig = { configurable: { thread_id: threadId, checkpoint_ns: checkpointNamespace, checkpoint_id: parentCheckpointId, }, }; } yield checkpointTuple; } } } } async put(config, checkpoint, metadata) { const preparedCheckpoint = copyCheckpoint(checkpoint); delete preparedCheckpoint.pending_sends; const threadId = config.configurable?.thread_id; const checkpointNamespace = config.configurable?.checkpoint_ns ?? ""; if (threadId === undefined) { throw new Error(`Failed to put checkpoint. The passed RunnableConfig is missing a required "thread_id" field in its "configurable" property.`); } if (!this.storage[threadId]) { this.storage[threadId] = {}; } if (!this.storage[threadId][checkpointNamespace]) { this.storage[threadId][checkpointNamespace] = {}; } const [, serializedCheckpoint] = this.serde.dumpsTyped(preparedCheckpoint); const [, serializedMetadata] = this.serde.dumpsTyped(metadata); this.storage[threadId][checkpointNamespace][checkpoint.id] = [ serializedCheckpoint, serializedMetadata, config.configurable?.checkpoint_id, // parent ]; return { configurable: { thread_id: threadId, checkpoint_ns: checkpointNamespace, checkpoint_id: checkpoint.id, }, }; } async putWrites(config, writes, taskId) { const threadId = config.configurable?.thread_id; const checkpointNamespace = config.configurable?.checkpoint_ns; const checkpointId = config.configurable?.checkpoint_id; if (threadId === undefined) { throw new Error(`Failed to put writes. The passed RunnableConfig is missing a required "thread_id" field in its "configurable" property`); } if (checkpointId === undefined) { throw new Error(`Failed to put writes. The passed RunnableConfig is missing a required "checkpoint_id" field in its "configurable" property.`); } const outerKey = _generateKey(threadId, checkpointNamespace, checkpointId); const outerWrites_ = this.writes[outerKey]; if (this.writes[outerKey] === undefined) { this.writes[outerKey] = {}; } writes.forEach(([channel, value], idx) => { const [, serializedValue] = this.serde.dumpsTyped(value); const innerKey = [ taskId, WRITES_IDX_MAP[channel] || idx, ]; const innerKeyStr = `${innerKey[0]},${innerKey[1]}`; if (innerKey[1] >= 0 && outerWrites_ && innerKeyStr in outerWrites_) { return; } this.writes[outerKey][innerKeyStr] = [taskId, channel, serializedValue]; }); } } //# sourceMappingURL=memory.js.map