@langchain/langgraph-checkpoint-sqlite
Version:
307 lines (305 loc) • 11.1 kB
JavaScript
import Database$1 from "better-sqlite3";
import { BaseCheckpointSaver, TASKS, copyCheckpoint, maxChannelVersion } from "@langchain/langgraph-checkpoint";
//#region src/index.ts
/**
 * Identity helper: hands back the very array it was given.
 *
 * It performs no runtime validation; the value passes through untouched.
 *
 * @param keys - Array of key names.
 * @returns The same `keys` reference.
 */
function validateKeys(keys) {
  return keys;
}

// Metadata fields that may be used as filter keys when listing checkpoints;
// any other key supplied in a filter is dropped before the query is built.
const checkpointMetadataKeys = ["source", "step", "parents"];
const validCheckpointMetadataKeys = validateKeys(checkpointMetadataKeys);
/**
 * Assembles and prepares the SELECT statement that loads one checkpoint row
 * together with two JSON-aggregated sub-selects over the `writes` table:
 * `pending_writes` (writes stored against the checkpoint itself) and
 * `pending_sends` (writes on the TASKS channel stored against its parent).
 *
 * @param db - Database handle exposing `prepare(sql)`.
 * @param checkpointId - Truthy to select a specific checkpoint (the statement
 *   then binds thread_id, checkpoint_ns, checkpoint_id); falsy to select the
 *   row with the greatest checkpoint_id (binds thread_id, checkpoint_ns).
 * @returns Whatever `db.prepare` returns for the assembled SQL.
 */
function prepareSql(db, checkpointId) {
  // Only the tail of the statement differs between the two variants.
  let rowSelector;
  if (checkpointId) {
    rowSelector = "AND checkpoint_id = ?";
  } else {
    rowSelector = "ORDER BY checkpoint_id DESC LIMIT 1";
  }
  // The SQL text is kept flush-left so the prepared string is unchanged.
  const statementText = `
SELECT
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
(
SELECT
json_group_array(
json_object(
'task_id', pw.task_id,
'channel', pw.channel,
'type', pw.type,
'value', CAST(pw.value AS TEXT)
)
)
FROM writes as pw
WHERE pw.thread_id = checkpoints.thread_id
AND pw.checkpoint_ns = checkpoints.checkpoint_ns
AND pw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
SELECT
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
)
FROM writes as ps
WHERE ps.thread_id = checkpoints.thread_id
AND ps.checkpoint_ns = checkpoints.checkpoint_ns
AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
) as pending_sends
FROM checkpoints
WHERE thread_id = ? AND checkpoint_ns = ? ${rowSelector}`;
  return db.prepare(statementText);
}
/**
 * Checkpoint saver that stores checkpoints and their pending writes in two
 * SQLite tables (`checkpoints` and `writes`) through a synchronous database
 * handle such as the one created by `better-sqlite3`.
 *
 * Tables and the two prepared lookup statements are created lazily by
 * `setup()`, which every public method calls before touching the database.
 */
var SqliteSaver = class SqliteSaver extends BaseCheckpointSaver {
  /** Database handle used for every statement. */
  db;
  /** True once `setup()` has created the tables and prepared statements. */
  isSetup;
  /** Prepared statement selecting the newest checkpoint of a thread/namespace. */
  withoutCheckpoint;
  /** Prepared statement selecting one checkpoint by its id. */
  withCheckpoint;

  /**
   * @param db - Open database handle.
   * @param serde - Optional serializer forwarded to the base class.
   */
  constructor(db, serde) {
    super(serde);
    this.db = db;
    this.isSetup = false;
  }

  /**
   * Creates a saver backed by a new database opened from the given string.
   *
   * @param connStringOrLocalPath - Value passed straight to the database constructor.
   * @returns A new `SqliteSaver`.
   */
  static fromConnString(connStringOrLocalPath) {
    return new SqliteSaver(new Database$1(connStringOrLocalPath));
  }

  /**
   * Idempotently switches the journal to WAL, creates the `checkpoints` and
   * `writes` tables if they are missing, and prepares the lookup statements.
   */
  setup() {
    if (this.isSetup) return;
    this.db.pragma("journal_mode=WAL");
    this.db.exec(`
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);`);
    this.db.exec(`
CREATE TABLE IF NOT EXISTS writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
value BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);`);
    this.withoutCheckpoint = prepareSql(this.db, false);
    this.withCheckpoint = prepareSql(this.db, true);
    // Flag is set last so a failure above lets the next call retry the setup.
    this.isSetup = true;
  }

  /**
   * Loads one checkpoint tuple.
   *
   * When `config.configurable.checkpoint_id` is set that exact checkpoint is
   * fetched; otherwise the row with the greatest checkpoint_id for the
   * thread/namespace is returned.
   *
   * @param config - Config whose `configurable` carries `thread_id`,
   *   optional `checkpoint_ns` (defaults to "") and optional `checkpoint_id`.
   * @returns The tuple (checkpoint, config, metadata, parentConfig,
   *   pendingWrites), or `undefined` when no row matches.
   * @throws Error when a row was found but the resulting config lacks
   *   `thread_id` or `checkpoint_id`.
   */
  async getTuple(config) {
    this.setup();
    const { thread_id, checkpoint_ns = "", checkpoint_id } = config.configurable ?? {};
    const args = [thread_id, checkpoint_ns];
    if (checkpoint_id) args.push(checkpoint_id);
    const statement = checkpoint_id ? this.withCheckpoint : this.withoutCheckpoint;
    const row = statement.get(...args);
    if (row === undefined) return undefined;

    // Without an explicit id the caller's config cannot identify the row,
    // so build one from what was actually loaded.
    let finalConfig = config;
    if (!checkpoint_id) {
      finalConfig = {
        configurable: {
          thread_id: row.thread_id,
          checkpoint_ns,
          checkpoint_id: row.checkpoint_id
        }
      };
    }
    if (finalConfig.configurable?.thread_id === undefined || finalConfig.configurable?.checkpoint_id === undefined) {
      throw new Error("Missing thread_id or checkpoint_id");
    }

    const pendingWrites = await Promise.all(
      JSON.parse(row.pending_writes).map(async (write) => {
        return [
          write.task_id,
          write.channel,
          await this.serde.loadsTyped(write.type ?? "json", write.value ?? "")
        ];
      })
    );
    const checkpoint = await this.serde.loadsTyped(row.type ?? "json", row.checkpoint);
    // Checkpoints with v < 4 get their TASKS channel filled from the parent's writes.
    if (checkpoint.v < 4 && row.parent_checkpoint_id != null) {
      await this.migratePendingSends(checkpoint, row.thread_id, row.parent_checkpoint_id);
    }
    return {
      checkpoint,
      config: finalConfig,
      metadata: await this.serde.loadsTyped(row.type ?? "json", row.metadata),
      parentConfig: row.parent_checkpoint_id
        ? {
            configurable: {
              thread_id: row.thread_id,
              checkpoint_ns,
              checkpoint_id: row.parent_checkpoint_id
            }
          }
        : undefined,
      pendingWrites
    };
  }

  /**
   * Yields checkpoint tuples ordered by checkpoint_id descending.
   *
   * @param config - `configurable.thread_id` and `configurable.checkpoint_ns`
   *   restrict the result when present.
   * @param options - Optional `limit` (maximum rows), `before` (only
   *   checkpoints whose id is lower than `before.configurable.checkpoint_id`)
   *   and `filter` (metadata equality filter; only whitelisted keys with a
   *   defined value are applied).
   */
  async *list(config, options) {
    const { limit, before, filter } = options ?? {};
    this.setup();
    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns;
    const beforeCheckpointId = before?.configurable?.checkpoint_id;
    let sql = `
SELECT
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
(
SELECT
json_group_array(
json_object(
'task_id', pw.task_id,
'channel', pw.channel,
'type', pw.type,
'value', CAST(pw.value AS TEXT)
)
)
FROM writes as pw
WHERE pw.thread_id = checkpoints.thread_id
AND pw.checkpoint_ns = checkpoints.checkpoint_ns
AND pw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
SELECT
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
)
FROM writes as ps
WHERE ps.thread_id = checkpoints.thread_id
AND ps.checkpoint_ns = checkpoints.checkpoint_ns
AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
) as pending_sends
FROM checkpoints\n`;

    // Each placeholder is pushed together with its bound value so the number
    // of `?` markers can never drift from the number of arguments.
    const whereClause = [];
    const args = [];
    if (thread_id) {
      whereClause.push("thread_id = ?");
      args.push(thread_id);
    }
    if (checkpoint_ns !== undefined && checkpoint_ns !== null) {
      whereClause.push("checkpoint_ns = ?");
      args.push(checkpoint_ns);
    }
    if (beforeCheckpointId !== undefined && beforeCheckpointId !== null) {
      whereClause.push("checkpoint_id < ?");
      args.push(beforeCheckpointId);
    }
    for (const [key, value] of Object.entries(filter ?? {})) {
      // `key` is interpolated into the SQL text, so only whitelisted names pass.
      if (value === undefined || !validCheckpointMetadataKeys.includes(key)) continue;
      whereClause.push(`jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`);
      // The `->` operator yields JSON text, hence the JSON-encoded comparand.
      args.push(JSON.stringify(value));
    }
    if (whereClause.length > 0) sql += `WHERE\n ${whereClause.join(" AND\n ")}\n`;
    sql += "\nORDER BY checkpoint_id DESC";
    if (limit) sql += ` LIMIT ${parseInt(limit, 10)}`;

    const rows = this.db.prepare(sql).all(...args);
    if (rows) {
      for (const row of rows) {
        const pendingWrites = await Promise.all(
          JSON.parse(row.pending_writes).map(async (write) => {
            return [
              write.task_id,
              write.channel,
              await this.serde.loadsTyped(write.type ?? "json", write.value ?? "")
            ];
          })
        );
        const checkpoint = await this.serde.loadsTyped(row.type ?? "json", row.checkpoint);
        if (checkpoint.v < 4 && row.parent_checkpoint_id != null) {
          await this.migratePendingSends(checkpoint, row.thread_id, row.parent_checkpoint_id);
        }
        yield {
          config: {
            configurable: {
              thread_id: row.thread_id,
              checkpoint_ns: row.checkpoint_ns,
              checkpoint_id: row.checkpoint_id
            }
          },
          checkpoint,
          metadata: await this.serde.loadsTyped(row.type ?? "json", row.metadata),
          parentConfig: row.parent_checkpoint_id
            ? {
                configurable: {
                  thread_id: row.thread_id,
                  checkpoint_ns: row.checkpoint_ns,
                  checkpoint_id: row.parent_checkpoint_id
                }
              }
            : undefined,
          pendingWrites
        };
      }
    }
  }

  /**
   * Inserts (or replaces) a checkpoint row.
   *
   * `config.configurable.checkpoint_id`, when present, is stored as the new
   * row's parent checkpoint id.
   *
   * @param config - Must carry `configurable.thread_id`; `checkpoint_ns`
   *   defaults to "".
   * @param checkpoint - Checkpoint to store; its `id` becomes the row id.
   * @param metadata - Metadata stored alongside the checkpoint.
   * @returns Config pointing at the stored checkpoint.
   * @throws Error when `configurable` or `thread_id` is missing, or when the
   *   checkpoint and metadata serialize to different types.
   */
  async put(config, checkpoint, metadata) {
    this.setup();
    if (!config.configurable) throw new Error("Empty configuration supplied.");
    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
    const parent_checkpoint_id = config.configurable?.checkpoint_id;
    if (!thread_id) throw new Error(`Missing "thread_id" field in passed "config.configurable".`);
    const preparedCheckpoint = copyCheckpoint(checkpoint);
    const [[type1, serializedCheckpoint], [type2, serializedMetadata]] = await Promise.all([
      this.serde.dumpsTyped(preparedCheckpoint),
      this.serde.dumpsTyped(metadata)
    ]);
    // A single `type` column describes both blobs, so the two must agree.
    if (type1 !== type2) throw new Error("Failed to serialized checkpoint and metadata to the same type.");
    const row = [
      thread_id,
      checkpoint_ns,
      checkpoint.id,
      parent_checkpoint_id,
      type1,
      serializedCheckpoint,
      serializedMetadata
    ];
    this.db.prepare(`INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)`).run(...row);
    return {
      configurable: {
        thread_id,
        checkpoint_ns,
        checkpoint_id: checkpoint.id
      }
    };
  }

  /**
   * Stores the writes produced by one task against a checkpoint, all inside a
   * single transaction. Each write's position in `writes` is used as its `idx`.
   *
   * @param config - Must carry `configurable.thread_id` and
   *   `configurable.checkpoint_id`; `checkpoint_ns` defaults to "".
   * @param writes - Array of `[channel, value]` pairs.
   * @param taskId - Id of the task that produced the writes.
   * @throws Error when `configurable`, `thread_id` or `checkpoint_id` is missing.
   */
  async putWrites(config, writes, taskId) {
    this.setup();
    if (!config.configurable) throw new Error("Empty configuration supplied.");
    if (!config.configurable?.thread_id) throw new Error("Missing thread_id field in config.configurable.");
    if (!config.configurable?.checkpoint_id) throw new Error("Missing checkpoint_id field in config.configurable.");
    const { thread_id, checkpoint_id } = config.configurable;
    // Default the namespace explicitly, as put() does, rather than relying on
    // INSERT OR REPLACE substituting the column default for a NULL value.
    const checkpoint_ns = config.configurable.checkpoint_ns ?? "";
    const stmt = this.db.prepare(`
INSERT OR REPLACE INTO writes
(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`);
    // Serialization is asynchronous, so it is finished before the synchronous
    // transaction starts.
    const rows = await Promise.all(
      writes.map(async (write, idx) => {
        const [type, serializedWrite] = await this.serde.dumpsTyped(write[1]);
        return [
          thread_id,
          checkpoint_ns,
          checkpoint_id,
          taskId,
          idx,
          write[0],
          type,
          serializedWrite
        ];
      })
    );
    const insertAll = this.db.transaction((pendingRows) => {
      for (const row of pendingRows) stmt.run(...row);
    });
    insertAll(rows);
  }

  /**
   * Removes every checkpoint and write belonging to a thread, in one transaction.
   *
   * @param threadId - Thread whose rows are deleted.
   */
  async deleteThread(threadId) {
    // Make sure the tables exist; on a database that was never set up the
    // DELETE statements would otherwise fail with "no such table".
    this.setup();
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM checkpoints WHERE thread_id = ?`).run(threadId);
      this.db.prepare(`DELETE FROM writes WHERE thread_id = ?`).run(threadId);
    })();
  }

  /**
   * Fills the TASKS channel of `checkpoint` (mutated in place) from the
   * TASKS-channel writes stored against its parent checkpoint, and assigns
   * that channel a version.
   *
   * Called by `getTuple`/`list` for checkpoints whose `v` is below 4.
   *
   * @param checkpoint - Checkpoint to update in place.
   * @param threadId - Thread the checkpoint belongs to.
   * @param parentCheckpointId - Id of the parent checkpoint whose writes are read.
   */
  async migratePendingSends(checkpoint, threadId, parentCheckpointId) {
    // The aggregate has no GROUP BY, so exactly one row comes back even when
    // no write matches (json_group_array then yields an empty array).
    // NOTE(review): the ORDER BY sits outside the aggregate, so it does not by
    // itself order the elements inside json_group_array — confirm intended order.
    // NOTE(review): the lookup is not restricted by checkpoint_ns — confirm
    // parent checkpoint ids are unique across namespaces.
    const { pending_sends } = this.db.prepare(`
SELECT
checkpoint_id,
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
) as pending_sends
FROM writes as ps
WHERE ps.thread_id = ?
AND ps.checkpoint_id = ?
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
`).get(threadId, parentCheckpointId);
    const mutableCheckpoint = checkpoint;
    mutableCheckpoint.channel_values ??= {};
    mutableCheckpoint.channel_values[TASKS] = await Promise.all(
      JSON.parse(pending_sends).map(({ type, value }) => this.serde.loadsTyped(type, value))
    );
    // Use the highest existing channel version, or ask for a first version
    // when the checkpoint has no channel versions at all.
    mutableCheckpoint.channel_versions[TASKS] =
      Object.keys(checkpoint.channel_versions).length > 0
        ? maxChannelVersion(...Object.values(checkpoint.channel_versions))
        : this.getNextVersion(undefined);
  }
};
//#endregion
export { SqliteSaver };
//# sourceMappingURL=index.js.map