UNPKG

zenstack

Version:

FullStack enhancement for Prisma ORM: seamless integration from database to UI

522 lines 25.4 kB
"use strict"; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.PolicyGenerator = void 0; const ast_1 = require("@zenstackhq/language/ast"); const sdk_1 = require("@zenstackhq/sdk"); const prisma_1 = require("@zenstackhq/sdk/prisma"); const langium_1 = require("langium"); const lower_case_first_1 = require("lower-case-first"); const path_1 = __importDefault(require("path")); const ts_morph_1 = require("ts-morph"); const ast_utils_1 = require("../../../utils/ast-utils"); const constraint_transformer_1 = require("./constraint-transformer"); const utils_1 = require("./utils"); /** * Generates source file that contains Prisma query guard objects used for injecting database queries */ class PolicyGenerator { constructor(options) { this.options = options; } generate(project, model, output) { const sf = project.createSourceFile(path_1.default.join(output, 'policy.ts'), undefined, { overwrite: true }); this.writeImports(model, output, sf); const models = (0, sdk_1.getDataModels)(model); sf.addVariableStatement({ declarationKind: ts_morph_1.VariableDeclarationKind.Const, declarations: [ { name: 'policy', type: 'PolicyDef', initializer: (writer) => { writer.block(() => { this.writePolicy(writer, models, sf); this.writeValidationMeta(writer, models); this.writeAuthSelector(models, writer); }); }, }, ], }); sf.addStatements('export default policy'); // save ts files if requested explicitly or the user provided const preserveTsFiles = this.options.preserveTsFiles === true || !!this.options.output; if (preserveTsFiles) { (0, sdk_1.saveSourceFile)(sf); } } writeImports(model, output, sf) { sf.addImportDeclaration({ namedImports: [ { name: 'type QueryContext' }, { name: 'type CrudContract' }, { name: 'type PermissionCheckerContext' }, ], moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}`, }); sf.addImportDeclaration({ namedImports: [{ name: 'allFieldsEqual' }], moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}/validation`, }); sf.addImportDeclaration({ namedImports: [{ name: 'type PolicyDef' }, { name: 'type PermissionCheckerConstraint' }], moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}/enhancements/node`, }); // import enums const prismaImport = (0, prisma_1.getPrismaClientImportSpec)(output, this.options); for (const e of model.declarations.filter((d) => (0, ast_1.isEnum)(d) && (0, utils_1.isEnumReferenced)(model, d))) { sf.addImportDeclaration({ namedImports: [{ name: e.name }], moduleSpecifier: prismaImport, }); } } writePolicy(writer, models, sourceFile) { writer.write('policy:'); writer.inlineBlock(() => { for (const model of models) { writer.write(`${(0, lower_case_first_1.lowerCaseFirst)(model.name)}:`); writer.block(() => { // model-level guards this.writeModelLevelDefs(model, writer, sourceFile); // field-level guards this.writeFieldLevelDefs(model, writer, sourceFile); }); writer.writeLine(','); } }); writer.writeLine(','); } // #region Model-level definitions // writes model-level policy def for each operation kind for a model // `[modelName]: { [operationKind]: [funcName] },` writeModelLevelDefs(model, writer, sourceFile) { const policies = (0, sdk_1.analyzePolicies)(model); writer.write('modelLevel:'); writer.inlineBlock(() => { this.writeModelReadDef(model, policies, writer, sourceFile); this.writeModelCreateDef(model, policies, writer, sourceFile); this.writeModelUpdateDef(model, policies, writer, sourceFile); this.writeModelPostUpdateDef(model, policies, writer, sourceFile); this.writeModelDeleteDef(model, policies, writer, sourceFile); }); writer.writeLine(','); } // writes `read: ...` for a given model writeModelReadDef(model, policies, writer, sourceFile) { writer.write(`read:`); writer.inlineBlock(() => { this.writeCommonModelDef(model, 'read', policies, writer, sourceFile); }); writer.writeLine(','); } // writes `create: ...` for a given model writeModelCreateDef(model, policies, writer, sourceFile) { writer.write(`create:`); writer.inlineBlock(() => { this.writeCommonModelDef(model, 'create', policies, writer, sourceFile); // create policy has an additional input checker for validating the payload this.writeCreateInputChecker(model, writer, sourceFile); }); writer.writeLine(','); } // writes `inputChecker: [funcName]` for a given model writeCreateInputChecker(model, writer, sourceFile) { if (this.canCheckCreateBasedOnInput(model)) { const inputCheckFunc = this.generateCreateInputCheckerFunction(model, sourceFile); writer.write(`inputChecker: ${inputCheckFunc.getName()},`); } } canCheckCreateBasedOnInput(model) { const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'create', false, 'all'); const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', 'create', false, 'all'); return [...allows, ...denies].every((rule) => { return (0, langium_1.streamAst)(rule).every((expr) => { var _a; if ((0, ast_1.isThisExpr)(expr)) { return false; } if ((0, ast_1.isReferenceExpr)(expr)) { if ((0, ast_1.isDataModel)((_a = expr.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl)) { // if policy rules uses relation fields, // we can't check based on create input return false; } if ((0, ast_1.isDataModelField)(expr.target.ref) && expr.target.ref.$container === model && (0, sdk_1.hasAttribute)(expr.target.ref, '@default')) { // reference to field of current model // if it has default value, we can't check // based on create input return false; } if ((0, ast_1.isDataModelField)(expr.target.ref) && (0, sdk_1.isForeignKeyField)(expr.target.ref)) { // reference to foreign key field // we can't check based on create input return false; } } return true; }); }); } // generates a function for checking "create" input generateCreateInputCheckerFunction(model, sourceFile) { const statements = []; const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'create'); const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', 'create'); (0, utils_1.generateNormalizedAuthRef)(model, allows, denies, statements); statements.push((writer) => { if (allows.length === 0) { writer.write('return false;'); return; } const transformer = new sdk_1.TypeScriptExpressionTransformer({ context: sdk_1.ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', operationContext: 'create', }); let expr = denies.length > 0 ? '!(' + denies .map((deny) => { return transformer.transform(deny, false); }) .join(' || ') + ')' : undefined; const allowStmt = allows .map((allow) => { return transformer.transform(allow, false); }) .join(' || '); expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); }); const func = sourceFile.addFunction({ name: model.name + '_create_input', returnType: 'boolean', parameters: [ { name: 'input', type: 'any', }, { name: 'context', type: 'QueryContext', }, ], statements, }); return func; } // writes `update: ...` for a given model writeModelUpdateDef(model, policies, writer, sourceFile) { writer.write(`update:`); writer.inlineBlock(() => { this.writeCommonModelDef(model, 'update', policies, writer, sourceFile); }); writer.writeLine(','); } // writes `postUpdate: ...` for a given model writeModelPostUpdateDef(model, policies, writer, sourceFile) { writer.write(`postUpdate:`); writer.inlineBlock(() => { this.writeCommonModelDef(model, 'postUpdate', policies, writer, sourceFile); // post-update policy has an additional selector for reading the pre-update entity data this.writePostUpdatePreValueSelector(model, writer); }); writer.writeLine(','); } writePostUpdatePreValueSelector(model, writer) { const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'postUpdate'); const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', 'postUpdate'); const preValueSelect = (0, utils_1.generateSelectForRules)([...allows, ...denies], 'postUpdate'); if (preValueSelect) { writer.writeLine(`preUpdateSelector: ${JSON.stringify(preValueSelect)},`); } } // writes `delete: ...` for a given model writeModelDeleteDef(model, policies, writer, sourceFile) { writer.write(`delete:`); writer.inlineBlock(() => { this.writeCommonModelDef(model, 'delete', policies, writer, sourceFile); }); } // writes `[kind]: ...` for a given model writeCommonModelDef(model, kind, policies, writer, sourceFile) { const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', kind); const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', kind); // policy guard this.writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile); // permission checker if (kind !== 'postUpdate') { this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); } // write cross-model comparison rules as entity checker functions // because they cannot be checked inside Prisma const { functionName, selector } = this.writeEntityChecker(model, kind, sourceFile, false); if (this.shouldUseEntityChecker(model, kind, true, false)) { writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); } } shouldUseEntityChecker(target, kind, onlyCrossModelComparison, forOverride) { const allows = (0, utils_1.getPolicyExpressions)(target, 'allow', kind, forOverride, onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all'); const denies = (0, utils_1.getPolicyExpressions)(target, 'deny', kind, forOverride, onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all'); if (allows.length > 0 || denies.length > 0) { return true; } const allRules = [ ...(0, utils_1.getPolicyExpressions)(target, 'allow', kind, forOverride, 'all'), ...(0, utils_1.getPolicyExpressions)(target, 'deny', kind, forOverride, 'all'), ]; return allRules.some((rule) => { return (0, langium_1.streamAst)(rule).some((node) => { var _a; if ((0, ast_utils_1.isCheckInvocation)(node)) { const expr = node; const fieldRef = expr.args[0].value; const targetModel = (_a = fieldRef.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl; return this.shouldUseEntityChecker(targetModel, kind, onlyCrossModelComparison, forOverride); } return false; }); }); } writeEntityChecker(target, kind, sourceFile, forOverride) { var _a; const allows = (0, utils_1.getPolicyExpressions)(target, 'allow', kind, forOverride, 'all'); const denies = (0, utils_1.getPolicyExpressions)(target, 'deny', kind, forOverride, 'all'); const model = (0, ast_1.isDataModel)(target) ? target : target.$container; const func = (0, utils_1.generateEntityCheckerFunction)(sourceFile, model, kind, allows, denies, (0, ast_1.isDataModelField)(target) ? target : undefined, forOverride); const selector = (_a = (0, utils_1.generateSelectForRules)([...allows, ...denies], kind, false, kind !== 'postUpdate')) !== null && _a !== void 0 ? _a : {}; return { functionName: func.getName(), selector }; } // writes `guard: ...` for a given policy operation kind writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile) { // first handle several cases where a constant function can be used if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart let func; if ((0, utils_1.getPolicyExpressions)(model, 'allow', 'postUpdate').length === 0) { func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, false); } else { func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, true); } writer.write(`guard: ${func.getName()},`); return; } if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { // no 'postUpdate' rule, always allow const func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, true); writer.write(`guard: ${func.getName()},`); return; } if (kind in policies && typeof policies[kind] === 'boolean') { // constant policy const func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, policies[kind]); writer.write(`guard: ${func.getName()},`); return; } // generate a policy function that evaluates a partial prisma query const guardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, kind, allows, denies); writer.write(`guard: ${guardFunc.getName()},`); } // writes `permissionChecker: ...` for a given policy operation kind writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile) { if (this.options.generatePermissionChecker !== true) { return; } if (policies[kind] === true || policies[kind] === false) { // constant policy writer.write(`permissionChecker: ${policies[kind]},`); return; } if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart if ((0, utils_1.getPolicyExpressions)(model, 'allow', 'postUpdate').length === 0) { writer.write(`permissionChecker: false,`); } else { writer.write(`permissionChecker: true,`); } return; } const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile); writer.write(`permissionChecker: ${guardFunc.getName()},`); } generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile) { const statements = []; (0, utils_1.generateNormalizedAuthRef)(model, allows, denies, statements); const transformed = new constraint_transformer_1.ConstraintTransformer({ authAccessor: 'user', }).transformRules(allows, denies); statements.push(`return ${transformed};`); const func = sourceFile.addFunction({ name: `${model.name}$checker$${kind}`, returnType: 'PermissionCheckerConstraint', parameters: [ { name: 'context', type: 'PermissionCheckerContext', }, ], statements, }); return func; } // #endregion // #region Field-level definitions writeFieldLevelDefs(model, writer, sf) { writer.write('fieldLevel:'); writer.inlineBlock(() => { this.writeFieldReadDef(model, writer, sf); this.writeFieldUpdateDef(model, writer, sf); }); writer.writeLine(','); } writeFieldReadDef(model, writer, sourceFile) { writer.writeLine('read:'); writer.block(() => { for (const field of model.fields) { const allows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'read'); const denies = (0, utils_1.getPolicyExpressions)(field, 'deny', 'read'); const overrideAllows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'read', true); if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) { continue; } writer.write(`${field.name}:`); writer.block(() => { // guard const guardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'read', allows, denies, field); writer.write(`guard: ${guardFunc.getName()},`); // checker function // write all field-level rules as entity checker function const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, false); if (this.shouldUseEntityChecker(field, 'read', false, false)) { writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); } if (overrideAllows.length > 0) { // override guard function const denies = (0, utils_1.getPolicyExpressions)(field, 'deny', 'read'); const overrideGuardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'read', overrideAllows, denies, field, true); writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); // additional entity checker for override const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, true); if (this.shouldUseEntityChecker(field, 'read', false, true)) { writer.write(`overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); } } }); writer.writeLine(','); } }); writer.writeLine(','); } writeFieldUpdateDef(model, writer, sourceFile) { writer.writeLine('update:'); writer.block(() => { for (const field of model.fields) { const allows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'update'); const denies = (0, utils_1.getPolicyExpressions)(field, 'deny', 'update'); const overrideAllows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'update', true); if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) { continue; } writer.write(`${field.name}:`); writer.block(() => { // guard const guardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'update', allows, denies, field); writer.write(`guard: ${guardFunc.getName()},`); // write cross-model comparison rules as entity checker functions // because they cannot be checked inside Prisma const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, false); if (this.shouldUseEntityChecker(field, 'update', true, false)) { writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); } if (overrideAllows.length > 0) { // override guard const overrideGuardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'update', overrideAllows, denies, field, true); writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); // write cross-model comparison override rules as entity checker functions // because they cannot be checked inside Prisma const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, true); if (this.shouldUseEntityChecker(field, 'update', true, true)) { writer.write(`overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); } } }); writer.writeLine(','); } }); writer.writeLine(','); } // #endregion //#region Auth selector writeAuthSelector(models, writer) { const authSelector = this.generateAuthSelector(models); if (authSelector) { writer.write(`authSelector: ${JSON.stringify(authSelector)},`); } } // Generates a { select: ... } object to select `auth()` fields used in policy rules generateAuthSelector(models) { const authRules = []; models.forEach((model) => { // model-level rules const modelPolicyAttrs = model.attributes.filter((attr) => ['@@allow', '@@deny'].includes(attr.decl.$refText)); // field-level rules const fieldPolicyAttrs = model.fields .flatMap((f) => f.attributes) .filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); // all rule expression const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs] .filter((attr) => attr.args.length > 1) .map((attr) => attr.args[1].value); // collect `auth()` member access allExpressions.forEach((rule) => { (0, langium_1.streamAst)(rule).forEach((node) => { if ((0, ast_1.isMemberAccessExpr)(node) && (0, sdk_1.isAuthInvocation)(node.operand)) { authRules.push(node); } }); }); }); if (authRules.length > 0) { return (0, utils_1.generateSelectForRules)(authRules, undefined, true); } else { return undefined; } } // #endregion // #region Validation meta writeValidationMeta(writer, models) { writer.write('validation:'); writer.inlineBlock(() => { for (const model of models) { writer.write(`${(0, lower_case_first_1.lowerCaseFirst)(model.name)}:`); writer.inlineBlock(() => { writer.write(`hasValidation: ${ // explicit validation rules (0, sdk_1.hasValidationAttributes)(model) || // type-def fields require schema validation this.hasTypeDefFields(model)}`); }); writer.writeLine(','); } }); writer.writeLine(','); } hasTypeDefFields(model) { return model.fields.some((f) => { var _a; return (0, ast_1.isTypeDef)((_a = f.type.reference) === null || _a === void 0 ? void 0 : _a.ref); }); } } exports.PolicyGenerator = PolicyGenerator; //# sourceMappingURL=policy-guard-generator.js.map