Object.defineProperty(exports, "__esModule", { value: true });
exports.ViolationOfExpectationsChain = void 0;
const messages_1 = require("@langchain/core/messages");
const output_parsers_1 = require("@langchain/core/output_parsers");
const openai_functions_js_1 = require("../../../output_parsers/openai_functions.cjs");
const base_js_1 = require("../../../chains/base.cjs");
const types_js_1 = require("./types.cjs");
const violation_of_expectations_prompt_js_1 = require("./violation_of_expectations_prompt.cjs");
/**
 * Chain that generates key insights/facts about a user based on a
 * chat conversation with an AI.
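 *
 * @example
 * // A minimal usage sketch, assuming the `@langchain/openai` and
 * // `langchain/vectorstores/memory` entrypoints are available in your
 * // install; adjust the model and retriever to your own setup.
 * const vectorStore = await MemoryVectorStore.fromTexts(
 *   [" "],
 *   [{ id: 1 }],
 *   new OpenAIEmbeddings()
 * );
 * const chain = ViolationOfExpectationsChain.fromLLM(
 *   new ChatOpenAI({ modelName: "gpt-4" }),
 *   vectorStore.asRetriever()
 * );
 * // `chat_history` is this chain's input key; see `_call` below for the
 * // `{ insights }` output shape.
 * const result = await chain.call({
 *   chat_history: [
 *     new HumanMessage("I enjoy hiking on weekends."),
 *     new AIMessage("Nice! Do you have a favorite trail?"),
 *     new HumanMessage("Mostly the trails around Mount Tamalpais."),
 *   ],
 * });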
 */
class ViolationOfExpectationsChain extends base_js_1.BaseChain {
    static lc_name() {
        return "ViolationOfExpectationsChain";
    }
    _chainType() {
        return "violation_of_expectation_chain";
    }
    get inputKeys() {
        return [this.chatHistoryKey];
    }
    get outputKeys() {
        return [this.thoughtsKey];
    }
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "chatHistoryKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "chat_history"
        });
        Object.defineProperty(this, "thoughtsKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "thoughts"
        });
        Object.defineProperty(this, "retriever", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "llm", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "jsonOutputParser", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "stringOutputParser", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.retriever = fields.retriever;
        this.llm = fields.llm;
        this.jsonOutputParser = new openai_functions_js_1.JsonOutputFunctionsParser();
        this.stringOutputParser = new output_parsers_1.StringOutputParser();
    }
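    /**
     * Serializes a chat history into a newline-delimited string, prefixing
     * each message with "Human:" or "AI:" according to its type.
     */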
    getChatHistoryString(chatHistory) {
        return chatHistory
            .map((chatMessage) => {
            if (chatMessage._getType() === "human") {
                return `Human: ${chatMessage.content}`;
            }
            else if (chatMessage._getType() === "ai") {
                return `AI: ${chatMessage.content}`;
            }
            else {
                return `${chatMessage.content}`;
            }
        })
            .join("\n");
    }
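    /**
     * Removes duplicate strings while preserving first-seen order.
     */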
    removeDuplicateStrings(strings) {
        return [...new Set(strings)];
    }
    /**
     * This method breaks down the chat history into chunks of messages.
     * Each chunk consists of a sequence of messages ending with an AI message and the subsequent user response, if any.
     *
     * @param {BaseMessage[]} chatHistory - The chat history to be chunked.
     *
     * @returns {MessageChunkResult[]} An array of message chunks. Each chunk includes a sequence of messages and the subsequent user response.
     *
     * @description
     * The method iterates over the chat history and pushes each message into a temporary array.
     * When it encounters an AI message, it checks for a subsequent user message.
     * If a user message is found, it is treated as the user response to the AI message.
     * If no user message is found after the AI message, the user response is undefined.
     * The method then pushes the chunk (sequence of messages and user response) into the result array.
     * This process continues until all messages in the chat history have been processed.
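     *
     * @example
     * // Illustrative sketch (messages assumed, not from the source): for a
     * // history [human1, ai1, human2], one chunk is produced at ai1, with
     * // human2 as its userResponse. Note that `chunkedMessages` holds a
     * // reference to the running message array, not a point-in-time copy.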
     */
    chunkMessagesByAIResponse(chatHistory) {
        const newArray = [];
        const tempArray = [];
        chatHistory.forEach((item, index) => {
            tempArray.push(item);
            if (item._getType() === "ai") {
                let userResponse = chatHistory[index + 1];
                if (!userResponse || userResponse._getType() !== "human") {
                    userResponse = undefined;
                }
                newArray.push({
                    chunkedMessages: tempArray,
                    userResponse: userResponse
                        ? new messages_1.HumanMessage(userResponse)
                        : undefined,
                });
            }
        });
        return newArray;
    }
    /**
     * This method processes a chat history to generate insights about the user.
     *
     * @param {ChainValues} values - The input values for the chain. It should contain a key for chat history.
     * @param {CallbackManagerForChainRun} [runManager] - Optional callback manager for the chain run.
     *
     * @returns {Promise<ChainValues>} A promise that resolves to a list of insights about the user.
     *
     * @throws {Error} If the chat history key is not found in the input values or if the chat history is not an array of BaseMessages.
     *
     * @description
     * The method performs the following steps:
     * 1. Checks if the chat history key is present in the input values and if the chat history is an array of BaseMessages.
     * 2. Breaks the chat history into chunks of messages.
     * 3. For each chunk, it generates an initial prediction for the user's next message.
     * 4. For each prediction, it generates insights and prediction violations, and regenerates the prediction based on the violations.
     * 5. For each set of messages, it generates a fact/insight about the user.
     * The method returns a list of these insights.
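     *
     * @example
     * // Illustrative output shape (values assumed, not from the source):
     * // { insights: ["The user is an avid weekend hiker.", "..."] }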
     */
    async _call(values, runManager) {
        if (!(this.chatHistoryKey in values)) {
            throw new Error(`Chat history key ${this.chatHistoryKey} not found`);
        }
        const chatHistory = values[this.chatHistoryKey];
        const isEveryMessageBaseMessage = chatHistory.every((message) => (0, messages_1.isBaseMessage)(message));
        if (!isEveryMessageBaseMessage) {
            throw new Error("Chat history must be an array of BaseMessages");
        }
        const messageChunks = this.chunkMessagesByAIResponse(chatHistory);
        // Generate the initial prediction for every user message.
        const userPredictions = await Promise.all(messageChunks.map(async (chatHistoryChunk) => ({
            userPredictions: await this.predictNextUserMessage(chatHistoryChunk.chunkedMessages),
            userResponse: chatHistoryChunk.userResponse,
            runManager,
        })));
        // Generate insights and prediction violations for every user message.
        // This call also regenerates the prediction based on the violations.
        const predictionViolations = await Promise.all(userPredictions.map((prediction) => this.getPredictionViolations({
            userPredictions: prediction.userPredictions,
            userResponse: prediction.userResponse,
            runManager,
        })));
        // Generate a fact/insight about the user for every set of messages.
        const insights = await Promise.all(predictionViolations.map((violation) => this.generateFacts({
            userResponse: violation.userResponse,
            predictions: {
                revisedPrediction: violation.revisedPrediction,
                explainedPredictionErrors: violation.explainedPredictionErrors,
            },
        })));
        return {
            insights,
        };
    }
    /**
     * This method predicts the next user message based on the chat history.
     *
     * @param {BaseMessage[]} chatHistory - The chat history based on which the next user message is predicted.
     * @param {CallbackManagerForChainRun} [runManager] - Optional callback manager for the chain run.
     *
     * @returns {Promise<PredictNextUserMessageResponse>} A promise that resolves to the predicted next user message, the user state, and any insights.
     *
     * @throws {Error} If the response from the language model does not contain the expected keys: 'userState', 'predictedUserMessage', and 'insights'.
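     *
     * @example
     * // Shape of a successful parsed response (illustrative values, not
     * // from the source):
     * // {
     * //   userState: "The user is curious about local trails.",
     * //   predictedUserMessage: "Can you recommend a trail?",
     * //   insights: ["The user hikes regularly."]
     * // }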
     */
    async predictNextUserMessage(chatHistory, runManager) {
        const messageString = this.getChatHistoryString(chatHistory);
        const llmWithFunctions = this.llm.bind({
            functions: [types_js_1.PREDICT_NEXT_USER_MESSAGE_FUNCTION],
            function_call: { name: types_js_1.PREDICT_NEXT_USER_MESSAGE_FUNCTION.name },
        });
        const chain = violation_of_expectations_prompt_js_1.PREDICT_NEXT_USER_MESSAGE_PROMPT.pipe(llmWithFunctions).pipe(this.jsonOutputParser);
        const res = await chain.invoke({
            chat_history: messageString,
        }, runManager?.getChild("prediction"));
        if (!("userState" in res &&
            "predictedUserMessage" in res &&
            "insights" in res)) {
            throw new Error(`Invalid response from LLM: ${JSON.stringify(res)}`);
        }
        const predictionResponse = res;
        // Query the retriever for relevant insights, using the generated insights as the query.
        const retrievedDocs = await this.retrieveRelevantInsights(predictionResponse.insights);
        const relevantDocs = this.removeDuplicateStrings([
            ...predictionResponse.insights,
            ...retrievedDocs,
        ]);
        return {
            ...predictionResponse,
            insights: relevantDocs,
        };
    }
    /**
     * Retrieves relevant insights based on the provided insights.
     *
     * @param {Array<string>} insights - An array of insights to be used for retrieving relevant documents.
     *
     * @returns {Promise<Array<string>>} A promise that resolves to an array of relevant insights content.
     */
    async retrieveRelevantInsights(insights) {
        // Only extract the first relevant doc from the retriever. We don't need more than one.
        const relevantInsightsDocuments = await Promise.all(insights.map(async (insight) => {
            const relevantInsight = await this.retriever.getRelevantDocuments(insight);
            return relevantInsight[0];
        }));
        // Note: this assumes the retriever returns at least one document per
        // query; an empty result would make the `pageContent` access below throw.
        const relevantInsightsContent = relevantInsightsDocuments.map((document) => document.pageContent);
        return relevantInsightsContent;
    }
    /**
     * This method generates prediction violations based on the predicted and actual user responses.
     * It also generates a revised prediction based on the identified violations.
     *
     * @param {Object} params - The parameters for the method.
     * @param {PredictNextUserMessageResponse} params.userPredictions - The predicted user message, user state, and insights.
     * @param {BaseMessage} [params.userResponse] - The actual user response.
     * @param {CallbackManagerForChainRun} [params.runManager] - Optional callback manager for the chain run.
     *
     * @returns {Promise<{ userResponse: BaseMessage | undefined; revisedPrediction: string; explainedPredictionErrors: Array<string>; }>} A promise that resolves to an object containing the actual user response, the revised prediction, and the explained prediction errors.
     *
     * @throws {Error} If the response from the language model does not contain the expected keys: 'violationExplanation', 'explainedPredictionErrors', and 'accuratePrediction'.
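     *
     * @example
     * // Illustrative return value (contents assumed, not from the source):
     * // {
     * //   userResponse,
     * //   revisedPrediction: "The user will ask about trail difficulty.",
     * //   explainedPredictionErrors: ["The prediction assumed ..."]
     * // }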
     */
    async getPredictionViolations({ userPredictions, userResponse, runManager, }) {
        const llmWithFunctions = this.llm.bind({
            functions: [types_js_1.PREDICTION_VIOLATIONS_FUNCTION],
            function_call: { name: types_js_1.PREDICTION_VIOLATIONS_FUNCTION.name },
        });
        const chain = violation_of_expectations_prompt_js_1.PREDICTION_VIOLATIONS_PROMPT.pipe(llmWithFunctions).pipe(this.jsonOutputParser);
        if (typeof userResponse?.content !== "string") {
            throw new Error("This chain does not support non-string message content.");
        }
        const res = (await chain.invoke({
            predicted_output: userPredictions.predictedUserMessage,
            actual_output: userResponse?.content ?? "",
            user_insights: userPredictions.insights.join("\n"),
        }, runManager?.getChild("prediction_violations")));
        // Generate a revised prediction based on violations.
        const revisedPrediction = await this.generateRevisedPrediction({
            originalPrediction: userPredictions.predictedUserMessage,
            explainedPredictionErrors: res.explainedPredictionErrors,
            userInsights: userPredictions.insights,
            runManager,
        });
        return {
            userResponse,
            revisedPrediction,
            explainedPredictionErrors: res.explainedPredictionErrors,
        };
    }
    /**
     * This method generates a revised prediction based on the original prediction, explained prediction errors, and user insights.
     *
     * @param {Object} params - The parameters for the method.
     * @param {string} params.originalPrediction - The original prediction made by the model.
     * @param {Array<string>} params.explainedPredictionErrors - An array of explained prediction errors.
     * @param {Array<string>} params.userInsights - An array of insights about the user.
     * @param {CallbackManagerForChainRun} [params.runManager] - Optional callback manager for the chain run.
     *
     * @returns {Promise<string>} A promise that resolves to a revised prediction.
     */
    async generateRevisedPrediction({ originalPrediction, explainedPredictionErrors, userInsights, runManager, }) {
        const revisedPredictionChain = violation_of_expectations_prompt_js_1.GENERATE_REVISED_PREDICTION_PROMPT.pipe(this.llm).pipe(this.stringOutputParser);
        const revisedPredictionRes = await revisedPredictionChain.invoke({
            prediction: originalPrediction,
            explained_prediction_errors: explainedPredictionErrors.join("\n"),
            user_insights: userInsights.join("\n"),
        }, runManager?.getChild("prediction_revision"));
        return revisedPredictionRes;
    }
    /**
     * This method generates facts or insights about the user based on the revised prediction, explained prediction errors, and the user's response.
     *
     * @param {Object} params - The parameters for the method.
     * @param {BaseMessage} [params.userResponse] - The actual user response.
     * @param {Object} params.predictions - The revised prediction and explained prediction errors.
     * @param {string} params.predictions.revisedPrediction - The revised prediction made by the model.
     * @param {Array<string>} params.predictions.explainedPredictionErrors - An array of explained prediction errors.
     * @param {CallbackManagerForChainRun} [params.runManager] - Optional callback manager for the chain run.
     *
     * @returns {Promise<string>} A promise that resolves to a string containing the generated facts or insights about the user.
     */
    async generateFacts({ userResponse, predictions, runManager, }) {
        const chain = violation_of_expectations_prompt_js_1.GENERATE_FACTS_PROMPT.pipe(this.llm).pipe(this.stringOutputParser);
        if (typeof userResponse?.content !== "string") {
            throw new Error("This chain does not support non-string message content.");
        }
        const res = await chain.invoke({
            prediction_violations: predictions.explainedPredictionErrors.join("\n"),
            prediction: predictions.revisedPrediction,
            user_message: userResponse?.content ?? "",
        }, runManager?.getChild("generate_facts"));
        return res;
    }
    /**
     * Static method that creates a ViolationOfExpectationsChain instance from a
     * ChatOpenAI model and a retriever. It also accepts an optional set of
     * fields to customize the chain.
     *
     * @param llm The ChatOpenAI instance.
     * @param retriever The retriever used for similarity search.
     * @param options Optional fields to customize the chain.
     *
     * @returns A new instance of ViolationOfExpectationsChain.
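     *
     * @example
     * // Minimal sketch; `verbose` is a standard BaseChain field, shown here
     * // only as an assumed illustration of the options argument.
     * const chain = ViolationOfExpectationsChain.fromLLM(llm, retriever, {
     *   verbose: true,
     * });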
     */
    static fromLLM(llm, retriever, options) {
        return new this({
            retriever,
            llm,
            ...options,
        });
    }
}
exports.ViolationOfExpectationsChain = ViolationOfExpectationsChain;