/*
 * @langchain/core — Core LangChain.js abstractions and schemas.
 * Listing metadata: 453 lines (452 loc), 16.5 kB, JavaScript. Version: not recorded.
 */
import { RunnableLambda } from "../runnables/base.js";
import { AIMessage, AIMessageChunk } from "./ai.js";
import { isBaseMessageChunk, } from "./base.js";
import { ChatMessage, ChatMessageChunk, } from "./chat.js";
import { FunctionMessage, FunctionMessageChunk, } from "./function.js";
import { HumanMessage, HumanMessageChunk } from "./human.js";
import { RemoveMessage } from "./modifier.js";
import { SystemMessage, SystemMessageChunk } from "./system.js";
import { ToolMessage, ToolMessageChunk, } from "./tool.js";
import { convertToChunk } from "./utils.js";
const _isMessageType = (msg, types) => {
const typesAsStrings = [
...new Set(types?.map((t) => {
if (typeof t === "string") {
return t;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const instantiatedMsgClass = new t({});
if (!("getType" in instantiatedMsgClass) ||
typeof instantiatedMsgClass.getType !== "function") {
throw new Error("Invalid type provided.");
}
return instantiatedMsgClass.getType();
})),
];
const msgType = msg.getType();
return typesAsStrings.some((t) => t === msgType);
};
export function filterMessages(messagesOrOptions, options) {
if (Array.isArray(messagesOrOptions)) {
return _filterMessages(messagesOrOptions, options);
}
return RunnableLambda.from((input) => {
return _filterMessages(input, messagesOrOptions);
});
}
function _filterMessages(messages, options = {}) {
const { includeNames, excludeNames, includeTypes, excludeTypes, includeIds, excludeIds, } = options;
const filtered = [];
for (const msg of messages) {
if (excludeNames && msg.name && excludeNames.includes(msg.name)) {
continue;
}
else if (excludeTypes && _isMessageType(msg, excludeTypes)) {
continue;
}
else if (excludeIds && msg.id && excludeIds.includes(msg.id)) {
continue;
}
// default to inclusion when no inclusion criteria given.
if (!(includeTypes || includeIds || includeNames)) {
filtered.push(msg);
}
else if (includeNames &&
msg.name &&
includeNames.some((iName) => iName === msg.name)) {
filtered.push(msg);
}
else if (includeTypes && _isMessageType(msg, includeTypes)) {
filtered.push(msg);
}
else if (includeIds && msg.id && includeIds.some((id) => id === msg.id)) {
filtered.push(msg);
}
}
return filtered;
}
export function mergeMessageRuns(messages) {
if (Array.isArray(messages)) {
return _mergeMessageRuns(messages);
}
return RunnableLambda.from(_mergeMessageRuns);
}
function _mergeMessageRuns(messages) {
if (!messages.length) {
return [];
}
const merged = [];
for (const msg of messages) {
const curr = msg;
const last = merged.pop();
if (!last) {
merged.push(curr);
}
else if (curr.getType() === "tool" ||
!(curr.getType() === last.getType())) {
merged.push(last, curr);
}
else {
const lastChunk = convertToChunk(last);
const currChunk = convertToChunk(curr);
const mergedChunks = lastChunk.concat(currChunk);
if (typeof lastChunk.content === "string" &&
typeof currChunk.content === "string") {
mergedChunks.content = `${lastChunk.content}\n${currChunk.content}`;
}
merged.push(_chunkToMsg(mergedChunks));
}
}
return merged;
}
export function trimMessages(messagesOrOptions, options) {
if (Array.isArray(messagesOrOptions)) {
const messages = messagesOrOptions;
if (!options) {
throw new Error("Options parameter is required when providing messages.");
}
return _trimMessagesHelper(messages, options);
}
else {
const trimmerOptions = messagesOrOptions;
return RunnableLambda.from((input) => _trimMessagesHelper(input, trimmerOptions)).withConfig({
runName: "trim_messages",
});
}
}
async function _trimMessagesHelper(messages, options) {
const { maxTokens, tokenCounter, strategy = "last", allowPartial = false, endOn, startOn, includeSystem = false, textSplitter, } = options;
if (startOn && strategy === "first") {
throw new Error("`startOn` should only be specified if `strategy` is 'last'.");
}
if (includeSystem && strategy === "first") {
throw new Error("`includeSystem` should only be specified if `strategy` is 'last'.");
}
let listTokenCounter;
if ("getNumTokens" in tokenCounter) {
listTokenCounter = async (msgs) => {
const tokenCounts = await Promise.all(msgs.map((msg) => tokenCounter.getNumTokens(msg.content)));
return tokenCounts.reduce((sum, count) => sum + count, 0);
};
}
else {
listTokenCounter = async (msgs) => tokenCounter(msgs);
}
let textSplitterFunc = defaultTextSplitter;
if (textSplitter) {
if ("splitText" in textSplitter) {
textSplitterFunc = textSplitter.splitText;
}
else {
textSplitterFunc = async (text) => textSplitter(text);
}
}
if (strategy === "first") {
return _firstMaxTokens(messages, {
maxTokens,
tokenCounter: listTokenCounter,
textSplitter: textSplitterFunc,
partialStrategy: allowPartial ? "first" : undefined,
endOn,
});
}
else if (strategy === "last") {
return _lastMaxTokens(messages, {
maxTokens,
tokenCounter: listTokenCounter,
textSplitter: textSplitterFunc,
allowPartial,
includeSystem,
startOn,
endOn,
});
}
else {
throw new Error(`Unrecognized strategy: '${strategy}'. Must be one of 'first' or 'last'.`);
}
}
async function _firstMaxTokens(messages, options) {
const { maxTokens, tokenCounter, textSplitter, partialStrategy, endOn } = options;
let messagesCopy = [...messages];
let idx = 0;
for (let i = 0; i < messagesCopy.length; i += 1) {
const remainingMessages = i > 0 ? messagesCopy.slice(0, -i) : messagesCopy;
if ((await tokenCounter(remainingMessages)) <= maxTokens) {
idx = messagesCopy.length - i;
break;
}
}
if (idx < messagesCopy.length - 1 && partialStrategy) {
let includedPartial = false;
if (Array.isArray(messagesCopy[idx].content)) {
const excluded = messagesCopy[idx];
if (typeof excluded.content === "string") {
throw new Error("Expected content to be an array.");
}
const numBlock = excluded.content.length;
const reversedContent = partialStrategy === "last"
? [...excluded.content].reverse()
: excluded.content;
for (let i = 1; i <= numBlock; i += 1) {
const partialContent = partialStrategy === "first"
? reversedContent.slice(0, i)
: reversedContent.slice(-i);
const fields = Object.fromEntries(Object.entries(excluded).filter(([k]) => k !== "type" && !k.startsWith("lc_")));
const updatedMessage = _switchTypeToMessage(excluded.getType(), {
...fields,
content: partialContent,
});
const slicedMessages = [...messagesCopy.slice(0, idx), updatedMessage];
if ((await tokenCounter(slicedMessages)) <= maxTokens) {
messagesCopy = slicedMessages;
idx += 1;
includedPartial = true;
}
else {
break;
}
}
if (includedPartial && partialStrategy === "last") {
excluded.content = [...reversedContent].reverse();
}
}
if (!includedPartial) {
const excluded = messagesCopy[idx];
let text;
if (Array.isArray(excluded.content) &&
excluded.content.some((block) => typeof block === "string" || block.type === "text")) {
const textBlock = excluded.content.find((block) => block.type === "text" && block.text);
text = textBlock?.text;
}
else if (typeof excluded.content === "string") {
text = excluded.content;
}
if (text) {
const splitTexts = await textSplitter(text);
const numSplits = splitTexts.length;
if (partialStrategy === "last") {
splitTexts.reverse();
}
for (let _ = 0; _ < numSplits - 1; _ += 1) {
splitTexts.pop();
excluded.content = splitTexts.join("");
if ((await tokenCounter([...messagesCopy.slice(0, idx), excluded])) <=
maxTokens) {
if (partialStrategy === "last") {
excluded.content = [...splitTexts].reverse().join("");
}
messagesCopy = [...messagesCopy.slice(0, idx), excluded];
idx += 1;
break;
}
}
}
}
}
if (endOn) {
const endOnArr = Array.isArray(endOn) ? endOn : [endOn];
while (idx > 0 && !_isMessageType(messagesCopy[idx - 1], endOnArr)) {
idx -= 1;
}
}
return messagesCopy.slice(0, idx);
}
async function _lastMaxTokens(messages, options) {
const { allowPartial = false, includeSystem = false, endOn, startOn, ...rest } = options;
// Create a copy of messages to avoid mutation
let messagesCopy = messages.map((message) => {
const fields = Object.fromEntries(Object.entries(message).filter(([k]) => k !== "type" && !k.startsWith("lc_")));
return _switchTypeToMessage(message.getType(), fields, isBaseMessageChunk(message));
});
if (endOn) {
const endOnArr = Array.isArray(endOn) ? endOn : [endOn];
while (messagesCopy.length > 0 &&
!_isMessageType(messagesCopy[messagesCopy.length - 1], endOnArr)) {
messagesCopy = messagesCopy.slice(0, -1);
}
}
const swappedSystem = includeSystem && messagesCopy[0]?.getType() === "system";
let reversed_ = swappedSystem
? messagesCopy.slice(0, 1).concat(messagesCopy.slice(1).reverse())
: messagesCopy.reverse();
reversed_ = await _firstMaxTokens(reversed_, {
...rest,
partialStrategy: allowPartial ? "last" : undefined,
endOn: startOn,
});
if (swappedSystem) {
return [reversed_[0], ...reversed_.slice(1).reverse()];
}
else {
return reversed_.reverse();
}
}
const _MSG_CHUNK_MAP = {
human: {
message: HumanMessage,
messageChunk: HumanMessageChunk,
},
ai: {
message: AIMessage,
messageChunk: AIMessageChunk,
},
system: {
message: SystemMessage,
messageChunk: SystemMessageChunk,
},
developer: {
message: SystemMessage,
messageChunk: SystemMessageChunk,
},
tool: {
message: ToolMessage,
messageChunk: ToolMessageChunk,
},
function: {
message: FunctionMessage,
messageChunk: FunctionMessageChunk,
},
generic: {
message: ChatMessage,
messageChunk: ChatMessageChunk,
},
remove: {
message: RemoveMessage,
messageChunk: RemoveMessage, // RemoveMessage does not have a chunk class.
},
};
function _switchTypeToMessage(messageType, fields, returnChunk) {
let chunk;
let msg;
switch (messageType) {
case "human":
if (returnChunk) {
chunk = new HumanMessageChunk(fields);
}
else {
msg = new HumanMessage(fields);
}
break;
case "ai":
if (returnChunk) {
let aiChunkFields = {
...fields,
};
if ("tool_calls" in aiChunkFields) {
aiChunkFields = {
...aiChunkFields,
tool_call_chunks: aiChunkFields.tool_calls?.map((tc) => ({
...tc,
type: "tool_call_chunk",
index: undefined,
args: JSON.stringify(tc.args),
})),
};
}
chunk = new AIMessageChunk(aiChunkFields);
}
else {
msg = new AIMessage(fields);
}
break;
case "system":
if (returnChunk) {
chunk = new SystemMessageChunk(fields);
}
else {
msg = new SystemMessage(fields);
}
break;
case "developer":
if (returnChunk) {
chunk = new SystemMessageChunk({
...fields,
additional_kwargs: {
...fields.additional_kwargs,
__openai_role__: "developer",
},
});
}
else {
msg = new SystemMessage({
...fields,
additional_kwargs: {
...fields.additional_kwargs,
__openai_role__: "developer",
},
});
}
break;
case "tool":
if ("tool_call_id" in fields) {
if (returnChunk) {
chunk = new ToolMessageChunk(fields);
}
else {
msg = new ToolMessage(fields);
}
}
else {
throw new Error("Can not convert ToolMessage to ToolMessageChunk if 'tool_call_id' field is not defined.");
}
break;
case "function":
if (returnChunk) {
chunk = new FunctionMessageChunk(fields);
}
else {
if (!fields.name) {
throw new Error("FunctionMessage must have a 'name' field");
}
msg = new FunctionMessage(fields);
}
break;
case "generic":
if ("role" in fields) {
if (returnChunk) {
chunk = new ChatMessageChunk(fields);
}
else {
msg = new ChatMessage(fields);
}
}
else {
throw new Error("Can not convert ChatMessage to ChatMessageChunk if 'role' field is not defined.");
}
break;
default:
throw new Error(`Unrecognized message type ${messageType}`);
}
if (returnChunk && chunk) {
return chunk;
}
if (msg) {
return msg;
}
throw new Error(`Unrecognized message type ${messageType}`);
}
function _chunkToMsg(chunk) {
const chunkType = chunk.getType();
let msg;
const fields = Object.fromEntries(Object.entries(chunk).filter(([k]) => !["type", "tool_call_chunks"].includes(k) && !k.startsWith("lc_")));
if (chunkType in _MSG_CHUNK_MAP) {
msg = _switchTypeToMessage(chunkType, fields);
}
if (!msg) {
throw new Error(`Unrecognized message chunk class ${chunkType}. Supported classes are ${Object.keys(_MSG_CHUNK_MAP)}`);
}
return msg;
}
/**
* The default text splitter function that splits text by newlines.
*
* @param {string} text
* @returns A promise that resolves to an array of strings split by newlines.
*/
export function defaultTextSplitter(text) {
const splits = text.split("\n");
return Promise.resolve([
...splits.slice(0, -1).map((s) => `${s}\n`),
splits[splits.length - 1],
]);
}