@langchain/core
Version:
Core LangChain.js abstractions and schemas
311 lines (309 loc) • 12.4 kB
JavaScript
import { isBaseMessageChunk } from "./base.js";
import { ToolMessage, ToolMessageChunk } from "./tool.js";
import { AIMessage, AIMessageChunk } from "./ai.js";
import { ChatMessage, ChatMessageChunk } from "./chat.js";
import { FunctionMessage, FunctionMessageChunk } from "./function.js";
import { HumanMessage, HumanMessageChunk } from "./human.js";
import { SystemMessage, SystemMessageChunk } from "./system.js";
import { RemoveMessage } from "./modifier.js";
import { convertToChunk } from "./utils.js";
import { RunnableLambda } from "../runnables/base.js";
//#region src/messages/transformers.ts
/**
 * Checks whether `msg` has one of the given message types.
 *
 * Each entry of `types` is either a type string or a message class; classes are
 * resolved to their type string by constructing a throwaway instance with `{}`
 * and calling its `getType()`.
 *
 * @param msg - Object exposing `getType()`.
 * @param types - Array of type strings and/or message classes (nullish means "none").
 * @returns `true` when `msg.getType()` equals any resolved type.
 * @throws {Error} "Invalid type provided." when a class instance has no `getType` method.
 */
const _isMessageType = (msg, types) => {
  const toTypeName = (entry) => {
    if (typeof entry === "string") return entry;
    const probe = new entry({});
    const hasGetType = "getType" in probe && typeof probe.getType === "function";
    if (!hasGetType) throw new Error("Invalid type provided.");
    return probe.getType();
  };
  const candidateNames = new Set((types ?? []).map(toTypeName));
  const actualType = msg.getType();
  for (const name of candidateNames) {
    if (name === actualType) return true;
  }
  return false;
};
/**
 * Filters messages by name, type and id.
 *
 * Called with an array, the filter is applied immediately. Called with only an
 * options object, returns a `RunnableLambda` that applies those options to its
 * input when invoked.
 *
 * @param messagesOrOptions - Messages to filter, or the filter options.
 * @param options - Filter options when the first argument is an array.
 * @returns The filtered array, or a runnable performing the filtering.
 */
function filterMessages(messagesOrOptions, options) {
  if (!Array.isArray(messagesOrOptions)) {
    const filterOptions = messagesOrOptions;
    return RunnableLambda.from((input) => _filterMessages(input, filterOptions));
  }
  return _filterMessages(messagesOrOptions, options);
}
/**
 * Applies exclusion rules first, then inclusion rules.
 *
 * A message is dropped if it matches any `exclude*` rule. If no `include*`
 * rule is given, every remaining message is kept; otherwise a remaining
 * message is kept when it matches at least one `include*` rule. Name and id
 * rules only match messages whose `name` / `id` is truthy.
 *
 * @param messages - Iterable of messages.
 * @param options - `{ includeNames, excludeNames, includeTypes, excludeTypes, includeIds, excludeIds }`.
 * @returns A new array with the kept messages in their original order.
 */
function _filterMessages(messages, options = {}) {
  const { includeNames, excludeNames, includeTypes, excludeTypes, includeIds, excludeIds } = options;
  const hasIncludeRules = Boolean(includeTypes || includeIds || includeNames);
  const isExcluded = (msg) => Boolean(
    (excludeNames && msg.name && excludeNames.includes(msg.name)) ||
    (excludeTypes && _isMessageType(msg, excludeTypes)) ||
    (excludeIds && msg.id && excludeIds.includes(msg.id))
  );
  const isIncluded = (msg) => Boolean(
    (includeNames && msg.name && includeNames.some((wanted) => wanted === msg.name)) ||
    (includeTypes && _isMessageType(msg, includeTypes)) ||
    (includeIds && msg.id && includeIds.some((wanted) => wanted === msg.id))
  );
  const kept = [];
  for (const msg of messages) {
    if (isExcluded(msg)) continue;
    if (!hasIncludeRules || isIncluded(msg)) kept.push(msg);
  }
  return kept;
}
/**
 * Merges consecutive messages of the same type.
 *
 * With an array argument the merge runs immediately; otherwise a
 * `RunnableLambda` wrapping the merge is returned.
 *
 * @param messages - Messages to merge (optional when building a runnable).
 * @returns The merged array, or a runnable that performs the merge.
 */
function mergeMessageRuns(messages) {
  return Array.isArray(messages) ? _mergeMessageRuns(messages) : RunnableLambda.from(_mergeMessageRuns);
}
/**
 * Collapses runs of consecutive same-type messages into single messages.
 *
 * Tool messages are never merged. Two neighbours are merged by converting both
 * to chunks and calling `concat`. When both contents are non-empty strings they
 * are joined with a newline. If either side is empty, the plain `concat` result
 * is kept, so merging an empty message does not leave a stray "\n".
 *
 * @param messages - Messages to merge.
 * @returns A new array of (possibly merged) messages.
 */
function _mergeMessageRuns(messages) {
  if (!messages.length) return [];
  const merged = [];
  for (const curr of messages) {
    const last = merged[merged.length - 1];
    if (!last || curr.getType() === "tool" || curr.getType() !== last.getType()) {
      merged.push(curr);
      continue;
    }
    const lastChunk = convertToChunk(last);
    const currChunk = convertToChunk(curr);
    const mergedChunk = lastChunk.concat(currChunk);
    const lastText = lastChunk.content;
    const currText = currChunk.content;
    // Only insert the separator between two non-empty string contents.
    if (typeof lastText === "string" && typeof currText === "string" && lastText && currText) {
      mergedChunk.content = `${lastText}\n${currText}`;
    }
    merged[merged.length - 1] = _chunkToMsg(mergedChunk);
  }
  return merged;
}
/**
 * Trims a message list to a token budget.
 *
 * With an array as first argument, `options` is mandatory and the trim runs
 * immediately (returning a promise). With only an options object, returns a
 * `RunnableLambda` named "trim_messages" that trims its input.
 *
 * @param messagesOrOptions - Messages to trim, or the trim options.
 * @param options - Trim options when the first argument is an array.
 * @returns A promise of trimmed messages, or a configured runnable.
 * @throws {Error} When messages are given without options.
 */
function trimMessages(messagesOrOptions, options) {
  if (!Array.isArray(messagesOrOptions)) {
    const trimmerOptions = messagesOrOptions;
    const runnable = RunnableLambda.from((input) => _trimMessagesHelper(input, trimmerOptions));
    return runnable.withConfig({ runName: "trim_messages" });
  }
  if (!options) throw new Error("Options parameter is required when providing messages.");
  return _trimMessagesHelper(messagesOrOptions, options);
}
/**
 * Validates trim options, normalizes the token counter and text splitter into
 * plain async functions, and dispatches to the chosen strategy.
 *
 * @param messages - Messages to trim.
 * @param options - `{ maxTokens, tokenCounter, strategy = "last", allowPartial = false,
 *   endOn, startOn, includeSystem = false, textSplitter }`. `tokenCounter` is either an
 *   object with `getNumTokens(content)` (summed per message) or a function over the
 *   whole list. `textSplitter` is either an object with `splitText(text)` or a function.
 * @returns Promise of the trimmed messages.
 * @throws {Error} If `startOn`/`includeSystem` is combined with strategy "first",
 *   or the strategy is not "first" or "last".
 */
async function _trimMessagesHelper(messages, options) {
  const { maxTokens, tokenCounter, strategy = "last", allowPartial = false, endOn, startOn, includeSystem = false, textSplitter } = options;
  if (startOn && strategy === "first") throw new Error("`startOn` should only be specified if `strategy` is 'last'.");
  if (includeSystem && strategy === "first") throw new Error("`includeSystem` should only be specified if `strategy` is 'last'.");
  let listTokenCounter;
  if ("getNumTokens" in tokenCounter) {
    listTokenCounter = async (msgs) => {
      const tokenCounts = await Promise.all(msgs.map((msg) => tokenCounter.getNumTokens(msg.content)));
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    };
  } else {
    listTokenCounter = async (msgs) => tokenCounter(msgs);
  }
  let textSplitterFunc = defaultTextSplitter;
  if (textSplitter) {
    if ("splitText" in textSplitter) {
      // Call through the object: passing the bare method would lose `this`,
      // which breaks splitters whose `splitText` reads instance state.
      textSplitterFunc = (text) => textSplitter.splitText(text);
    } else {
      textSplitterFunc = async (text) => textSplitter(text);
    }
  }
  if (strategy === "first") {
    return _firstMaxTokens(messages, {
      maxTokens,
      tokenCounter: listTokenCounter,
      textSplitter: textSplitterFunc,
      partialStrategy: allowPartial ? "first" : void 0,
      endOn
    });
  }
  if (strategy === "last") {
    return _lastMaxTokens(messages, {
      maxTokens,
      tokenCounter: listTokenCounter,
      textSplitter: textSplitterFunc,
      allowPartial,
      includeSystem,
      startOn,
      endOn
    });
  }
  throw new Error(`Unrecognized strategy: '${strategy}'. Must be one of 'first' or 'last'.`);
}
/**
 * Keeps the longest prefix of `messages` that fits in `maxTokens`, optionally
 * followed by a truncated copy of the first message that did not fit.
 *
 * Callers that want to keep the end of a conversation (`_lastMaxTokens`) pass
 * the messages reversed together with `partialStrategy: "last"`.
 *
 * @param messages - Messages in keep-priority order.
 * @param options.maxTokens - Token budget for the returned list.
 * @param options.tokenCounter - `(messages) => number | Promise<number>`.
 * @param options.textSplitter - `(text) => string[] | Promise<string[]>`; pieces are re-joined with "".
 * @param options.partialStrategy - "first" keeps the start of the overflowing message,
 *   "last" keeps its end (in original order); falsy disables partial inclusion.
 * @param options.endOn - Type(s) the last returned message must match; trailing
 *   messages are dropped until one does.
 * @returns A new array. A partial message is a new instance built with
 *   `_switchTypeToMessage`; this function never assigns to the input messages.
 */
async function _firstMaxTokens(messages, options) {
  const { maxTokens, tokenCounter, textSplitter, partialStrategy, endOn } = options;
  let messagesCopy = [...messages];
  let idx = 0;
  // Drop messages from the end one at a time until the remaining prefix fits.
  for (let i = 0; i < messagesCopy.length; i += 1) {
    const remainingMessages = i > 0 ? messagesCopy.slice(0, -i) : messagesCopy;
    if (await tokenCounter(remainingMessages) <= maxTokens) {
      idx = messagesCopy.length - i;
      break;
    }
  }
  if (idx < messagesCopy.length && partialStrategy) {
    const excluded = messagesCopy[idx];
    const kept = messagesCopy.slice(0, idx);
    // Own fields minus serialization internals, used to build partial copies.
    const fields = Object.fromEntries(Object.entries(excluded).filter(([k]) => k !== "type" && !k.startsWith("lc_")));
    let partial;
    if (Array.isArray(excluded.content)) {
      const blocks = excluded.content;
      // Grow the number of kept blocks; remember the largest candidate that fits.
      for (let count = 1; count <= blocks.length; count += 1) {
        const partialContent = partialStrategy === "first" ? blocks.slice(0, count) : blocks.slice(-count);
        const candidate = _switchTypeToMessage(excluded.getType(), {
          ...fields,
          content: partialContent
        });
        if (await tokenCounter([...kept, candidate]) <= maxTokens) partial = candidate;
        else break;
      }
    }
    if (!partial) {
      let text;
      if (Array.isArray(excluded.content) && excluded.content.some((block) => typeof block === "string" || block.type === "text")) {
        const textBlock = excluded.content.find((block) => block.type === "text" && block.text);
        text = textBlock?.text;
      } else if (typeof excluded.content === "string") {
        text = excluded.content;
      }
      if (text) {
        const splitTexts = await textSplitter(text);
        // Try the largest number of pieces first (all but one), shrinking until one fits.
        for (let keep = splitTexts.length - 1; keep > 0; keep -= 1) {
          const pieces = partialStrategy === "last" ? splitTexts.slice(-keep) : splitTexts.slice(0, keep);
          const candidate = _switchTypeToMessage(excluded.getType(), {
            ...fields,
            content: pieces.join("")
          }, isBaseMessageChunk(excluded));
          if (await tokenCounter([...kept, candidate]) <= maxTokens) {
            partial = candidate;
            break;
          }
        }
      }
    }
    if (partial) {
      messagesCopy = [...kept, partial];
      idx += 1;
    }
  }
  if (endOn) {
    const endOnArr = Array.isArray(endOn) ? endOn : [endOn];
    while (idx > 0 && !_isMessageType(messagesCopy[idx - 1], endOnArr)) idx -= 1;
  }
  return messagesCopy.slice(0, idx);
}
/**
 * Keeps the most recent messages that fit in the token budget.
 *
 * Works on fresh copies of the input messages (rebuilt with
 * `_switchTypeToMessage`, preserving chunk-ness). Trailing messages are
 * dropped until one matches `endOn`. The list is then reversed and handed to
 * `_firstMaxTokens`, with `startOn` acting as that function's `endOn`. With
 * `includeSystem`, a leading system message stays pinned at the front.
 *
 * @param messages - Messages in chronological order.
 * @param options - `{ allowPartial = false, includeSystem = false, endOn, startOn, ...rest }`,
 *   where `rest` (maxTokens, tokenCounter, textSplitter) is forwarded to `_firstMaxTokens`.
 * @returns Promise of the kept messages in chronological order (empty if nothing fits).
 */
async function _lastMaxTokens(messages, options) {
  const { allowPartial = false, includeSystem = false, endOn, startOn, ...rest } = options;
  let messagesCopy = messages.map((message) => {
    const fields = Object.fromEntries(Object.entries(message).filter(([k]) => k !== "type" && !k.startsWith("lc_")));
    return _switchTypeToMessage(message.getType(), fields, isBaseMessageChunk(message));
  });
  if (endOn) {
    const endOnArr = Array.isArray(endOn) ? endOn : [endOn];
    // Find the cut point once instead of re-slicing on every step.
    let end = messagesCopy.length;
    while (end > 0 && !_isMessageType(messagesCopy[end - 1], endOnArr)) end -= 1;
    messagesCopy = messagesCopy.slice(0, end);
  }
  const swappedSystem = includeSystem && messagesCopy[0]?.getType() === "system";
  const reversed = swappedSystem ? [messagesCopy[0], ...messagesCopy.slice(1).reverse()] : [...messagesCopy].reverse();
  const kept = await _firstMaxTokens(reversed, {
    ...rest,
    partialStrategy: allowPartial ? "last" : void 0,
    endOn: startOn
  });
  // slice() keeps an empty result empty instead of producing [undefined].
  if (swappedSystem) return [...kept.slice(0, 1), ...kept.slice(1).reverse()];
  return kept.reverse();
}
/**
 * Message type string -> `{ message, messageChunk }` classes.
 *
 * `_chunkToMsg` uses the keys to decide whether a chunk type is supported.
 * "developer" reuses the system classes, and "remove" maps both entries to
 * `RemoveMessage`.
 */
const _MSG_CHUNK_MAP = {
  human: { message: HumanMessage, messageChunk: HumanMessageChunk },
  ai: { message: AIMessage, messageChunk: AIMessageChunk },
  system: { message: SystemMessage, messageChunk: SystemMessageChunk },
  developer: { message: SystemMessage, messageChunk: SystemMessageChunk },
  tool: { message: ToolMessage, messageChunk: ToolMessageChunk },
  function: { message: FunctionMessage, messageChunk: FunctionMessageChunk },
  generic: { message: ChatMessage, messageChunk: ChatMessageChunk },
  remove: { message: RemoveMessage, messageChunk: RemoveMessage }
};
/**
 * Constructs a message, or its chunk counterpart, for a message type string.
 *
 * @param messageType - One of "human", "ai", "system", "developer", "tool",
 *   "function" or "generic".
 * @param fields - Constructor fields for the message.
 * @param returnChunk - When truthy, build the `*MessageChunk` class instead.
 *   For AI chunks, any `tool_calls` are also exposed as `tool_call_chunks`
 *   with JSON-stringified `args`. "developer" builds a system message whose
 *   `additional_kwargs.__openai_role__` is "developer".
 * @returns The constructed message or chunk.
 * @throws {Error} For unknown types, a "tool" without `tool_call_id`, a
 *   "generic" without `role`, or a non-chunk "function" without `name`.
 */
function _switchTypeToMessage(messageType, fields, returnChunk) {
  switch (messageType) {
    case "human":
      return returnChunk ? new HumanMessageChunk(fields) : new HumanMessage(fields);
    case "ai": {
      if (!returnChunk) return new AIMessage(fields);
      const copied = { ...fields };
      const chunkFields = "tool_calls" in copied
        ? {
            ...copied,
            tool_call_chunks: copied.tool_calls?.map((tc) => ({
              ...tc,
              type: "tool_call_chunk",
              index: void 0,
              args: JSON.stringify(tc.args)
            }))
          }
        : copied;
      return new AIMessageChunk(chunkFields);
    }
    case "system":
      return returnChunk ? new SystemMessageChunk(fields) : new SystemMessage(fields);
    case "developer": {
      const developerFields = {
        ...fields,
        additional_kwargs: {
          ...fields.additional_kwargs,
          __openai_role__: "developer"
        }
      };
      return returnChunk ? new SystemMessageChunk(developerFields) : new SystemMessage(developerFields);
    }
    case "tool":
      if (!("tool_call_id" in fields)) {
        throw new Error("Can not convert ToolMessage to ToolMessageChunk if 'tool_call_id' field is not defined.");
      }
      return returnChunk ? new ToolMessageChunk(fields) : new ToolMessage(fields);
    case "function":
      if (returnChunk) return new FunctionMessageChunk(fields);
      if (!fields.name) throw new Error("FunctionMessage must have a 'name' field");
      return new FunctionMessage(fields);
    case "generic":
      if (!("role" in fields)) {
        throw new Error("Can not convert ChatMessage to ChatMessageChunk if 'role' field is not defined.");
      }
      return returnChunk ? new ChatMessageChunk(fields) : new ChatMessage(fields);
    default:
      throw new Error(`Unrecognized message type ${messageType}`);
  }
}
/**
 * Converts a message chunk back into a non-chunk message of the same type.
 *
 * Copies the chunk's own fields except `type`, `tool_call_chunks` and any
 * `lc_*` keys, then rebuilds the message via `_switchTypeToMessage`.
 *
 * @param chunk - Object exposing `getType()`.
 * @returns The rebuilt message.
 * @throws {Error} When the chunk type is not a key of `_MSG_CHUNK_MAP`.
 */
function _chunkToMsg(chunk) {
  const chunkType = chunk.getType();
  const isCopiedKey = ([key]) => key !== "type" && key !== "tool_call_chunks" && !key.startsWith("lc_");
  const fields = Object.fromEntries(Object.entries(chunk).filter(isCopiedKey));
  const msg = chunkType in _MSG_CHUNK_MAP ? _switchTypeToMessage(chunkType, fields) : void 0;
  if (!msg) {
    throw new Error(`Unrecognized message chunk class ${chunkType}. Supported classes are ${Object.keys(_MSG_CHUNK_MAP)}`);
  }
  return msg;
}
/**
 * Default splitter used when trimming text: splits on newlines and keeps the
 * "\n" attached to the end of each piece, so joining the pieces with ""
 * reproduces the input exactly.
 *
 * @param {string} text - Text to split.
 * @returns A promise resolving to the pieces, e.g. "a\nb" -> ["a\n", "b"].
 */
function defaultTextSplitter(text) {
  const lines = text.split("\n");
  const tail = lines.pop();
  return Promise.resolve([...lines.map((line) => `${line}\n`), tail]);
}
//#endregion
export { defaultTextSplitter, filterMessages, mergeMessageRuns, trimMessages };
//# sourceMappingURL=transformers.js.map