@langchain/langgraph-checkpoint
Version:
Library with base interfaces for LangGraph checkpoint savers.
263 lines • 12.2 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.MemorySaver = void 0;
const base_js_1 = require("./base.cjs");
const types_js_1 = require("./serde/types.cjs");
function _generateKey(threadId, checkpointNamespace, checkpointId) {
return JSON.stringify([threadId, checkpointNamespace, checkpointId]);
}
class MemorySaver extends base_js_1.BaseCheckpointSaver {
constructor(serde) {
super(serde);
// thread ID -> checkpoint namespace -> checkpoint ID -> checkpoint mapping
Object.defineProperty(this, "storage", {
enumerable: true,
configurable: true,
writable: true,
value: {}
});
Object.defineProperty(this, "writes", {
enumerable: true,
configurable: true,
writable: true,
value: {}
});
}
async _getPendingSends(threadId, checkpointNs, parentCheckpointId) {
let pendingSends = [];
if (parentCheckpointId !== undefined) {
const key = _generateKey(threadId, checkpointNs, parentCheckpointId);
pendingSends = await Promise.all(Object.values(this.writes[key] || {})
?.filter(([_taskId, channel]) => {
return channel === types_js_1.TASKS;
})
.map(([_taskId, _channel, writes]) => {
return this.serde.loadsTyped("json", writes);
}) ?? []);
}
return pendingSends;
}
async getTuple(config) {
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
let checkpoint_id = (0, base_js_1.getCheckpointId)(config);
if (checkpoint_id) {
const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id];
if (saved !== undefined) {
const [checkpoint, metadata, parentCheckpointId] = saved;
const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id);
const pending_sends = await this._getPendingSends(thread_id, checkpoint_ns, parentCheckpointId);
const deserializedCheckpoint = {
...(await this.serde.loadsTyped("json", checkpoint)),
pending_sends,
};
const pendingWrites = await Promise.all(Object.values(this.writes[key] || {}).map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}));
const checkpointTuple = {
config,
checkpoint: deserializedCheckpoint,
metadata: (await this.serde.loadsTyped("json", metadata)),
pendingWrites,
};
if (parentCheckpointId !== undefined) {
checkpointTuple.parentConfig = {
configurable: {
thread_id,
checkpoint_ns,
checkpoint_id: parentCheckpointId,
},
};
}
return checkpointTuple;
}
}
else {
const checkpoints = this.storage[thread_id]?.[checkpoint_ns];
if (checkpoints !== undefined) {
// eslint-disable-next-line prefer-destructuring
checkpoint_id = Object.keys(checkpoints).sort((a, b) => b.localeCompare(a))[0];
const saved = checkpoints[checkpoint_id];
const [checkpoint, metadata, parentCheckpointId] = saved;
const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id);
const pending_sends = await this._getPendingSends(thread_id, checkpoint_ns, parentCheckpointId);
const deserializedCheckpoint = {
...(await this.serde.loadsTyped("json", checkpoint)),
pending_sends,
};
const pendingWrites = await Promise.all(Object.values(this.writes[key] || {}).map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}));
const checkpointTuple = {
config: {
configurable: {
thread_id,
checkpoint_id,
checkpoint_ns,
},
},
checkpoint: deserializedCheckpoint,
metadata: (await this.serde.loadsTyped("json", metadata)),
pendingWrites,
};
if (parentCheckpointId !== undefined) {
checkpointTuple.parentConfig = {
configurable: {
thread_id,
checkpoint_ns,
checkpoint_id: parentCheckpointId,
},
};
}
return checkpointTuple;
}
}
return undefined;
}
async *list(config, options) {
// eslint-disable-next-line prefer-const
let { before, limit, filter } = options ?? {};
const threadIds = config.configurable?.thread_id
? [config.configurable?.thread_id]
: Object.keys(this.storage);
const configCheckpointNamespace = config.configurable?.checkpoint_ns;
const configCheckpointId = config.configurable?.checkpoint_id;
for (const threadId of threadIds) {
for (const checkpointNamespace of Object.keys(this.storage[threadId] ?? {})) {
if (configCheckpointNamespace !== undefined &&
checkpointNamespace !== configCheckpointNamespace) {
continue;
}
const checkpoints = this.storage[threadId]?.[checkpointNamespace] ?? {};
const sortedCheckpoints = Object.entries(checkpoints).sort((a, b) => b[0].localeCompare(a[0]));
for (const [checkpointId, [checkpoint, metadataStr, parentCheckpointId],] of sortedCheckpoints) {
// Filter by checkpoint ID from config
if (configCheckpointId && checkpointId !== configCheckpointId) {
continue;
}
// Filter by checkpoint ID from before config
if (before &&
before.configurable?.checkpoint_id &&
checkpointId >= before.configurable.checkpoint_id) {
continue;
}
// Parse metadata
const metadata = (await this.serde.loadsTyped("json", metadataStr));
if (filter &&
!Object.entries(filter).every(([key, value]) => metadata[key] === value)) {
continue;
}
// Limit search results
if (limit !== undefined) {
if (limit <= 0)
break;
limit -= 1;
}
const key = _generateKey(threadId, checkpointNamespace, checkpointId);
const writes = Object.values(this.writes[key] || {});
const pending_sends = await this._getPendingSends(threadId, checkpointNamespace, parentCheckpointId);
const pendingWrites = await Promise.all(writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}));
const deserializedCheckpoint = {
...(await this.serde.loadsTyped("json", checkpoint)),
pending_sends,
};
const checkpointTuple = {
config: {
configurable: {
thread_id: threadId,
checkpoint_ns: checkpointNamespace,
checkpoint_id: checkpointId,
},
},
checkpoint: deserializedCheckpoint,
metadata,
pendingWrites,
};
if (parentCheckpointId !== undefined) {
checkpointTuple.parentConfig = {
configurable: {
thread_id: threadId,
checkpoint_ns: checkpointNamespace,
checkpoint_id: parentCheckpointId,
},
};
}
yield checkpointTuple;
}
}
}
}
async put(config, checkpoint, metadata) {
const preparedCheckpoint = (0, base_js_1.copyCheckpoint)(checkpoint);
delete preparedCheckpoint.pending_sends;
const threadId = config.configurable?.thread_id;
const checkpointNamespace = config.configurable?.checkpoint_ns ?? "";
if (threadId === undefined) {
throw new Error(`Failed to put checkpoint. The passed RunnableConfig is missing a required "thread_id" field in its "configurable" property.`);
}
if (!this.storage[threadId]) {
this.storage[threadId] = {};
}
if (!this.storage[threadId][checkpointNamespace]) {
this.storage[threadId][checkpointNamespace] = {};
}
const [, serializedCheckpoint] = this.serde.dumpsTyped(preparedCheckpoint);
const [, serializedMetadata] = this.serde.dumpsTyped(metadata);
this.storage[threadId][checkpointNamespace][checkpoint.id] = [
serializedCheckpoint,
serializedMetadata,
config.configurable?.checkpoint_id, // parent
];
return {
configurable: {
thread_id: threadId,
checkpoint_ns: checkpointNamespace,
checkpoint_id: checkpoint.id,
},
};
}
async putWrites(config, writes, taskId) {
const threadId = config.configurable?.thread_id;
const checkpointNamespace = config.configurable?.checkpoint_ns;
const checkpointId = config.configurable?.checkpoint_id;
if (threadId === undefined) {
throw new Error(`Failed to put writes. The passed RunnableConfig is missing a required "thread_id" field in its "configurable" property`);
}
if (checkpointId === undefined) {
throw new Error(`Failed to put writes. The passed RunnableConfig is missing a required "checkpoint_id" field in its "configurable" property.`);
}
const outerKey = _generateKey(threadId, checkpointNamespace, checkpointId);
const outerWrites_ = this.writes[outerKey];
if (this.writes[outerKey] === undefined) {
this.writes[outerKey] = {};
}
writes.forEach(([channel, value], idx) => {
const [, serializedValue] = this.serde.dumpsTyped(value);
const innerKey = [
taskId,
base_js_1.WRITES_IDX_MAP[channel] || idx,
];
const innerKeyStr = `${innerKey[0]},${innerKey[1]}`;
if (innerKey[1] >= 0 && outerWrites_ && innerKeyStr in outerWrites_) {
return;
}
this.writes[outerKey][innerKeyStr] = [taskId, channel, serializedValue];
});
}
}
exports.MemorySaver = MemorySaver;
//# sourceMappingURL=memory.js.map