@reliverse/rse
Version:
@reliverse/rse is your all-in-one companion for bootstrapping and improving any kind of project (especially web apps built with frameworks like Next.js) — whether you're kicking off something new or upgrading an existing app. It is also a little AI-power…
144 lines (143 loc) • 5.31 kB
JavaScript
import { existsSync } from "@reliverse/relifso";
import {
getAuthTables
} from "better-auth/db";
/**
 * Convert a camelCase or PascalCase identifier to snake_case.
 *
 * Every ASCII uppercase letter is lowercased and, unless it is the first
 * character, prefixed with an underscore. So "emailVerified" becomes
 * "email_verified" and "User" becomes "user" (not "_user").
 *
 * @param {string} str - Identifier to convert.
 * @returns {string} The snake_case form of `str`.
 */
export function convertToSnakeCase(str) {
  return str.replace(/[A-Z]/g, (letter, offset) => {
    const lower = letter.toLowerCase();
    // A leading capital must not produce a leading underscore in SQL names.
    return offset === 0 ? lower : `_${lower}`;
  });
}
/**
 * Generate the source text of a Drizzle ORM schema file for the Better Auth tables.
 *
 * @param {object} params
 * @param {object} params.options - Better Auth options, passed to `getAuthTables`.
 *   When `options.advanced?.database?.useNumberId` is truthy, ids become
 *   auto-incrementing integers and foreign keys become integer columns.
 * @param {string} [params.file] - Output path; defaults to "./auth-schema.ts".
 * @param {object} params.adapter - Drizzle adapter. `adapter.options.provider`
 *   must be "pg", "mysql" or "sqlite"; `adapter.options.usePlural` pluralizes
 *   model names (see getModelName).
 * @returns {Promise<{ code: string, fileName: string, overwrite: boolean }>}
 *   The generated source, the target path, and whether that path already exists.
 * @throws {Error} If the provider is missing or unsupported, or a field has a
 *   type this generator cannot map to a Drizzle column.
 */
export const generateDrizzleSchema = async ({ options, file, adapter }) => {
  const tables = getAuthTables(options);
  const filePath = file || "./auth-schema.ts";
  const databaseType = adapter.options?.provider;
  if (!databaseType) {
    throw new Error(
      `Database provider type is undefined during Drizzle schema generation. Please define a \`provider\` in the Drizzle adapter config. Read more at https://better-auth.com/docs/adapters/drizzle`,
    );
  }
  const supportedProviders = ["pg", "mysql", "sqlite"];
  if (!supportedProviders.includes(databaseType)) {
    throw new Error(
      `Unsupported database provider "${databaseType}" during Drizzle schema generation. Expected one of: ${supportedProviders.join(", ")}. Read more at https://better-auth.com/docs/adapters/drizzle`,
    );
  }
  const useNumberId = Boolean(options.advanced?.database?.useNumberId);
  const fileExist = existsSync(filePath);

  let code = generateImport({ databaseType, tables, useNumberId });
  for (const table of Object.values(tables)) {
    const schema = buildTableSchema(table, {
      databaseType,
      useNumberId,
      adapterOptions: adapter.options,
    });
    code += `\n${schema}\n`;
  }
  return {
    code,
    fileName: filePath,
    overwrite: fileExist,
  };
};

/**
 * Render one `export const <model> = <db>Table("<snake_name>", { ... });` declaration.
 */
function buildTableSchema(table, { databaseType, useNumberId, adapterOptions }) {
  const modelName = getModelName(table.modelName, adapterOptions);
  const columns = [`id: ${getIdColumn(databaseType, useNumberId)}`];
  for (const [key, attr] of Object.entries(table.fields)) {
    let column = getColumnType(key, attr, databaseType, useNumberId);
    // Null/undefined check rather than truthiness, so `false` and `0` defaults are kept.
    if (attr.defaultValue !== undefined && attr.defaultValue !== null) {
      column += formatDefault(attr.defaultValue);
    }
    if (attr.required) {
      column += ".notNull()";
    }
    if (attr.unique) {
      column += ".unique()";
    }
    if (attr.references) {
      const target = getModelName(attr.references.model, adapterOptions);
      const onDelete = attr.references.onDelete || "cascade";
      column += `.references(() => ${target}.${attr.references.field}, { onDelete: '${onDelete}' })`;
    }
    columns.push(`${key}: ${column}`);
  }
  return `export const ${modelName} = ${databaseType}Table("${convertToSnakeCase(
    modelName,
  )}", {\n  ${columns.join(",\n  ")}\n});`;
}

/**
 * Column builder expression for the primary-key `id` column.
 */
function getIdColumn(databaseType, useNumberId) {
  if (useNumberId) {
    if (databaseType === "pg") {
      return `serial('id').primaryKey()`;
    }
    if (databaseType === "mysql") {
      return `int('id').autoincrement().primaryKey()`;
    }
    return `integer('id', { mode: 'number' }).primaryKey({ autoIncrement: true })`;
  }
  return databaseType === "mysql"
    ? `varchar('id', { length: 36 }).primaryKey()`
    : `text('id').primaryKey()`;
}

/**
 * Column builder expression for a non-id field; the column name is the
 * snake_case form of the field key.
 */
function getColumnType(key, field, databaseType, useNumberId) {
  const name = convertToSnakeCase(key);
  if (field.references?.field === "id") {
    // A foreign key is a plain column whose type matches the referenced id
    // (see getIdColumn) — never a primary key or auto-increment column.
    if (useNumberId) {
      return databaseType === "mysql" ? `int('${name}')` : `integer('${name}')`;
    }
    return databaseType === "mysql"
      ? `varchar('${name}', { length: 36 })`
      : `text('${name}')`;
  }
  const typeMap = {
    string: {
      sqlite: `text('${name}')`,
      pg: `text('${name}')`,
      mysql: field.unique
        ? `varchar('${name}', { length: 255 })`
        : field.references
          ? `varchar('${name}', { length: 36 })`
          : `text('${name}')`,
    },
    boolean: {
      sqlite: `integer('${name}', { mode: 'boolean' })`,
      pg: `boolean('${name}')`,
      mysql: `boolean('${name}')`,
    },
    number: {
      sqlite: `integer('${name}')`,
      pg: field.bigint ? `bigint('${name}', { mode: 'number' })` : `integer('${name}')`,
      mysql: field.bigint ? `bigint('${name}', { mode: 'number' })` : `int('${name}')`,
    },
    date: {
      sqlite: `integer('${name}', { mode: 'timestamp' })`,
      pg: `timestamp('${name}')`,
      mysql: `timestamp('${name}')`,
    },
    // NOTE(review): `.array()` looks like a pg-core-only column method; the
    // sqlite/mysql array variants below are likely invalid Drizzle — confirm.
    "number[]": {
      sqlite: `integer('${name}').array()`,
      pg: field.bigint
        ? `bigint('${name}', { mode: 'number' }).array()`
        : `integer('${name}').array()`,
      mysql: field.bigint
        ? `bigint('${name}', { mode: 'number' }).array()`
        : `int('${name}').array()`,
    },
    "string[]": {
      sqlite: `text('${name}').array()`,
      pg: `text('${name}').array()`,
      mysql: `text('${name}').array()`,
    },
  };
  const byDatabase = typeMap[field.type];
  // Also rejects inherited keys such as "toString", whose value is not a map.
  if (!byDatabase || typeof byDatabase[databaseType] !== "string") {
    throw new Error(
      `Unsupported field type ${JSON.stringify(field.type)} for field "${key}" during Drizzle schema generation.`,
    );
  }
  return byDatabase[databaseType];
}

/**
 * Render a default-value modifier for a column builder.
 */
function formatDefault(value) {
  if (typeof value === "function") {
    // The function's source text is embedded verbatim in the generated file.
    return `.$defaultFn(${value})`;
  }
  // Strings must be quoted to form valid source; other primitives print as literals.
  return `.default(${typeof value === "string" ? JSON.stringify(value) : value})`;
}

/**
 * Build the `drizzle-orm/<db>-core` import line covering every column builder
 * the generated schema may use.
 */
function generateImport({ databaseType, tables, useNumberId = false }) {
  const hasBigint = Object.values(tables).some((table) =>
    Object.values(table.fields).some((field) => field.bigint),
  );
  const imports = [`${databaseType}Table`];
  if (databaseType === "mysql") {
    imports.push("varchar");
  }
  imports.push("text");
  if (databaseType !== "sqlite") {
    if (hasBigint) {
      imports.push("bigint");
    }
    imports.push("timestamp", "boolean");
  }
  if (databaseType === "pg" && useNumberId) {
    // getIdColumn emits `serial('id')` for Postgres numeric ids.
    imports.push("serial");
  }
  imports.push(databaseType === "mysql" ? "int" : "integer");
  return `import { ${imports.join(", ")} } from "drizzle-orm/${databaseType}-core";\n`;
}
/**
 * Resolve the exported model name for a table.
 *
 * @param {string} modelName - Base model name.
 * @param {object} [options] - Adapter options; a truthy `usePlural` appends "s".
 * @returns {string} `modelName`, pluralized when requested.
 */
function getModelName(modelName, options) {
  if (!options?.usePlural) {
    return modelName;
  }
  return `${modelName}s`;
}