prisma-extension-casl
Version:
Enforce casl abilities on prisma client
418 lines (363 loc) • 18.6 kB
text/typescript
import { AbilityBuilder, AbilityTuple, PureAbility } from '@casl/ability'
import { PrismaQuery } from '@casl/prisma'
import { Prisma, PrismaClient, PrismaPromise } from '@prisma/client'
import { applyCaslToQuery } from './applyCaslToQuery'
import { filterQueryResults } from './filterQueryResults'
import { caslOperationDict, getFluentField, getFluentModel, PrismaCaslOperation, PrismaExtensionCaslOptions, propertyFieldsByModel, relationFieldsByModel } from './helpers'
export { applyCaslToQuery }
/**
* enrich a prisma client to check for CASL abilities even in nested queries
*
* `client.$extends(useCaslAbilities(build))`
*
* https://casl.js.org/v6/en/package/casl-prisma
*
*
* @param getAbilities function to return CASL prisma abilities
* - this is a function call to instantiate abilities on each prisma query to allow adding i.e. context or claims
* @returns enriched prisma client
*/
/**
* enrich a prisma client to check for CASL abilities even in nested queries
*
* `client.$extends(useCaslAbilities(build))`
*
* https://casl.js.org/v6/en/package/casl-prisma
*
*
* @param getAbilityFactory function to return CASL prisma abilities
* - this is a function call to instantiate abilities on each prisma query to allow adding i.e. context or claims
* @param opts additional options: { permissionField, additionalActions }
* @returns enriched prisma client
* @returns
*/
export function useCaslAbilities(
getAbilityFactory: () => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>,
opts?: PrismaExtensionCaslOptions) {
// Set default options
const txMaxWait = opts?.txMaxWait ?? 30000
const txTimeout = opts?.txTimeout ?? 30000
return Prisma.defineExtension((client) => {
let tickActive = false;
const batches: Record<string, Array<{
params: object;
model: string;
action: string;
args: unknown;
/** called before resolve */
callback: (result: unknown) => void;
resolve: (result: unknown) => void;
reject: (error: unknown) => void;
}>> = {};
function extendCaslAbilities(extendFactory: (factory: AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) {
// alter the getAblities function shortly
const extendedClient = client.$extends({
query: {
$allModels: {
...allOperations(() => extendFactory(getAbilityFactory())),
},
},
})
// if we are within a transaction, return client with transaction
//@ts-ignore
const transactionId = Prisma.getExtensionContext(this)[Symbol.for('prisma.client.transaction.id')]
if (transactionId) {
//@ts-ignore
const transactionClient = extendedClient._createItxClient({
kind: 'itx',
id: transactionId
}) as typeof extendedClient
//@ts-ignore
transactionClient.$casl = extendCaslAbilities
// if $transaction is called on already existing transaction client, just use current transaction
transactionClient.$transaction = async (first: any) => {
return first(transactionClient)
}
return transactionClient
}
//@ts-ignore
extendedClient.$casl = extendCaslAbilities
return extendedClient
}
const allOperations = (getAbilities: () => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) => ({
async $allOperations<T>({ args, query, model, operation, ...rest }: { args: any, query: any, model: any, operation: any }) {
const { fluentModel, fluentRelationModel, fluentRelationField } = getFluentModel(model, rest)
const __internalParams = (rest as any).__internalParams
const transaction = __internalParams.transaction
const debug = (process.env.NODE_ENV === 'development' || process.env.NODE_ENV === 'test') && args.debugCasl
const debugAllErrors = args.debugCasl
delete args.debugCasl
const perf = debug ? performance : undefined
const logger = debug ? console : undefined
perf?.clearMeasures('prisma-casl-extension-Overall')
perf?.clearMeasures('prisma-casl-extension-Create Abilities')
perf?.clearMeasures('prisma-casl-extension-Create Casl Query')
perf?.clearMeasures('prisma-casl-extension-Finish Query')
perf?.clearMeasures('prisma-casl-extension-Filtering Results')
perf?.clearMarks('prisma-casl-extension-0')
perf?.clearMarks('prisma-casl-extension-1')
perf?.clearMarks('prisma-casl-extension-2')
perf?.clearMarks('prisma-casl-extension-3')
perf?.clearMarks('prisma-casl-extension-4')
if (!(operation in caslOperationDict)) {
return query(args)
}
perf?.mark('prisma-casl-extension-0')
const abilities = transaction?.abilities ?? getAbilities().build()
if (transaction) {
transaction.abilities = abilities
}
perf?.mark('prisma-casl-extension-1')
/**
* for read actions we return null, if casl has an error
* except we use debugCasl
*/
function getCaslQuery() {
try {
return applyCaslToQuery(operation, args, abilities, model, opts?.permissionField ? true : false)
}
catch (e) {
if (debugAllErrors || caslOperationDict[operation as PrismaCaslOperation].action !== 'read') {
throw e
}
}
}
const caslQuery = getCaslQuery()
if (fluentRelationField?.isList && !caslQuery?.args.select[fluentRelationField.name]) {
return []
}
if (!caslQuery) {
/** if casl query did not return a result we return either null or an empty array for findMany or list relation */
if (operation === 'findMany') {
return []
} else {
return null
}
}
perf?.mark('prisma-casl-extension-2')
logger?.log('Query Args', JSON.stringify(caslQuery.args))
logger?.log('Query Mask', JSON.stringify(caslQuery.mask))
const cleanupResults = (result: any) => {
perf?.mark('prisma-casl-extension-3')
if (fluentRelationModel && caslQuery.mask) {
// on fluent models we need to take mask of the relation
caslQuery.mask = fluentRelationModel && fluentRelationModel in caslQuery.mask ? caslQuery.mask[fluentRelationModel] : {}
}
const filteredResult = filterQueryResults(result, caslQuery.mask, caslQuery.creationTree, abilities, fluentModel as Prisma.ModelName, operation, opts)
if (perf) {
perf.mark('prisma-casl-extension-4')
logger?.log(
[perf.measure('prisma-casl-extension-Overall', 'prisma-casl-extension-0', 'prisma-casl-extension-4'),
perf.measure('prisma-casl-extension-Create Abilities', 'prisma-casl-extension-0', 'prisma-casl-extension-1'),
perf.measure('prisma-casl-extension-Create Casl Query', 'prisma-casl-extension-1', 'prisma-casl-extension-2'),
perf.measure('prisma-casl-extension-Finish Query', 'prisma-casl-extension-2', 'prisma-casl-extension-3'),
perf.measure('prisma-casl-extension-Filtering Results', 'prisma-casl-extension-3', 'prisma-casl-extension-4')
].map((measure) => {
return `${measure.name.replace('prisma-casl-extension-', '')}: ${measure.duration}`
})
)
}
return filteredResult
}
const operationAbility = caslOperationDict[operation as PrismaCaslOperation]
/**
* on update or create we need to create a transaction
* since there can be errors if newly created db entries
* are not permitted by abilities
*
* for reads and deletes we skip the transaction
*/
if (transaction && transaction.kind === 'batch') {
//@ts-ignore
throw new Error('Sequential transactions are not supported in prisma-extension-casl.')
// const extendedRequest = request.then(cleanupResults)
// extendedRequest.requestTransaction = request.requestTransaction
//@ts-ignore
// return client._createPrismaPromise(new Promise((resolve, reject) => {
// query(caslQuery.args).then(cleanupResults).then((result: any) => resolve(result)).catch(((e: any) => reject(e)))
// })
}
const hash = transaction?.id ?? 'batch'
if (!batches[hash]) {
batches[hash] = []
}
// make sure, that we only tick once at a time
if (!tickActive) {
tickActive = true;
process.nextTick(() => {
dispatchBatches(transaction);
tickActive = false;
});
}
/** batchQuery collects query within batches that will be dispatched every tick */
const batchQuery = (
model: string,
action: string,
args: any,
callback: (result: any) => void
) => new Promise((resolve, reject) => {
batches[hash].push({
params: __internalParams,
model,
action,
args,
reject,
resolve,
callback,
})
});
if (operationAbility.action === 'update' || operationAbility.action === 'create' || operation === 'deleteMany') {
/**
* we get all update/deleteMany entries for logging purposes.
*/
// const getMany = operation === 'deleteMany' || operation === 'updateMany'
// const manyResult: any[] = getMany ? await batchQuery(model, 'findMany', caslQuery.args.where ? { where: caslQuery.args.where } : undefined, (res: any[]) => {
// /** create update objects for updateMany */
// return operation === 'updateMany' ? res.map((r) => ({ ...caslQuery.args.data, id: r.id })) : res
// }) : []
/**
* we use createManyAndReturn instead of createMany createMany entries for logging purposes and to check permissions on new entries
*/
const op = operation === 'createMany' ? 'createManyAndReturn' : operation === 'updateMany' ? 'updateManyAndReturn' : operation
return batchQuery(model, op, caslQuery.args, async (result: any) => {
const filteredResult = cleanupResults(result)//getMany ? manyResult : result)
const results = operation === 'createMany' || operation === 'deleteMany' || operation === 'updateMany'
? { count: result.length }
// : getMany ? { count: manyResult.length }
: filteredResult
return results
})
} else {
return batchQuery(model, operation, caslQuery.args, async (result: any) => {
const fluentField = getFluentField(rest)
if (fluentField) {
return cleanupResults(result?.[fluentField])
}
return cleanupResults(result)
})
}
}
})
// Derived from yates
// https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227
//
// By default, Prisma will batch requests by the transaction ID if it is present.
// This behaviour prevents automatic batching from working when using this client extension, since all queries are executed inside an interactive transaction.
// To get around this we monkey patch the batching function to use the batch ID and transaction ID.
// To get the batching to work we also need to ensure that all the requests we might want to batch together are generated inside the same tick.
// This means that all the requests per-tick that have the same role and context values will be batched together,
// allowing the in-built prisma batch optimizations to work for us.
// This is why we use process.nextTick and the tickActive flag to ensure we only tick once at a time.
// See:
// - https://github.com/prisma/prisma/blob/5.21.1/packages/client/src/runtime/RequestHandler.ts#L122
// - https://www.prisma.io/docs/orm/prisma-client/queries/query-optimization-performance
; (client as any)._requestHandler.dataloader.options.batchBy = (
request: any,
) => {
const batchId = getBatchId(request.protocolQuery);
if (request.transaction?.id) {
return `transaction-${request.transaction.id}${batchId ? `-${batchId}` : ""
}`;
}
return batchId
};
/**
* Derived from yates
* https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227
*
* This function is called once per tick, and processes all the batches that have been created during that tick.
* If the batch happened within an existing transaction, we use it to recreate its client, so we keep its interactve transaction logic
**/
const dispatchBatches = (transaction?: { kind: 'itx' | 'batch' }) => {
for (const [key, batch] of Object.entries(batches)) {
delete batches[key];
const runBatchTransaction = async (tx: any) => {
if (opts?.beforeQuery) {
await opts.beforeQuery(tx as any)
}
const results = await Promise.all(
batch.map((request: any) => {
//@ts-ignore
return tx[request.model][request.action](request.args).then((res) => request.callback(res))
.catch((e: Error) => {
throw (e)
})
}),
);
// Switch role back to admin user
if (opts?.afterQuery) {
await opts?.afterQuery(tx as any)
}
return results;
}
new Promise((resolve, reject) => {
if (transaction && transaction.kind === 'itx') {
runBatchTransaction((client as any)._createItxClient(transaction)).then(resolve).catch(reject)
} else {
client.$transaction(async (tx) => {
return runBatchTransaction(tx);
}, {
maxWait: txMaxWait,
timeout: txTimeout,
}).then(resolve).catch(reject)
}
}).then((results: any) => {
results.forEach((result: any, index: number) => {
batch[index].resolve(result);
});
})
.catch((e) => {
for (const request of batch) {
request.reject(e);
}
delete batches[key]
})
}
};
return client.$extends({
name: "prisma-extension-casl",
client: {
$casl: extendCaslAbilities
},
query: {
$allModels: {
...allOperations(getAbilityFactory)
},
}
})
})
}
/**
* recreates getBatchId from prisma
* //https://github.com/prisma/prisma/blob/1a9ef0fbd3948ee708add6816a33743e1ff7df9c/packages/client/src/runtime/core/jsonProtocol/getBatchId.ts#L4
*
* @param query
* @returns
*/
export function getBatchId(query: any): string | undefined {
if (query.action !== "findUnique" && query.action !== "findUniqueOrThrow") {
return undefined;
}
const parts: string[] = [];
if (query.modelName) {
parts.push(query.modelName);
}
if (query.query.arguments) {
parts.push(buildKeysString(query.query.arguments));
}
parts.push(buildKeysString(query.query.selection));
return parts.join("");
}
function buildKeysString(obj: object): string {
const keysArray = Object.keys(obj)
.sort()
.map((key) => {
// @ts-ignore
const value = obj[key];
if (typeof value === "object" && value !== null) {
return `(${key} ${buildKeysString(value)})`;
}
return key;
});
return `(${keysArray.join(" ")})`;
}