@auth/drizzle-adapter
Version:
Drizzle adapter for Auth.js.
640 lines (606 loc) • 18.1 kB
text/typescript
import { GeneratedColumnConfig, and, eq, getTableColumns } from "drizzle-orm"
import {
MySqlColumn,
MySqlDatabase,
boolean,
int,
mysqlTable,
primaryKey,
timestamp,
varchar,
PreparedQueryHKTBase,
MySqlTableWithColumns,
MySqlQueryResultHKT,
} from "drizzle-orm/mysql-core"
import type {
Adapter,
AdapterAccount,
AdapterAccountType,
AdapterSession,
AdapterUser,
VerificationToken,
AdapterAuthenticator,
} from "@auth/core/adapters"
import { Awaitable } from "@auth/core/types"
/**
 * Resolves the full set of tables the adapter works with.
 *
 * Each table supplied in `schema` is used as-is; every table that is missing
 * falls back to a default `mysqlTable` definition built here.
 *
 * @param schema Optional overrides for any subset of the adapter tables.
 * @returns All five tables, each either caller-supplied or a default.
 */
export function defineTables(
  schema: Partial<DefaultMySqlSchema> = {}
): Required<DefaultMySqlSchema> {
  const usersTable =
    schema.usersTable ??
    (mysqlTable("user", {
      id: varchar("id", { length: 255 })
        .primaryKey()
        .$defaultFn(() => crypto.randomUUID()),
      name: varchar("name", { length: 255 }),
      email: varchar("email", { length: 255 }).unique(),
      emailVerified: timestamp("emailVerified", { mode: "date", fsp: 3 }),
      image: varchar("image", { length: 255 }),
    }) satisfies DefaultMySqlUsersTable)

  // The default account, session and authenticator tables all carry the same
  // non-null `userId` column that references the resolved users table and
  // cascades on delete. A fresh column builder is produced on every call.
  const userIdColumn = () =>
    varchar("userId", { length: 255 })
      .notNull()
      .references(() => usersTable.id, { onDelete: "cascade" })

  const accountsTable =
    schema.accountsTable ??
    (mysqlTable(
      "account",
      {
        userId: userIdColumn(),
        type: varchar("type", { length: 255 })
          .$type<AdapterAccountType>()
          .notNull(),
        provider: varchar("provider", { length: 255 }).notNull(),
        providerAccountId: varchar("providerAccountId", {
          length: 255,
        }).notNull(),
        refresh_token: varchar("refresh_token", { length: 255 }),
        access_token: varchar("access_token", { length: 255 }),
        expires_at: int("expires_at"),
        token_type: varchar("token_type", { length: 255 }),
        scope: varchar("scope", { length: 255 }),
        id_token: varchar("id_token", { length: 2048 }),
        session_state: varchar("session_state", { length: 255 }),
      },
      (columns) => ({
        // An account is identified by its provider plus the provider's own id.
        compositePk: primaryKey({
          columns: [columns.provider, columns.providerAccountId],
        }),
      })
    ) satisfies DefaultMySqlAccountsTable)

  const sessionsTable =
    schema.sessionsTable ??
    (mysqlTable("session", {
      sessionToken: varchar("sessionToken", { length: 255 }).primaryKey(),
      userId: userIdColumn(),
      expires: timestamp("expires", { mode: "date" }).notNull(),
    }) satisfies DefaultMySqlSessionsTable)

  const verificationTokensTable =
    schema.verificationTokensTable ??
    (mysqlTable(
      "verificationToken",
      {
        identifier: varchar("identifier", { length: 255 }).notNull(),
        token: varchar("token", { length: 255 }).notNull(),
        expires: timestamp("expires", { mode: "date" }).notNull(),
      },
      (columns) => ({
        // A token row is keyed by the (identifier, token) pair.
        compositePk: primaryKey({
          columns: [columns.identifier, columns.token],
        }),
      })
    ) satisfies DefaultMySqlVerificationTokenTable)

  const authenticatorsTable =
    schema.authenticatorsTable ??
    (mysqlTable(
      "authenticator",
      {
        credentialID: varchar("credentialID", { length: 255 })
          .notNull()
          .unique(),
        userId: userIdColumn(),
        providerAccountId: varchar("providerAccountId", {
          length: 255,
        }).notNull(),
        credentialPublicKey: varchar("credentialPublicKey", {
          length: 255,
        }).notNull(),
        counter: int("counter").notNull(),
        credentialDeviceType: varchar("credentialDeviceType", {
          length: 255,
        }).notNull(),
        credentialBackedUp: boolean("credentialBackedUp").notNull(),
        transports: varchar("transports", { length: 255 }),
      },
      (columns) => ({
        compositePk: primaryKey({
          columns: [columns.userId, columns.credentialID],
        }),
      })
    ) satisfies DefaultMySqlAuthenticatorTable)

  return {
    usersTable,
    accountsTable,
    sessionsTable,
    verificationTokensTable,
    authenticatorsTable,
  }
}
/**
 * Creates an Auth.js `Adapter` backed by a Drizzle MySQL database.
 *
 * Methods that return the written row do so by issuing a separate SELECT
 * after the INSERT/UPDATE.
 *
 * @param client Drizzle MySQL database handle used for every query.
 * @param schema Optional custom tables; tables not supplied are filled in by
 *   `defineTables`.
 * @returns The adapter object implementing the methods below.
 */
export function MySqlDrizzleAdapter(
  client: MySqlDatabase<MySqlQueryResultHKT, PreparedQueryHKTBase, any>,
  schema?: DefaultMySqlSchema
): Adapter {
  const {
    usersTable,
    accountsTable,
    sessionsTable,
    verificationTokensTable,
    authenticatorsTable,
  } = defineTables(schema)

  // Predicate matching one account row by (provider, providerAccountId).
  const matchesAccount = (provider: string, providerAccountId: string) =>
    and(
      eq(accountsTable.provider, provider),
      eq(accountsTable.providerAccountId, providerAccountId)
    )

  // Predicate matching one verification-token row by (identifier, token).
  const matchesVerificationToken = (identifier: string, token: string) =>
    and(
      eq(verificationTokensTable.identifier, identifier),
      eq(verificationTokensTable.token, token)
    )

  return {
    /** Inserts a user and returns the stored row. */
    async createUser(data: AdapterUser) {
      const { id, ...insertData } = data
      // When the id column has a default generator, the incoming id is dropped
      // so the column default produces it; otherwise the given id is inserted.
      const hasDefaultId = getTableColumns(usersTable)["id"]["defaultFn"]

      const [insertedUser] = (await client
        .insert(usersTable)
        .values(hasDefaultId ? insertData : { ...insertData, id })
        .$returningId()) as [{ id: string }] | []

      // Fall back to the caller-supplied id when the insert reported none.
      return client
        .select()
        .from(usersTable)
        .where(eq(usersTable.id, insertedUser ? insertedUser.id : id))
        .then((res) => res[0]) as Awaitable<AdapterUser>
    },

    /** Looks a user up by id; resolves to `null` when no row matches. */
    async getUser(userId: string) {
      return client
        .select()
        .from(usersTable)
        .where(eq(usersTable.id, userId))
        .then((res) =>
          res.length > 0 ? res[0] : null
        ) as Awaitable<AdapterUser | null>
    },

    /** Looks a user up by e-mail; resolves to `null` when no row matches. */
    async getUserByEmail(email: string) {
      return client
        .select()
        .from(usersTable)
        .where(eq(usersTable.email, email))
        .then((res) =>
          res.length > 0 ? res[0] : null
        ) as Awaitable<AdapterUser | null>
    },

    /** Inserts a session and returns the row stored under its token. */
    async createSession(data: {
      sessionToken: string
      userId: string
      expires: Date
    }) {
      await client.insert(sessionsTable).values(data)

      return client
        .select()
        .from(sessionsTable)
        .where(eq(sessionsTable.sessionToken, data.sessionToken))
        .then((res) => res[0])
    },

    /**
     * Fetches a session together with its user via an inner join.
     * Resolves to `null` when the token matches no joined row.
     */
    async getSessionAndUser(sessionToken: string) {
      return client
        .select({
          session: sessionsTable,
          user: usersTable,
        })
        .from(sessionsTable)
        .where(eq(sessionsTable.sessionToken, sessionToken))
        .innerJoin(usersTable, eq(usersTable.id, sessionsTable.userId))
        .then((res) => (res.length > 0 ? res[0] : null)) as Awaitable<{
        session: AdapterSession
        user: AdapterUser
      } | null>
    },

    /**
     * Applies the given fields to the user with `data.id` and returns the row.
     *
     * @throws Error("No user id.") when `data.id` is falsy.
     * @throws Error("No user found.") when no row exists after the update.
     */
    async updateUser(data: Partial<AdapterUser> & Pick<AdapterUser, "id">) {
      if (!data.id) {
        throw new Error("No user id.")
      }

      await client
        .update(usersTable)
        .set(data)
        .where(eq(usersTable.id, data.id))

      const [result] = await client
        .select()
        .from(usersTable)
        .where(eq(usersTable.id, data.id))

      if (!result) {
        throw new Error("No user found.")
      }

      return result as Awaitable<AdapterUser>
    },

    /**
     * Applies the given fields to the session identified by its token and
     * returns the row read back afterwards (`undefined` if none matches).
     */
    async updateSession(
      data: Partial<AdapterSession> & Pick<AdapterSession, "sessionToken">
    ) {
      await client
        .update(sessionsTable)
        .set(data)
        .where(eq(sessionsTable.sessionToken, data.sessionToken))

      return client
        .select()
        .from(sessionsTable)
        .where(eq(sessionsTable.sessionToken, data.sessionToken))
        .then((res) => res[0])
    },

    /** Inserts an account row; resolves to nothing. */
    async linkAccount(data: AdapterAccount) {
      await client.insert(accountsTable).values(data)
    },

    /**
     * Finds the user owning the account with the given provider and
     * provider account id; resolves to `null` when there is none.
     */
    async getUserByAccount(
      account: Pick<AdapterAccount, "provider" | "providerAccountId">
    ) {
      const result = await client
        .select({
          account: accountsTable,
          user: usersTable,
        })
        .from(accountsTable)
        .innerJoin(usersTable, eq(accountsTable.userId, usersTable.id))
        .where(matchesAccount(account.provider, account.providerAccountId))
        .then((res) => res[0])

      const user = result?.user ?? null
      return user as Awaitable<AdapterUser | null>
    },

    /** Deletes the session stored under the given token. */
    async deleteSession(sessionToken: string) {
      await client
        .delete(sessionsTable)
        .where(eq(sessionsTable.sessionToken, sessionToken))
    },

    /** Inserts a verification token and returns the stored row. */
    async createVerificationToken(data: VerificationToken) {
      await client.insert(verificationTokensTable).values(data)

      // Read back by identifier AND token: an identifier can have several
      // outstanding tokens, so filtering on identifier alone may return a
      // different row than the one just inserted.
      return client
        .select()
        .from(verificationTokensTable)
        .where(matchesVerificationToken(data.identifier, data.token))
        .then((res) => res[0])
    },

    /**
     * Consumes a verification token: returns the matching row and deletes it,
     * or resolves to `null` when no row matches.
     */
    async useVerificationToken(params: { identifier: string; token: string }) {
      // NOTE(review): the select and the delete are two separate statements,
      // so two concurrent calls could both read the token before either
      // deletes it. Confirm whether callers need this to be atomic.
      const deletedToken = await client
        .select()
        .from(verificationTokensTable)
        .where(matchesVerificationToken(params.identifier, params.token))
        .then((res) => (res.length > 0 ? res[0] : null))

      if (deletedToken) {
        await client
          .delete(verificationTokensTable)
          .where(matchesVerificationToken(params.identifier, params.token))
      }

      return deletedToken
    },

    /** Deletes the user with the given id. */
    async deleteUser(id: string) {
      await client.delete(usersTable).where(eq(usersTable.id, id))
    },

    /** Deletes the account with the given provider and provider account id. */
    async unlinkAccount(
      params: Pick<AdapterAccount, "provider" | "providerAccountId">
    ) {
      await client
        .delete(accountsTable)
        .where(matchesAccount(params.provider, params.providerAccountId))
    },

    /**
     * Fetches one account by provider account id and provider; resolves to
     * `null` when no row matches. Note the parameter order.
     */
    async getAccount(providerAccountId: string, provider: string) {
      return client
        .select()
        .from(accountsTable)
        .where(matchesAccount(provider, providerAccountId))
        .then((res) => res[0] ?? null) as Promise<AdapterAccount | null>
    },

    /** Inserts an authenticator and returns the row read back by credential id. */
    async createAuthenticator(data: AdapterAuthenticator) {
      await client.insert(authenticatorsTable).values(data)

      return (await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.credentialID, data.credentialID))
        .then((res) => res[0] ?? null)) as Awaitable<AdapterAuthenticator>
    },

    /** Fetches an authenticator by credential id; `null` when none matches. */
    async getAuthenticator(credentialID: string) {
      return (await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.credentialID, credentialID))
        .then(
          (res) => res[0] ?? null
        )) as Awaitable<AdapterAuthenticator | null>
    },

    /** Lists every authenticator row belonging to the given user id. */
    async listAuthenticatorsByUserId(userId: string) {
      return (await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.userId, userId))
        .then((res) => res)) as Awaitable<AdapterAuthenticator[]>
    },

    /**
     * Stores a new counter value for an authenticator and returns the row.
     *
     * @throws Error("Authenticator not found.") when no row exists for the
     *   credential id after the update.
     */
    async updateAuthenticatorCounter(credentialID: string, newCounter: number) {
      await client
        .update(authenticatorsTable)
        .set({ counter: newCounter })
        .where(eq(authenticatorsTable.credentialID, credentialID))

      const authenticator = await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.credentialID, credentialID))
        .then((res) => res[0])

      if (!authenticator) throw new Error("Authenticator not found.")

      return authenticator as Awaitable<AdapterAuthenticator>
    },
  }
}
/**
 * Shorthand for the `MySqlColumn` shape the adapter's table types are built
 * from.
 *
 * The type argument carries the parts that differ per column (`columnType`,
 * `data`, `dataType`, `notNull` and, optionally, `isPrimaryKey`); every other
 * column property is left at its broad type.
 */
type DefaultMyqlColumn<
  T extends {
    isPrimaryKey?: boolean
    columnType:
      | "MySqlVarChar"
      | "MySqlText"
      | "MySqlBoolean"
      | "MySqlTimestamp"
      | "MySqlInt"
    data: string | number | boolean | Date
    dataType: "string" | "number" | "boolean" | "date"
    notNull: boolean
  },
> = MySqlColumn<{
  // Taken from the type argument.
  columnType: T["columnType"]
  data: T["data"]
  dataType: T["dataType"]
  notNull: T["notNull"]
  // Anything other than an explicit `true` (including omission) becomes `false`.
  isPrimaryKey: T["isPrimaryKey"] extends true ? true : false
  // Left broad for every column.
  name: string
  tableName: string
  driverParam: string | number | boolean
  enumValues: any
  hasDefault: boolean
  hasRuntimeDefault: boolean
  isAutoincrement: boolean
  generated: GeneratedColumnConfig<T["data"]> | undefined
}>
/**
 * Table shape expected for the users table.
 *
 * `id` is a non-null primary-key string column; `name`, `email`,
 * `emailVerified` and `image` may each be nullable or not (`notNull: boolean`).
 */
export type DefaultMySqlUsersTable = MySqlTableWithColumns<{
  name: string
  dialect: "mysql"
  schema: string | undefined
  columns: {
    id: DefaultMyqlColumn<{
      isPrimaryKey: true
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    name: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    email: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    emailVerified: DefaultMyqlColumn<{
      columnType: "MySqlTimestamp"
      data: Date
      dataType: "date"
      notNull: boolean
    }>
    image: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
  }
}>
/**
 * Table shape expected for the accounts table.
 *
 * `userId`, `type`, `provider` and `providerAccountId` are non-null string
 * columns; the remaining token-related columns may each be nullable or not
 * (`notNull: boolean`). `expires_at` is the only numeric column.
 */
export type DefaultMySqlAccountsTable = MySqlTableWithColumns<{
  name: string
  dialect: "mysql"
  schema: string | undefined
  columns: {
    userId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    type: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    provider: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    providerAccountId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    refresh_token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    // The stray `driverParam` member previously passed here was never read by
    // the column helper type, so it has been dropped to match the siblings.
    access_token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    expires_at: DefaultMyqlColumn<{
      columnType: "MySqlInt"
      data: number
      dataType: "number"
      notNull: boolean
    }>
    token_type: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    scope: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    id_token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
    session_state: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: boolean
    }>
  }
}>
/**
 * Table shape expected for the sessions table.
 *
 * All three columns are non-null: `sessionToken` is the string primary key,
 * `userId` is a string and `expires` is a timestamp mapped to `Date`.
 */
export type DefaultMySqlSessionsTable = MySqlTableWithColumns<{
  name: string
  dialect: "mysql"
  schema: string | undefined
  columns: {
    sessionToken: DefaultMyqlColumn<{
      isPrimaryKey: true
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    userId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    expires: DefaultMyqlColumn<{
      columnType: "MySqlTimestamp"
      data: Date
      dataType: "date"
      notNull: true
    }>
  }
}>
/**
 * Table shape expected for the verification-token table.
 *
 * All three columns are non-null: `identifier` and `token` are strings and
 * `expires` is a timestamp mapped to `Date`.
 */
export type DefaultMySqlVerificationTokenTable = MySqlTableWithColumns<{
  name: string
  dialect: "mysql"
  schema: string | undefined
  columns: {
    identifier: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    expires: DefaultMyqlColumn<{
      columnType: "MySqlTimestamp"
      data: Date
      dataType: "date"
      notNull: true
    }>
  }
}>
/**
 * Table shape expected for the authenticators table.
 *
 * Every column is non-null except `transports`, which is nullable
 * (`notNull: false`). `counter` is numeric and `credentialBackedUp` is
 * boolean; all other columns are strings.
 */
export type DefaultMySqlAuthenticatorTable = MySqlTableWithColumns<{
  name: string
  dialect: "mysql"
  schema: string | undefined
  columns: {
    credentialID: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    userId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    providerAccountId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    credentialPublicKey: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    counter: DefaultMyqlColumn<{
      columnType: "MySqlInt"
      data: number
      dataType: "number"
      notNull: true
    }>
    credentialDeviceType: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: true
    }>
    credentialBackedUp: DefaultMyqlColumn<{
      columnType: "MySqlBoolean"
      data: boolean
      dataType: "boolean"
      notNull: true
    }>
    transports: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      dataType: "string"
      notNull: false
    }>
  }
}>
/**
 * The set of tables the adapter operates on.
 *
 * When a schema object is supplied, the users and accounts tables are
 * required members; the session, verification-token and authenticator
 * tables are optional.
 */
export type DefaultMySqlSchema = {
// Required members.
usersTable: DefaultMySqlUsersTable
accountsTable: DefaultMySqlAccountsTable
// Optional members.
sessionsTable?: DefaultMySqlSessionsTable
verificationTokensTable?: DefaultMySqlVerificationTokenTable
authenticatorsTable?: DefaultMySqlAuthenticatorTable
}