/*
 * mongodb-rag-core
 * Version: (not specified)
 * Common elements used by MongoDB Chatbot Framework components.
 * 106 lines (103 loc) • 3.81 kB
 * JavaScript
 */
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.makeClassifier = exports.Classification = void 0;
require("dotenv/config");
const common_tags_1 = require("common-tags");
const zod_1 = require("zod");
exports.Classification = zod_1.z.object({
type: zod_1.z.string().describe("The classification type."),
reason: zod_1.z
.string()
.optional()
.describe("The reason for the classification. Only available if `chainOfThought` was set to `true`."),
});
/**
 * Create an LLM-backed classifier that asks an OpenAI chat model to assign
 * an input to exactly one of `classificationTypes`. It forces a call to a
 * `classify` tool whose `type` argument is restricted to the known types.
 *
 * @param {object} params
 * @param {object} params.openAiClient - OpenAI client; only
 *   `chat.completions.create(...)` is called on it.
 * @param {string} params.model - Model name forwarded to the completion call.
 * @param {Array<{type: string, description: string, examples?: Array<{text: string, reason?: string}>}>} params.classificationTypes -
 *   Allowed categories. Every `type` becomes an enum value in the tool
 *   schema; `description` and any `examples` are rendered into the prompt.
 * @param {boolean} [params.chainOfThought=false] - When `true`, the tool
 *   schema gains a required `reason` string property.
 * @returns {(args: {input: string}) => Promise<{classification: object, inputMessages: Array<{role: string, content: string}>}>}
 *   Async `classify` function. It resolves to the tool arguments parsed by
 *   `Classification` plus the messages that were sent to the model.
 * @throws {Error} (from `classify`) "No response from OpenAI" when the
 *   completion has no first choice message, and "No function call in
 *   response from OpenAI" when that message carries no tool call. Errors
 *   thrown by `JSON.parse` or `Classification.parse` propagate unchanged.
 */
function makeClassifier({ openAiClient, model, classificationTypes, chainOfThought = false, }) {
    // One "- type: description" line per category.
    const classificationCategoriesList = classificationTypes
        .map(({ type, description }) => `- ${type}: ${description}`)
        .join("\n");
    // Every example of every category rendered as an <Example> tag, one per
    // entry. Categories whose `examples` is missing, null or empty
    // contribute nothing.
    const examplesList = classificationTypes
        .flatMap(({ examples, type }) => (examples ?? []).map(({ text, reason }) => (0, common_tags_1.html) `
<Example classification={"${type}"} reason={"${reason ?? "null"}"} >
${text}
</Example>
`.trimEnd()))
        .join("\n");
    const makeSystemPrompt = (input) => (0, common_tags_1.stripIndents) `
Your task is to classify a provided input into one of the following categories.
This information will be used to drive a generative process, so precision is incredibly important.
Classification categories:
${classificationCategoriesList}
Examples:
${examplesList}
Input:
${input}
`;
    // If chainOfThought is true, add a `reason` field to the classification
    const chainOfThoughtProp = chainOfThought
        ? {
            reason: {
                type: "string",
                description: "Reason for classification. Be concise. Think step by step.",
            },
        }
        : {};
    const required = chainOfThought ? ["type", "reason"] : ["type"];
    const classifyFunc = {
        name: "classify",
        description: "Classify the type of the provided input",
        parameters: {
            type: "object",
            properties: {
                // `reason` is spread in ahead of `type`, presumably so the model
                // writes its reasoning before committing to a type; keep this order.
                ...chainOfThoughtProp,
                type: {
                    type: "string",
                    enum: classificationTypes.map(({ type }) => type),
                    description: "Type of the provided input",
                },
            },
            required,
            additionalProperties: false,
        },
    };
    return async function classify({ input }) {
        const messages = [
            {
                role: "system",
                content: makeSystemPrompt(input),
            },
        ];
        const result = await openAiClient.chat.completions.create({
            model,
            messages,
            temperature: 0,
            max_tokens: 300,
            tools: [
                {
                    type: "function",
                    function: classifyFunc,
                },
            ],
            // Force the model to answer through the `classify` tool.
            tool_choice: {
                type: "function",
                function: {
                    name: classifyFunc.name,
                },
            },
            stream: false,
        });
        // `?.` so an empty `choices` array reaches the descriptive error
        // below instead of failing with a TypeError.
        const response = result.choices[0]?.message;
        if (response == null) {
            throw new Error("No response from OpenAI");
        }
        // `tool_calls` may be absent, null, or an empty array.
        const toolCall = response.tool_calls?.[0];
        if (toolCall == null) {
            throw new Error("No function call in response from OpenAI");
        }
        const classification = exports.Classification.parse(JSON.parse(toolCall.function.arguments));
        return { classification, inputMessages: messages };
    };
}
// Publish the classifier factory on the CommonJS exports object.
Object.assign(exports, { makeClassifier });
//# sourceMappingURL=makeClassifier.js.map