UNPKG

prisma-extension-casl

Version:
418 lines (363 loc) 18.6 kB
import { AbilityBuilder, AbilityTuple, PureAbility } from '@casl/ability' import { PrismaQuery } from '@casl/prisma' import { Prisma, PrismaClient, PrismaPromise } from '@prisma/client' import { applyCaslToQuery } from './applyCaslToQuery' import { filterQueryResults } from './filterQueryResults' import { caslOperationDict, getFluentField, getFluentModel, PrismaCaslOperation, PrismaExtensionCaslOptions, propertyFieldsByModel, relationFieldsByModel } from './helpers' export { applyCaslToQuery } /** * enrich a prisma client to check for CASL abilities even in nested queries * * `client.$extends(useCaslAbilities(build))` * * https://casl.js.org/v6/en/package/casl-prisma * * * @param getAbilities function to return CASL prisma abilities * - this is a function call to instantiate abilities on each prisma query to allow adding i.e. context or claims * @returns enriched prisma client */ /** * enrich a prisma client to check for CASL abilities even in nested queries * * `client.$extends(useCaslAbilities(build))` * * https://casl.js.org/v6/en/package/casl-prisma * * * @param getAbilityFactory function to return CASL prisma abilities * - this is a function call to instantiate abilities on each prisma query to allow adding i.e. context or claims * @param opts additional options: { permissionField, additionalActions } * @returns enriched prisma client * @returns */ export function useCaslAbilities( getAbilityFactory: () => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>, opts?: PrismaExtensionCaslOptions) { // Set default options const txMaxWait = opts?.txMaxWait ?? 30000 const txTimeout = opts?.txTimeout ?? 30000 return Prisma.defineExtension((client) => { let tickActive = false; const batches: Record<string, Array<{ params: object; model: string; action: string; args: unknown; /** called before resolve */ callback: (result: unknown) => void; resolve: (result: unknown) => void; reject: (error: unknown) => void; }>> = {}; function extendCaslAbilities(extendFactory: (factory: AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) { // alter the getAblities function shortly const extendedClient = client.$extends({ query: { $allModels: { ...allOperations(() => extendFactory(getAbilityFactory())), }, }, }) // if we are within a transaction, return client with transaction //@ts-ignore const transactionId = Prisma.getExtensionContext(this)[Symbol.for('prisma.client.transaction.id')] if (transactionId) { //@ts-ignore const transactionClient = extendedClient._createItxClient({ kind: 'itx', id: transactionId }) as typeof extendedClient //@ts-ignore transactionClient.$casl = extendCaslAbilities // if $transaction is called on already existing transaction client, just use current transaction transactionClient.$transaction = async (first: any) => { return first(transactionClient) } return transactionClient } //@ts-ignore extendedClient.$casl = extendCaslAbilities return extendedClient } const allOperations = (getAbilities: () => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) => ({ async $allOperations<T>({ args, query, model, operation, ...rest }: { args: any, query: any, model: any, operation: any }) { const { fluentModel, fluentRelationModel, fluentRelationField } = getFluentModel(model, rest) const __internalParams = (rest as any).__internalParams const transaction = __internalParams.transaction const debug = (process.env.NODE_ENV === 'development' || process.env.NODE_ENV === 'test') && args.debugCasl const debugAllErrors = args.debugCasl delete args.debugCasl const perf = debug ? performance : undefined const logger = debug ? console : undefined perf?.clearMeasures('prisma-casl-extension-Overall') perf?.clearMeasures('prisma-casl-extension-Create Abilities') perf?.clearMeasures('prisma-casl-extension-Create Casl Query') perf?.clearMeasures('prisma-casl-extension-Finish Query') perf?.clearMeasures('prisma-casl-extension-Filtering Results') perf?.clearMarks('prisma-casl-extension-0') perf?.clearMarks('prisma-casl-extension-1') perf?.clearMarks('prisma-casl-extension-2') perf?.clearMarks('prisma-casl-extension-3') perf?.clearMarks('prisma-casl-extension-4') if (!(operation in caslOperationDict)) { return query(args) } perf?.mark('prisma-casl-extension-0') const abilities = transaction?.abilities ?? getAbilities().build() if (transaction) { transaction.abilities = abilities } perf?.mark('prisma-casl-extension-1') /** * for read actions we return null, if casl has an error * except we use debugCasl */ function getCaslQuery() { try { return applyCaslToQuery(operation, args, abilities, model, opts?.permissionField ? true : false) } catch (e) { if (debugAllErrors || caslOperationDict[operation as PrismaCaslOperation].action !== 'read') { throw e } } } const caslQuery = getCaslQuery() if (fluentRelationField?.isList && !caslQuery?.args.select[fluentRelationField.name]) { return [] } if (!caslQuery) { /** if casl query did not return a result we return either null or an empty array for findMany or list relation */ if (operation === 'findMany') { return [] } else { return null } } perf?.mark('prisma-casl-extension-2') logger?.log('Query Args', JSON.stringify(caslQuery.args)) logger?.log('Query Mask', JSON.stringify(caslQuery.mask)) const cleanupResults = (result: any) => { perf?.mark('prisma-casl-extension-3') if (fluentRelationModel && caslQuery.mask) { // on fluent models we need to take mask of the relation caslQuery.mask = fluentRelationModel && fluentRelationModel in caslQuery.mask ? caslQuery.mask[fluentRelationModel] : {} } const filteredResult = filterQueryResults(result, caslQuery.mask, caslQuery.creationTree, abilities, fluentModel as Prisma.ModelName, operation, opts) if (perf) { perf.mark('prisma-casl-extension-4') logger?.log( [perf.measure('prisma-casl-extension-Overall', 'prisma-casl-extension-0', 'prisma-casl-extension-4'), perf.measure('prisma-casl-extension-Create Abilities', 'prisma-casl-extension-0', 'prisma-casl-extension-1'), perf.measure('prisma-casl-extension-Create Casl Query', 'prisma-casl-extension-1', 'prisma-casl-extension-2'), perf.measure('prisma-casl-extension-Finish Query', 'prisma-casl-extension-2', 'prisma-casl-extension-3'), perf.measure('prisma-casl-extension-Filtering Results', 'prisma-casl-extension-3', 'prisma-casl-extension-4') ].map((measure) => { return `${measure.name.replace('prisma-casl-extension-', '')}: ${measure.duration}` }) ) } return filteredResult } const operationAbility = caslOperationDict[operation as PrismaCaslOperation] /** * on update or create we need to create a transaction * since there can be errors if newly created db entries * are not permitted by abilities * * for reads and deletes we skip the transaction */ if (transaction && transaction.kind === 'batch') { //@ts-ignore throw new Error('Sequential transactions are not supported in prisma-extension-casl.') // const extendedRequest = request.then(cleanupResults) // extendedRequest.requestTransaction = request.requestTransaction //@ts-ignore // return client._createPrismaPromise(new Promise((resolve, reject) => { // query(caslQuery.args).then(cleanupResults).then((result: any) => resolve(result)).catch(((e: any) => reject(e))) // }) } const hash = transaction?.id ?? 'batch' if (!batches[hash]) { batches[hash] = [] } // make sure, that we only tick once at a time if (!tickActive) { tickActive = true; process.nextTick(() => { dispatchBatches(transaction); tickActive = false; }); } /** batchQuery collects query within batches that will be dispatched every tick */ const batchQuery = ( model: string, action: string, args: any, callback: (result: any) => void ) => new Promise((resolve, reject) => { batches[hash].push({ params: __internalParams, model, action, args, reject, resolve, callback, }) }); if (operationAbility.action === 'update' || operationAbility.action === 'create' || operation === 'deleteMany') { /** * we get all update/deleteMany entries for logging purposes. */ // const getMany = operation === 'deleteMany' || operation === 'updateMany' // const manyResult: any[] = getMany ? await batchQuery(model, 'findMany', caslQuery.args.where ? { where: caslQuery.args.where } : undefined, (res: any[]) => { // /** create update objects for updateMany */ // return operation === 'updateMany' ? res.map((r) => ({ ...caslQuery.args.data, id: r.id })) : res // }) : [] /** * we use createManyAndReturn instead of createMany createMany entries for logging purposes and to check permissions on new entries */ const op = operation === 'createMany' ? 'createManyAndReturn' : operation === 'updateMany' ? 'updateManyAndReturn' : operation return batchQuery(model, op, caslQuery.args, async (result: any) => { const filteredResult = cleanupResults(result)//getMany ? manyResult : result) const results = operation === 'createMany' || operation === 'deleteMany' || operation === 'updateMany' ? { count: result.length } // : getMany ? { count: manyResult.length } : filteredResult return results }) } else { return batchQuery(model, operation, caslQuery.args, async (result: any) => { const fluentField = getFluentField(rest) if (fluentField) { return cleanupResults(result?.[fluentField]) } return cleanupResults(result) }) } } }) // Derived from yates // https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227 // // By default, Prisma will batch requests by the transaction ID if it is present. // This behaviour prevents automatic batching from working when using this client extension, since all queries are executed inside an interactive transaction. // To get around this we monkey patch the batching function to use the batch ID and transaction ID. // To get the batching to work we also need to ensure that all the requests we might want to batch together are generated inside the same tick. // This means that all the requests per-tick that have the same role and context values will be batched together, // allowing the in-built prisma batch optimizations to work for us. // This is why we use process.nextTick and the tickActive flag to ensure we only tick once at a time. // See: // - https://github.com/prisma/prisma/blob/5.21.1/packages/client/src/runtime/RequestHandler.ts#L122 // - https://www.prisma.io/docs/orm/prisma-client/queries/query-optimization-performance ; (client as any)._requestHandler.dataloader.options.batchBy = ( request: any, ) => { const batchId = getBatchId(request.protocolQuery); if (request.transaction?.id) { return `transaction-${request.transaction.id}${batchId ? `-${batchId}` : "" }`; } return batchId }; /** * Derived from yates * https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227 * * This function is called once per tick, and processes all the batches that have been created during that tick. * If the batch happened within an existing transaction, we use it to recreate its client, so we keep its interactve transaction logic **/ const dispatchBatches = (transaction?: { kind: 'itx' | 'batch' }) => { for (const [key, batch] of Object.entries(batches)) { delete batches[key]; const runBatchTransaction = async (tx: any) => { if (opts?.beforeQuery) { await opts.beforeQuery(tx as any) } const results = await Promise.all( batch.map((request: any) => { //@ts-ignore return tx[request.model][request.action](request.args).then((res) => request.callback(res)) .catch((e: Error) => { throw (e) }) }), ); // Switch role back to admin user if (opts?.afterQuery) { await opts?.afterQuery(tx as any) } return results; } new Promise((resolve, reject) => { if (transaction && transaction.kind === 'itx') { runBatchTransaction((client as any)._createItxClient(transaction)).then(resolve).catch(reject) } else { client.$transaction(async (tx) => { return runBatchTransaction(tx); }, { maxWait: txMaxWait, timeout: txTimeout, }).then(resolve).catch(reject) } }).then((results: any) => { results.forEach((result: any, index: number) => { batch[index].resolve(result); }); }) .catch((e) => { for (const request of batch) { request.reject(e); } delete batches[key] }) } }; return client.$extends({ name: "prisma-extension-casl", client: { $casl: extendCaslAbilities }, query: { $allModels: { ...allOperations(getAbilityFactory) }, } }) }) } /** * recreates getBatchId from prisma * //https://github.com/prisma/prisma/blob/1a9ef0fbd3948ee708add6816a33743e1ff7df9c/packages/client/src/runtime/core/jsonProtocol/getBatchId.ts#L4 * * @param query * @returns */ export function getBatchId(query: any): string | undefined { if (query.action !== "findUnique" && query.action !== "findUniqueOrThrow") { return undefined; } const parts: string[] = []; if (query.modelName) { parts.push(query.modelName); } if (query.query.arguments) { parts.push(buildKeysString(query.query.arguments)); } parts.push(buildKeysString(query.query.selection)); return parts.join(""); } function buildKeysString(obj: object): string { const keysArray = Object.keys(obj) .sort() .map((key) => { // @ts-ignore const value = obj[key]; if (typeof value === "object" && value !== null) { return `(${key} ${buildKeysString(value)})`; } return key; }); return `(${keysArray.join(" ")})`; }