UNPKG

langchain

Version:
271 lines (270 loc) 10.6 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.createSqlQueryChain = exports.SqlDatabaseChain = void 0;
const base_1 = require("@langchain/core/language_models/base");
const runnables_1 = require("@langchain/core/runnables");
const output_parsers_1 = require("@langchain/core/output_parsers");
const sql_db_prompt_js_1 = require("./sql_db_prompt.cjs");
const base_js_1 = require("../base.cjs");
const llm_chain_js_1 = require("../llm_chain.cjs");
const sql_utils_js_1 = require("../../util/sql_utils.cjs");
/**
 * Class that represents a SQL database chain in the LangChain framework.
 * It extends the BaseChain class and implements the functionality
 * specific to a SQL database chain.
 *
 * @security **Security Notice**
 * This chain generates SQL queries for the given database.
 * The SQLDatabase class provides a getTableInfo method that can be used
 * to get column information as well as sample data from the table.
 * To mitigate risk of leaking sensitive data, limit permissions
 * to read and scope to the tables that are needed.
 * Optionally, use the includesTables or ignoreTables class parameters
 * to limit which tables can/cannot be accessed.
 *
 * @link See https://js.langchain.com/docs/security for more information.
 * @example
 * ```typescript
 * const chain = new SqlDatabaseChain({
 *   llm: new OpenAI({ temperature: 0 }),
 *   database: new SqlDatabase({ ...config }),
 * });
 *
 * const result = await chain.run("How many tracks are there?");
 * ```
 */
class SqlDatabaseChain extends base_js_1.BaseChain {
    static lc_name() {
        return "SqlDatabaseChain";
    }
    /**
     * @param fields Chain configuration. `llm` and `database` are read
     * unconditionally; `topK`, `inputKey`, `outputKey`, `sqlOutputKey`,
     * `returnDirect` and `prompt` are optional and fall back to the defaults
     * declared below when nullish.
     */
    constructor(fields) {
        super(fields);
        // LLM wrapper to use
        Object.defineProperty(this, "llm", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        // SQL Database to connect to.
        Object.defineProperty(this, "database", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        // Prompt to use to translate natural language to SQL.
        Object.defineProperty(this, "prompt", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: sql_db_prompt_js_1.DEFAULT_SQL_DATABASE_PROMPT
        });
        // Number of results to return from the query
        Object.defineProperty(this, "topK", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 5
        });
        // Key under which the question is read from the chain input.
        Object.defineProperty(this, "inputKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "query"
        });
        // Key under which the final result is returned.
        Object.defineProperty(this, "outputKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "result"
        });
        // Optional key under which the generated SQL is also returned.
        Object.defineProperty(this, "sqlOutputKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        // Whether to return the result of querying the SQL table directly.
        Object.defineProperty(this, "returnDirect", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        this.llm = fields.llm;
        this.database = fields.database;
        this.topK = fields.topK ?? this.topK;
        this.inputKey = fields.inputKey ?? this.inputKey;
        this.outputKey = fields.outputKey ?? this.outputKey;
        this.sqlOutputKey = fields.sqlOutputKey ?? this.sqlOutputKey;
        // Honour a caller-supplied returnDirect; previously the option was
        // declared but never read from `fields`, so it could only be enabled
        // by mutating the instance after construction.
        this.returnDirect = fields.returnDirect ?? this.returnDirect;
        // When no prompt is supplied, one is derived from the data source
        // (this replaces the DEFAULT_SQL_DATABASE_PROMPT initial value).
        this.prompt =
            fields.prompt ??
                (0, sql_utils_js_1.getPromptTemplateFromDataSource)(this.database.appDataSource);
    }
    /**
     * Generates a SQL command for the question with the LLM, runs it against
     * the database and, unless `returnDirect` is set, asks the LLM a second
     * time to phrase an answer from the query result.
     * @ignore
     */
    async _call(values, runManager) {
        const llmChain = new llm_chain_js_1.LLMChain({
            prompt: this.prompt,
            llm: this.llm,
            outputKey: this.outputKey,
            memory: this.memory,
        });
        if (!(this.inputKey in values)) {
            throw new Error(`Question key ${this.inputKey} not found.`);
        }
        const question = values[this.inputKey];
        let inputText = `${question}\nSQLQuery:`;
        // Optional subset of tables supplied by the caller alongside the question.
        const tablesToUse = values.table_names_to_use;
        const tableInfo = await this.database.getTableInfo(tablesToUse);
        const llmInputs = {
            input: inputText,
            top_k: this.topK,
            dialect: this.database.appDataSourceOptions.type,
            table_info: tableInfo,
            stop: ["\nSQLResult:"],
        };
        await this.verifyNumberOfTokens(inputText, tableInfo);
        const sqlCommand = await llmChain.predict(llmInputs, runManager?.getChild("sql_generation"));
        let queryResult = "";
        try {
            queryResult = await this.database.appDataSource.query(sqlCommand);
        }
        catch (error) {
            // Best effort: a failing query is logged and the chain continues
            // with an empty result instead of rejecting.
            console.error(error);
        }
        let finalResult;
        if (this.returnDirect) {
            finalResult = { [this.outputKey]: queryResult };
        }
        else {
            // Extend the original prompt input with the generated SQL and its
            // result so the second LLM call can produce the final answer.
            inputText += `${sqlCommand}\nSQLResult: ${JSON.stringify(queryResult)}\nAnswer:`;
            llmInputs.input = inputText;
            finalResult = {
                [this.outputKey]: await llmChain.predict(llmInputs, runManager?.getChild("result_generation")),
            };
        }
        if (this.sqlOutputKey != null) {
            finalResult[this.sqlOutputKey] = sqlCommand;
        }
        return finalResult;
    }
    _chainType() {
        return "sql_database_chain";
    }
    get inputKeys() {
        return [this.inputKey];
    }
    get outputKeys() {
        if (this.sqlOutputKey != null) {
            return [this.outputKey, this.sqlOutputKey];
        }
        return [this.outputKey];
    }
    /**
     * Private method that checks the size of the input text, prompt template
     * and table information against the model's token budget. The check is
     * skipped unless the LLM reports the type "openai".
     * @param inputText The input text.
     * @param tableinfo The table information.
     * @returns A promise that resolves when the verification is complete.
     * @throws {Error} If the value returned by `calculateMaxTokens` is lower
     * than the LLM's configured `maxTokens`.
     */
    async verifyNumberOfTokens(inputText, tableinfo) {
        // We verify it only for OpenAI for the moment
        if (this.llm._llmType() !== "openai") {
            return;
        }
        const llm = this.llm;
        const promptTemplate = this.prompt.template;
        const stringWeSend = `${inputText}${promptTemplate}${tableinfo}`;
        const maxToken = await (0, base_1.calculateMaxTokens)({
            prompt: stringWeSend,
            // Cast here to allow for other models that may not fit the union
            modelName: llm.model,
        });
        // With no maxTokens configured the threshold is -1.
        if (maxToken < (llm.maxTokens ?? -1)) {
            throw new Error(`The combination of the database structure and your question is too big for the model ${llm.model} which can compute only a max tokens of ${(0, base_1.getModelContextSize)(llm.model)}. We suggest you to use the includeTables parameters when creating the SqlDatabase object to select only a subset of the tables. You can also use a model which can handle more tokens.`);
        }
    }
}
exports.SqlDatabaseChain = SqlDatabaseChain;
/**
 * Cleans up raw model output: unescapes `\"`, trims whitespace and removes
 * one pair of double quotes wrapping the whole string.
 * @param {string} text Raw text produced by the model.
 * @returns {string} The cleaned text.
 */
const strip = (text) => {
    // Replace escaped quotes with actual quotes
    let newText = text.replace(/\\"/g, '"').trim();
    // Remove wrapping quotes if the entire string is wrapped in quotes.
    // The length guard keeps a lone `"` intact (it is both prefix and suffix).
    if (newText.length >= 2 &&
        newText.startsWith('"') &&
        newText.endsWith('"')) {
        newText = newText.slice(1, -1);
    }
    return newText;
};
/**
 * Set difference: the elements of `setA` that are not in `setB`.
 * @param {Set} setA
 * @param {Set} setB
 * @returns {Set} A new set; neither input is modified.
 */
const difference = (setA, setB) => new Set([...setA].filter((x) => !setB.has(x)));
/**
 * Create a SQL query chain that can create SQL queries for the given database.
 * Returns a Runnable.
 *
 * @param {Object} params Named parameters.
 * @param {BaseLanguageModel} params.llm The language model to use in the chain.
 * @param {SqlDatabase} params.db The database to use in the chain.
 * @param {BasePromptTemplate | undefined} params.prompt The prompt to use in the chain.
 * If omitted, the prompt registered for `dialect` is used, falling back to the default SQL prompt.
 * @param {number | undefined} params.k The amount of docs/results to return (default 5).
 * Passed through the prompt input value `top_k`.
 * @param {SqlDialect} params.dialect The SQL dialect to use in the chain.
 * @returns {Promise<RunnableSequence<Record<string, unknown>, string>>} A runnable sequence representing the chain.
 * @throws {Error} If the prompt lacks any of the input variables `input`, `top_k`, `table_info`.
 * @example ```typescript
 * const datasource = new DataSource({
 *   type: "sqlite",
 *   database: "../../../../Chinook.db",
 * });
 * const db = await SqlDatabase.fromDataSourceParams({
 *   appDataSource: datasource,
 * });
 * const llm = new ChatOpenAI({ temperature: 0 });
 * const chain = await createSqlQueryChain({
 *   llm,
 *   db,
 *   dialect: "sqlite",
 * });
 * ```
 */
async function createSqlQueryChain({ llm, db, prompt, k = 5, dialect, }) {
    let promptToUse;
    if (prompt) {
        promptToUse = prompt;
    }
    else if (sql_db_prompt_js_1.SQL_PROMPTS_MAP[dialect]) {
        promptToUse = sql_db_prompt_js_1.SQL_PROMPTS_MAP[dialect];
    }
    else {
        promptToUse = sql_db_prompt_js_1.DEFAULT_SQL_DATABASE_PROMPT;
    }
    // Every required variable must be declared by the chosen prompt.
    if (difference(new Set(["input", "top_k", "table_info"]), new Set(promptToUse.inputVariables)).size > 0) {
        throw new Error(`Prompt must have input variables: 'input', 'top_k', 'table_info'. Received prompt with input variables: ` +
            `${promptToUse.inputVariables}. Full prompt:\n\n${promptToUse}`);
    }
    // Pre-fill the values that are fixed for the lifetime of the chain.
    if (promptToUse.inputVariables.includes("dialect")) {
        promptToUse = await promptToUse.partial({ dialect });
    }
    promptToUse = await promptToUse.partial({ top_k: k.toString() });
    const inputs = {
        input: (x) => {
            if ("question" in x) {
                return `${x.question}\nSQLQuery: `;
            }
            throw new Error("Input must include a question property.");
        },
        table_info: async (x) => db.getTableInfo(x.tableNamesToUse),
    };
    return runnables_1.RunnableSequence.from([
        runnables_1.RunnablePassthrough.assign(inputs),
        // Drop the caller-facing keys once they have been mapped to prompt inputs.
        (x) => {
            const newInputs = { ...x };
            delete newInputs.question;
            delete newInputs.tableNamesToUse;
            return newInputs;
        },
        promptToUse,
        llm.bind({ stop: ["\nSQLResult:"] }),
        new output_parsers_1.StringOutputParser(),
        strip,
    ]);
}
exports.createSqlQueryChain = createSqlQueryChain;