UNPKG

mongodb-rag-core

Version:

Common elements used by MongoDB Chatbot Framework components.

106 lines (103 loc) 3.81 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.makeClassifier = exports.Classification = void 0; require("dotenv/config"); const common_tags_1 = require("common-tags"); const zod_1 = require("zod"); exports.Classification = zod_1.z.object({ type: zod_1.z.string().describe("The classification type."), reason: zod_1.z .string() .optional() .describe("The reason for the classification. Only available if `chainOfThought` was set to `true`."), }); function makeClassifier({ openAiClient, model, classificationTypes, chainOfThought = false, }) { const classificationCategoriesList = classificationTypes .map(({ type, description }) => `- ${type}: ${description}`) .join("\n"); const examplesList = classificationTypes .filter(({ examples }) => examples?.length ?? 0 > 0) .map(({ examples, type }) => (examples ?? []) .map(({ text, reason }) => (0, common_tags_1.html) ` <Example classification={"${type}"} reason={"${reason ?? "null"}"} > ${text} </Example> `.trimEnd()) .join("\n")) .join("\n"); const makeSystemPrompt = (input) => (0, common_tags_1.stripIndents) ` Your task is to classify a provided input into one of the following categories. This information will be used to drive a generative process, so precision is incredibly important. Classification categories: ${classificationCategoriesList} Examples: ${examplesList} Input: ${input} `; // If chainOfThought is true, add a `reason` field to the classification const chainOfThoughtProp = chainOfThought ? { reason: { type: "string", description: "Reason for classification. Be concise. Think step by step.", }, } : {}; const required = chainOfThought ? ["type", "reason"] : ["type"]; const classifyFunc = { name: "classify", description: "Classify the type of the provided input", parameters: { type: "object", properties: { ...chainOfThoughtProp, type: { type: "string", enum: classificationTypes.map(({ type }) => type), description: "Type of the provided input", }, }, required, additionalProperties: false, }, }; return async function classify({ input }) { const messages = [ { role: "system", content: makeSystemPrompt(input), }, ]; const result = await openAiClient.chat.completions.create({ model, messages, temperature: 0, max_tokens: 300, tools: [ { type: "function", function: classifyFunc, }, ], tool_choice: { type: "function", function: { name: classifyFunc.name, }, }, stream: false, }); const response = result.choices[0].message; if (response === undefined) { throw new Error("No response from OpenAI"); } if (response.tool_calls === undefined || response.tool_calls === null) { throw new Error("No function call in response from OpenAI"); } const classification = exports.Classification.parse(JSON.parse(response.tool_calls[0].function.arguments)); return { classification, inputMessages: messages }; }; } exports.makeClassifier = makeClassifier; //# sourceMappingURL=makeClassifier.js.map