zenstack
Version:
FullStack enhancement for Prisma ORM: seamless integration from database to UI
506 lines • 21.3 kB
JavaScript
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.getPolicyExpressions = getPolicyExpressions;
exports.generateSelectForRules = generateSelectForRules;
exports.generateConstantQueryGuardFunction = generateConstantQueryGuardFunction;
exports.generateQueryGuardFunction = generateQueryGuardFunction;
exports.generateEntityCheckerFunction = generateEntityCheckerFunction;
exports.generateNormalizedAuthRef = generateNormalizedAuthRef;
exports.isEnumReferenced = isEnumReferenced;
const sdk_1 = require("@zenstackhq/sdk");
const ast_1 = require("@zenstackhq/sdk/ast");
const deepmerge_1 = __importDefault(require("deepmerge"));
const langium_1 = require("langium");
const __1 = require("..");
const ast_utils_1 = require("../../../utils/ast-utils");
const expression_writer_1 = require("./expression-writer");
/**
 * Collects the policy rule expressions declared on a model or field for an operation.
 *
 * @param target data model or field whose `attributes` are inspected
 * @param kind policy kind; prefixed with `@@` for a data model and `@` otherwise to
 * form the attribute name to look for
 * @param operation operation to collect rules for; `postUpdate` matches rules
 * declared for `update`
 * @param forOverride when true, only rules carrying a literal `override: true`
 * argument are returned; otherwise only rules without it
 * @param filter `'onlyCrossModelComparison'` / `'withoutCrossModelComparison'` keep
 * or drop rules containing cross-model comparisons; any other value keeps all
 * @returns the rule expressions (second attribute argument) that apply
 */
function getPolicyExpressions(target, kind, operation, forOverride = false, filter = 'all') {
    const { getAttributeArg, getLiteral } = sdk_1;
    const { isDataModel } = ast_1;

    const declaredAttributes = target.attributes;
    const expectedName = `${isDataModel(target) ? '@@' : '@'}${kind}`;
    // post-update rules are written against the `update` operation
    const declaredOperation = operation === 'postUpdate' ? 'update' : operation;

    // attribute has the expected name and the requested override flavor
    const isRequestedVariant = (attr) => {
        const decl = attr.decl.ref;
        const declaredName = decl === null || decl === undefined ? undefined : decl.name;
        if (declaredName !== expectedName) {
            return false;
        }
        const overrideArg = getAttributeArg(attr, 'override');
        const isOverride = !!overrideArg && getLiteral(overrideArg) === true;
        return forOverride ? isOverride : !isOverride;
    };

    // attribute's comma-separated operation list covers the requested operation
    const coversOperation = (attr) => {
        const operationList = getLiteral(attr.args[0].value);
        if (!operationList) {
            return false;
        }
        const operations = operationList.split(',').map((item) => item.trim());
        return operations.includes(declaredOperation) || operations.includes('all');
    };

    let rules = declaredAttributes
        .filter(isRequestedVariant)
        .filter((attr) => coversOperation(attr))
        .map((attr) => attr.args[1].value);

    switch (filter) {
        case 'onlyCrossModelComparison':
            rules = rules.filter((rule) => hasCrossModelComparison(rule));
            break;
        case 'withoutCrossModelComparison':
            rules = rules.filter((rule) => !hasCrossModelComparison(rule));
            break;
        default:
            break;
    }

    if (operation === 'update') {
        return processUpdatePolicies(rules, false);
    }
    if (operation === 'postUpdate') {
        return processUpdatePolicies(rules, true);
    }
    return rules;
}
/**
 * Tells whether the expression tree contains an invocation of the standard
 * library's `future` function.
 *
 * @param expr root expression to scan (the root itself is included)
 * @returns true if such an invocation is found, false otherwise
 */
function hasFutureReference(expr) {
    const { isInvocationExpr } = ast_1;
    const { isFromStdlib } = sdk_1;
    const { streamAst } = langium_1;

    const isStdlibFutureCall = (node) => {
        if (!isInvocationExpr(node)) {
            return false;
        }
        const callee = node.function.ref;
        // an unresolved callee has no name and therefore never matches
        const calleeName = callee === null || callee === undefined ? undefined : callee.name;
        return calleeName === 'future' && isFromStdlib(callee) ? true : false;
    };

    return streamAst(expr).some((node) => isStdlibFutureCall(node));
}
/**
 * Decides which update rules take part in the pre-update or post-update check.
 *
 * If any rule contains a `future()` reference, all rules are treated as
 * post-update rules and the pre-update check gets none; otherwise all rules are
 * pre-update rules and the post-update check gets none.
 *
 * @param expressions update rule expressions
 * @param postUpdate true when compiling post-update rules, false for pre-update
 * @returns `expressions` itself when the rules apply to the requested phase,
 * otherwise a new empty array
 */
function processUpdatePolicies(expressions, postUpdate) {
    const involvesFuture = expressions.some(hasFutureReference);
    const appliesToPhase = postUpdate ? involvesFuture : !involvesFuture;
    return appliesToPhase ? expressions : [];
}
/**
* Generates a "select" object that contains (recursively) fields referenced by the
* given policy rules
*
* @param rules policy rule expressions to collect field references from
* @param forOperation operation used to look up the policies of models targeted by
* `check()` calls that do not pass an explicit operation argument
* @param forAuthContext when true, a member access whose operand is an `auth()`
* invocation contributes its member name as a selection path
* @param ignoreFutureReference when true (the default), member accesses on
* `future()` are left out of the selection
* @returns the select object, or `undefined` when no field ended up selected
*/
function generateSelectForRules(rules, forOperation, forAuthContext = false, ignoreFutureReference = true) {
let result = {};
// records one selection path into `result`: the last segment becomes `true`,
// every segment before it becomes a nested `{ select: { ... } }`
const addPath = (path) => {
const thisIndex = path.lastIndexOf('$this');
if (thisIndex >= 0) {
// drop the last `$this` marker and everything before it
path = path.slice(thisIndex + 1);
}
let curr = result;
path.forEach((seg, i) => {
if (i === path.length - 1) {
// leaf segment: mark the field as selected
curr[seg] = true;
}
else {
// intermediate segment: descend into its nested select, creating it if absent
if (!curr[seg]) {
curr[seg] = { select: {} };
}
curr = curr[seg].select;
}
});
};
// visit a reference or member access expression to build a
// selection path; returns `undefined` when the node yields no path
const visit = (node) => {
if ((0, ast_1.isThisExpr)(node)) {
// marker segment; `addPath` later cuts the path right after it
return ['$this'];
}
if ((0, sdk_1.isFutureExpr)(node)) {
// `future()` itself adds no segment, so `future().x` becomes ['x']
return [];
}
if ((0, ast_1.isReferenceExpr)(node)) {
const target = (0, sdk_1.resolved)(node.target);
if ((0, ast_1.isDataModelField)(target)) {
// a field selection, it's a terminal
return [target.name];
}
}
if ((0, ast_1.isMemberAccessExpr)(node)) {
if (forAuthContext && (0, sdk_1.isAuthInvocation)(node.operand)) {
// `auth().member`: the path starts at the member name
return [node.member.$refText];
}
if ((0, sdk_1.isFutureExpr)(node.operand) && ignoreFutureReference) {
// future().field is not subject to pre-update select
return undefined;
}
// build a selection path inside-out for chained member access
const inner = visit(node.operand);
if (inner) {
return [...inner, node.member.$refText];
}
}
return undefined;
};
// collect selection paths from the given expression; always returns an
// array of paths (each path being an array of segments)
const collectReferencePaths = (expr) => {
var _a, _b, _c;
if ((0, ast_1.isThisExpr)(expr) && !(0, ast_1.isMemberAccessExpr)(expr.$container)) {
// a standalone `this` expression, include all id fields
const model = (_a = expr.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl;
const idFields = (0, sdk_1.getIdFields)(model);
return idFields.map((field) => [field.name]);
}
if ((0, ast_1.isMemberAccessExpr)(expr) || (0, ast_1.isReferenceExpr)(expr)) {
const path = visit(expr);
if (path) {
if ((0, ast_1.isDataModel)((_b = expr.$resolvedType) === null || _b === void 0 ? void 0 : _b.decl)) {
// member selection ended at a data model field, include its id fields
const idFields = (0, sdk_1.getIdFields)((_c = expr.$resolvedType) === null || _c === void 0 ? void 0 : _c.decl);
return idFields.map((field) => [...path, field.name]);
}
else {
return [path];
}
}
else {
return [];
}
}
else if ((0, ast_utils_1.isCollectionPredicate)(expr)) {
const path = visit(expr.left);
// recurse into RHS
const rhs = collectReferencePaths(expr.right);
if (path) {
// combine path of LHS and RHS
return rhs.map((r) => [...path, ...r]);
}
else {
// LHS is not rooted from the current model,
// only keep RHS items that contains '$this'
return rhs.filter((r) => r.includes('$this'));
}
}
else if ((0, ast_1.isInvocationExpr)(expr)) {
// recurse into function arguments
return expr.args.flatMap((arg) => collectReferencePaths(arg.value));
}
else {
// recurse into the direct child expressions
const children = (0, langium_1.streamContents)(expr)
.filter((child) => (0, ast_1.isExpression)(child))
.toArray();
return children.flatMap((child) => collectReferencePaths(child));
}
};
for (const rule of rules) {
const paths = collectReferencePaths(rule);
paths.forEach((p) => addPath(p));
// merge selectors from models referenced by `check()` calls
(0, langium_1.streamAst)(rule).forEach((node) => {
var _a, _b, _c;
if ((0, ast_utils_1.isCheckInvocation)(node)) {
const expr = node;
// first argument is the field being checked; the declaration of its
// resolved type is the model whose policies are pulled in
const fieldRef = expr.args[0].value;
const targetModel = (_a = fieldRef.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl;
// use the literal operation argument when given, else fall back to `forOperation`
const targetOperation = (_c = (0, sdk_1.getLiteral)((_b = expr.args[1]) === null || _b === void 0 ? void 0 : _b.value)) !== null && _c !== void 0 ? _c : forOperation;
const targetSelector = generateSelectForRules([
...getPolicyExpressions(targetModel, 'allow', targetOperation),
...getPolicyExpressions(targetModel, 'deny', targetOperation),
], targetOperation, forAuthContext, ignoreFutureReference);
// `undefined` means the target model's rules select nothing, so there is nothing to merge
if (targetSelector) {
// nest the target model's selection under the reference text of the checked field
result = (0, deepmerge_1.default)(result, { [fieldRef.target.$refText]: { select: targetSelector } });
}
}
});
}
// an empty selection is reported as `undefined` rather than `{}`
return Object.keys(result).length === 0 ? undefined : result;
}
/**
 * Adds a query guard function to `sourceFile` whose body returns a constant.
 *
 * @param sourceFile object exposing `addFunction`, which receives the function structure
 * @param model model the guard is named after
 * @param kind operation kind, used when deriving the guard function's name
 * @param value when truthy the guard returns the `TRUE` constant from
 * `expression-writer`, otherwise `FALSE`
 * @returns whatever `sourceFile.addFunction` returns
 */
function generateConstantQueryGuardFunction(sourceFile, model, kind, value) {
    const { getQueryGuardFunctionName } = sdk_1;

    const guardName = getQueryGuardFunctionName(model, undefined, false, kind);
    const constantResult = value ? expression_writer_1.TRUE : expression_writer_1.FALSE;

    const contextParam = { name: 'context', type: 'QueryContext' };
    // for generating field references used by field comparison in the same model
    const dbParam = { name: 'db', type: 'CrudContract' };

    return sourceFile.addFunction({
        name: guardName,
        returnType: 'any',
        parameters: [contextParam, dbParam],
        statements: [`return ${constantResult};`],
    });
}
/**
 * Generates a query guard function that returns a partial Prisma query for the given model or field
 *
 * Rules containing cross-model comparisons are filtered out here (they are generated
 * as additional checker functions). When none of the remaining rules touches a model
 * field (`this`, `future()` or a field reference), the guard is compiled to plain
 * boolean checks; otherwise it is compiled to a filter object combining the rules.
 *
 * @param sourceFile object exposing `addFunction`, which receives the function structure
 * @param model model the guard is generated for
 * @param kind operation kind; `'postUpdate'` switches the generators to post-guard mode
 * and makes the model-level default "allow"
 * @param allows allow rule expressions
 * @param denies deny rule expressions
 * @param forField field the guard is for, or a falsy value for a model-level guard
 * @param fieldOverride passed through to the guard function name generator
 * @returns whatever `sourceFile.addFunction` returns
 * @throws PluginError when a rule cannot be transformed to TypeScript (raised when the
 * generated statement writer runs); other errors are rethrown unchanged
 */
function generateQueryGuardFunction(sourceFile, model, kind, allows, denies, forField, fieldOverride = false) {
    const { TRUE, FALSE } = expression_writer_1;
    const { isThisExpr, isReferenceExpr, isDataModelField } = ast_1;
    const { isFutureExpr, getQueryGuardFunctionName } = sdk_1;
    const { streamAst } = langium_1;

    const isPostUpdate = kind === 'postUpdate';
    const statements = [];
    const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule));
    const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule));
    generateNormalizedAuthRef(model, allowRules, denyRules, statements);

    // `this.???`, `future().???` or a direct field reference
    const touchesModelField = (node) =>
        isThisExpr(node) ||
        isFutureExpr(node) ||
        (isReferenceExpr(node) && isDataModelField(node.target.ref));
    const hasFieldAccess = [...denyRules, ...allowRules].some((rule) =>
        streamAst(rule).some((node) => touchesModelField(node)));

    if (!hasFieldAccess) {
        // none of the rules reference model fields, we can compile down to a plain boolean
        // function in this case (so we can skip doing SQL queries when validating)
        statements.push((writer) => {
            const transformer = new sdk_1.TypeScriptExpressionTransformer({
                context: sdk_1.ExpressionContext.AccessPolicy,
                isPostGuard: isPostUpdate,
                operationContext: kind,
            });
            try {
                // deny rules are checked first so that a matching deny always wins
                denyRules.forEach((rule) => {
                    writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`);
                });
                allowRules.forEach((rule) => {
                    writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`);
                });
            }
            catch (err) {
                if (err instanceof sdk_1.TypeScriptExpressionTransformerError) {
                    throw new sdk_1.PluginError(__1.name, err.message);
                }
                throw err;
            }

            let fallback;
            if (allowRules.length < allows.length) {
                // some rules are filtered out here and will be generated as additional
                // checker functions, so we allow here to avoid a premature denial
                fallback = TRUE;
            }
            else if (forField) {
                // field-level: allow by default when there's no allow rule, otherwise
                // deny unless any allow rule evaluates to true
                fallback = allows.length === 0 ? TRUE : FALSE;
            }
            else {
                // for model-level rules, the default is always deny unless for 'postUpdate'
                fallback = isPostUpdate ? TRUE : FALSE;
            }
            writer.write(`return ${fallback};`);
        });
    }
    else {
        statements.push((writer) => {
            // every branch below emits only the filter expression; the surrounding
            // `return ` and `;` are written here
            writer.write('return ');
            const exprWriter = new expression_writer_1.ExpressionWriter(writer, {
                isPostGuard: isPostUpdate,
                operationContext: kind,
            });
            const writeDenies = () => {
                const wrap = denyRules.length > 1;
                writer.conditionalWrite(wrap, '{ AND: [');
                denyRules.forEach((rule, index) => {
                    writer.inlineBlock(() => {
                        writer.write('NOT: ');
                        exprWriter.write(rule);
                    });
                    writer.conditionalWrite(index !== denyRules.length - 1, ',');
                });
                writer.conditionalWrite(wrap, ']}');
            };
            const writeAllows = () => {
                const wrap = allowRules.length > 1;
                writer.conditionalWrite(wrap, '{ OR: [');
                allowRules.forEach((rule, index) => {
                    exprWriter.write(rule);
                    writer.conditionalWrite(index !== allowRules.length - 1, ',');
                });
                writer.conditionalWrite(wrap, ']}');
            };
            const hasAllows = allowRules.length > 0;
            const hasDenies = denyRules.length > 0;
            if (hasAllows && hasDenies) {
                // include both allow and deny rules
                writer.write('{ AND: [');
                writeDenies();
                writer.write(',');
                writeAllows();
                writer.write(']}');
            }
            else if (hasDenies) {
                // only deny rules
                writeDenies();
            }
            else if (hasAllows) {
                // only allow rules
                writeAllows();
            }
            else {
                // Defensive: field access implies at least one rule, so this is not
                // expected to run. Disallow any operation unless for 'postUpdate'.
                // Only the expression is written; emitting another `return`/`;` here
                // would produce the malformed `return return ...;;`.
                writer.write(`${isPostUpdate ? TRUE : FALSE}`);
            }
            writer.write(';');
        });
    }

    return sourceFile.addFunction({
        name: getQueryGuardFunctionName(model, forField, fieldOverride, kind),
        returnType: 'any',
        parameters: [
            {
                name: 'context',
                type: 'QueryContext',
            },
            {
                // for generating field references used by field comparison in the same model
                name: 'db',
                type: 'CrudContract',
            },
        ],
        statements,
    });
}
/**
 * Generates an entity checker function: a function taking `input` and `context`
 * whose body evaluates the deny rules, then the allow rules, against `input`.
 *
 * @param sourceFile object exposing `addFunction`, which receives the function structure
 * @param model model the checker is generated for
 * @param kind operation kind; for `'postUpdate'` the checker defaults to allow
 * @param allows allow rule expressions; the first one that holds returns `true`
 * @param denies deny rule expressions; the first one that holds returns `false`
 * @param forField field the checker is for, or a falsy value for a model-level checker
 * @param fieldOverride passed through to the checker function name generator
 * @returns whatever `sourceFile.addFunction` returns
 */
function generateEntityCheckerFunction(sourceFile, model, kind, allows, denies, forField, fieldOverride = false) {
    const { getEntityCheckerFunctionName } = sdk_1;
    const isPostUpdate = kind === 'postUpdate';

    const statements = [];
    generateNormalizedAuthRef(model, allows, denies, statements);

    // `this`, field references and `future()` are all resolved against `input`
    const transformer = new sdk_1.TypeScriptExpressionTransformer({
        context: sdk_1.ExpressionContext.AccessPolicy,
        thisExprContext: 'input',
        fieldReferenceContext: 'input',
        isPostGuard: isPostUpdate,
        futureRefContext: 'input',
        operationContext: kind,
    });

    // emits one short-circuiting `if` per rule, returning the given outcome
    const emitShortCircuits = (rules, outcome) => {
        for (const rule of rules) {
            const condition = transformer.transform(rule, false);
            statements.push(`if (${condition}) { return ${outcome}; }`);
        }
    };
    emitShortCircuits(denies, 'false');
    emitShortCircuits(allows, 'true');

    // 'postUpdate' defaults to allow; so does a field-level checker without any
    // allow rule; everything else defaults to deny
    const defaultsToAllow = isPostUpdate || (forField && allows.length === 0);
    if (defaultsToAllow) {
        statements.push('return true;');
    }
    else {
        statements.push(`return false;`);
    }

    const inputParam = { name: 'input', type: 'any' };
    const contextParam = { name: 'context', type: 'QueryContext' };
    return sourceFile.addFunction({
        name: getEntityCheckerFunctionName(model, forField, fieldOverride, kind),
        returnType: 'any',
        parameters: [inputParam, contextParam],
        statements,
    });
}
/**
 * Emits a normalized `user` declaration when any of the given rules calls `auth()`.
 *
 * Nothing is emitted (and nothing is validated) when no rule contains an `auth()`
 * invocation.
 *
 * @param model model whose `$container` is searched for the auth declaration
 * @param allows allow rule expressions
 * @param denies deny rule expressions
 * @param statements statement list the declaration is appended to
 * @throws PluginError when no auth model is found or it has no id field
 */
function generateNormalizedAuthRef(model, allows, denies, statements) {
    const { isAuthInvocation, getAuthDecl, getDataModelAndTypeDefs, getIdFields } = sdk_1;
    const { streamAst } = langium_1;

    const callsAuth = (rule) => streamAst(rule).some((node) => isAuthInvocation(node));
    if (![...allows, ...denies].some((rule) => callsAuth(rule))) {
        return;
    }

    const authModel = getAuthDecl(getDataModelAndTypeDefs(model.$container, true));
    if (!authModel) {
        throw new sdk_1.PluginError(__1.name, 'Auth model not found');
    }

    const authIdFields = getIdFields(authModel);
    const hasIdField = Boolean(authIdFields) && authIdFields.length !== 0;
    if (!hasIdField) {
        throw new sdk_1.PluginError(__1.name, 'User model does not have an id field');
    }

    // normalize user to null to avoid accidentally use undefined in filter
    statements.push(`const user: any = context.user ?? null;`);
}
/**
 * Check if the given enum is referenced in the model
 *
 * An enum counts as referenced when, inside any data model, either a field's type
 * reference resolves to it or an enum field reference targets one of its fields.
 *
 * @param model schema model whose data models are scanned
 * @param decl enum declaration to look for
 * @returns true if a reference is found, false otherwise
 */
function isEnumReferenced(model, decl) {
    const { getDataModels, isEnumFieldReference } = sdk_1;
    const { isDataModelField } = ast_1;
    const { streamAllContents } = langium_1;

    const refersToEnum = (node) => {
        if (isDataModelField(node)) {
            // referenced as field type
            const typeReference = node.type.reference;
            const fieldTypeDecl = typeReference === null || typeReference === undefined ? undefined : typeReference.ref;
            if (fieldTypeDecl === decl) {
                return true;
            }
        }
        if (isEnumFieldReference(node)) {
            // enum field is referenced
            const enumField = node.target.ref;
            const owningEnum = enumField === null || enumField === undefined ? undefined : enumField.$container;
            if (owningEnum === decl) {
                return true;
            }
        }
        return false;
    };

    return getDataModels(model).some((dataModel) => streamAllContents(dataModel).some((node) => refersToEnum(node)));
}
/**
 * Tells whether the expression tree contains a comparison (`==`, `!=`, `>`, `<`,
 * `>=`, `<=` or `in`) whose two operands are field accesses rooted in different
 * source models.
 *
 * @param expr root expression to scan
 * @returns true if such a comparison is found, false otherwise
 */
function hasCrossModelComparison(expr) {
    const { isBinaryExpr } = ast_1;
    const { streamAst } = langium_1;
    const comparisonOperators = ['==', '!=', '>', '<', '>=', '<=', 'in'];

    const comparesAcrossModels = (node) => {
        if (!isBinaryExpr(node) || !comparisonOperators.includes(node.operator)) {
            return false;
        }
        const lhsModel = getSourceModelOfFieldAccess(node.left);
        const rhsModel = getSourceModelOfFieldAccess(node.right);
        // both sides must resolve to a model, and the models must differ
        if (!lhsModel || !rhsModel) {
            return false;
        }
        return lhsModel !== rhsModel;
    };

    return streamAst(expr).some((node) => comparesAcrossModels(node));
}
/**
 * Finds the data model that a field-access expression reads from.
 *
 * Chained member accesses are followed through their operands until one of the
 * cases below matches.
 *
 * @param expr expression to inspect
 * @returns the source data model, or `undefined` when the expression is an
 * `auth()` access or does not read from a model
 */
function getSourceModelOfFieldAccess(expr) {
    const { isDataModel, isMemberAccessExpr, isThisExpr } = ast_1;
    const { isAuthInvocation, isDataModelFieldReference } = sdk_1;
    const { isFutureInvocation } = ast_utils_1;
    const { getContainerOfType } = langium_1;

    let current = expr;
    while (true) {
        // `auth()` access doesn't involve db field look up so doesn't count as cross-model comparison
        if (isAuthInvocation(current)) {
            return undefined;
        }

        // an expression that resolves to a data model and is part of a member access, return the model
        // e.g.: profile.age => Profile
        const resolvedType = current.$resolvedType;
        const resolvedDecl = resolvedType === null || resolvedType === undefined ? undefined : resolvedType.decl;
        if (isDataModel(resolvedDecl) && isMemberAccessExpr(current.$container)) {
            return resolvedDecl;
        }

        // `this` reference and `future()` both read from the enclosing model
        if (isThisExpr(current) || isFutureInvocation(current)) {
            return getContainerOfType(current, isDataModel);
        }

        // direct field reference, return the model that declares the field
        if (isDataModelFieldReference(current)) {
            return current.target.ref.$container;
        }

        // member access: continue with the operand; anything else has no source model
        if (!isMemberAccessExpr(current)) {
            return undefined;
        }
        current = current.operand;
    }
}
//# sourceMappingURL=utils.js.map
;