UNPKG

better-auth

Version:

The most comprehensive authentication framework for TypeScript.

339 lines (337 loc) • 12.2 kB
import { createKyselyAdapter } from "../adapters/kysely-adapter/dialect.mjs"; import { getSchema } from "./get-schema.mjs"; import { getAuthTables } from "@better-auth/core/db"; import { createLogger } from "@better-auth/core/env"; import { sql } from "kysely"; import { initGetFieldName, initGetModelName } from "@better-auth/core/db/adapter"; //#region src/db/get-migration.ts const map = { postgres: { string: [ "character varying", "varchar", "text", "uuid" ], number: [ "int4", "integer", "bigint", "smallint", "numeric", "real", "double precision" ], boolean: ["bool", "boolean"], date: [ "timestamptz", "timestamp", "date" ], json: ["json", "jsonb"] }, mysql: { string: [ "varchar", "text", "uuid" ], number: [ "integer", "int", "bigint", "smallint", "decimal", "float", "double" ], boolean: ["boolean", "tinyint"], date: [ "timestamp", "datetime", "date" ], json: ["json"] }, sqlite: { string: ["TEXT"], number: ["INTEGER", "REAL"], boolean: ["INTEGER", "BOOLEAN"], date: ["DATE", "INTEGER"], json: ["TEXT"] }, mssql: { string: [ "varchar", "nvarchar", "uniqueidentifier" ], number: [ "int", "bigint", "smallint", "decimal", "float", "double" ], boolean: ["bit", "smallint"], date: [ "datetime2", "date", "datetime" ], json: ["varchar", "nvarchar"] } }; function matchType(columnDataType, fieldType, dbType) { function normalize(type) { return type.toLowerCase().split("(")[0].trim(); } if (fieldType === "string[]" || fieldType === "number[]") return columnDataType.toLowerCase().includes("json"); const types = map[dbType]; return (Array.isArray(fieldType) ? types["string"].map((t) => t.toLowerCase()) : types[fieldType].map((t) => t.toLowerCase())).includes(normalize(columnDataType)); } /** * Get the current PostgreSQL schema (search_path) for the database connection * Returns the first schema in the search_path, defaulting to 'public' if not found */ async function getPostgresSchema(db) { try { const result = await sql`SHOW search_path`.execute(db); if (result.rows[0]?.search_path) return result.rows[0].search_path.split(",").map((s) => s.trim()).map((s) => s.replace(/^["']|["']$/g, "")).filter((s) => !s.startsWith("$"))[0] || "public"; } catch {} return "public"; } async function getMigrations(config) { const betterAuthSchema = getSchema(config); const logger$1 = createLogger(config.logger); let { kysely: db, databaseType: dbType } = await createKyselyAdapter(config); if (!dbType) { logger$1.warn("Could not determine database type, defaulting to sqlite. Please provide a type in the database options to avoid this."); dbType = "sqlite"; } if (!db) { logger$1.error("Only kysely adapter is supported for migrations. You can use `generate` command to generate the schema, if you're using a different adapter."); process.exit(1); } let currentSchema = "public"; if (dbType === "postgres") { currentSchema = await getPostgresSchema(db); logger$1.debug(`PostgreSQL migration: Using schema '${currentSchema}' (from search_path)`); try { if (!(await sql` SELECT schema_name FROM information_schema.schemata WHERE schema_name = ${currentSchema} `.execute(db)).rows[0]) logger$1.warn(`Schema '${currentSchema}' does not exist. Tables will be inspected from available schemas. Consider creating the schema first or checking your database configuration.`); } catch (error) { logger$1.debug(`Could not verify schema existence: ${error instanceof Error ? error.message : String(error)}`); } } const allTableMetadata = await db.introspection.getTables(); let tableMetadata = allTableMetadata; if (dbType === "postgres") try { const tablesInSchema = await sql` SELECT table_name FROM information_schema.tables WHERE table_schema = ${currentSchema} AND table_type = 'BASE TABLE' `.execute(db); const tableNamesInSchema = new Set(tablesInSchema.rows.map((row) => row.table_name)); tableMetadata = allTableMetadata.filter((table) => table.schema === currentSchema && tableNamesInSchema.has(table.name)); logger$1.debug(`Found ${tableMetadata.length} table(s) in schema '${currentSchema}': ${tableMetadata.map((t) => t.name).join(", ") || "(none)"}`); } catch (error) { logger$1.warn(`Could not filter tables by schema. Using all discovered tables. Error: ${error instanceof Error ? error.message : String(error)}`); } const toBeCreated = []; const toBeAdded = []; for (const [key, value] of Object.entries(betterAuthSchema)) { const table = tableMetadata.find((t) => t.name === key); if (!table) { const tIndex = toBeCreated.findIndex((t) => t.table === key); const tableData = { table: key, fields: value.fields, order: value.order || Infinity }; const insertIndex = toBeCreated.findIndex((t) => (t.order || Infinity) > tableData.order); if (insertIndex === -1) if (tIndex === -1) toBeCreated.push(tableData); else toBeCreated[tIndex].fields = { ...toBeCreated[tIndex].fields, ...value.fields }; else toBeCreated.splice(insertIndex, 0, tableData); continue; } let toBeAddedFields = {}; for (const [fieldName, field] of Object.entries(value.fields)) { const column = table.columns.find((c) => c.name === fieldName); if (!column) { toBeAddedFields[fieldName] = field; continue; } if (matchType(column.dataType, field.type, dbType)) continue; else logger$1.warn(`Field ${fieldName} in table ${key} has a different type in the database. Expected ${field.type} but got ${column.dataType}.`); } if (Object.keys(toBeAddedFields).length > 0) toBeAdded.push({ table: key, fields: toBeAddedFields, order: value.order || Infinity }); } const migrations = []; const useUUIDs = config.advanced?.database?.generateId === "uuid"; const useNumberId = config.advanced?.database?.useNumberId || config.advanced?.database?.generateId === "serial"; function getType(field, fieldName) { const type = field.type; const provider = dbType || "sqlite"; const typeMap = { string: { sqlite: "text", postgres: "text", mysql: field.unique ? "varchar(255)" : field.references ? "varchar(36)" : field.sortable ? "varchar(255)" : field.index ? "varchar(255)" : "text", mssql: field.unique || field.sortable ? "varchar(255)" : field.references ? "varchar(36)" : "varchar(8000)" }, boolean: { sqlite: "integer", postgres: "boolean", mysql: "boolean", mssql: "smallint" }, number: { sqlite: field.bigint ? "bigint" : "integer", postgres: field.bigint ? "bigint" : "integer", mysql: field.bigint ? "bigint" : "integer", mssql: field.bigint ? "bigint" : "integer" }, date: { sqlite: "date", postgres: "timestamptz", mysql: "timestamp(3)", mssql: sql`datetime2(3)` }, json: { sqlite: "text", postgres: "jsonb", mysql: "json", mssql: "varchar(8000)" }, id: { postgres: useNumberId ? sql`integer GENERATED BY DEFAULT AS IDENTITY` : useUUIDs ? "uuid" : "text", mysql: useNumberId ? "integer" : useUUIDs ? "varchar(36)" : "varchar(36)", mssql: useNumberId ? "integer" : useUUIDs ? "varchar(36)" : "varchar(36)", sqlite: useNumberId ? "integer" : "text" }, foreignKeyId: { postgres: useNumberId ? "integer" : useUUIDs ? "uuid" : "text", mysql: useNumberId ? "integer" : useUUIDs ? "varchar(36)" : "varchar(36)", mssql: useNumberId ? "integer" : useUUIDs ? "varchar(36)" : "varchar(36)", sqlite: useNumberId ? "integer" : "text" }, "string[]": { sqlite: "text", postgres: "jsonb", mysql: "json", mssql: "varchar(8000)" }, "number[]": { sqlite: "text", postgres: "jsonb", mysql: "json", mssql: "varchar(8000)" } }; if (fieldName === "id" || field.references?.field === "id") { if (fieldName === "id") return typeMap.id[provider]; return typeMap.foreignKeyId[provider]; } if (Array.isArray(type)) return "text"; if (!(type in typeMap)) throw new Error(`Unsupported field type '${String(type)}' for field '${fieldName}'. Allowed types are: string, number, boolean, date, string[], number[]. If you need to store structured data, store it as a JSON string (type: "string") or split it into primitive fields. See https://better-auth.com/docs/advanced/schema#additional-fields`); return typeMap[type][provider]; } const getModelName = initGetModelName({ schema: getAuthTables(config), usePlural: false }); const getFieldName = initGetFieldName({ schema: getAuthTables(config), usePlural: false }); function getReferencePath(model, field) { try { return `${getModelName(model)}.${getFieldName({ model, field })}`; } catch { return `${model}.${field}`; } } if (toBeAdded.length) for (const table of toBeAdded) for (const [fieldName, field] of Object.entries(table.fields)) { const type = getType(field, fieldName); let builder = db.schema.alterTable(table.table); if (field.index) { const index = db.schema.alterTable(table.table).addIndex(`${table.table}_${fieldName}_idx`); migrations.push(index); } let built = builder.addColumn(fieldName, type, (col) => { col = field.required !== false ? col.notNull() : col; if (field.references) col = col.references(getReferencePath(field.references.model, field.references.field)).onDelete(field.references.onDelete || "cascade"); if (field.unique) col = col.unique(); if (field.type === "date" && typeof field.defaultValue === "function" && (dbType === "postgres" || dbType === "mysql" || dbType === "mssql")) if (dbType === "mysql") col = col.defaultTo(sql`CURRENT_TIMESTAMP(3)`); else col = col.defaultTo(sql`CURRENT_TIMESTAMP`); return col; }); migrations.push(built); } let toBeIndexed = []; if (config.advanced?.database?.useNumberId) logger$1.warn("`useNumberId` is deprecated. Please use `generateId` with `serial` instead."); if (toBeCreated.length) for (const table of toBeCreated) { const idType = getType({ type: useNumberId ? "number" : "string" }, "id"); let dbT = db.schema.createTable(table.table).addColumn("id", idType, (col) => { if (useNumberId) { if (dbType === "postgres") return col.primaryKey().notNull(); else if (dbType === "sqlite") return col.primaryKey().notNull(); else if (dbType === "mssql") return col.identity().primaryKey().notNull(); return col.autoIncrement().primaryKey().notNull(); } if (useUUIDs) { if (dbType === "postgres") return col.primaryKey().defaultTo(sql`pg_catalog.gen_random_uuid()`).notNull(); return col.primaryKey().notNull(); } return col.primaryKey().notNull(); }); for (const [fieldName, field] of Object.entries(table.fields)) { const type = getType(field, fieldName); dbT = dbT.addColumn(fieldName, type, (col) => { col = field.required !== false ? col.notNull() : col; if (field.references) col = col.references(getReferencePath(field.references.model, field.references.field)).onDelete(field.references.onDelete || "cascade"); if (field.unique) col = col.unique(); if (field.type === "date" && typeof field.defaultValue === "function" && (dbType === "postgres" || dbType === "mysql" || dbType === "mssql")) if (dbType === "mysql") col = col.defaultTo(sql`CURRENT_TIMESTAMP(3)`); else col = col.defaultTo(sql`CURRENT_TIMESTAMP`); return col; }); if (field.index) { let builder = db.schema.createIndex(`${table.table}_${fieldName}_${field.unique ? "uidx" : "idx"}`).on(table.table).columns([fieldName]); toBeIndexed.push(field.unique ? builder.unique() : builder); } } migrations.push(dbT); } if (toBeIndexed.length) for (const index of toBeIndexed) migrations.push(index); async function runMigrations() { for (const migration of migrations) await migration.execute(); } async function compileMigrations() { return migrations.map((m) => m.compile().sql).join(";\n\n") + ";"; } return { toBeCreated, toBeAdded, runMigrations, compileMigrations }; } //#endregion export { getMigrations, matchType }; //# sourceMappingURL=get-migration.mjs.map