// Retrieved from UNPKG: @convex-dev/better-auth
//
// Version: (not shown in the original listing)
// Original listing size: 744 lines (718 loc), 21.1 kB
import { mutation, query, QueryCtx } from "../component/_generated/server";
import { asyncMap } from "convex-helpers";
import { Infer, v } from "convex/values";
import { Doc, Id, TableNames } from "../component/_generated/dataModel";
import schema, { specialFields } from "../component/schema";
import {
  PaginationOptions,
  paginationOptsValidator,
  PaginationResult,
} from "convex/server";
import { partial } from "convex-helpers/validators";
import { stream } from "convex-helpers/server/stream";
import { mergedStream } from "convex-helpers/server/stream";
import { stripIndent } from "common-tags";

/**
 * Validator for a single Better Auth `where` clause.
 *
 * A missing `operator` is treated as `"eq"` throughout this file.
 */
export const adapterWhereValidator = v.object({
  field: v.string(),
  operator: v.optional(
    v.union(
      v.literal("lt"),
      v.literal("lte"),
      v.literal("gt"),
      v.literal("gte"),
      v.literal("eq"),
      v.literal("in"),
      v.literal("ne"),
      v.literal("contains"),
      v.literal("starts_with"),
      v.literal("ends_with")
    )
  ),
  value: v.union(
    v.string(),
    v.number(),
    v.boolean(),
    v.array(v.string()),
    v.array(v.number()),
    v.null()
  ),
  connector: v.optional(v.union(v.literal("AND"), v.literal("OR"))),
});

/**
 * Validator for the common adapter read arguments: target model, filters,
 * sort order, field projection and paging hints.
 */
export const adapterArgsValidator = v.object({
  model: v.string(),
  where: v.optional(v.array(adapterWhereValidator)),
  sortBy: v.optional(
    v.object({
      field: v.string(),
      direction: v.union(v.literal("asc"), v.literal("desc")),
    })
  ),
  select: v.optional(v.array(v.string())),
  limit: v.optional(v.number()),
  offset: v.optional(v.number()),
  unique: v.optional(v.boolean()),
});

type AdapterWhere = Infer<typeof adapterWhereValidator>;
type AdapterArgs = Infer<typeof adapterArgsValidator>;

/** Whether `field` is flagged `unique` in `specialFields` for `model`. */
const isUniqueField = (model: TableNames, field: string) => {
  const fields = specialFields[model as keyof typeof specialFields];
  if (!fields) {
    return false;
  }
  return Object.entries(fields)
    .filter(([, value]) => value.unique)
    .map(([key]) => key)
    .includes(field);
};

/** Whether any key of `input` is a unique field of `model`. */
const hasUniqueFields = (model: TableNames, input: Record<string, any>) =>
  Object.keys(input).some((field) => isUniqueField(model, field));

/**
 * Pick a schema index that can serve the indexable part of `args.where`
 * together with the requested sort.
 *
 * Clauses with `eq`/`in`/`lt`/`lte`/`gt`/`gte` (or no operator) on fields
 * other than `id` are considered. Only the `eq` clauses (in alphabetical
 * field order) and at most one range field contribute to the index prefix,
 * followed by the sort field. `createdAt` is never an explicit index field:
 * it is served by `_creationTime`, which is appended to the returned fields.
 *
 * @returns `undefined` when no index is needed; `{ indexFields }` when one
 *   would be needed but no schema index matches; otherwise the matched
 *   index, the bound/sort fields and the values to apply to the index range.
 * @throws when OR is combined with multiple clauses, or when the range
 *   clauses cannot be expressed as a single index range.
 */
const findIndex = (args: Pick<AdapterArgs, "model" | "where" | "sortBy">) => {
  if (
    (args.where?.length ?? 0) > 1 &&
    args.where?.some((w) => w.connector === "OR")
  ) {
    throw new Error(
      `OR connector not supported with multiple where statements in findIndex, split up the where statements before calling findIndex: ${JSON.stringify(args.where)}`
    );
  }
  const where = args.where?.filter(
    (w) =>
      (!w.operator ||
        ["lt", "lte", "gt", "gte", "eq", "in"].includes(w.operator)) &&
      w.field !== "id"
  );
  if (!where?.length && !args.sortBy) {
    return;
  }

  // "lt"/"lte" cap the range from above, "gt"/"gte" from below.
  const upperBounds = (where ?? []).filter(
    (w) => w.operator === "lt" || w.operator === "lte"
  );
  if (upperBounds.length > 1) {
    throw new Error(
      `cannot have more than one upper bound where clause: ${JSON.stringify(where)}`
    );
  }
  const lowerBounds = (where ?? []).filter(
    (w) => w.operator === "gt" || w.operator === "gte"
  );
  if (lowerBounds.length > 1) {
    throw new Error(
      `cannot have more than one lower bound where clause: ${JSON.stringify(where)}`
    );
  }
  const upperBound = upperBounds[0];
  const lowerBound = lowerBounds[0];
  if (upperBound && lowerBound && upperBound.field !== lowerBound.field) {
    throw new Error(
      `lower bound and upper bound must have the same field: ${JSON.stringify(where)}`
    );
  }
  const boundField = upperBound?.field || lowerBound?.field;
  if (
    boundField &&
    where?.some(
      (w) => w.field === boundField && w !== upperBound && w !== lowerBound
    )
  ) {
    throw new Error(
      `too many where clauses on the bound field: ${JSON.stringify(where)}`
    );
  }

  // Equality clauses form the index prefix; sorted so the lookup is
  // independent of the order the caller listed them in.
  const indexEqFields: [string, AdapterWhere["value"]][] = (where ?? [])
    .filter((w) => !w.operator || w.operator === "eq")
    .sort((a, b) => a.field.localeCompare(b.field))
    .map((w): [string, AdapterWhere["value"]] => [w.field, w.value]);
  if (!indexEqFields.length && !boundField && !args.sortBy) {
    return;
  }

  const indexes =
    schema.tables[args.model as keyof typeof schema.tables][" indexes"]();
  const sortField = args.sortBy?.field;
  // We internally use _creationTime in place of Better Auth's createdAt, so
  // createdAt is never listed as an explicit index field. Field names are
  // pushed as-is: they are compared element-wise with each index's fields.
  const indexFields = indexEqFields.map(([field]) => field);
  if (boundField && boundField !== "createdAt") {
    indexFields.push(boundField);
  }
  if (sortField && sortField !== "createdAt" && sortField !== boundField) {
    indexFields.push(sortField);
  }
  if (!indexFields.length && !boundField && !sortField) {
    return;
  }

  // Use the built in _creationTime index if bounding or sorting by createdAt
  // with no other fields.
  const index = !indexFields.length
    ? { indexDescriptor: "by_creation_time", fields: [] as string[] }
    : indexes.find(({ fields }) => {
        const fieldsMatch = indexFields.every(
          (field, idx) => field === fields[idx]
        );
        // If ordering by createdAt, no intermediate fields can be on the
        // index as they would override the _creationTime order.
        const orderedByCreatedAt =
          boundField === "createdAt" || sortField === "createdAt";
        const lengthMatch = orderedByCreatedAt
          ? indexFields.length === fields.length
          : true;
        return fieldsMatch && lengthMatch;
      });
  if (!index) {
    return { indexFields };
  }
  return {
    index: {
      indexDescriptor: index.indexDescriptor,
      fields: [...index.fields, "_creationTime"],
    },
    boundField,
    sortField,
    values: {
      eq: indexEqFields.map(([, value]) => value),
      lt: upperBound?.operator === "lt" ? upperBound.value : undefined,
      lte: upperBound?.operator === "lte" ? upperBound.value : undefined,
      gt: lowerBound?.operator === "gt" ? lowerBound.value : undefined,
      gte: lowerBound?.operator === "gte" ? lowerBound.value : undefined,
    },
  };
};

/**
 * Throw if any unique field in `input` already holds the same value on a
 * different document of `table`.
 *
 * @param doc the document being updated, if any; it may keep its own value.
 * @throws when no index exists for a unique field, or on a conflict.
 */
const checkUniqueFields = async (
  ctx: QueryCtx,
  table: TableNames,
  input: Record<string, any>,
  doc?: Doc<any>
) => {
  if (!hasUniqueFields(table, input)) {
    return;
  }
  for (const field of Object.keys(input)) {
    if (!isUniqueField(table, field)) {
      continue;
    }
    const value = input[field];
    const { index } =
      findIndex({
        model: table,
        where: [{ field, operator: "eq", value }],
      }) || {};
    if (!index) {
      throw new Error(`No index found for ${table} ${field}`);
    }
    const existingDoc = await ctx.db
      .query(table as any)
      .withIndex(index.indexDescriptor, (q) => q.eq(field, value))
      .unique();
    if (existingDoc && existingDoc._id !== doc?._id) {
      throw new Error(`${table} ${field} already exists`);
    }
  }
};

/**
 * Project `doc` onto the `select` field list. Returns the document
 * unchanged when `select` is empty/absent, and null for a null document.
 */
const selectFields = <T extends TableNames, D extends Doc<T>>(
  doc: D | null,
  select?: string[]
) => {
  if (!doc) {
    return null;
  }
  if (!select?.length) {
    return doc;
  }
  return select.reduce((acc, field) => {
    (acc as any)[field] = doc[field];
    return acc;
  }, {} as D);
};

/**
 * `val < wVal` for where-clause filtering. A null/undefined bound matches
 * nothing; a null/undefined document value is below any present bound.
 * Null checks (not truthiness) keep 0, "" and false as real values.
 */
const whereLessThan = (
  val: AdapterWhere["value"],
  wVal: AdapterWhere["value"]
): boolean => {
  if (wVal == null) {
    return false;
  }
  if (val == null) {
    return true;
  }
  return val < wVal;
};

/** Mirror of `whereLessThan` for `val > wVal`. */
const whereGreaterThan = (
  val: AdapterWhere["value"],
  wVal: AdapterWhere["value"]
): boolean => {
  if (val == null) {
    return false;
  }
  if (wVal == null) {
    return true;
  }
  return val > wVal;
};

/** Whether a single document field value satisfies one where clause. */
const matchesWhereClause = (
  value: AdapterWhere["value"],
  w: AdapterWhere
): boolean => {
  switch (w.operator) {
    case undefined:
    case "eq":
      return value === w.value;
    case "in":
      return Array.isArray(w.value) && (w.value as unknown[]).includes(value);
    case "lt":
      return whereLessThan(value, w.value);
    case "lte":
      return value === w.value || whereLessThan(value, w.value);
    case "gt":
      return whereGreaterThan(value, w.value);
    case "gte":
      return value === w.value || whereGreaterThan(value, w.value);
    case "ne":
      return value !== w.value;
    case "contains":
      return typeof value === "string" && value.includes(w.value as string);
    case "starts_with":
      return typeof value === "string" && value.startsWith(w.value as string);
    case "ends_with":
      return typeof value === "string" && value.endsWith(w.value as string);
    default:
      // Unreachable for validated input; an unknown operator does not
      // exclude the document.
      return true;
  }
};

// Manually filter an individual document by where clauses. This is used to
// simplify queries that can only return 0 or 1 documents, or "in" clauses that
// query multiple single documents in parallel.
//
// Every clause accepted by `filterWhere` (all clauses when omitted) must
// match; a null document never matches.
const filterByWhere = (
  doc: Doc<any> | null,
  where?: AdapterWhere[],
  filterWhere?: (w: AdapterWhere) => any
) => {
  if (!doc) {
    return false;
  }
  for (const w of where ?? []) {
    if (filterWhere && !filterWhere(w)) {
      continue;
    }
    const value = doc[w.field as keyof typeof doc] as AdapterWhere["value"];
    if (!matchesWhereClause(value, w)) {
      return false;
    }
  }
  return true;
};

/**
 * Build a filtered, optionally ordered stream over `args.model`.
 *
 * When `findIndex` finds an index, the eq values and range bounds are
 * applied to it and only the non-indexable operators are re-checked per
 * document. When an index would be needed but none exists, a warning is
 * logged once and every where clause is checked per document.
 */
const generateQuery = (ctx: QueryCtx, args: AdapterArgs) => {
  const { index, values, boundField, indexFields } = findIndex(args) ?? {};
  const query = stream(ctx.db, schema).query(args.model as any);
  // Bounds are "set" unless null/undefined, so 0 and "" still apply.
  const hasValues =
    !!values?.eq?.length ||
    values?.lt != null ||
    values?.lte != null ||
    values?.gt != null ||
    values?.gte != null;
  // NOTE(review): when findIndex selects the built-in "by_creation_time"
  // index, no range is applied here and createdAt bounds are not re-checked
  // by the post-filter below — confirm this is intended.
  const indexedQuery =
    index && index.indexDescriptor !== "by_creation_time"
      ? query.withIndex(
          index.indexDescriptor,
          hasValues
            ? (q: any) => {
                for (const [idx, value] of (values?.eq ?? []).entries()) {
                  q = q.eq(index.fields[idx], value);
                }
                if (values?.lt != null) {
                  q = q.lt(boundField, values.lt);
                }
                if (values?.lte != null) {
                  q = q.lte(boundField, values.lte);
                }
                if (values?.gt != null) {
                  q = q.gt(boundField, values.gt);
                }
                if (values?.gte != null) {
                  q = q.gte(boundField, values.gte);
                }
                return q;
              }
            : undefined
        )
      : query;
  const orderedQuery = args.sortBy
    ? indexedQuery.order(args.sortBy.direction === "asc" ? "asc" : "desc")
    : indexedQuery;

  const scanWithoutIndex = !index && !!indexFields?.length;
  if (scanWithoutIndex) {
    // Logged once per query rather than once per scanned document.
    console.warn(
      stripIndent`
        Querying without an index on table "${args.model}".
        This can cause performance issues, and may hit the document read limit.
        To fix, add an index that begins with the following fields in order:
        [${(indexFields ?? []).join(", ")}]
      `
    );
  }
  // Operators an index cannot evaluate; these are always checked per doc.
  const postFilterOperators = ["contains", "starts_with", "ends_with", "ne"];
  return orderedQuery.filterWith(async (doc) =>
    scanWithoutIndex
      ? filterByWhere(doc, args.where)
      : filterByWhere(
          doc,
          args.where,
          (w) => w.operator && postFilterOperators.includes(w.operator)
        )
  );
};

// This is the core function for reading from the database, it parses and
// validates where conditions, selects indexes, and allows the caller to
// optionally paginate as needed. Every response is a pagination result.
/** An empty, completed pagination result. */
const emptyPaginationResult = (): PaginationResult<Doc<any>> => ({
  page: [],
  isDone: true,
  continueCursor: "",
});

/**
 * Comparator for sorting documents in memory by Better Auth's `sortBy`.
 * `createdAt` is mapped to Convex's `_creationTime`; no `sortBy` keeps the
 * existing order.
 */
const compareBySortBy =
  (sortBy: Infer<typeof adapterArgsValidator>["sortBy"]) =>
  (a: Doc<any>, b: Doc<any>): number => {
    if (!sortBy) {
      return 0;
    }
    if (sortBy.field === "createdAt") {
      return sortBy.direction === "asc"
        ? a._creationTime - b._creationTime
        : b._creationTime - a._creationTime;
    }
    const aValue = a[sortBy.field as keyof typeof a];
    const bValue = b[sortBy.field as keyof typeof b];
    if (aValue === bValue) {
      return 0;
    }
    if (sortBy.direction === "asc") {
      return aValue > bValue ? 1 : -1;
    }
    return aValue > bValue ? -1 : 1;
  };

/**
 * Read documents of `args.model` matching `args.where`, projected by
 * `args.select`, as a pagination result.
 *
 * Fast paths: an `eq` on `id` or a unique field returns at most one
 * document; an `in` on `id` fetches documents directly; any other `in` is
 * split into one `eq` stream per value and merged.
 *
 * @throws if `offset` is set, OR is combined with multiple clauses, `id` is
 *   used with an operator other than eq/in, an `in` value is not an array,
 *   or a unique field has no index.
 */
const paginate = async (
  ctx: QueryCtx,
  args: Infer<typeof adapterArgsValidator> & {
    paginationOpts: PaginationOptions;
  }
): Promise<PaginationResult<Doc<any>>> => {
  if (args.offset) {
    throw new Error(`offset not supported: ${JSON.stringify(args.offset)}`);
  }
  if (
    args.where?.some((w) => w.connector === "OR") &&
    (args.where?.length ?? 0) > 1
  ) {
    throw new Error(
      `OR connector not supported with multiple where statements in paginate, split up the where statements before calling paginate: ${JSON.stringify(args.where)}`
    );
  }
  if (
    args.where?.some(
      (w) => w.field === "id" && w.operator && !["eq", "in"].includes(w.operator)
    )
  ) {
    throw new Error(
      `id can only be used with eq or in operator: ${JSON.stringify(args.where)}`
    );
  }

  // If any where clause is "eq" (or missing operator) on a unique field,
  // we can only return a single document, so we get it and use any other
  // where clauses as static filters.
  const uniqueWhere = args.where?.find(
    (w) =>
      (!w.operator || w.operator === "eq") &&
      (isUniqueField(args.model as TableNames, w.field) || w.field === "id")
  );
  if (uniqueWhere) {
    let doc: Doc<any> | null;
    if (uniqueWhere.field === "id") {
      doc = await ctx.db.get(uniqueWhere.value as Id<TableNames>);
    } else {
      const { index } =
        findIndex({ model: args.model, where: [uniqueWhere] }) || {};
      if (!index) {
        // Fail clearly instead of calling withIndex with no index name.
        throw new Error(
          `No index found for ${args.model} ${uniqueWhere.field}`
        );
      }
      doc = await ctx.db
        .query(args.model as any)
        .withIndex(index.indexDescriptor as any, (q) =>
          q.eq(index.fields[0], uniqueWhere.value)
        )
        .unique();
    }
    if (filterByWhere(doc, args.where, (w) => w !== uniqueWhere)) {
      return {
        page: [selectFields(doc, args.select)].filter(Boolean),
        isDone: true,
        continueCursor: "",
      };
    }
    return emptyPaginationResult();
  }

  const paginationOpts = {
    ...args.paginationOpts,
    // If maximumRowsRead is not at least 1 higher than numItems, bad cursors
    // and incorrect paging will result (at least with convex-test).
    maximumRowsRead: Math.max((args.paginationOpts.numItems ?? 0) + 1, 200),
  };

  // Large queries using "in" clause will crash, but these are only currently
  // possible with the organization plugin listing all members with a high
  // limit. For cases like this we need to create proper convex queries in
  // the component as an alternative to using Better Auth api's.
  const inWhere = args.where?.find((w) => w.operator === "in");
  if (inWhere) {
    if (!Array.isArray(inWhere.value)) {
      throw new Error("in clause value must be an array");
    }
    // "field in []" matches nothing; avoid building an empty merged stream.
    if (inWhere.value.length === 0) {
      return emptyPaginationResult();
    }
    // For ids, just use asyncMap + .get()
    if (inWhere.field === "id") {
      const docs = await asyncMap(inWhere.value as any[], async (value) =>
        ctx.db.get(value as Id<TableNames>)
      );
      const page = docs
        .flatMap((doc) => doc || [])
        .filter((doc) => filterByWhere(doc, args.where, (w) => w !== inWhere))
        .sort(compareBySortBy(args.sortBy))
        // Apply the projection like every other read path does.
        .map((doc) => selectFields(doc, args.select));
      return {
        page,
        isDone: true,
        continueCursor: "",
      };
    }
    const inValues = inWhere.value as (string | number)[];
    const streams = inValues.map((value) =>
      generateQuery(ctx, {
        ...args,
        where: args.where?.map((w) =>
          w === inWhere ? { ...w, operator: "eq" as const, value } : w
        ),
      })
    );
    const orderFields = [
      args.sortBy?.field !== "createdAt" && args.sortBy?.field,
      "_creationTime",
    ].flatMap((f) => (f ? [f] : []));
    const result = await mergedStream(streams, orderFields)
      .filterWith(async (doc) =>
        filterByWhere(
          doc,
          args.where,
          (w) =>
            w.operator &&
            ["contains", "starts_with", "ends_with", "ne"].includes(w.operator)
        )
      )
      .paginate(paginationOpts);
    return {
      ...result,
      page: result.page.map((doc) => selectFields(doc, args.select)),
    };
  }

  const result = await generateQuery(ctx, args).paginate(paginationOpts);
  return {
    ...result,
    page: result.page.map((doc) => selectFields(doc, args.select)),
  };
};

/** First matching document, or null when nothing matches. */
const listOne = async (
  ctx: QueryCtx,
  args: Infer<typeof adapterArgsValidator>
): Promise<Doc<any> | null> => {
  const { page } = await paginate(ctx, {
    ...args,
    paginationOpts: {
      numItems: 1,
      cursor: null,
    },
  });
  // An empty page yields null rather than undefined, matching the type.
  return page[0] ?? null;
};

/** Insert a document after enforcing unique fields; returns the new doc. */
export const create = mutation({
  args: {
    input: v.union(
      ...Object.entries(schema.tables).map(([model, table]) =>
        v.object({
          model: v.literal(model),
          data: v.object(table.validator.fields),
        })
      )
    ),
  },
  handler: async (ctx, args) => {
    await checkUniqueFields(
      ctx,
      args.input.model as TableNames,
      args.input.data
    );
    const id = await ctx.db.insert(args.input.model as any, args.input.data);
    const doc = await ctx.db.get(id);
    if (!doc) {
      throw new Error(`Failed to create ${args.input.model}`);
    }
    return doc;
  },
});

/** First document matching the adapter args, or null. */
export const findOne = query({
  args: adapterArgsValidator,
  handler: async (ctx, args) => {
    return await listOne(ctx, args);
  },
});

/** One page of documents matching the adapter args. */
export const findMany = query({
  args: {
    ...adapterArgsValidator.fields,
    paginationOpts: paginationOptsValidator,
  },
  handler: async (ctx, args) => {
    return await paginate(ctx, args);
  },
});

/**
 * Patch the first document matching `where`; throws if none matches or a
 * unique field would collide. Returns the updated document.
 */
export const updateOne = mutation({
  args: {
    input: v.union(
      ...Object.entries(schema.tables).map(([model, table]) =>
        v.object({
          model: v.literal(model),
          where: v.optional(v.array(adapterWhereValidator)),
          update: v.object(partial(table.validator.fields)),
        })
      )
    ),
  },
  handler: async (ctx, args) => {
    const doc = await listOne(ctx, args.input);
    if (!doc) {
      throw new Error(`Failed to update ${args.input.model}`);
    }
    await checkUniqueFields(
      ctx,
      args.input.model as TableNames,
      args.input.update,
      doc
    );
    await ctx.db.patch(doc._id, args.input.update as any);
    const updatedDoc = await ctx.db.get(doc._id);
    if (!updatedDoc) {
      throw new Error(`Failed to update ${args.input.model}`);
    }
    return updatedDoc;
  },
});

/**
 * Patch every document in one page of matches. Setting a unique field on
 * more than one document is rejected. Returns the page cursor info and the
 * number of documents patched.
 */
export const updateMany = mutation({
  args: {
    input: v.union(
      ...Object.entries(schema.tables).map(([model, table]) =>
        v.object({
          ...adapterArgsValidator.fields,
          model: v.literal(model),
          where: v.optional(v.array(adapterWhereValidator)),
          update: v.object(partial(table.validator.fields)),
          paginationOpts: paginationOptsValidator,
        })
      )
    ),
  },
  handler: async (ctx, args) => {
    const { page, ...result } = await paginate(ctx, args.input);
    if (args.input.update) {
      if (
        hasUniqueFields(
          args.input.model as TableNames,
          args.input.update ?? {}
        ) &&
        page.length > 1
      ) {
        throw new Error(
          `Attempted to set unique fields in multiple documents in ${args.input.model} with the same value. Fields: ${Object.keys(args.input.update ?? {}).join(", ")}`
        );
      }
      await asyncMap(page, async (doc) => {
        await checkUniqueFields(
          ctx,
          args.input.model as TableNames,
          args.input.update ?? {},
          doc
        );
        await ctx.db.patch(doc._id, args.input.update as any);
      });
    }
    return {
      ...result,
      count: page.length,
    };
  },
});

/** Delete the first matching document and return it (undefined if none). */
export const deleteOne = mutation({
  args: adapterArgsValidator,
  handler: async (ctx, args) => {
    const doc = await listOne(ctx, args);
    if (!doc) {
      return;
    }
    await ctx.db.delete(doc._id);
    return doc;
  },
});

/** Delete one page of matching documents; returns cursor info and count. */
export const deleteMany = mutation({
  args: {
    ...adapterArgsValidator.fields,
    paginationOpts: paginationOptsValidator,
  },
  handler: async (ctx, args) => {
    const { page, ...result } = await paginate(ctx, args);
    await asyncMap(page, async (doc) => {
      await ctx.db.delete(doc._id);
    });
    return {
      ...result,
      count: page.length,
    };
  },
});

// Get the session via sessionId in jwt claims
// TODO: this needs a refresh, subquery only necessary for actions
export const getCurrentSession = query({
  args: {},
  handler: async (ctx) => {
    const identity = await ctx.auth.getUserIdentity();
    if (!identity) {
      return null;
    }
    // The claim may be absent or not a valid "session" id; return null
    // instead of letting ctx.db.get throw.
    const sessionId =
      typeof identity.sessionId === "string"
        ? ctx.db.normalizeId("session", identity.sessionId)
        : null;
    if (!sessionId) {
      return null;
    }
    return ctx.db.get(sessionId);
  },
});