zenstack
Version:
FullStack enhancement for Prisma ORM: seamless integration from database to UI
522 lines • 25.4 kB
JavaScript
// TypeScript-emitted interop helper: reuses an `__importDefault` already present on
// `this` if there is one; otherwise returns ES-module namespaces (objects flagged with
// `__esModule`) unchanged and wraps any other export as `{ default: mod }`, so callers
// can uniformly use `.default` (as done for `path` below).
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.PolicyGenerator = void 0;
const ast_1 = require("@zenstackhq/language/ast");
const sdk_1 = require("@zenstackhq/sdk");
const prisma_1 = require("@zenstackhq/sdk/prisma");
const langium_1 = require("langium");
const lower_case_first_1 = require("lower-case-first");
const path_1 = __importDefault(require("path"));
const ts_morph_1 = require("ts-morph");
const ast_utils_1 = require("../../../utils/ast-utils");
const constraint_transformer_1 = require("./constraint-transformer");
const utils_1 = require("./utils");
/**
 * Generates the `policy.ts` source file containing the policy definition object
 * (Prisma query guards, permission checkers, entity checkers, validation metadata and
 * auth selector) used for injecting access-policy conditions into database queries.
 */
class PolicyGenerator {
    /**
     * @param options plugin options. This class reads `preserveTsFiles`, `output` and
     *   `generatePermissionChecker`, and passes the whole object to
     *   `getPrismaClientImportSpec`.
     */
    constructor(options) {
        this.options = options;
    }

    /**
     * Creates `<output>/policy.ts` in the ts-morph `project` (overwriting any existing
     * file) containing a default-exported `policy: PolicyDef` constant that covers every
     * data model of `model`.
     *
     * @param project ts-morph project the source file is added to
     * @param model ZModel root whose data models and enums are processed
     * @param output directory in which `policy.ts` is created
     */
    generate(project, model, output) {
        const sf = project.createSourceFile(path_1.default.join(output, 'policy.ts'), undefined, { overwrite: true });
        this.writeImports(model, output, sf);
        const models = (0, sdk_1.getDataModels)(model);
        sf.addVariableStatement({
            declarationKind: ts_morph_1.VariableDeclarationKind.Const,
            declarations: [
                {
                    name: 'policy',
                    type: 'PolicyDef',
                    initializer: (writer) => {
                        writer.block(() => {
                            this.writePolicy(writer, models, sf);
                            this.writeValidationMeta(writer, models);
                            this.writeAuthSelector(models, writer);
                        });
                    },
                },
            ],
        });
        sf.addStatements('export default policy');
        // save ts files if requested explicitly or if the user configured a custom `output`
        const preserveTsFiles = this.options.preserveTsFiles === true || !!this.options.output;
        if (preserveTsFiles) {
            (0, sdk_1.saveSourceFile)(sf);
        }
    }

    /**
     * Adds the runtime type/helper imports, plus an import (from the Prisma client
     * module) for every enum that `isEnumReferenced` reports as referenced.
     */
    writeImports(model, output, sf) {
        sf.addImportDeclaration({
            namedImports: [
                { name: 'type QueryContext' },
                { name: 'type CrudContract' },
                { name: 'type PermissionCheckerContext' },
            ],
            moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}`,
        });
        sf.addImportDeclaration({
            namedImports: [{ name: 'allFieldsEqual' }],
            moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}/validation`,
        });
        sf.addImportDeclaration({
            namedImports: [{ name: 'type PolicyDef' }, { name: 'type PermissionCheckerConstraint' }],
            moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}/enhancements/node`,
        });
        // import enums
        const prismaImport = (0, prisma_1.getPrismaClientImportSpec)(output, this.options);
        const referencedEnums = model.declarations.filter((d) => (0, ast_1.isEnum)(d) && (0, utils_1.isEnumReferenced)(model, d));
        for (const e of referencedEnums) {
            sf.addImportDeclaration({
                namedImports: [{ name: e.name }],
                moduleSpecifier: prismaImport,
            });
        }
    }

    /**
     * Writes `policy: { <lowerCaseFirst(Model)>: { modelLevel: ..., fieldLevel: ... }, ... },`.
     */
    writePolicy(writer, models, sourceFile) {
        writer.write('policy:');
        writer.inlineBlock(() => {
            for (const model of models) {
                writer.write(`${(0, lower_case_first_1.lowerCaseFirst)(model.name)}:`);
                writer.block(() => {
                    // model-level guards
                    this.writeModelLevelDefs(model, writer, sourceFile);
                    // field-level guards
                    this.writeFieldLevelDefs(model, writer, sourceFile);
                });
                writer.writeLine(',');
            }
        });
        writer.writeLine(',');
    }

    // #region Model-level definitions

    /**
     * Writes `modelLevel: { read, create, update, postUpdate, delete },` for a model.
     * The `analyzePolicies` result is passed down so constant policies can be emitted
     * without generating evaluation logic.
     */
    writeModelLevelDefs(model, writer, sourceFile) {
        const policies = (0, sdk_1.analyzePolicies)(model);
        writer.write('modelLevel:');
        writer.inlineBlock(() => {
            this.writeModelReadDef(model, policies, writer, sourceFile);
            this.writeModelCreateDef(model, policies, writer, sourceFile);
            this.writeModelUpdateDef(model, policies, writer, sourceFile);
            this.writeModelPostUpdateDef(model, policies, writer, sourceFile);
            this.writeModelDeleteDef(model, policies, writer, sourceFile);
        });
        writer.writeLine(',');
    }

    /** Writes `read: { ... },` for a given model. */
    writeModelReadDef(model, policies, writer, sourceFile) {
        writer.write(`read:`);
        writer.inlineBlock(() => {
            this.writeCommonModelDef(model, 'read', policies, writer, sourceFile);
        });
        writer.writeLine(',');
    }

    /** Writes `create: { ... },` for a given model, including the optional input checker. */
    writeModelCreateDef(model, policies, writer, sourceFile) {
        writer.write(`create:`);
        writer.inlineBlock(() => {
            this.writeCommonModelDef(model, 'create', policies, writer, sourceFile);
            // create policy has an additional input checker for validating the payload
            this.writeCreateInputChecker(model, writer, sourceFile);
        });
        writer.writeLine(',');
    }

    /**
     * Writes `inputChecker: <funcName>,` when the model's create rules can be evaluated
     * against the create payload alone; writes nothing otherwise.
     */
    writeCreateInputChecker(model, writer, sourceFile) {
        if (this.canCheckCreateBasedOnInput(model)) {
            const inputCheckFunc = this.generateCreateInputCheckerFunction(model, sourceFile);
            writer.write(`inputChecker: ${inputCheckFunc.getName()},`);
        }
    }

    /**
     * Returns whether every create allow/deny rule of `model` can be evaluated purely
     * from the create input. This is not possible when any rule node is a `this`
     * expression, or a reference that resolves to a data model (relation), to a field of
     * this model carrying `@default`, or to a foreign key field.
     */
    canCheckCreateBasedOnInput(model) {
        const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'create', false, 'all');
        const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', 'create', false, 'all');
        return [...allows, ...denies].every((rule) => (0, langium_1.streamAst)(rule).every((expr) => {
            if ((0, ast_1.isThisExpr)(expr)) {
                return false;
            }
            if ((0, ast_1.isReferenceExpr)(expr)) {
                const resolvedType = expr.$resolvedType;
                if ((0, ast_1.isDataModel)(resolvedType == null ? undefined : resolvedType.decl)) {
                    // if policy rules uses relation fields,
                    // we can't check based on create input
                    return false;
                }
                const ref = expr.target.ref;
                if ((0, ast_1.isDataModelField)(ref) && ref.$container === model && (0, sdk_1.hasAttribute)(ref, '@default')) {
                    // reference to a field of the current model that has a default value:
                    // the effective value may not be present in the create input
                    return false;
                }
                if ((0, ast_1.isDataModelField)(ref) && (0, sdk_1.isForeignKeyField)(ref)) {
                    // reference to foreign key field
                    // we can't check based on create input
                    return false;
                }
            }
            return true;
        }));
    }

    /**
     * Adds a `<Model>_create_input(input: any, context: QueryContext): boolean` function
     * to `sourceFile` that returns `!(deny1 || ...) && (allow1 || ...)` evaluated against
     * the create input (or just the allow disjunction when there are no deny rules), or
     * `false` when there are no allow rules.
     *
     * @returns the ts-morph function declaration that was added
     */
    generateCreateInputCheckerFunction(model, sourceFile) {
        const statements = [];
        const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'create');
        const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', 'create');
        // presumably emits the normalized `auth()` variable the transformed rules refer to
        (0, utils_1.generateNormalizedAuthRef)(model, allows, denies, statements);
        statements.push((writer) => {
            if (allows.length === 0) {
                writer.write('return false;');
                return;
            }
            const transformer = new sdk_1.TypeScriptExpressionTransformer({
                context: sdk_1.ExpressionContext.AccessPolicy,
                fieldReferenceContext: 'input',
                operationContext: 'create',
            });
            // deny rules are transformed before allow rules, as in the original ordering
            const denyStmt = denies.map((deny) => transformer.transform(deny, false)).join(' || ');
            const allowStmt = allows.map((allow) => transformer.transform(allow, false)).join(' || ');
            const expr = denies.length > 0 ? `!(${denyStmt}) && (${allowStmt})` : allowStmt;
            writer.write(`return ${expr}`);
        });
        return sourceFile.addFunction({
            name: model.name + '_create_input',
            returnType: 'boolean',
            parameters: [
                {
                    name: 'input',
                    type: 'any',
                },
                {
                    name: 'context',
                    type: 'QueryContext',
                },
            ],
            statements,
        });
    }

    /** Writes `update: { ... },` for a given model. */
    writeModelUpdateDef(model, policies, writer, sourceFile) {
        writer.write(`update:`);
        writer.inlineBlock(() => {
            this.writeCommonModelDef(model, 'update', policies, writer, sourceFile);
        });
        writer.writeLine(',');
    }

    /** Writes `postUpdate: { ... },` for a given model, including the pre-update selector. */
    writeModelPostUpdateDef(model, policies, writer, sourceFile) {
        writer.write(`postUpdate:`);
        writer.inlineBlock(() => {
            this.writeCommonModelDef(model, 'postUpdate', policies, writer, sourceFile);
            // post-update policy has an additional selector for reading the pre-update entity data
            this.writePostUpdatePreValueSelector(model, writer);
        });
        writer.writeLine(',');
    }

    /**
     * Writes `preUpdateSelector: <json>,` when the post-update rules need any fields of
     * the pre-update entity; writes nothing when `generateSelectForRules` yields nothing.
     */
    writePostUpdatePreValueSelector(model, writer) {
        const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'postUpdate');
        const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', 'postUpdate');
        const preValueSelect = (0, utils_1.generateSelectForRules)([...allows, ...denies], 'postUpdate');
        if (preValueSelect) {
            writer.writeLine(`preUpdateSelector: ${JSON.stringify(preValueSelect)},`);
        }
    }

    /**
     * Writes `delete: { ... }` for a given model. As the last `modelLevel` entry it is
     * not followed by a comma.
     */
    writeModelDeleteDef(model, policies, writer, sourceFile) {
        writer.write(`delete:`);
        writer.inlineBlock(() => {
            this.writeCommonModelDef(model, 'delete', policies, writer, sourceFile);
        });
    }

    /**
     * Writes the entries shared by all model-level operation kinds: `guard`,
     * `permissionChecker` (except for `postUpdate`) and, when needed, `entityChecker`.
     */
    writeCommonModelDef(model, kind, policies, writer, sourceFile) {
        const allows = (0, utils_1.getPolicyExpressions)(model, 'allow', kind);
        const denies = (0, utils_1.getPolicyExpressions)(model, 'deny', kind);
        // policy guard
        this.writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile);
        // permission checker
        if (kind !== 'postUpdate') {
            this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile);
        }
        // write cross-model comparison rules as entity checker functions
        // because they cannot be checked inside Prisma
        const { functionName, selector } = this.writeEntityChecker(model, kind, sourceFile, false);
        if (this.shouldUseEntityChecker(model, kind, true, false)) {
            writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`);
        }
    }

    /**
     * Returns whether `target` (a data model or field) needs an entity checker for
     * `kind`: either it has matching rules itself (only cross-model-comparison rules when
     * `onlyCrossModelComparison` is set), or one of its rules calls `check()` on a
     * relation whose model transitively needs one.
     *
     * @param visited models/fields already explored in this traversal; callers normally
     *   omit it. Guards against infinite recursion on cyclic `check()` references.
     */
    shouldUseEntityChecker(target, kind, onlyCrossModelComparison, forOverride, visited = new Set()) {
        if (visited.has(target)) {
            // already explored (or being explored) in this traversal: revisiting cannot
            // reach anything new, and with cyclic `check()` references it would never end
            return false;
        }
        visited.add(target);
        const filter = onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all';
        const allows = (0, utils_1.getPolicyExpressions)(target, 'allow', kind, forOverride, filter);
        const denies = (0, utils_1.getPolicyExpressions)(target, 'deny', kind, forOverride, filter);
        if (allows.length > 0 || denies.length > 0) {
            return true;
        }
        const allRules = [
            ...(0, utils_1.getPolicyExpressions)(target, 'allow', kind, forOverride, 'all'),
            ...(0, utils_1.getPolicyExpressions)(target, 'deny', kind, forOverride, 'all'),
        ];
        return allRules.some((rule) => (0, langium_1.streamAst)(rule).some((node) => {
            if (!(0, ast_utils_1.isCheckInvocation)(node)) {
                return false;
            }
            // `check(relationField, ...)`: follow the relation to its target model
            const fieldRef = node.args[0].value;
            const resolvedType = fieldRef.$resolvedType;
            const targetModel = resolvedType == null ? undefined : resolvedType.decl;
            return this.shouldUseEntityChecker(targetModel, kind, onlyCrossModelComparison, forOverride, visited);
        }));
    }

    /**
     * Generates an entity checker function for `target` (a data model, or a field when
     * field-level) from all its `kind` rules, and returns its name with the selector of
     * the fields it needs (`{}` when none).
     *
     * NOTE(review): callers invoke this even when the checker ends up not being
     * referenced from the policy object; presumably because `check()` rules in other
     * models' checkers call it by name — confirm before making generation conditional.
     */
    writeEntityChecker(target, kind, sourceFile, forOverride) {
        const allows = (0, utils_1.getPolicyExpressions)(target, 'allow', kind, forOverride, 'all');
        const denies = (0, utils_1.getPolicyExpressions)(target, 'deny', kind, forOverride, 'all');
        const model = (0, ast_1.isDataModel)(target) ? target : target.$container;
        const field = (0, ast_1.isDataModelField)(target) ? target : undefined;
        const func = (0, utils_1.generateEntityCheckerFunction)(sourceFile, model, kind, allows, denies, field, forOverride);
        const select = (0, utils_1.generateSelectForRules)([...allows, ...denies], kind, false, kind !== 'postUpdate');
        const selector = select == null ? {} : select;
        return { functionName: func.getName(), selector };
    }

    /**
     * Writes `guard: <funcName>,` for a given operation kind, using a constant guard
     * function where the outcome is statically known.
     */
    writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile) {
        // first handle several cases where a constant function can be used
        if (kind === 'update' && allows.length === 0) {
            // no allow rule for 'update': the guard is constantly true iff there is a
            // post-update allow rule that can still grant the update
            const hasPostUpdateAllows = (0, utils_1.getPolicyExpressions)(model, 'allow', 'postUpdate').length !== 0;
            const func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, hasPostUpdateAllows);
            writer.write(`guard: ${func.getName()},`);
            return;
        }
        if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
            // no 'postUpdate' rule, always allow
            const func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, true);
            writer.write(`guard: ${func.getName()},`);
            return;
        }
        if (kind in policies && typeof policies[kind] === 'boolean') {
            // constant policy
            const func = (0, utils_1.generateConstantQueryGuardFunction)(sourceFile, model, kind, policies[kind]);
            writer.write(`guard: ${func.getName()},`);
            return;
        }
        // generate a policy function that evaluates a partial prisma query
        const guardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, kind, allows, denies);
        writer.write(`guard: ${guardFunc.getName()},`);
    }

    /**
     * Writes `permissionChecker: ...,` for a given operation kind. Does nothing unless
     * the `generatePermissionChecker` option is exactly `true`.
     *
     * NOTE(review): unlike `writePolicyGuard`, the constant-policy case is checked before
     * the "update without allow rules" case here; confirm the differing order is intended.
     */
    writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile) {
        if (this.options.generatePermissionChecker !== true) {
            return;
        }
        if (policies[kind] === true || policies[kind] === false) {
            // constant policy
            writer.write(`permissionChecker: ${policies[kind]},`);
            return;
        }
        if (kind === 'update' && allows.length === 0) {
            // no allow rule for 'update', policy is constant based on if there's
            // post-update counterpart
            if ((0, utils_1.getPolicyExpressions)(model, 'allow', 'postUpdate').length === 0) {
                writer.write(`permissionChecker: false,`);
            }
            else {
                writer.write(`permissionChecker: true,`);
            }
            return;
        }
        const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile);
        writer.write(`permissionChecker: ${guardFunc.getName()},`);
    }

    /**
     * Adds a `<Model>$checker$<kind>(context: PermissionCheckerContext):
     * PermissionCheckerConstraint` function returning the constraint built by
     * `ConstraintTransformer` from the allow/deny rules.
     *
     * @returns the ts-morph function declaration that was added
     */
    generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile) {
        const statements = [];
        (0, utils_1.generateNormalizedAuthRef)(model, allows, denies, statements);
        const transformed = new constraint_transformer_1.ConstraintTransformer({
            authAccessor: 'user',
        }).transformRules(allows, denies);
        statements.push(`return ${transformed};`);
        return sourceFile.addFunction({
            name: `${model.name}$checker$${kind}`,
            returnType: 'PermissionCheckerConstraint',
            parameters: [
                {
                    name: 'context',
                    type: 'PermissionCheckerContext',
                },
            ],
            statements,
        });
    }

    // #endregion

    // #region Field-level definitions

    /** Writes `fieldLevel: { read: {...}, update: {...} },` for a model. */
    writeFieldLevelDefs(model, writer, sf) {
        writer.write('fieldLevel:');
        writer.inlineBlock(() => {
            this.writeFieldReadDef(model, writer, sf);
            this.writeFieldUpdateDef(model, writer, sf);
        });
        writer.writeLine(',');
    }

    /**
     * Writes `read: { <field>: { guard, entityChecker?, overrideGuard?,
     * overrideEntityChecker? }, ... },`, skipping fields without any field-level read rule.
     */
    writeFieldReadDef(model, writer, sourceFile) {
        writer.writeLine('read:');
        writer.block(() => {
            for (const field of model.fields) {
                const allows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'read');
                const denies = (0, utils_1.getPolicyExpressions)(field, 'deny', 'read');
                const overrideAllows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'read', true);
                if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) {
                    continue;
                }
                writer.write(`${field.name}:`);
                writer.block(() => {
                    // guard
                    const guardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'read', allows, denies, field);
                    writer.write(`guard: ${guardFunc.getName()},`);
                    // checker function
                    // write all field-level rules as entity checker function
                    const checker = this.writeEntityChecker(field, 'read', sourceFile, false);
                    if (this.shouldUseEntityChecker(field, 'read', false, false)) {
                        writer.write(`entityChecker: { func: ${checker.functionName}, selector: ${JSON.stringify(checker.selector)} },`);
                    }
                    if (overrideAllows.length > 0) {
                        // override guard function (reuses the field's deny rules computed above)
                        const overrideGuardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'read', overrideAllows, denies, field, true);
                        writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`);
                        // additional entity checker for override
                        const overrideChecker = this.writeEntityChecker(field, 'read', sourceFile, true);
                        if (this.shouldUseEntityChecker(field, 'read', false, true)) {
                            writer.write(`overrideEntityChecker: { func: ${overrideChecker.functionName}, selector: ${JSON.stringify(overrideChecker.selector)} },`);
                        }
                    }
                });
                writer.writeLine(',');
            }
        });
        writer.writeLine(',');
    }

    /**
     * Writes `update: { <field>: { guard, entityChecker?, overrideGuard?,
     * overrideEntityChecker? }, ... },`, skipping fields without any field-level update rule.
     */
    writeFieldUpdateDef(model, writer, sourceFile) {
        writer.writeLine('update:');
        writer.block(() => {
            for (const field of model.fields) {
                const allows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'update');
                const denies = (0, utils_1.getPolicyExpressions)(field, 'deny', 'update');
                const overrideAllows = (0, utils_1.getPolicyExpressions)(field, 'allow', 'update', true);
                if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) {
                    continue;
                }
                writer.write(`${field.name}:`);
                writer.block(() => {
                    // guard
                    const guardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'update', allows, denies, field);
                    writer.write(`guard: ${guardFunc.getName()},`);
                    // write cross-model comparison rules as entity checker functions
                    // because they cannot be checked inside Prisma
                    const checker = this.writeEntityChecker(field, 'update', sourceFile, false);
                    if (this.shouldUseEntityChecker(field, 'update', true, false)) {
                        writer.write(`entityChecker: { func: ${checker.functionName}, selector: ${JSON.stringify(checker.selector)} },`);
                    }
                    if (overrideAllows.length > 0) {
                        // override guard
                        const overrideGuardFunc = (0, utils_1.generateQueryGuardFunction)(sourceFile, model, 'update', overrideAllows, denies, field, true);
                        writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`);
                        // write cross-model comparison override rules as entity checker functions
                        // because they cannot be checked inside Prisma
                        const overrideChecker = this.writeEntityChecker(field, 'update', sourceFile, true);
                        if (this.shouldUseEntityChecker(field, 'update', true, true)) {
                            writer.write(`overrideEntityChecker: { func: ${overrideChecker.functionName}, selector: ${JSON.stringify(overrideChecker.selector)} },`);
                        }
                    }
                });
                writer.writeLine(',');
            }
        });
        writer.writeLine(',');
    }

    // #endregion

    //#region Auth selector

    /** Writes `authSelector: <json>,` when any policy rule accesses a member of `auth()`. */
    writeAuthSelector(models, writer) {
        const authSelector = this.generateAuthSelector(models);
        if (authSelector) {
            writer.write(`authSelector: ${JSON.stringify(authSelector)},`);
        }
    }

    /**
     * Builds a `{ select: ... }`-style object (via `generateSelectForRules`) for the
     * `auth()` member accesses used in any model- or field-level `allow`/`deny` rule.
     *
     * @returns the selector, or `undefined` when no rule accesses `auth()` members
     */
    generateAuthSelector(models) {
        const authRules = [];
        models.forEach((model) => {
            // model-level rules
            const modelPolicyAttrs = model.attributes.filter((attr) => ['@@allow', '@@deny'].includes(attr.decl.$refText));
            // field-level rules
            const fieldPolicyAttrs = model.fields
                .flatMap((f) => f.attributes)
                .filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText));
            // rule expressions are the second attribute argument
            const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs]
                .filter((attr) => attr.args.length > 1)
                .map((attr) => attr.args[1].value);
            // collect `auth()` member access
            allExpressions.forEach((rule) => {
                (0, langium_1.streamAst)(rule).forEach((node) => {
                    if ((0, ast_1.isMemberAccessExpr)(node) && (0, sdk_1.isAuthInvocation)(node.operand)) {
                        authRules.push(node);
                    }
                });
            });
        });
        if (authRules.length > 0) {
            return (0, utils_1.generateSelectForRules)(authRules, undefined, true);
        }
        return undefined;
    }

    // #endregion

    // #region Validation meta

    /** Writes `validation: { <lowerCaseFirst(Model)>: { hasValidation: <bool> }, ... },`. */
    writeValidationMeta(writer, models) {
        writer.write('validation:');
        writer.inlineBlock(() => {
            for (const model of models) {
                writer.write(`${(0, lower_case_first_1.lowerCaseFirst)(model.name)}:`);
                writer.inlineBlock(() => {
                    writer.write(`hasValidation: ${
                    // explicit validation rules
                    (0, sdk_1.hasValidationAttributes)(model) ||
                        // type-def fields require schema validation
                        this.hasTypeDefFields(model)}`);
                });
                writer.writeLine(',');
            }
        });
        writer.writeLine(',');
    }

    /** Returns whether any field of `model` is typed with a type definition (`isTypeDef`). */
    hasTypeDefFields(model) {
        return model.fields.some((f) => {
            const reference = f.type.reference;
            return (0, ast_1.isTypeDef)(reference == null ? undefined : reference.ref);
        });
    }
}
exports.PolicyGenerator = PolicyGenerator;
//# sourceMappingURL=policy-guard-generator.js.map
;