@schoolai/spicedb-zed-schema-parser
Version:
SpiceDB `.zed` schema file parser and analyzer, written in TypeScript
218 lines (202 loc) • 6.63 kB
text/typescript
import { PermissionExpression, RelationType } from '../schema-parser/parser'
import { SymbolTable } from './symbol-table'
/**
 * Infers which concrete subject types a permission expression can resolve to,
 * using the relations and permissions registered in a {@link SymbolTable}.
 *
 * Subject-relation type references (e.g. `group#member`) are expanded
 * recursively into the types they ultimately contain. Recursive schemas
 * (nested groups such as `relation member: user | group#member`, or
 * `parent->view` chains) are handled by tracking the `definition#name` pairs
 * on the current resolution path: a reference back onto that path contributes
 * nothing instead of recursing forever.
 */
export class TypeInferenceEngine {
  constructor(private readonly symbolTable: SymbolTable) {}

  /**
   * Infers the subject types that `expr`, evaluated in definition `defName`,
   * can yield.
   *
   * @param defName - definition in which `expr` is evaluated
   * @param expr - the expression to analyze
   * @param callStack - `definition#name` keys currently being resolved, used
   *   for cycle detection. Keys added during this call are removed before it
   *   returns, so a caller-supplied set is left as it was passed in.
   * @returns the deduplicated subject types, or `null` when the expression
   *   cannot be resolved. That happens for an unknown identifier, a reference
   *   back onto the current resolution path, an unresolvable intersection
   *   operand, an unresolvable arrow left side, or an unrecognized expression.
   */
  inferExpressionType(
    defName: string,
    expr: PermissionExpression,
    callStack: Set<string> = new Set(),
  ): RelationType[] | null {
    switch (expr.type) {
      case 'identifier':
        // An identifier names a relation or a permission on `defName`; both
        // are resolved (with cycle protection) by the same lookup.
        return this.inferSubjectRelationType(defName, expr.name, callStack)
      case 'union': {
        const allTypes: RelationType[] = []
        for (const operand of expr.operands) {
          const types = this.inferExpressionType(defName, operand, callStack)
          if (types) allTypes.push(...types)
        }
        return this.deduplicateTypes(allTypes)
      }
      case 'intersection': {
        let commonTypes: RelationType[] | null = null
        for (const operand of expr.operands) {
          const types = this.inferExpressionType(defName, operand, callStack)
          // A single unresolvable operand makes the whole intersection unresolvable.
          if (!types) return null
          commonTypes = commonTypes
            ? this.intersectTypes(commonTypes, types)
            : types
        }
        return commonTypes
      }
      case 'exclusion':
        // The excluded (right) operand can only remove subjects, so only the
        // left operand determines which types are possible.
        return this.inferExpressionType(defName, expr.left, callStack)
      case 'arrow':
      case 'any':
      case 'all': {
        // `left->target`, `left.any(target)` and `left.all(target)` all
        // resolve `target` on every type reachable through `left`.
        const leftTypes = this.inferExpressionType(
          defName,
          expr.left,
          callStack,
        )
        if (!leftTypes) return null
        const resultTypes: RelationType[] = []
        for (const leftType of leftTypes) {
          const targetTypes = this.inferSubjectRelationType(
            leftType.typeName,
            expr.target,
            callStack,
          )
          // `null` means the target is unknown on this type or already on the
          // resolution path (a cycle); either way this path adds nothing.
          if (targetTypes) resultTypes.push(...targetTypes)
        }
        return this.deduplicateTypes(resultTypes)
      }
      default:
        return null
    }
  }

  /**
   * Resolves the relation or permission `relationName` on `typeName` to
   * concrete subject types.
   *
   * A relation is looked up first. If none exists, the permission of that
   * name is tried. Both are guarded against cycles with the key
   * `typeName#relationName`.
   *
   * @returns the deduplicated subject types, or `null` if the name is unknown
   *   or the key is already being resolved (a cycle).
   */
  private inferSubjectRelationType(
    typeName: string,
    relationName: string,
    callStack: Set<string>,
  ): RelationType[] | null {
    const stackKey = `${typeName}#${relationName}`
    const targetRel = this.symbolTable.getRelation(typeName, relationName)
    if (targetRel) {
      return this.withCycleGuard(stackKey, callStack, () =>
        this.deduplicateTypes(
          this.resolveTypeRefs(targetRel.types, callStack),
        ),
      )
    }
    const targetPerm = this.symbolTable.getPermission(typeName, relationName)
    if (targetPerm) {
      return this.withCycleGuard(stackKey, callStack, () =>
        this.inferExpressionType(typeName, targetPerm.expression, callStack),
      )
    }
    return null
  }

  /**
   * Expands a relation's allowed type references.
   *
   * Plain types (including wildcards) are kept as they are. Each
   * `type#relation` reference is replaced by the types that relation (or
   * permission) resolves to. The result is NOT deduplicated.
   */
  private resolveTypeRefs(
    typeRefs: readonly RelationType[],
    callStack: Set<string>,
  ): RelationType[] {
    const resolved: RelationType[] = []
    for (const typeRef of typeRefs) {
      if (!typeRef.relation) {
        resolved.push(typeRef)
        continue
      }
      const subjectTypes = this.inferSubjectRelationType(
        typeRef.typeName,
        typeRef.relation,
        callStack,
      )
      if (subjectTypes) resolved.push(...subjectTypes)
    }
    return resolved
  }

  /**
   * Runs `resolve` with `key` pushed onto `callStack`.
   *
   * If `key` is already on the stack, this is a cycle: it returns `null`
   * without calling `resolve`. Otherwise the key is removed again afterwards,
   * even if `resolve` throws.
   */
  private withCycleGuard(
    key: string,
    callStack: Set<string>,
    resolve: () => RelationType[] | null,
  ): RelationType[] | null {
    if (callStack.has(key)) return null
    callStack.add(key)
    try {
      return resolve()
    } finally {
      callStack.delete(key)
    }
  }

  /**
   * Identity key used to compare types: `name`, `name:*` (wildcard) or
   * `name#relation`.
   */
  private typeKey(type: RelationType): string {
    return `${type.typeName}${type.wildcard ? ':*' : ''}${type.relation ? '#' + type.relation : ''}`
  }

  /** Removes duplicate types (by {@link typeKey}), keeping the first occurrence of each. */
  private deduplicateTypes(types: RelationType[]): RelationType[] {
    const seen = new Set<string>()
    return types.filter(type => {
      const key = this.typeKey(type)
      if (seen.has(key)) return false
      seen.add(key)
      return true
    })
  }

  /**
   * Returns the types of `a` whose key also occurs in `b`. The result keeps
   * `a`'s order and is deduplicated.
   */
  private intersectTypes(a: RelationType[], b: RelationType[]): RelationType[] {
    const bKeys = new Set(b.map(type => this.typeKey(type)))
    return this.deduplicateTypes(a.filter(type => bKeys.has(this.typeKey(type))))
  }
}