UNPKG

@thi.ng/shader-ast-optimize

Version:

Shader AST code/tree optimization passes/strategies

212 lines (211 loc) 6.35 kB
import { DEFAULT, defmulti } from "@thi.ng/defmulti/defmulti";
import { illegalArgs } from "@thi.ng/errors/illegal-arguments";
import { LogLevel } from "@thi.ng/logger/api";
import { deg, rad } from "@thi.ng/math/angle";
import { clamp } from "@thi.ng/math/interval";
import { mix } from "@thi.ng/math/mix";
import { fract, mod } from "@thi.ng/math/prec";
import { matchingPrimFor, neg } from "@thi.ng/shader-ast";
import {
  isFloat,
  isInt,
  isLitNumOrVecConst,
  isLitNumericConst,
  isLitVecConst,
  isUint,
} from "@thi.ng/shader-ast/ast/checks";
import {
  FLOAT0,
  FLOAT1,
  FLOAT2,
  bool,
  float,
  int,
  lit,
  uint,
} from "@thi.ng/shader-ast/ast/lit";
import { allChildren, walk } from "@thi.ng/shader-ast/ast/scope";
import { LOGGER } from "@thi.ng/shader-ast/logger";

/**
 * Overwrites `node` in place with the contents of `next`: every key of `node`
 * that `next` does not own is deleted, then all of `next`'s properties are
 * copied over. Mutating in place (instead of returning a new node) keeps the
 * parent's reference to `node` valid. Always returns `true` so fold handlers
 * can signal "tree changed".
 */
const __replaceNode = (node, next) => {
  if (LOGGER.level <= LogLevel.DEBUG) {
    LOGGER.debug(`replacing AST node:`);
    LOGGER.debug(" old: " + JSON.stringify(node));
    LOGGER.debug(" new: " + JSON.stringify(next));
  }
  for (const k in node) {
    if (!Object.prototype.hasOwnProperty.call(next, k)) delete node[k];
  }
  Object.assign(node, next);
  return true;
};

/**
 * Replaces `node` with a literal of `node`'s own type holding `res`.
 * `int` results are truncated to signed 32-bit, `uint` to unsigned 32-bit.
 */
const __replaceNumericNode = (node, res) => {
  if (node.type === "int") res |= 0;
  else if (node.type === "uint") res >>>= 0;
  return __replaceNode(node, lit(node.type, res));
};

/** Replaces `node` with a boolean literal. */
const __replaceBooleanNode = (node, res) => __replaceNode(node, bool(res));

/** Replaces `node` with constant `n`, broadcast to the primitive type of `ref`. */
const __replaceWithConst = (node, ref, n) =>
  __replaceNode(node, matchingPrimFor(ref, n));

/**
 * Evaluates an arithmetic binary op on two numbers.
 * Returns `undefined` for non-arithmetic ops; throws on division by zero.
 */
const __maybeFoldMath = (op, l, r) => {
  switch (op) {
    case "+":
      return l + r;
    case "-":
      return l - r;
    case "*":
      return l * r;
    case "/":
      return r != 0 ? l / r : illegalArgs(`division by zero: ${l}/${r}`);
    default:
      return undefined;
  }
};

/**
 * Evaluates a comparison op on two numbers.
 * Returns `undefined` for non-comparison ops.
 */
const __maybeFoldCompare = (op, l, r) => {
  switch (op) {
    case "==":
      return l === r;
    case "!=":
      return l !== r;
    case "<":
      return l < r;
    case "<=":
      return l <= r;
    case ">=":
      return l >= r;
    case ">":
      return l > r;
    default:
      return undefined;
  }
};

/**
 * Component index for single-letter swizzle ids. Includes the GLSL position
 * (xyzw), color (rgba) and texture (stpq) name sets.
 */
const COMPS = {
  x: 0, y: 1, z: 2, w: 3,
  r: 0, g: 1, b: 2, a: 3,
  s: 0, t: 1, p: 2, q: 3,
};

/** Scalar evaluators for builtin calls whose arguments are all numeric literals. */
const BUILTINS = {
  abs: ([a]) => Math.abs(a),
  acos: ([a]) => Math.acos(a),
  asin: ([a]) => Math.asin(a),
  ceil: ([a]) => Math.ceil(a),
  clamp: ([a, b, c]) => clamp(a, b, c),
  cos: ([a]) => Math.cos(a),
  degrees: ([a]) => deg(a),
  exp: ([a]) => Math.exp(a),
  exp2: ([a]) => Math.pow(2, a),
  floor: ([a]) => Math.floor(a),
  fract: ([a]) => fract(a),
  inversesqrt: ([a]) => 1 / Math.sqrt(a),
  log: ([a]) => Math.log(a),
  log2: ([a]) => Math.log2(a),
  max: ([a, b]) => Math.max(a, b),
  min: ([a, b]) => Math.min(a, b),
  mix: ([a, b, c]) => mix(a, b, c),
  mod: ([a, b]) => mod(a, b),
  pow: ([a, b]) => Math.pow(a, b),
  radians: ([a]) => rad(a),
  sign: ([a]) => Math.sign(a),
  sin: ([a]) => Math.sin(a),
  tan: ([a]) => Math.tan(a),
  sqrt: ([a]) => Math.sqrt(a),
};

/**
 * Single-node constant folder, dispatching on the node's `tag`. Each handler
 * rewrites the node in place and returns `true` if it changed anything;
 * otherwise it returns a falsy value.
 */
const foldNode = defmulti(
  (t) => t.tag,
  {},
  {
    [DEFAULT]: () => false,

    // Unary minus on a numeric literal => negated literal.
    op1: (node) => {
      const operand = node.val;
      if (node.op == "-" && isLitNumericConst(operand)) {
        // Build a fresh literal rather than mutating `operand`: it may be
        // shared (e.g. the FLOAT1 constant, or a subterm reused elsewhere in
        // the tree), and in-place mutation would corrupt every other use.
        return __replaceNode(node, { ...operand, val: operand.val * -1 });
      }
    },

    // Binary ops: full evaluation when both sides are numeric literals,
    // otherwise algebraic identities (x*0, x*1, 0/x, x/1, x+0, 0-x, x-0).
    op2: (node) => {
      const { op, l, r } = node;
      const isNumL = isLitNumericConst(l);
      const isNumR = isLitNumericConst(r);
      if (isNumL && isNumR) {
        const num = __maybeFoldMath(op, l.val, r.val);
        if (num !== undefined) return __replaceNumericNode(node, num);
        const res = __maybeFoldCompare(op, l.val, r.val);
        if (res !== undefined) return __replaceBooleanNode(node, res);
        return undefined;
      }
      switch (op) {
        case "*":
          if (isNumL && l.val === 0) return __replaceWithConst(node, r, FLOAT0);
          if (isNumR && r.val === 0) return __replaceWithConst(node, l, FLOAT0);
          if (isNumL && l.val === 1) return __replaceNode(node, r);
          if (isNumR && r.val === 1) return __replaceNode(node, l);
          break;
        case "/":
          if (isNumL && l.val === 0) return __replaceWithConst(node, r, FLOAT0);
          // illegalArgs() throws
          if (isNumR && r.val === 0) illegalArgs("division by zero");
          if (isNumR && r.val === 1) return __replaceNode(node, l);
          break;
        case "+":
          if (isNumL && l.val === 0) return __replaceNode(node, r);
          if (isNumR && r.val === 0) return __replaceNode(node, l);
          break;
        case "-":
          if (isNumL && l.val === 0) return __replaceNode(node, neg(r));
          if (isNumR && r.val === 0) return __replaceNode(node, l);
          break;
      }
      return undefined;
    },

    // Builtin calls: evaluate numerically when every argument is a numeric
    // literal and a scalar evaluator exists; otherwise defer to foldBuiltin's
    // per-builtin simplifications.
    call_i: (node) => {
      const args = node.args;
      if (args.every((x) => isLitNumericConst(x))) {
        // own-property lookup so ids such as "toString" never resolve to
        // inherited Object.prototype members
        const evaluate = Object.prototype.hasOwnProperty.call(BUILTINS, node.id)
          ? BUILTINS[node.id]
          : undefined;
        if (evaluate !== undefined) {
          return __replaceNumericNode(node, evaluate(args.map((x) => x.val)));
        }
      }
      return foldBuiltin(node);
    },

    // Cast of a numeric literal, e.g. float(int(3)) => 3.0. The result must
    // take the *outer* (target) type, so the check is on `node`, not on the
    // wrapped literal. __replaceNumericNode applies int/uint truncation.
    lit: (node) => {
      const inner = node.val;
      if (
        isLitNumericConst(inner) &&
        (isFloat(node) || isInt(node) || isUint(node))
      ) {
        return __replaceNumericNode(node, inner.val);
      }
    },

    // Single-component float swizzle of a constant vector => float literal.
    swizzle: (node) => {
      const vec = node.val;
      if (isLitVecConst(vec) && isFloat(node)) {
        const idx = COMPS[node.id];
        const comp = typeof idx === "number" ? vec.val[idx] : undefined;
        // leave the node alone if the component can't be resolved
        if (comp !== undefined) {
          return __replaceNode(node, float(comp));
        }
      }
    },
  }
);

/**
 * Builtin-specific algebraic simplifications, dispatching on the call's `id`.
 * Applies to calls whose arguments are not all numeric literals. It also
 * applies when no numeric evaluator exists for the builtin.
 */
const foldBuiltin = defmulti(
  (x) => x.id,
  {},
  {
    [DEFAULT]: () => false,

    // exp2(0) => 1, exp2(1) => 2 (broadcast to the argument's type)
    exp2: (node) => {
      const a = node.args[0];
      if (isLitNumOrVecConst(a, 0)) {
        return __replaceWithConst(node, a, FLOAT1);
      }
      if (isLitNumOrVecConst(a, 1)) {
        return __replaceWithConst(node, a, FLOAT2);
      }
    },

    // pow(a, 0) => 1, pow(a, 1) => a
    pow: (node) => {
      const [a, b] = node.args;
      if (isLitNumOrVecConst(b, 0)) {
        return __replaceWithConst(node, a, FLOAT1);
      }
      if (isLitNumOrVecConst(b, 1)) {
        return __replaceNode(node, a);
      }
    },
  }
);

/**
 * Repeatedly walks `tree` and applies `foldNode` to every node until a full
 * pass makes no change. The tree is modified in place and returned.
 */
const constantFolding = (tree) => {
  let changed;
  do {
    changed = false;
    walk(
      (_, node) => {
        changed = foldNode(node) || changed;
      },
      allChildren,
      null,
      tree,
      false
    );
  } while (changed);
  return tree;
};

export { constantFolding, foldBuiltin, foldNode };