// @langchain/langgraph-checkpoint-postgres — compiled CommonJS output.
// Listing metadata: 450 lines • 17.4 kB • JavaScript.
// TypeScript interop helper: wrap a CommonJS export so it can be consumed
// as if it were an ES module with a `default` export. Reuses an existing
// helper on `this` when one is present (e.g. shared helpers bundle).
var __importDefault = (this && this.__importDefault) || function (mod) {
    if (mod && mod.__esModule) {
        // Already an ES-module-shaped namespace object; pass it through untouched.
        return mod;
    }
    return { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.PostgresSaver = void 0;
const langgraph_checkpoint_1 = require("@langchain/langgraph-checkpoint");
const pg_1 = __importDefault(require("pg"));
const migrations_js_1 = require("./migrations.cjs");
const sql_js_1 = require("./sql.cjs");
// Defaults applied when the caller omits configuration fields.
const _defaultOptions = {
    schema: "public",
};
// Return a fully-populated options object, filling any missing fields from
// `_defaultOptions`. The caller's object is never mutated; `options` may be
// null/undefined entirely.
const _ensureCompleteOptions = (options) => {
    const schema = options?.schema ?? _defaultOptions.schema;
    return { ...options, schema };
};
const { Pool } = pg_1.default;
/**
* LangGraph checkpointer that uses a Postgres instance as the backing store.
* Uses the [node-postgres](https://node-postgres.com/) package internally
* to connect to a Postgres instance.
*
* @example
* ```
* import { ChatOpenAI } from "@langchain/openai";
* import { PostgresSaver } from "@langchain/langgraph-checkpoint-postgres";
* import { createReactAgent } from "@langchain/langgraph/prebuilt";
*
* const checkpointer = PostgresSaver.fromConnString(
* "postgresql://user:password@localhost:5432/db",
* // optional configuration object
* {
* schema: "custom_schema" // defaults to "public"
* }
* );
*
* // NOTE: you need to call .setup() the first time you're using your checkpointer
* await checkpointer.setup();
*
* const graph = createReactAgent({
* tools: [getWeather],
* llm: new ChatOpenAI({
* model: "gpt-4o-mini",
* }),
* checkpointSaver: checkpointer,
* });
* const config = { configurable: { thread_id: "1" } };
*
* await graph.invoke({
* messages: [{
* role: "user",
* content: "what's the weather in sf"
* }],
* }, config);
* ```
*/
class PostgresSaver extends langgraph_checkpoint_1.BaseCheckpointSaver {
    /**
     * @param {pg.Pool} pool - node-postgres connection pool used for all queries.
     * @param {SerializerProtocol} [serde] - Optional serializer; forwarded to the base class.
     * @param {PostgresSaverOptions} [options] - Optional configuration (e.g. `schema`).
     */
    constructor(pool, serde, options) {
        super(serde);
        // Transpiled class-field declarations (kept for parity with the TS source).
        Object.defineProperty(this, "pool", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "options", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "SQL_STATEMENTS", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "isSetup", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.pool = pool;
        this.isSetup = false;
        this.options = _ensureCompleteOptions(options);
        // Pre-render the schema-qualified SQL statements once per instance.
        this.SQL_STATEMENTS = (0, sql_js_1.getSQLStatements)(this.options.schema);
    }
    /**
     * Creates a new instance of PostgresSaver from a connection string.
     *
     * @param {string} connString - The connection string to connect to the Postgres database.
     * @param {PostgresSaverOptions} [options] - Optional configuration object.
     * @returns {PostgresSaver} A new instance of PostgresSaver.
     *
     * @example
     * const connString = "postgresql://user:password@localhost:5432/db";
     * const checkpointer = PostgresSaver.fromConnString(connString, {
     *   schema: "custom_schema" // defaults to "public"
     * });
     * await checkpointer.setup();
     */
    static fromConnString(connString, options) {
        const pool = new Pool({ connectionString: connString });
        return new PostgresSaver(pool, undefined, options);
    }
    /**
     * Set up the checkpoint database asynchronously.
     *
     * This method creates the necessary tables in the Postgres database if they don't
     * already exist and runs database migrations. It MUST be called directly by the user
     * the first time checkpointer is used.
     *
     * Migrations are idempotent per instance: the highest applied migration version
     * is read from the migrations table and only newer migrations are run.
     */
    async setup() {
        const client = await this.pool.connect();
        const SCHEMA_TABLES = (0, sql_js_1.getTablesWithSchema)(this.options.schema);
        try {
            // NOTE(review): the schema name is interpolated directly into SQL here
            // (and inside sql.cjs/migrations.cjs), not parameterized. `schema` must
            // come from a trusted source -- never pass user-controlled input.
            await client.query(`CREATE SCHEMA IF NOT EXISTS ${this.options.schema}`);
            let version = -1;
            const MIGRATIONS = (0, migrations_js_1.getMigrations)(this.options.schema);
            try {
                const result = await client.query(`SELECT v FROM ${SCHEMA_TABLES.checkpoint_migrations} ORDER BY v DESC LIMIT 1`);
                if (result.rows.length > 0) {
                    version = result.rows[0].v;
                }
            }
            catch (error) {
                // Assume table doesn't exist if there's an error
                if (typeof error === "object" &&
                    error !== null &&
                    "code" in error &&
                    typeof error.code === "string" &&
                    error.code === "42P01" // Postgres error code for undefined_table
                ) {
                    version = -1;
                }
                else {
                    throw error;
                }
            }
            // Apply every migration newer than the recorded version, recording
            // each one so a partially-completed run resumes where it stopped.
            for (let v = version + 1; v < MIGRATIONS.length; v += 1) {
                await client.query(MIGRATIONS[v]);
                await client.query(`INSERT INTO ${SCHEMA_TABLES.checkpoint_migrations} (v) VALUES ($1)`, [v]);
            }
            // Fix: the flag was declared and initialized to false but never set,
            // leaving it permanently stale. Mark setup as completed.
            this.isSetup = true;
        }
        finally {
            client.release();
        }
    }
    /**
     * Rehydrate a checkpoint row: deserialize its pending sends and channel
     * value blobs back into a full in-memory checkpoint object.
     */
    async _loadCheckpoint(checkpoint, channelValues, pendingSends) {
        return {
            ...checkpoint,
            pending_sends: await Promise.all((pendingSends || []).map(([c, b]) => this.serde.loadsTyped(c.toString(), b))),
            channel_values: await this._loadBlobs(channelValues),
        };
    }
    /**
     * Deserialize [key, type, value] blob triples into a channel-values object,
     * skipping placeholder rows whose type is the literal string "empty".
     */
    async _loadBlobs(blobValues) {
        if (!blobValues || blobValues.length === 0) {
            return {};
        }
        // Fix: hoisted -- the original allocated a new TextDecoder per element
        // in both the filter and the map callbacks.
        const decoder = new TextDecoder();
        const entries = await Promise.all(blobValues
            .filter(([, t]) => decoder.decode(t) !== "empty")
            .map(async ([k, t, v]) => [
            decoder.decode(k),
            await this.serde.loadsTyped(decoder.decode(t), v),
        ]));
        return Object.fromEntries(entries);
    }
    /**
     * Round-trip metadata through the serializer so the value returned to the
     * caller matches what a read-from-database load would produce.
     */
    async _loadMetadata(metadata) {
        const [type, dumpedValue] = this.serde.dumpsTyped(metadata);
        return this.serde.loadsTyped(type, dumpedValue);
    }
    /**
     * Deserialize pending-write rows into [taskId, channel, value] triples.
     * Returns an empty array when `writes` is null/undefined.
     */
    async _loadWrites(writes) {
        const decoder = new TextDecoder();
        return writes
            ? await Promise.all(writes.map(async ([tid, channel, t, v]) => [
                decoder.decode(tid),
                decoder.decode(channel),
                await this.serde.loadsTyped(decoder.decode(t), v),
            ]))
            : [];
    }
    /**
     * Serialize channel values into blob rows keyed by channel name + version.
     * Channels present in `versions` but absent from `values` are written as
     * "empty" placeholders so the version history stays complete.
     */
    _dumpBlobs(threadId, checkpointNs, values, versions) {
        if (Object.keys(versions).length === 0) {
            return [];
        }
        return Object.entries(versions).map(([k, ver]) => {
            const [type, value] = k in values ? this.serde.dumpsTyped(values[k]) : ["empty", null];
            return [
                threadId,
                checkpointNs,
                k,
                ver.toString(),
                type,
                value ? new Uint8Array(value) : undefined,
            ];
        });
    }
    /**
     * Produce the JSON-storable form of a checkpoint: channel values are
     * stripped (they are persisted separately as blobs) and pending sends
     * are reset, matching the on-disk representation.
     */
    _dumpCheckpoint(checkpoint) {
        const serialized = {
            ...checkpoint,
            pending_sends: [],
        };
        if ("channel_values" in serialized) {
            delete serialized.channel_values;
        }
        return serialized;
    }
    /**
     * Serialize metadata to a plain JSON object suitable for a jsonb column.
     * Null characters are stripped because Postgres jsonb rejects \u0000.
     */
    _dumpMetadata(metadata) {
        const [, serializedMetadata] = this.serde.dumpsTyped(metadata);
        // We need to remove null characters before writing
        return JSON.parse(new TextDecoder().decode(serializedMetadata).replace(/\0/g, ""));
    }
    /**
     * Serialize [channel, value] writes into parameter rows for the writes
     * table. Special channels get a fixed index from WRITES_IDX_MAP; ordinary
     * channels use their position in the list.
     */
    _dumpWrites(threadId, checkpointNs, checkpointId, taskId, writes) {
        return writes.map(([channel, value], idx) => {
            const [type, serializedValue] = this.serde.dumpsTyped(value);
            return [
                threadId,
                checkpointNs,
                checkpointId,
                taskId,
                langgraph_checkpoint_1.WRITES_IDX_MAP[channel] ?? idx,
                channel,
                type,
                new Uint8Array(serializedValue),
            ];
        });
    }
    /**
     * Return WHERE clause predicates for a given list() config, filter, cursor.
     *
     * This method returns a tuple of a string and a tuple of values. The string
     * is the parameterized WHERE clause predicate (including the WHERE keyword):
     * "WHERE column1 = $1 AND column2 IS $2". The list of values contains the
     * values for each of the corresponding parameters.
     */
    _searchWhere(config, filter, before) {
        const wheres = [];
        const paramValues = [];
        // construct predicate for config filter
        if (config?.configurable?.thread_id) {
            wheres.push(`thread_id = $${paramValues.length + 1}`);
            paramValues.push(config.configurable.thread_id);
        }
        // strict checks for undefined/null because empty strings are falsy
        if (config?.configurable?.checkpoint_ns !== undefined &&
            config?.configurable?.checkpoint_ns !== null) {
            wheres.push(`checkpoint_ns = $${paramValues.length + 1}`);
            paramValues.push(config.configurable.checkpoint_ns);
        }
        if (config?.configurable?.checkpoint_id) {
            wheres.push(`checkpoint_id = $${paramValues.length + 1}`);
            paramValues.push(config.configurable.checkpoint_id);
        }
        // construct predicate for metadata filter (jsonb containment)
        if (filter && Object.keys(filter).length > 0) {
            wheres.push(`metadata @> $${paramValues.length + 1}`);
            paramValues.push(JSON.stringify(filter));
        }
        // construct predicate for `before`
        if (before?.configurable?.checkpoint_id !== undefined) {
            wheres.push(`checkpoint_id < $${paramValues.length + 1}`);
            paramValues.push(before.configurable.checkpoint_id);
        }
        return [
            wheres.length > 0 ? `WHERE ${wheres.join(" AND ")}` : "",
            paramValues,
        ];
    }
    /**
     * Get a checkpoint tuple from the database.
     * This method retrieves a checkpoint tuple from the Postgres database
     * based on the provided config. If the config's configurable field contains
     * a "checkpoint_id" key, the checkpoint with the matching thread_id and
     * namespace is retrieved. Otherwise, the latest checkpoint for the given
     * thread_id is retrieved.
     * @param config The config to use for retrieving the checkpoint.
     * @returns The retrieved checkpoint tuple, or undefined.
     */
    async getTuple(config) {
        const { thread_id, checkpoint_ns = "", checkpoint_id, } = config.configurable ?? {};
        let args;
        let where;
        if (checkpoint_id) {
            where =
                "WHERE thread_id = $1 AND checkpoint_ns = $2 AND checkpoint_id = $3";
            args = [thread_id, checkpoint_ns, checkpoint_id];
        }
        else {
            // No explicit checkpoint requested: fetch the newest for the thread.
            where =
                "WHERE thread_id = $1 AND checkpoint_ns = $2 ORDER BY checkpoint_id DESC LIMIT 1";
            args = [thread_id, checkpoint_ns];
        }
        const result = await this.pool.query(this.SQL_STATEMENTS.SELECT_SQL + where, args);
        const [row] = result.rows;
        if (row === undefined) {
            return undefined;
        }
        const checkpoint = await this._loadCheckpoint(row.checkpoint, row.channel_values, row.pending_sends);
        const finalConfig = {
            configurable: {
                thread_id,
                checkpoint_ns,
                checkpoint_id: row.checkpoint_id,
            },
        };
        const metadata = await this._loadMetadata(row.metadata);
        const parentConfig = row.parent_checkpoint_id
            ? {
                configurable: {
                    thread_id,
                    checkpoint_ns,
                    checkpoint_id: row.parent_checkpoint_id,
                },
            }
            : undefined;
        const pendingWrites = await this._loadWrites(row.pending_writes);
        return {
            config: finalConfig,
            checkpoint,
            metadata,
            parentConfig,
            pendingWrites,
        };
    }
    /**
     * List checkpoints from the database.
     *
     * This method retrieves a list of checkpoint tuples from the Postgres database based
     * on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
     */
    async *list(config, options) {
        const { filter, before, limit } = options ?? {};
        const [where, args] = this._searchWhere(config, filter, before);
        let query = `${this.SQL_STATEMENTS.SELECT_SQL}${where} ORDER BY checkpoint_id DESC`;
        if (limit !== undefined) {
            query += ` LIMIT ${Number.parseInt(limit.toString(), 10)}`; // sanitize via parseInt, as limit could be an externally provided value
        }
        const result = await this.pool.query(query, args);
        for (const value of result.rows) {
            yield {
                config: {
                    configurable: {
                        thread_id: value.thread_id,
                        checkpoint_ns: value.checkpoint_ns,
                        checkpoint_id: value.checkpoint_id,
                    },
                },
                checkpoint: await this._loadCheckpoint(value.checkpoint, value.channel_values, value.pending_sends),
                metadata: await this._loadMetadata(value.metadata),
                parentConfig: value.parent_checkpoint_id
                    ? {
                        configurable: {
                            thread_id: value.thread_id,
                            checkpoint_ns: value.checkpoint_ns,
                            checkpoint_id: value.parent_checkpoint_id,
                        },
                    }
                    : undefined,
                pendingWrites: await this._loadWrites(value.pending_writes),
            };
        }
    }
    /**
     * Save a checkpoint to the database.
     *
     * This method saves a checkpoint to the Postgres database. The checkpoint is associated
     * with the provided config and its parent config (if any). Blobs and the
     * checkpoint row are written inside a single transaction.
     * @param config
     * @param checkpoint
     * @param metadata
     * @returns The config pointing at the newly stored checkpoint.
     */
    async put(config, checkpoint, metadata, newVersions) {
        if (config.configurable === undefined) {
            throw new Error(`Missing "configurable" field in "config" param`);
        }
        const { thread_id, checkpoint_ns = "", checkpoint_id, } = config.configurable;
        const nextConfig = {
            configurable: {
                thread_id,
                checkpoint_ns,
                checkpoint_id: checkpoint.id,
            },
        };
        const client = await this.pool.connect();
        const serializedCheckpoint = this._dumpCheckpoint(checkpoint);
        try {
            await client.query("BEGIN");
            const serializedBlobs = this._dumpBlobs(thread_id, checkpoint_ns, checkpoint.channel_values, newVersions);
            for (const serializedBlob of serializedBlobs) {
                await client.query(this.SQL_STATEMENTS.UPSERT_CHECKPOINT_BLOBS_SQL, serializedBlob);
            }
            await client.query(this.SQL_STATEMENTS.UPSERT_CHECKPOINTS_SQL, [
                thread_id,
                checkpoint_ns,
                checkpoint.id,
                checkpoint_id, // becomes the parent_checkpoint_id of the new row
                serializedCheckpoint,
                this._dumpMetadata(metadata),
            ]);
            await client.query("COMMIT");
        }
        catch (e) {
            await client.query("ROLLBACK");
            throw e;
        }
        finally {
            client.release();
        }
        return nextConfig;
    }
    /**
     * Store intermediate writes linked to a checkpoint.
     *
     * This method saves intermediate writes associated with a checkpoint to the Postgres database.
     * All writes are stored in a single transaction. Writes on special channels
     * (present in WRITES_IDX_MAP) are upserted; others are plain inserts.
     * @param config Configuration of the related checkpoint.
     * @param writes List of writes to store.
     * @param taskId Identifier for the task creating the writes.
     */
    async putWrites(config, writes, taskId) {
        const query = writes.every((w) => w[0] in langgraph_checkpoint_1.WRITES_IDX_MAP)
            ? this.SQL_STATEMENTS.UPSERT_CHECKPOINT_WRITES_SQL
            : this.SQL_STATEMENTS.INSERT_CHECKPOINT_WRITES_SQL;
        const dumpedWrites = this._dumpWrites(config.configurable?.thread_id, config.configurable?.checkpoint_ns, config.configurable?.checkpoint_id, taskId, writes);
        const client = await this.pool.connect();
        try {
            await client.query("BEGIN");
            // Fix: `dumpedWrites` is a plain array, so `for...of` is correct;
            // the original used `for await`, which only applies to async iterables.
            for (const dumpedWrite of dumpedWrites) {
                await client.query(query, dumpedWrite);
            }
            await client.query("COMMIT");
        }
        catch (error) {
            await client.query("ROLLBACK");
            throw error;
        }
        finally {
            client.release();
        }
    }
    /**
     * Close the underlying connection pool. The saver is unusable afterwards.
     */
    async end() {
        return this.pool.end();
    }
}
exports.PostgresSaver = PostgresSaver;
//# sourceMappingURL=index.js.map
;