UNPKG

@langchain/langgraph

Version:

LangGraph

401 lines 16.5 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.createReactAgentAnnotation = void 0;
exports._shouldBindTools = _shouldBindTools;
exports._getModel = _getModel;
exports.createReactAgent = createReactAgent;
const messages_1 = require("@langchain/core/messages");
const runnables_1 = require("@langchain/core/runnables");
const index_js_1 = require("../graph/index.cjs");
const tool_node_js_1 = require("./tool_node.cjs");
const annotation_js_1 = require("../graph/annotation.cjs");
const message_js_1 = require("../graph/message.cjs");
const constants_js_1 = require("../constants.cjs");
const agentName_js_1 = require("./agentName.cjs");
/**
 * Converts a `messageModifier` (which operates on the message list) into a
 * prompt value accepted by `_getPromptRunnable` (which operates on the state).
 *
 * - A string or a system message is returned unchanged.
 * - A function is wrapped so it is called with `state.messages`.
 * - A Runnable is piped after a lambda that extracts `state.messages`.
 *
 * @param messageModifier - String, system message, function or Runnable.
 * @returns A prompt value: string, system message, async function or Runnable.
 * @throws {Error} If `messageModifier` is none of the supported kinds.
 */
function _convertMessageModifierToPrompt(messageModifier) {
    // Handle string or SystemMessage
    if (typeof messageModifier === "string" ||
        ((0, messages_1.isBaseMessage)(messageModifier) &&
            messageModifier._getType() === "system")) {
        return messageModifier;
    }
    // Handle callable function
    if (typeof messageModifier === "function") {
        return async (state) => messageModifier(state.messages);
    }
    // Handle Runnable
    if (runnables_1.Runnable.isRunnable(messageModifier)) {
        return runnables_1.RunnableLambda.from((state) => state.messages).pipe(messageModifier);
    }
    throw new Error(`Unexpected type for messageModifier: ${typeof messageModifier}`);
}
// Run name attached to every prompt runnable built by `_getPromptRunnable`.
const PROMPT_RUNNABLE_NAME = "prompt";
/**
 * Builds the runnable that turns the graph state into the model input.
 *
 * - `null`/`undefined`: passes `state.messages` through.
 * - string: prepends a `SystemMessage` built from the string.
 * - system message: prepends that message.
 * - function: wrapped in a `RunnableLambda`.
 * - Runnable: used as-is (no run name is attached in this case).
 *
 * @param prompt - Optional prompt in one of the forms listed above.
 * @returns The prompt runnable.
 * @throws {Error} If `prompt` is none of the supported kinds.
 */
function _getPromptRunnable(prompt) {
    let promptRunnable;
    if (prompt == null) {
        promptRunnable = runnables_1.RunnableLambda.from((state) => state.messages).withConfig({ runName: PROMPT_RUNNABLE_NAME });
    }
    else if (typeof prompt === "string") {
        const systemMessage = new messages_1.SystemMessage(prompt);
        promptRunnable = runnables_1.RunnableLambda.from((state) => {
            return [systemMessage, ...(state.messages ?? [])];
        }).withConfig({ runName: PROMPT_RUNNABLE_NAME });
    }
    else if ((0, messages_1.isBaseMessage)(prompt) &&
        prompt._getType() === "system") {
        // Same missing-messages guard as the string branch above, so both
        // system-prompt forms behave alike when `state.messages` is absent.
        promptRunnable = runnables_1.RunnableLambda.from((state) => [
            prompt,
            ...(state.messages ?? []),
        ]).withConfig({ runName: PROMPT_RUNNABLE_NAME });
    }
    else if (typeof prompt === "function") {
        promptRunnable = runnables_1.RunnableLambda.from(prompt).withConfig({
            runName: PROMPT_RUNNABLE_NAME,
        });
    }
    else if (runnables_1.Runnable.isRunnable(prompt)) {
        promptRunnable = prompt;
    }
    else {
        throw new Error(`Got unexpected type for 'prompt': ${typeof prompt}`);
    }
    return promptRunnable;
}
/**
 * Reports whether `tool` is a Runnable. This module treats Runnable tools as
 * client tools: they are the ones handed to `ToolNode` and matched by name.
 *
 * @param tool - Tool definition to test.
 * @returns `true` if `tool` is a Runnable.
 */
function isClientTool(tool) {
    return runnables_1.Runnable.isRunnable(tool);
}
/**
 * Resolves the three mutually exclusive prompt options into one prompt
 * runnable.
 *
 * @param prompt - Prompt option, used directly.
 * @param stateModifier - Alternative to `prompt`, used directly.
 * @param messageModifier - Alternative to `prompt`, converted through
 *   `_convertMessageModifierToPrompt` first.
 * @returns The prompt runnable built by `_getPromptRunnable`.
 * @throws {Error} If more than one of the three options is non-null.
 */
function _getPrompt(prompt, stateModifier, messageModifier) {
    // Check if multiple modifiers exist
    const definedCount = [prompt, stateModifier, messageModifier].filter((x) => x != null).length;
    if (definedCount > 1) {
        throw new Error("Expected only one of prompt, stateModifier, or messageModifier, got multiple values");
    }
    let finalPrompt = prompt;
    if (stateModifier != null) {
        finalPrompt = stateModifier;
    }
    else if (messageModifier != null) {
        finalPrompt = _convertMessageModifierToPrompt(messageModifier);
    }
    return _getPromptRunnable(finalPrompt);
}
/**
 * Duck-type check for a chat model: an object with an `invoke` function and a
 * `_modelType` property.
 *
 * @param model - Object to test.
 * @returns `true` if `model` looks like a chat model.
 */
function _isBaseChatModel(model) {
    return ("invoke" in model &&
        typeof model.invoke === "function" &&
        "_modelType" in model);
}
/**
 * Duck-type check for a configurable model: an object with a
 * `_queuedMethodOperations` property and a `_model` function.
 *
 * @param model - Object to test.
 * @returns `true` if `model` looks like a configurable model.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _isConfigurableModel(model) {
    return ("_queuedMethodOperations" in model &&
        "_model" in model &&
        typeof model._model === "function");
}
/**
 * Decides whether `createReactAgent` still has to call `bindTools` on `llm`.
 *
 * @param llm - Model, RunnableBinding, RunnableSequence or configurable model.
 * @param tools - Tools passed to `createReactAgent`.
 * @returns `true` if the resolved model is not a RunnableBinding or has no
 *   `tools` in its kwargs; `false` if tools are already bound and every
 *   client tool name is present among them.
 * @throws {Error} If tools are already bound but their count differs from
 *   `tools`, or a client tool name is missing from the bound tools.
 */
async function _shouldBindTools(llm, tools) {
    // If model is a RunnableSequence, find a RunnableBinding or BaseChatModel in its steps
    let model = llm;
    if (runnables_1.RunnableSequence.isRunnableSequence(model)) {
        model =
            model.steps.find((step) => runnables_1.RunnableBinding.isRunnableBinding(step) ||
                _isBaseChatModel(step) ||
                _isConfigurableModel(step)) || model;
    }
    if (_isConfigurableModel(model)) {
        model = await model._model();
    }
    // If not a RunnableBinding, we should bind tools
    if (!runnables_1.RunnableBinding.isRunnableBinding(model)) {
        return true;
    }
    // If no tools in kwargs, we should bind tools
    if (!model.kwargs ||
        typeof model.kwargs !== "object" ||
        !("tools" in model.kwargs)) {
        return true;
    }
    let boundTools = model.kwargs.tools;
    // google-style
    if (boundTools.length === 1 && "functionDeclarations" in boundTools[0]) {
        boundTools = boundTools[0].functionDeclarations;
    }
    // Check if tools count matches
    if (tools.length !== boundTools.length) {
        throw new Error("Number of tools in the model.bindTools() and tools passed to createReactAgent must match");
    }
    // Only client (Runnable) tools contribute a name to compare.
    const toolNames = new Set(tools.flatMap((tool) => (isClientTool(tool) ? tool.name : [])));
    const boundToolNames = new Set();
    for (const boundTool of boundTools) {
        let boundToolName;
        // OpenAI-style tool
        if ("type" in boundTool && boundTool.type === "function") {
            boundToolName = boundTool.function.name;
        }
        // Anthropic or Google-style tool
        else if ("name" in boundTool) {
            boundToolName = boundTool.name;
        }
        // Bedrock-style tool
        else if ("toolSpec" in boundTool && "name" in boundTool.toolSpec) {
            boundToolName = boundTool.toolSpec.name;
        }
        // unknown tool type so we'll ignore it
        else {
            continue;
        }
        if (boundToolName) {
            boundToolNames.add(boundToolName);
        }
    }
    const missingTools = [...toolNames].filter((x) => !boundToolNames.has(x));
    if (missingTools.length > 0) {
        // The two sentences are separated by a space so the message reads cleanly.
        throw new Error(`Missing tools '${missingTools}' in the model.bindTools(). ` +
            `Tools in the model.bindTools() must match the tools passed to createReactAgent.`);
    }
    return false;
}
/**
 * Resolves the underlying chat model from `llm`.
 *
 * A RunnableSequence is searched for its first binding / chat model /
 * configurable model step, a configurable model is resolved through
 * `_model()`, and a RunnableBinding is unwrapped to its `bound` runnable.
 *
 * @param llm - Model, RunnableBinding, RunnableSequence or configurable model.
 * @returns The resolved chat model.
 * @throws {Error} If the resolved object does not pass `_isBaseChatModel`.
 */
async function _getModel(llm) {
    // If model is a RunnableSequence, find a RunnableBinding or BaseChatModel in its steps
    let model = llm;
    if (runnables_1.RunnableSequence.isRunnableSequence(model)) {
        model =
            model.steps.find((step) => runnables_1.RunnableBinding.isRunnableBinding(step) ||
                _isBaseChatModel(step) ||
                _isConfigurableModel(step)) || model;
    }
    if (_isConfigurableModel(model)) {
        model = await model._model();
    }
    // Get the underlying model from a RunnableBinding
    if (runnables_1.RunnableBinding.isRunnableBinding(model)) {
        model = model.bound;
    }
    if (!_isBaseChatModel(model)) {
        throw new Error(`Expected \`llm\` to be a ChatModel or RunnableBinding (e.g. llm.bind_tools(...)) with invoke() and generate() methods, got ${model.constructor.name}`);
    }
    return model;
}
/**
 * Builds the default state annotation used by `createReactAgent` when no
 * `stateSchema` is supplied: a `messages` channel reduced with
 * `messagesStateReducer` (default `[]`) and a `structuredResponse` channel.
 *
 * @returns A new annotation root.
 */
const createReactAgentAnnotation = () => annotation_js_1.Annotation.Root({
    messages: (0, annotation_js_1.Annotation)({
        reducer: message_js_1.messagesStateReducer,
        default: () => [],
    }),
    structuredResponse: (annotation_js_1.Annotation),
});
exports.createReactAgentAnnotation = createReactAgentAnnotation;
// Extra channel merged into the agent node's input schema when a
// `preModelHook` is configured; see `getModelInputState` in `createReactAgent`.
const PreHookAnnotation = annotation_js_1.Annotation.Root({
    llmInputMessages: (0, annotation_js_1.Annotation)({
        reducer: message_js_1.messagesStateReducer,
        default: () => [],
    }),
});
/**
 * Creates a StateGraph agent that relies on a chat model utilizing tool calling.
 *
 * The compiled graph has an "agent" node (calls the model) and a "tools"
 * node, plus optional "pre_model_hook", "post_model_hook" and
 * "generate_structured_response" nodes depending on the options given.
 *
 * @param params - Agent options. Fields read here: `llm`, `tools` (an array
 *   of tools or an object exposing `.tools` that is used as the tool node),
 *   `prompt` / `stateModifier` / `messageModifier` (at most one),
 *   `stateSchema`, `checkpointer` (falls back to `checkpointSaver`),
 *   `interruptBefore`, `interruptAfter`, `store`, `responseFormat`,
 *   `preModelHook`, `postModelHook`, `name`, `includeAgentName`.
 * @returns The compiled graph.
 *
 * @example
 * ```ts
 * import { ChatOpenAI } from "@langchain/openai";
 * import { tool } from "@langchain/core/tools";
 * import { z } from "zod";
 * import { createReactAgent } from "@langchain/langgraph/prebuilt";
 *
 * const model = new ChatOpenAI({
 *   model: "gpt-4o",
 * });
 *
 * const getWeather = tool((input) => {
 *   if (["sf", "san francisco"].includes(input.location.toLowerCase())) {
 *     return "It's 60 degrees and foggy.";
 *   } else {
 *     return "It's 90 degrees and sunny.";
 *   }
 * }, {
 *   name: "get_weather",
 *   description: "Call to get the current weather.",
 *   schema: z.object({
 *     location: z.string().describe("Location to get the weather for."),
 *   })
 * })
 *
 * const agent = createReactAgent({ llm: model, tools: [getWeather] });
 *
 * const inputs = {
 *   messages: [{ role: "user", content: "what is the weather in SF?" }],
 * };
 *
 * const stream = await agent.stream(inputs, { streamMode: "values" });
 *
 * for await (const { messages } of stream) {
 *   console.log(messages);
 * }
 * // Returns the messages in the state at each step of execution
 * ```
 */
function createReactAgent(params) {
    const { llm, tools, messageModifier, stateModifier, prompt, stateSchema, checkpointSaver, checkpointer, interruptBefore, interruptAfter, store, responseFormat, preModelHook, postModelHook, name, includeAgentName, } = params;
    let toolClasses;
    let toolNode;
    if (!Array.isArray(tools)) {
        // A non-array `tools` is used directly as the tool node.
        toolClasses = tools.tools;
        toolNode = tools;
    }
    else {
        toolClasses = tools;
        toolNode = new tool_node_js_1.ToolNode(toolClasses.filter(isClientTool));
    }
    // Built on first model call and reused afterwards.
    let cachedModelRunnable = null;
    const getModelRunnable = async (llm) => {
        if (cachedModelRunnable) {
            return cachedModelRunnable;
        }
        let modelWithTools;
        if (await _shouldBindTools(llm, toolClasses)) {
            if (!("bindTools" in llm) || typeof llm.bindTools !== "function") {
                throw new Error(`llm ${llm} must define bindTools method.`);
            }
            modelWithTools = llm.bindTools(toolClasses);
        }
        else {
            modelWithTools = llm;
        }
        const promptRunnable = _getPrompt(prompt, stateModifier, messageModifier);
        const modelRunnable = includeAgentName === "inline"
            ? promptRunnable.pipe((0, agentName_js_1.withAgentName)(modelWithTools, includeAgentName))
            : promptRunnable.pipe(modelWithTools);
        cachedModelRunnable = modelRunnable;
        return modelRunnable;
    };
    // If any of the tools are configured to return_directly after running,
    // our graph needs to check if these were called
    const shouldReturnDirect = new Set(toolClasses
        .filter(isClientTool)
        .filter((tool) => "returnDirect" in tool && tool.returnDirect)
        .map((tool) => tool.name));
    // Prefer `llmInputMessages` as the model input when it is non-empty;
    // otherwise fall back to the regular `messages`.
    function getModelInputState(state) {
        const { messages, llmInputMessages, ...rest } = state;
        if (llmInputMessages != null && llmInputMessages.length > 0) {
            return { messages: llmInputMessages, ...rest };
        }
        return { messages, ...rest };
    }
    const generateStructuredResponse = async (state, config) => {
        if (responseFormat == null) {
            throw new Error("Attempted to generate structured output with no passed response schema. Please contact us for help.");
        }
        const messages = [...state.messages];
        let modelWithStructuredOutput;
        if (typeof responseFormat === "object" &&
            "prompt" in responseFormat &&
            "schema" in responseFormat) {
            // `{ prompt, schema }` form: the prompt becomes a leading system message.
            const { prompt, schema } = responseFormat;
            modelWithStructuredOutput = (await _getModel(llm)).withStructuredOutput(schema);
            messages.unshift(new messages_1.SystemMessage({ content: prompt }));
        }
        else {
            modelWithStructuredOutput = (await _getModel(llm)).withStructuredOutput(responseFormat);
        }
        const response = await modelWithStructuredOutput.invoke(messages, config);
        return { structuredResponse: response };
    };
    const callModel = async (state, config) => {
        // NOTE: we're dynamically creating the model runnable here
        // to ensure that we can validate ConfigurableModel properly
        const modelRunnable = await getModelRunnable(llm);
        // TODO: Auto-promote streaming.
        const response = (await modelRunnable.invoke(getModelInputState(state), config));
        // add agent name to the AIMessage
        // TODO: figure out if we can avoid mutating the message directly
        response.name = name;
        response.lc_kwargs.name = name;
        return { messages: [response] };
    };
    const schema = stateSchema ?? (0, exports.createReactAgentAnnotation)();
    const workflow = new index_js_1.StateGraph(schema).addNode("tools", toolNode);
    const allNodeWorkflows = workflow;
    // Drops entries whose value is null/undefined so that only routes that
    // actually exist are declared on a conditional edge.
    const conditionalMap = (map) => {
        return Object.fromEntries(Object.entries(map).filter(([_, v]) => v != null));
    };
    let entrypoint = "agent";
    let inputSchema;
    if (preModelHook != null) {
        allNodeWorkflows
            .addNode("pre_model_hook", preModelHook)
            .addEdge("pre_model_hook", "agent");
        entrypoint = "pre_model_hook";
        inputSchema = annotation_js_1.Annotation.Root({
            ...schema.spec,
            ...PreHookAnnotation.spec,
        });
    }
    allNodeWorkflows
        .addNode("agent", callModel, { input: inputSchema })
        .addEdge(constants_js_1.START, entrypoint);
    if (postModelHook != null) {
        allNodeWorkflows
            .addNode("post_model_hook", postModelHook)
            .addEdge("agent", "post_model_hook")
            .addConditionalEdges("post_model_hook", (state) => {
            const { messages } = state;
            const lastMessage = messages[messages.length - 1];
            if ((0, messages_1.isAIMessage)(lastMessage) &&
                lastMessage.tool_calls?.length) {
                return "tools";
            }
            if ((0, messages_1.isToolMessage)(lastMessage))
                return entrypoint;
            if (responseFormat != null)
                return "generate_structured_response";
            return constants_js_1.END;
        }, conditionalMap({
            tools: "tools",
            [entrypoint]: entrypoint,
            generate_structured_response: responseFormat != null ? "generate_structured_response" : null,
            [constants_js_1.END]: responseFormat != null ? null : constants_js_1.END,
        }));
    }
    // Use the same `!= null` test as the routing above: a `null`
    // responseFormat must not register a node that no edge ever targets.
    if (responseFormat != null) {
        workflow
            .addNode("generate_structured_response", generateStructuredResponse)
            .addEdge("generate_structured_response", constants_js_1.END);
    }
    if (postModelHook == null) {
        allNodeWorkflows.addConditionalEdges("agent", (state) => {
            const { messages } = state;
            const lastMessage = messages[messages.length - 1];
            // if there's no function call, we finish
            if (!(0, messages_1.isAIMessage)(lastMessage) ||
                !lastMessage.tool_calls?.length) {
                if (responseFormat != null)
                    return "generate_structured_response";
                return constants_js_1.END;
            }
            // there are function calls, we continue
            return "tools";
        }, conditionalMap({
            tools: "tools",
            generate_structured_response: responseFormat != null ? "generate_structured_response" : null,
            [constants_js_1.END]: responseFormat != null ? null : constants_js_1.END,
        }));
    }
    if (shouldReturnDirect.size > 0) {
        allNodeWorkflows.addConditionalEdges("tools", (state) => {
            // Check the last consecutive tool calls
            for (let i = state.messages.length - 1; i >= 0; i -= 1) {
                const message = state.messages[i];
                if (!(0, messages_1.isToolMessage)(message))
                    break;
                // Check if this tool is configured to return directly
                if (message.name !== undefined &&
                    shouldReturnDirect.has(message.name)) {
                    return constants_js_1.END;
                }
            }
            return entrypoint;
        }, conditionalMap({ [entrypoint]: entrypoint, [constants_js_1.END]: constants_js_1.END }));
    }
    else {
        allNodeWorkflows.addEdge("tools", entrypoint);
    }
    return allNodeWorkflows.compile({
        checkpointer: checkpointer ?? checkpointSaver,
        interruptBefore,
        interruptAfter,
        store,
        name,
    });
}
//# sourceMappingURL=react_agent_executor.js.map