@langchain/langgraph
Version:
LangGraph
762 lines • 29.3 kB
JavaScript
/* eslint-disable @typescript-eslint/no-use-before-define */
import { _coerceToRunnable, Runnable, } from "@langchain/core/runnables";
import { isBaseChannel } from "../channels/base.js";
import { CompiledGraph, Graph, Branch, } from "./graph.js";
import { ChannelWrite, PASSTHROUGH, } from "../pregel/write.js";
import { ChannelRead, PregelNode } from "../pregel/read.js";
import { NamedBarrierValue, NamedBarrierValueAfterFinish, } from "../channels/named_barrier_value.js";
import { EphemeralValue } from "../channels/ephemeral_value.js";
import { RunnableCallable } from "../utils.js";
import { isCommand, _isSend, CHECKPOINT_NAMESPACE_END, CHECKPOINT_NAMESPACE_SEPARATOR, Command, END, SELF, START, TAG_HIDDEN, } from "../constants.js";
import { InvalidUpdateError, ParentCommand } from "../errors.js";
import { getChannel, } from "./annotation.js";
import { isConfiguredManagedValue } from "../managed/base.js";
import { isPregelLike } from "../pregel/utils/subgraph.js";
import { getChannelsFromZod, applyZodPlugin, isAnyZodObject, } from "./zod/state.js";
import { LastValueAfterFinish } from "../channels/last_value.js";
const ROOT = "__root__";
const PartialStateSchema = Symbol.for("langgraph.state.partial");
/**
* A graph whose nodes communicate by reading and writing to a shared state.
* Each node takes a defined `State` as input and returns a `Partial<State>`.
*
* Each state key can optionally be annotated with a reducer function that
* will be used to aggregate the values of that key received from multiple nodes.
* The signature of a reducer function is (left: Value, right: UpdateValue) => Value.
*
* See {@link Annotation} for more on defining state.
*
* After adding nodes and edges to your graph, you must call `.compile()` on it before
* you can use it.
*
* @example
* ```ts
* import {
* type BaseMessage,
* AIMessage,
* HumanMessage,
* } from "@langchain/core/messages";
* import { StateGraph, Annotation } from "@langchain/langgraph";
*
* // Define a state with a single key named "messages" that will
* // combine a returned BaseMessage or arrays of BaseMessages
* const StateAnnotation = Annotation.Root({
* sentiment: Annotation<string>,
* messages: Annotation<BaseMessage[]>({
* reducer: (left: BaseMessage[], right: BaseMessage | BaseMessage[]) => {
* if (Array.isArray(right)) {
* return left.concat(right);
* }
* return left.concat([right]);
* },
* default: () => [],
* }),
* });
*
* const graphBuilder = new StateGraph(StateAnnotation);
*
* // A node in the graph that returns an object with a "messages" key
* // will update the state by combining the existing value with the returned one.
* const myNode = (state: typeof StateAnnotation.State) => {
* return {
* messages: [new AIMessage("Some new response")],
* sentiment: "positive",
* };
* };
*
* const graph = graphBuilder
* .addNode("myNode", myNode)
* .addEdge("__start__", "myNode")
* .addEdge("myNode", "__end__")
* .compile();
*
* await graph.invoke({ messages: [new HumanMessage("how are you?")] });
*
* // {
* // messages: [HumanMessage("how are you?"), AIMessage("Some new response")],
* // sentiment: "positive",
* // }
* ```
*/
export class StateGraph extends Graph {
constructor(fields, configSchema) {
super();
Object.defineProperty(this, "channels", {
enumerable: true,
configurable: true,
writable: true,
value: {}
});
// TODO: this doesn't dedupe edges as in py, so worth fixing at some point
Object.defineProperty(this, "waitingEdges", {
enumerable: true,
configurable: true,
writable: true,
value: new Set()
});
/** @internal */
Object.defineProperty(this, "_schemaDefinition", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/** @internal */
Object.defineProperty(this, "_schemaRuntimeDefinition", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/** @internal */
Object.defineProperty(this, "_inputDefinition", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/** @internal */
Object.defineProperty(this, "_inputRuntimeDefinition", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/** @internal */
Object.defineProperty(this, "_outputDefinition", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/** @internal */
Object.defineProperty(this, "_outputRuntimeDefinition", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/**
* Map schemas to managed values
* @internal
*/
Object.defineProperty(this, "_schemaDefinitions", {
enumerable: true,
configurable: true,
writable: true,
value: new Map()
});
/** @internal Used only for typing. */
Object.defineProperty(this, "_configSchema", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/** @internal */
Object.defineProperty(this, "_configRuntimeSchema", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
if (isZodStateGraphArgsWithStateSchema(fields)) {
const stateDef = getChannelsFromZod(fields.state);
const inputDef = fields.input != null ? getChannelsFromZod(fields.input) : stateDef;
const outputDef = fields.output != null ? getChannelsFromZod(fields.output) : stateDef;
this._schemaDefinition = stateDef;
this._schemaRuntimeDefinition = fields.state;
this._inputDefinition = inputDef;
this._inputRuntimeDefinition = fields.input ?? PartialStateSchema;
this._outputDefinition = outputDef;
this._outputRuntimeDefinition = fields.output ?? fields.state;
}
else if (isAnyZodObject(fields)) {
const stateDef = getChannelsFromZod(fields);
this._schemaDefinition = stateDef;
this._schemaRuntimeDefinition = fields;
this._inputDefinition = stateDef;
this._inputRuntimeDefinition = PartialStateSchema;
this._outputDefinition = stateDef;
this._outputRuntimeDefinition = fields;
}
else if (isStateGraphArgsWithInputOutputSchemas(fields)) {
this._schemaDefinition = fields.input.spec;
this._inputDefinition = fields.input.spec;
this._outputDefinition = fields.output.spec;
}
else if (isStateGraphArgsWithStateSchema(fields)) {
this._schemaDefinition = fields.stateSchema.spec;
this._inputDefinition = (fields.input?.spec ??
this._schemaDefinition);
this._outputDefinition = (fields.output?.spec ??
this._schemaDefinition);
}
else if (isStateDefinition(fields) || isAnnotationRoot(fields)) {
const spec = isAnnotationRoot(fields) ? fields.spec : fields;
this._schemaDefinition = spec;
}
else if (isStateGraphArgs(fields)) {
const spec = _getChannels(fields.channels);
this._schemaDefinition = spec;
}
else {
throw new Error("Invalid StateGraph input. Make sure to pass a valid Annotation.Root or Zod schema.");
}
this._inputDefinition ??= this._schemaDefinition;
this._outputDefinition ??= this._schemaDefinition;
this._addSchema(this._schemaDefinition);
this._addSchema(this._inputDefinition);
this._addSchema(this._outputDefinition);
if (isAnyZodObject(configSchema)) {
this._configRuntimeSchema = configSchema.passthrough();
}
}
get allEdges() {
return new Set([
...this.edges,
...Array.from(this.waitingEdges).flatMap(([starts, end]) => starts.map((start) => [start, end])),
]);
}
_addSchema(stateDefinition) {
if (this._schemaDefinitions.has(stateDefinition)) {
return;
}
// TODO: Support managed values
this._schemaDefinitions.set(stateDefinition, stateDefinition);
for (const [key, val] of Object.entries(stateDefinition)) {
let channel;
if (typeof val === "function") {
channel = val();
}
else {
channel = val;
}
if (this.channels[key] !== undefined) {
if (this.channels[key] !== channel) {
if (!isConfiguredManagedValue(channel) &&
channel.lc_graph_name !== "LastValue") {
throw new Error(`Channel "${key}" already exists with a different type.`);
}
}
}
else {
this.channels[key] = channel;
}
}
}
addNode(...args) {
function isMultipleNodes(args) {
return args.length >= 1 && typeof args[0] !== "string";
}
const nodes = (isMultipleNodes(args) // eslint-disable-line no-nested-ternary
? Array.isArray(args[0])
? args[0]
: Object.entries(args[0]).map(([key, action]) => [
key,
action,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
action[Symbol.for("langgraph.state.node")] ?? undefined,
])
: [[args[0], args[1], args[2]]]);
if (nodes.length === 0) {
throw new Error("No nodes provided in `addNode`");
}
for (const [key, action, options] of nodes) {
if (key in this.channels) {
throw new Error(`${key} is already being used as a state attribute (a.k.a. a channel), cannot also be used as a node name.`);
}
for (const reservedChar of [
CHECKPOINT_NAMESPACE_SEPARATOR,
CHECKPOINT_NAMESPACE_END,
]) {
if (key.includes(reservedChar)) {
throw new Error(`"${reservedChar}" is a reserved character and is not allowed in node names.`);
}
}
this.warnIfCompiled(`Adding a node to a graph that has already been compiled. This will not be reflected in the compiled graph.`);
if (key in this.nodes) {
throw new Error(`Node \`${key}\` already present.`);
}
if (key === END || key === START) {
throw new Error(`Node \`${key}\` is reserved.`);
}
let inputSpec = this._schemaDefinition;
if (options?.input !== undefined) {
if (isAnyZodObject(options.input)) {
inputSpec = getChannelsFromZod(options.input);
}
else if (options.input.spec !== undefined) {
inputSpec = options.input.spec;
}
}
if (inputSpec !== undefined) {
this._addSchema(inputSpec);
}
let runnable;
if (Runnable.isRunnable(action)) {
runnable = action;
}
else if (typeof action === "function") {
runnable = new RunnableCallable({
func: action,
name: key,
trace: false,
});
}
else {
runnable = _coerceToRunnable(action);
}
let cachePolicy = options?.cachePolicy;
if (typeof cachePolicy === "boolean") {
cachePolicy = cachePolicy ? {} : undefined;
}
const nodeSpec = {
runnable: runnable,
retryPolicy: options?.retryPolicy,
cachePolicy,
metadata: options?.metadata,
input: inputSpec ?? this._schemaDefinition,
subgraphs: isPregelLike(runnable)
? // eslint-disable-next-line @typescript-eslint/no-explicit-any
[runnable]
: options?.subgraphs,
ends: options?.ends,
defer: options?.defer,
};
this.nodes[key] = nodeSpec;
}
return this;
}
addEdge(startKey, endKey) {
if (typeof startKey === "string") {
return super.addEdge(startKey, endKey);
}
if (this.compiled) {
console.warn("Adding an edge to a graph that has already been compiled. This will " +
"not be reflected in the compiled graph.");
}
for (const start of startKey) {
if (start === END) {
throw new Error("END cannot be a start node");
}
if (!Object.keys(this.nodes).some((node) => node === start)) {
throw new Error(`Need to add a node named "${start}" first`);
}
}
if (endKey === END) {
throw new Error("END cannot be an end node");
}
if (!Object.keys(this.nodes).some((node) => node === endKey)) {
throw new Error(`Need to add a node named "${endKey}" first`);
}
this.waitingEdges.add([startKey, endKey]);
return this;
}
addSequence(nodes) {
const parsedNodes = Array.isArray(nodes)
? nodes
: Object.entries(nodes).map(([key, action]) => [
key,
action,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
action[Symbol.for("langgraph.state.node")] ?? undefined,
]);
if (parsedNodes.length === 0) {
throw new Error("Sequence requires at least one node.");
}
let previousNode;
for (const [key, action, options] of parsedNodes) {
if (key in this.nodes) {
throw new Error(`Node names must be unique: node with the name "${key}" already exists.`);
}
const validKey = key;
this.addNode(validKey, action, options);
if (previousNode != null) {
this.addEdge(previousNode, validKey);
}
previousNode = validKey;
}
return this;
}
compile({ checkpointer, store, cache, interruptBefore, interruptAfter, name, } = {}) {
// validate the graph
this.validate([
...(Array.isArray(interruptBefore) ? interruptBefore : []),
...(Array.isArray(interruptAfter) ? interruptAfter : []),
]);
// prepare output channels
const outputKeys = Object.keys(this._schemaDefinitions.get(this._outputDefinition));
const outputChannels = outputKeys.length === 1 && outputKeys[0] === ROOT ? ROOT : outputKeys;
const streamKeys = Object.keys(this.channels);
const streamChannels = streamKeys.length === 1 && streamKeys[0] === ROOT ? ROOT : streamKeys;
// create empty compiled graph
const compiled = new CompiledStateGraph({
builder: this,
checkpointer,
interruptAfter,
interruptBefore,
autoValidate: false,
nodes: {},
channels: {
...this.channels,
[START]: new EphemeralValue(),
},
inputChannels: START,
outputChannels,
streamChannels,
streamMode: "updates",
store,
cache,
name,
});
// attach nodes, edges and branches
compiled.attachNode(START);
for (const [key, node] of Object.entries(this.nodes)) {
compiled.attachNode(key, node);
}
compiled.attachBranch(START, SELF, _getControlBranch(), {
withReader: false,
});
for (const [key] of Object.entries(this.nodes)) {
compiled.attachBranch(key, SELF, _getControlBranch(), {
withReader: false,
});
}
for (const [start, end] of this.edges) {
compiled.attachEdge(start, end);
}
for (const [starts, end] of this.waitingEdges) {
compiled.attachEdge(starts, end);
}
for (const [start, branches] of Object.entries(this.branches)) {
for (const [name, branch] of Object.entries(branches)) {
compiled.attachBranch(start, name, branch);
}
}
return compiled.validate();
}
}
function _getChannels(schema) {
const channels = {};
for (const [name, val] of Object.entries(schema)) {
if (name === ROOT) {
channels[name] = getChannel(val);
}
else {
const key = name;
channels[name] = getChannel(val);
}
}
return channels;
}
/**
* Final result from building and compiling a {@link StateGraph}.
* Should not be instantiated directly, only using the StateGraph `.compile()`
* instance method.
*/
export class CompiledStateGraph extends CompiledGraph {
attachNode(key, node) {
let outputKeys;
if (key === START) {
// Get input schema keys excluding managed values
outputKeys = Object.entries(this.builder._schemaDefinitions.get(this.builder._inputDefinition))
.filter(([_, v]) => !isConfiguredManagedValue(v))
.map(([k]) => k);
}
else {
outputKeys = Object.keys(this.builder.channels);
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _getRoot(input) {
if (isCommand(input)) {
if (input.graph === Command.PARENT) {
return null;
}
return input._updateAsTuples();
}
else if (Array.isArray(input) &&
input.length > 0 &&
input.some((i) => isCommand(i))) {
const updates = [];
for (const i of input) {
if (isCommand(i)) {
if (i.graph === Command.PARENT) {
continue;
}
updates.push(...i._updateAsTuples());
}
else {
updates.push([ROOT, i]);
}
}
return updates;
}
else if (input != null) {
return [[ROOT, input]];
}
return null;
}
// to avoid name collision below
const nodeKey = key;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _getUpdates(input) {
if (!input) {
return null;
}
else if (isCommand(input)) {
if (input.graph === Command.PARENT) {
return null;
}
return input._updateAsTuples().filter(([k]) => outputKeys.includes(k));
}
else if (Array.isArray(input) &&
input.length > 0 &&
input.some(isCommand)) {
const updates = [];
for (const item of input) {
if (isCommand(item)) {
if (item.graph === Command.PARENT) {
continue;
}
updates.push(...item._updateAsTuples().filter(([k]) => outputKeys.includes(k)));
}
else {
const itemUpdates = _getUpdates(item);
if (itemUpdates) {
updates.push(...(itemUpdates ?? []));
}
}
}
return updates;
}
else if (typeof input === "object" && !Array.isArray(input)) {
return Object.entries(input).filter(([k]) => outputKeys.includes(k));
}
else {
const typeofInput = Array.isArray(input) ? "array" : typeof input;
throw new InvalidUpdateError(`Expected node "${nodeKey.toString()}" to return an object or an array containing at least one Command object, received ${typeofInput}`, {
lc_error_code: "INVALID_GRAPH_NODE_RETURN_VALUE",
});
}
}
const stateWriteEntries = [
{
value: PASSTHROUGH,
mapper: new RunnableCallable({
func: outputKeys.length && outputKeys[0] === ROOT
? _getRoot
: _getUpdates,
trace: false,
recurse: false,
}),
},
];
// add node and output channel
if (key === START) {
this.nodes[key] = new PregelNode({
tags: [TAG_HIDDEN],
triggers: [START],
channels: [START],
writers: [new ChannelWrite(stateWriteEntries, [TAG_HIDDEN])],
});
}
else {
const inputDefinition = node?.input ?? this.builder._schemaDefinition;
const inputValues = Object.fromEntries(Object.keys(this.builder._schemaDefinitions.get(inputDefinition)).map((k) => [k, k]));
const isSingleInput = Object.keys(inputValues).length === 1 && ROOT in inputValues;
const branchChannel = `branch:to:${key}`;
this.channels[branchChannel] = node?.defer
? new LastValueAfterFinish()
: new EphemeralValue(false);
this.nodes[key] = new PregelNode({
triggers: [branchChannel],
// read state keys
channels: isSingleInput ? Object.keys(inputValues) : inputValues,
// publish to state keys
writers: [new ChannelWrite(stateWriteEntries, [TAG_HIDDEN])],
mapper: isSingleInput
? undefined
: // eslint-disable-next-line @typescript-eslint/no-explicit-any
(input) => {
return Object.fromEntries(Object.entries(input).filter(([k]) => k in inputValues));
},
bound: node?.runnable,
metadata: node?.metadata,
retryPolicy: node?.retryPolicy,
cachePolicy: node?.cachePolicy,
subgraphs: node?.subgraphs,
ends: node?.ends,
});
}
}
attachEdge(starts, end) {
if (end === END)
return;
if (typeof starts === "string") {
this.nodes[starts].writers.push(new ChannelWrite([{ channel: `branch:to:${end}`, value: null }], [TAG_HIDDEN]));
}
else if (Array.isArray(starts)) {
const channelName = `join:${starts.join("+")}:${end}`;
// register channel
this.channels[channelName] = this.builder.nodes[end].defer
? new NamedBarrierValueAfterFinish(new Set(starts))
: new NamedBarrierValue(new Set(starts));
// subscribe to channel
this.nodes[end].triggers.push(channelName);
// publish to channel
for (const start of starts) {
this.nodes[start].writers.push(new ChannelWrite([{ channel: channelName, value: start }], [TAG_HIDDEN]));
}
}
}
attachBranch(start, _, branch, options = { withReader: true }) {
const branchWriter = async (packets, config) => {
const filteredPackets = packets.filter((p) => p !== END);
if (!filteredPackets.length)
return;
const writes = filteredPackets.map((p) => {
if (_isSend(p))
return p;
return { channel: p === END ? p : `branch:to:${p}`, value: start };
});
await ChannelWrite.doWrite({ ...config, tags: (config.tags ?? []).concat([TAG_HIDDEN]) }, writes);
};
// attach branch publisher
this.nodes[start].writers.push(branch.run(branchWriter,
// reader
options.withReader
? (config) => ChannelRead.doRead(config, this.streamChannels ?? this.outputChannels, true)
: undefined));
}
async _validateInput(input) {
if (input == null)
return input;
const schema = (() => {
const input = this.builder._inputRuntimeDefinition;
const schema = this.builder._schemaRuntimeDefinition;
const apply = (schema) => {
if (schema == null)
return undefined;
return applyZodPlugin(schema, { reducer: true });
};
if (isAnyZodObject(input))
return apply(input);
if (input === PartialStateSchema)
return apply(schema)?.partial();
return undefined;
})();
if (isCommand(input)) {
const parsedInput = input;
if (input.update && schema != null)
parsedInput.update = schema.parse(input.update);
return parsedInput;
}
if (schema != null)
return schema.parse(input);
return input;
}
async _validateConfigurable(config) {
const configSchema = this.builder._configRuntimeSchema;
if (isAnyZodObject(configSchema))
configSchema.parse(config);
return config;
}
}
function isStateDefinition(obj) {
return (typeof obj === "object" &&
obj !== null &&
!Array.isArray(obj) &&
Object.keys(obj).length > 0 &&
Object.values(obj).every((v) => typeof v === "function" || isBaseChannel(v)));
}
function isAnnotationRoot(obj) {
return (typeof obj === "object" &&
obj !== null &&
"lc_graph_name" in obj &&
obj.lc_graph_name === "AnnotationRoot");
}
function isStateGraphArgs(obj) {
return (typeof obj === "object" &&
obj !== null &&
obj.channels !== undefined);
}
function isStateGraphArgsWithStateSchema(obj) {
return (typeof obj === "object" &&
obj !== null &&
obj.stateSchema !== undefined);
}
function isStateGraphArgsWithInputOutputSchemas(obj) {
return (typeof obj === "object" &&
obj !== null &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
obj.stateSchema === undefined &&
obj.input !== undefined &&
obj.output !== undefined);
}
function isZodStateGraphArgsWithStateSchema(value) {
if (typeof value !== "object" || value == null) {
return false;
}
if (!("state" in value) || !isAnyZodObject(value.state)) {
return false;
}
if ("input" in value && !isAnyZodObject(value.input)) {
return false;
}
if ("output" in value && !isAnyZodObject(value.output)) {
return false;
}
return true;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _controlBranch(value) {
if (_isSend(value)) {
return [value];
}
const commands = [];
if (isCommand(value)) {
commands.push(value);
}
else if (Array.isArray(value)) {
commands.push(...value.filter(isCommand));
}
const destinations = [];
for (const command of commands) {
if (command.graph === Command.PARENT) {
throw new ParentCommand(command);
}
if (_isSend(command.goto)) {
destinations.push(command.goto);
}
else if (typeof command.goto === "string") {
destinations.push(command.goto);
}
else {
if (Array.isArray(command.goto)) {
destinations.push(...command.goto);
}
}
}
return destinations;
}
function _getControlBranch() {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const CONTROL_BRANCH_PATH = new RunnableCallable({
func: _controlBranch,
tags: [TAG_HIDDEN],
trace: false,
recurse: false,
name: "<control_branch>",
});
return new Branch({
path: CONTROL_BRANCH_PATH,
});
}
export function typedNode(_state, _options) {
return (func, options) => {
Object.assign(func, { [Symbol.for("langgraph.state.node")]: options });
return func;
};
}
//# sourceMappingURL=state.js.map