sqlauthz
Version:
Declarative permission management for PostgreSQL
644 lines (581 loc) • 16.9 kB
text/typescript
import { Variable } from "oso";
import { Expression } from "oso/dist/src/Expression.js";
import { Pattern } from "oso/dist/src/Pattern.js";
import { Predicate } from "oso/dist/src/Predicate.js";
import { PolarOperator } from "oso/dist/src/types.js";
import { arrayProduct } from "./utils.js";
export interface Literal {
readonly type: "value";
readonly value: unknown;
}
export interface FunctionCall {
readonly type: "function-call";
readonly schema: string;
readonly name: string;
readonly args: Value[];
}
export interface Column {
readonly type: "column";
readonly value: string;
}
export type Value = Literal | Column | FunctionCall;
export interface ExpressionClause {
readonly type: "expression";
readonly operator: PolarOperator;
readonly values: readonly [Value, Value];
}
export interface NotClause {
readonly type: "not";
readonly clause: Clause;
}
export interface AndClause {
readonly type: "and";
readonly clauses: readonly Clause[];
}
export interface OrClause {
readonly type: "or";
readonly clauses: readonly Clause[];
}
export type Clause =
| ExpressionClause
| NotClause
| AndClause
| OrClause
| Value;
export const TrueClause = {
type: "and",
clauses: [],
} as const satisfies AndClause;
export const FalseClause = {
type: "or",
clauses: [],
} as const satisfies OrClause;
export function isTrueClause(
clause: Clause,
): clause is AndClause & { clauses: [] } {
return clause.type === "and" && clause.clauses.length === 0;
}
export function isFalseClause(
clause: Clause,
): clause is OrClause & { clauses: [] } {
return clause.type === "or" && clause.clauses.length === 0;
}
export function isColumn(clause: Clause): clause is Column {
return clause.type === "column";
}
export function isValue(clause: Clause): clause is Value {
return clause.type === "value";
}
export function mapClauses(
clause: Clause,
func: (clause: Clause) => Clause,
): Clause {
if (clause.type === "and" || clause.type === "or") {
const subClauses = clause.clauses.map((subClause) =>
mapClauses(subClause, func),
);
return func({
type: clause.type,
clauses: subClauses,
});
}
if (clause.type === "not") {
const subClause = mapClauses(clause.clause, func);
return func({
type: "not",
clause: subClause,
});
}
if (clause.type === "expression") {
const newValues = clause.values.map((value) => mapClauses(value, func)) as [
Value,
Value,
];
return func({
type: "expression",
operator: clause.operator,
values: newValues,
});
}
if (clause.type === "function-call") {
const values = clause.args.map((arg) => mapClauses(arg, func)) as Value[];
return func({
type: "function-call",
schema: clause.schema,
name: clause.name,
args: values,
});
}
return func(clause);
}
function clausesEqual(clause1: Clause, clause2: Clause): boolean {
if (clause1.type !== clause2.type) {
return false;
}
if (
(clause1.type === "and" && clause2.type === "and") ||
(clause1.type === "or" && clause2.type === "or")
) {
const deduped1 = deduplicateClauses(clause1.clauses);
const deduped2 = deduplicateClauses(clause2.clauses);
return (
deduped1.length === deduped2.length &&
deduped1.every((clause, idx) => clausesEqual(clause, deduped2[idx]!))
);
}
if (clause1.type === "not" && clause2.type === "not") {
return clausesEqual(clause1.clause, clause2.clause);
}
if (clause1.type === "expression" && clause2.type === "expression") {
return (
clause1.operator === clause2.operator &&
clause1.values.every((value, idx) =>
clausesEqual(value, clause2.values[idx]!),
)
);
}
if (
(clause1.type === "value" && clause2.type === "value") ||
(clause1.type === "column" && clause2.type === "column")
) {
return clause1.value === clause2.value;
}
if (clause1.type === "function-call" && clause2.type === "function-call") {
return (
clause1.name === clause2.name &&
clause1.schema === clause2.schema &&
clause1.args.length === clause2.args.length &&
clause1.args.every((arg, idx) => arg === clause2.args[idx])
);
}
return false;
}
function deduplicateClauses(clauses: readonly Clause[]): readonly Clause[] {
if (clauses.length <= 1) {
return clauses;
}
if (clauses.length === 2) {
if (clausesEqual(clauses[0]!, clauses[1]!)) {
return [clauses[0]!];
}
return clauses;
}
const first = clauses[0]!;
const rest = deduplicateClauses(clauses.slice(1));
const out: Clause[] = [first];
for (const clause of rest) {
if (!clausesEqual(first, clause)) {
out.push(clause);
}
}
return out;
}
export function optimizeClause(clause: Clause): Clause {
if (clause.type === "and") {
const outClauses: Clause[] = [];
for (const subClause of deduplicateClauses(clause.clauses)) {
const optimized = optimizeClause(subClause);
if (isTrueClause(optimized)) {
continue;
}
if (isFalseClause(optimized)) {
return FalseClause;
}
if (optimized.type === "and") {
outClauses.push(...optimized.clauses);
continue;
}
outClauses.push(optimized);
}
if (outClauses.length === 1) {
return outClauses[0]!;
}
return {
type: "and",
clauses: outClauses,
};
}
if (clause.type === "or") {
const outClauses: Clause[] = [];
for (const subClause of deduplicateClauses(clause.clauses)) {
const optimized = optimizeClause(subClause);
if (isTrueClause(optimized)) {
return TrueClause;
}
if (isFalseClause(optimized)) {
continue;
}
if (optimized.type === "or") {
outClauses.push(...optimized.clauses);
continue;
}
outClauses.push(optimized);
}
if (outClauses.length === 1) {
return outClauses[0]!;
}
return { type: "or", clauses: outClauses };
}
if (clause.type === "not") {
const optimized = optimizeClause(clause.clause);
if (optimized.type === "and") {
const orClause: OrClause = {
type: "or",
clauses: optimized.clauses.map((subClause) => {
return { type: "not", clause: subClause };
}),
};
return optimizeClause(orClause);
}
if (optimized.type === "or") {
const andClause: AndClause = {
type: "and",
clauses: optimized.clauses.map((subClause) => ({
type: "not",
clause: subClause,
})),
};
return optimizeClause(andClause);
}
return { type: "not", clause: optimized };
}
return clause;
}
export function valueToClause(value: unknown): Clause {
if (value instanceof Expression) {
if (value.operator === "And") {
const outClauses = value.args.map((arg) => valueToClause(arg));
return { type: "and", clauses: outClauses };
}
if (value.operator === "Or") {
const outClauses = value.args.map((arg) => valueToClause(arg));
return { type: "or", clauses: outClauses };
}
if (value.operator === "Dot") {
if (typeof value.args[0] === "string") {
const col: Column = {
type: "column",
value: ["_this", value.args[1]].join("."),
};
return {
type: "and",
clauses: [
col,
{
type: "expression",
operator: "Eq",
values: [
{ type: "column", value: "_this" },
{ type: "value", value: value.args[0] },
],
},
],
};
}
const args = value.args.map((arg) => valueToClause(arg));
const src = args[0] as Value | AndClause;
const name = args[1] as Value;
// TODO: is this the right behavior?
if (src.type === "function-call" || name.type === "function-call") {
throw new Error("Unexpected function call");
}
if (src.type === "and") {
const col = src.clauses[0] as Column;
const newCol: Column = {
type: "column",
value: [col.value, name.value].join("."),
};
return {
type: "and",
clauses: [newCol, ...src.clauses.slice(1)],
};
}
return {
type: "column",
value: [src.value, name.value].join("."),
};
}
if (value.operator === "Not") {
const subClause = valueToClause(value.args[0]);
return {
type: "not",
clause: subClause,
};
// Ignore these operators
}
if (
value.operator === "Cut" ||
value.operator === "Assign" ||
value.operator === "ForAll" ||
value.operator === "Isa" ||
value.operator === "Print"
) {
return TrueClause;
}
const clauses: Clause[] = [];
const leftClause = valueToClause(value.args[0]) as Value | AndClause;
let left: Value;
if (leftClause.type === "and") {
left = leftClause.clauses[0] as Value;
clauses.push(...leftClause.clauses.slice(1));
} else {
left = leftClause;
}
const rightClause = valueToClause(value.args[1]) as Value | AndClause;
let right: Value;
if (rightClause.type === "and") {
right = rightClause.clauses[0] as Value;
clauses.push(...rightClause.clauses.slice(1));
} else {
right = rightClause;
}
const operator = value.operator === "Unify" ? "Eq" : value.operator;
const newClause: ExpressionClause = {
type: "expression",
operator,
values: [left, right],
};
if (clauses.length > 0) {
return { type: "and", clauses: [newClause, ...clauses] };
}
return newClause;
}
if (value instanceof Variable) {
return {
type: "column",
value: value.name,
};
}
if (value instanceof Pattern) {
// TODO
return TrueClause;
}
if (value instanceof Predicate) {
const parts = value.name.split(".");
let schema: string;
let name: string;
if (parts.length === 1) {
schema = "";
name = parts[0]!;
} else {
schema = parts[0]!;
name = parts[1]!;
}
const clauses: Clause[] = [];
const args: Value[] = [];
for (const arg of value.args) {
const subClause = valueToClause(arg) as Value | AndClause;
if (subClause.type === "and") {
args.push(subClause.clauses[0] as Value);
clauses.push(...subClause.clauses.slice(1));
} else {
args.push(subClause);
}
}
const newClause: FunctionCall = {
type: "function-call",
schema: schema!,
name: name!,
args,
};
if (clauses.length > 0) {
return { type: "and", clauses: [newClause, ...clauses] };
}
return newClause;
}
return { type: "value", value };
}
export function factorOrClauses(clause: Clause): Clause[] {
const inner = (clause: Clause): Clause[] => {
if (clause.type === "and") {
const subOrs = clause.clauses.map((subClause) =>
factorOrClauses(subClause),
);
return Array.from(arrayProduct(subOrs)).map((subClauses) => ({
type: "and",
clauses: subClauses,
}));
}
if (clause.type === "or") {
return clause.clauses.flatMap((subClause) => factorOrClauses(subClause));
}
if (clause.type === "not") {
const subClauses = factorOrClauses(clause.clause);
if (subClauses.length > 1) {
const negativeAndClause: AndClause = {
type: "and",
clauses: subClauses.map((subClause) => ({
type: "not",
clause: subClause,
})),
};
return factorOrClauses(negativeAndClause);
}
return [{ type: "not", clause: subClauses[0]! }];
}
return [clause];
};
return inner(optimizeClause(clause)).map((subClause) =>
optimizeClause(subClause),
);
}
export interface EvaluateClauseArgs {
clause: Clause;
evaluate: (
expr: Exclude<Clause, AndClause | OrClause | NotClause>,
) => EvaluateClauseResult;
strictFields?: boolean;
}
export interface EvaluateClauseSuccess {
type: "success";
result: boolean;
}
export interface EvaluateClauseError {
type: "error";
errors: string[];
}
export type EvaluateClauseResult = EvaluateClauseSuccess | EvaluateClauseError;
export function evaluateClause({
clause,
evaluate,
strictFields,
}: EvaluateClauseArgs): EvaluateClauseResult {
if (clause.type === "and") {
const errors: string[] = [];
let result = true;
for (const subClause of clause.clauses) {
const clauseResult = evaluateClause({ clause: subClause, evaluate });
if (clauseResult.type === "success") {
result &&= clauseResult.result;
} else {
errors.push(...clauseResult.errors);
}
}
if ((strictFields || result) && errors.length > 0) {
return { type: "error", errors };
}
return { type: "success", result };
}
if (clause.type === "or") {
const errors: string[] = [];
let result = false;
for (const subClause of clause.clauses) {
const clauseResult = evaluateClause({ clause: subClause, evaluate });
if (clauseResult.type === "success") {
result ||= clauseResult.result;
} else {
errors.push(...clauseResult.errors);
}
}
if (errors.length > 0) {
return { type: "error", errors };
}
return { type: "success", result };
}
if (clause.type === "not") {
const clauseResult = evaluateClause({ clause: clause.clause, evaluate });
if (clauseResult.type === "success") {
return { type: "success", result: !clauseResult.result };
}
return { type: "error", errors: clauseResult.errors };
}
return evaluate(clause);
}
export interface SimpleEvaluatorArgs {
variableName: string;
errorVariableName: string;
// biome-ignore lint/suspicious/noExplicitAny: needed here
getValue: (value: Value) => any;
}
export function simpleEvaluator({
variableName,
errorVariableName,
getValue,
}: SimpleEvaluatorArgs): EvaluateClauseArgs["evaluate"] {
const func: EvaluateClauseArgs["evaluate"] = (expr) => {
if (expr.type === "column" && expr.value === variableName) {
return { type: "success", result: true };
}
if (expr.type === "column") {
return {
type: "error",
errors: [`${errorVariableName}: invalid reference: ${expr.value}`],
};
}
if (expr.type === "value") {
return func({
type: "expression",
operator: "Eq",
values: [{ type: "column", value: "_this" }, expr],
});
}
if (expr.type === "function-call") {
// TODO: is this the right behavior?
return {
type: "error",
errors: [`${errorVariableName}: unexpected function call`],
};
}
let operatorFunc: (a: unknown, b: unknown) => boolean;
if (expr.operator === "Eq") {
operatorFunc = (a, b) => a === b;
} else if (expr.operator === "Neq") {
operatorFunc = (a, b) => a !== b;
} else if (expr.operator === "Geq") {
operatorFunc = (a, b) => (a as string | number) >= (b as string | number);
} else if (expr.operator === "Gt") {
operatorFunc = (a, b) => (a as string | number) > (b as string | number);
} else if (expr.operator === "Lt") {
operatorFunc = (a, b) => (a as string | number) < (b as string | number);
} else if (expr.operator === "Leq") {
operatorFunc = (a, b) => (a as string | number) <= (b as string | number);
} else {
return {
type: "error",
errors: [
`${errorVariableName}: unsupported operator: ${expr.operator}`,
],
};
}
if (expr.values[0].type === "value" && expr.values[1].type === "value") {
return {
type: "success",
result: operatorFunc(expr.values[0].value, expr.values[1].value),
};
}
const errors: string[] = [];
// biome-ignore lint/suspicious/noExplicitAny: needed here
let left: any;
// biome-ignore lint/suspicious/noExplicitAny: needed here
let right: any;
try {
left = getValue(expr.values[0]);
} catch (error) {
if (error instanceof ValidationError) {
errors.push(error.message);
} else {
throw error;
}
}
try {
right = getValue(expr.values[1]);
} catch (error) {
if (error instanceof ValidationError) {
errors.push(error.message);
} else {
throw error;
}
}
if (errors.length > 0) {
return { type: "error", errors };
}
return { type: "success", result: operatorFunc(left, right) };
};
return func;
}
export class ValidationError extends Error {
constructor(readonly message: string) {
super(message);
Object.setPrototypeOf(this, new.target.prototype);
}
}