// @thi.ng/shader-ast-optimize
// Version:
// Shader AST code/tree optimization passes/strategies
// 212 lines (211 loc) • 6.35 kB
// JavaScript
import { DEFAULT, defmulti } from "@thi.ng/defmulti/defmulti";
import { illegalArgs } from "@thi.ng/errors/illegal-arguments";
import { LogLevel } from "@thi.ng/logger/api";
import { deg, rad } from "@thi.ng/math/angle";
import { clamp } from "@thi.ng/math/interval";
import { mix } from "@thi.ng/math/mix";
import { fract, mod } from "@thi.ng/math/prec";
import {
matchingPrimFor,
neg
} from "@thi.ng/shader-ast";
import {
isFloat,
isInt,
isLitNumOrVecConst,
isLitNumericConst,
isLitVecConst,
isUint
} from "@thi.ng/shader-ast/ast/checks";
import {
FLOAT0,
FLOAT1,
FLOAT2,
bool,
float,
int,
lit,
uint
} from "@thi.ng/shader-ast/ast/lit";
import { allChildren, walk } from "@thi.ng/shader-ast/ast/scope";
import { LOGGER } from "@thi.ng/shader-ast/logger";
/**
 * Overwrites `node` in place so that it ends up with the own properties of
 * `next`. The object identity of `node` is kept, so anything already
 * referencing it observes the replacement.
 *
 * When the logger level is at or below `LogLevel.DEBUG`, the old and new
 * nodes are logged as JSON before the rewrite happens.
 *
 * @param node - AST node to overwrite (mutated)
 * @param next - node whose own properties are copied onto `node`
 * @returns always `true`, so callers can directly return "node was changed"
 */
const __replaceNode = (node, next) => {
  const isDebugEnabled = LOGGER.level <= LogLevel.DEBUG;
  if (isDebugEnabled) {
    LOGGER.debug(`replacing AST node:`);
    LOGGER.debug(` old: ${JSON.stringify(node)}`);
    LOGGER.debug(` new: ${JSON.stringify(next)}`);
  }
  // First drop every key which the replacement does not define itself...
  for (const key of Object.keys(node)) {
    if (!next.hasOwnProperty(key)) {
      delete node[key];
    }
  }
  // ...then copy (and overwrite) everything the replacement provides.
  Object.assign(node, next);
  return true;
};
/**
 * Replaces `node` with a numeric literal of the node's own type holding
 * `res`.
 *
 * For `"int"` nodes the value is first coerced to a signed 32-bit integer
 * (`| 0`), for `"uint"` nodes to an unsigned 32-bit integer (`>>> 0`). All
 * other types receive the value unchanged.
 *
 * @param node - AST node to overwrite; its `type` selects the coercion
 * @param res - numeric result to store in the new literal
 * @returns result of `__replaceNode`
 */
const __replaceNumericNode = (node, res) => {
  let value = res;
  if (node.type === "int") {
    value |= 0;
  } else if (node.type === "uint") {
    value >>>= 0;
  }
  return __replaceNode(node, lit(node.type, value));
};
/**
 * Replaces `node` with the boolean literal produced by `bool(res)`.
 *
 * @param node - AST node to overwrite
 * @param res - value handed to the `bool` literal constructor
 * @returns result of `__replaceNode`
 */
const __replaceBooleanNode = (node, res) => {
  const replacement = bool(res);
  return __replaceNode(node, replacement);
};
/**
 * Replaces `node` with the term returned by `matchingPrimFor(ref, n)`.
 *
 * NOTE(review): `matchingPrimFor` is defined in `@thi.ng/shader-ast`;
 * presumably it converts the constant `n` to the primitive type matching
 * `ref` - confirm against that package.
 *
 * @param node - AST node to overwrite
 * @param ref - reference term passed as first argument to `matchingPrimFor`
 * @param n - constant term passed as second argument to `matchingPrimFor`
 * @returns result of `__replaceNode`
 */
const __replaceWithConst = (node, ref, n) => {
  const constant = matchingPrimFor(ref, n);
  return __replaceNode(node, constant);
};
/**
 * Evaluates the binary arithmetic operator `op` for two constant operands.
 *
 * Supported operators are `+`, `-`, `*` and `/`. For a division whose divisor
 * compares loosely equal to zero the result of
 * `illegalArgs("division by zero: ...")` is returned instead (that helper is
 * expected to throw).
 *
 * @param op - operator string
 * @param l - left operand
 * @param r - right operand
 * @returns computed value, or `undefined` if `op` is not an arithmetic
 * operator handled here
 */
const __maybeFoldMath = (op, l, r) => {
  switch (op) {
    case "+":
      return l + r;
    case "-":
      return l - r;
    case "*":
      return l * r;
    case "/":
      // loose comparison kept as-is from the original implementation
      if (r != 0) {
        return l / r;
      }
      return illegalArgs(`division by zero: ${l}/${r}`);
    default:
      return undefined;
  }
};
/**
 * Evaluates the comparison operator `op` for two constant operands.
 *
 * Supported operators are `==`, `!=`, `<`, `<=`, `>=` and `>`. Equality and
 * inequality are evaluated with strict (`===` / `!==`) semantics.
 *
 * @param op - operator string
 * @param l - left operand
 * @param r - right operand
 * @returns boolean result, or `undefined` if `op` is not a comparison
 * operator handled here
 */
const __maybeFoldCompare = (op, l, r) => {
  switch (op) {
    case "==":
      return l === r;
    case "!=":
      return l !== r;
    case "<":
      return l < r;
    case "<=":
      return l <= r;
    case ">=":
      return l >= r;
    case ">":
      return l > r;
    default:
      return undefined;
  }
};
/**
 * Maps single-component swizzle names to their index in a vector literal's
 * list of components.
 */
const COMPS = {
  x: 0,
  y: 1,
  z: 2,
  w: 3,
};
/**
 * JavaScript implementations of builtin functions, keyed by builtin name.
 *
 * Every implementation receives a single array holding the already
 * extracted numeric argument values (in call order) and returns the
 * computed number. Used for evaluating builtin calls whose arguments are
 * all numeric constants.
 */
const BUILTINS = {
  abs([x]) { return Math.abs(x); },
  acos([x]) { return Math.acos(x); },
  asin([x]) { return Math.asin(x); },
  ceil([x]) { return Math.ceil(x); },
  clamp([x, lo, hi]) { return clamp(x, lo, hi); },
  cos([x]) { return Math.cos(x); },
  degrees([x]) { return deg(x); },
  exp([x]) { return Math.exp(x); },
  exp2([x]) { return Math.pow(2, x); },
  floor([x]) { return Math.floor(x); },
  fract([x]) { return fract(x); },
  inversesqrt([x]) { return 1 / Math.sqrt(x); },
  log([x]) { return Math.log(x); },
  log2([x]) { return Math.log2(x); },
  max([x, y]) { return Math.max(x, y); },
  min([x, y]) { return Math.min(x, y); },
  mix([x, y, t]) { return mix(x, y, t); },
  mod([x, y]) { return mod(x, y); },
  pow([base, exponent]) { return Math.pow(base, exponent); },
  radians([x]) { return rad(x); },
  sign([x]) { return Math.sign(x); },
  sin([x]) { return Math.sin(x); },
  tan([x]) { return Math.tan(x); },
  sqrt([x]) { return Math.sqrt(x); },
};
/**
 * Attempts to simplify a single AST node in place, dispatching on the node's
 * `tag`. Handlers exist for `op1`, `op2`, `call_i`, `lit` and `swizzle`
 * nodes; the `DEFAULT` handler returns `false`.
 *
 * A handler returns `true` (the result of the `__replace*` helpers) if the
 * node has been rewritten. If nothing could be folded the result is falsy
 * (`false` or `undefined`).
 *
 * @param node - AST node to simplify (mutated on success)
 */
const foldNode = defmulti(
  (t) => t.tag,
  {},
  {
    [DEFAULT]: () => false,

    // Unary minus applied to a numeric literal => negated literal.
    op1: (node) => {
      if (node.op === "-" && isLitNumericConst(node.val)) {
        const operand = node.val;
        // Replace with a negated *copy* of the operand. Negating the operand
        // object itself would also alter every other place referencing the
        // same literal object (e.g. a shared constant used more than once).
        return __replaceNode(
          node,
          Object.assign({}, operand, { val: -operand.val })
        );
      }
    },

    // Binary operators: full evaluation if both sides are numeric literals,
    // otherwise algebraic identities involving a literal 0 or 1.
    op2: (node) => {
      const { op, l, r } = node;
      const isNumL = isLitNumericConst(l);
      const isNumR = isLitNumericConst(r);
      if (isNumL && isNumR) {
        const num = __maybeFoldMath(op, l.val, r.val);
        if (num !== undefined) return __replaceNumericNode(node, num);
        const cmp = __maybeFoldCompare(op, l.val, r.val);
        if (cmp !== undefined) return __replaceBooleanNode(node, cmp);
      } else if (op === "*") {
        // 0 * x, x * 0 => zero constant derived from the other operand
        if (isNumL && l.val === 0) {
          return __replaceWithConst(node, r, FLOAT0);
        }
        if (isNumR && r.val === 0) {
          return __replaceWithConst(node, l, FLOAT0);
        }
        // 1 * x, x * 1 => x
        if (isNumL && l.val === 1) return __replaceNode(node, r);
        if (isNumR && r.val === 1) return __replaceNode(node, l);
      } else if (op === "/") {
        // 0 / x => zero constant derived from the divisor
        if (isNumL && l.val === 0) {
          return __replaceWithConst(node, r, FLOAT0);
        }
        // x / 0 is reported as an error rather than folded
        if (isNumR && r.val === 0) illegalArgs("division by zero");
        // x / 1 => x
        if (isNumR && r.val === 1) return __replaceNode(node, l);
      } else if (op === "+") {
        // 0 + x, x + 0 => x
        if (isNumL && l.val === 0) return __replaceNode(node, r);
        if (isNumR && r.val === 0) return __replaceNode(node, l);
      } else if (op === "-") {
        // 0 - x => -x
        if (isNumL && l.val === 0) return __replaceNode(node, neg(r));
        // x - 0 => x
        if (isNumR && r.val === 0) return __replaceNode(node, l);
      }
    },

    // Builtin function calls: evaluate via `BUILTINS` if all arguments are
    // numeric literals, otherwise try the special cases in `foldBuiltin`.
    call_i: (node) => {
      if (!node.args.every((x) => isLitNumericConst(x))) {
        return foldBuiltin(node);
      }
      // Own-property check, so that ids such as "toString" or "constructor"
      // never resolve to functions inherited from `Object.prototype`.
      if (Object.prototype.hasOwnProperty.call(BUILTINS, node.id)) {
        const res = BUILTINS[node.id](node.args.map((x) => x.val));
        // Only fold finite results. NaN / +-Infinity (e.g. from `sqrt(-1)` or
        // `log(0)`) are left as the original call instead of being turned
        // into a numeric literal.
        if (Number.isFinite(res)) {
          return __replaceNumericNode(node, res);
        }
      }
    },

    // Literal wrapping another numeric literal (i.e. a cast of a constant)
    // => plain literal of the *outer* node's type.
    lit: (node) => {
      if (isLitNumericConst(node.val)) {
        const value = node.val.val;
        // The type check must use the outer node: it determines the type
        // of the expression as seen by the rest of the tree. Selecting the
        // replacement by the inner literal's type would e.g. turn
        // `float(int(2))` into an int literal and so change the node's type.
        if (isFloat(node)) {
          return __replaceNode(node, float(value));
        }
        // Integer targets apply the same 32-bit coercions as used by
        // `__replaceNumericNode`, since the inner value may be fractional
        // or of the opposite signedness.
        if (isInt(node)) {
          return __replaceNode(node, int(value | 0));
        }
        if (isUint(node)) {
          return __replaceNode(node, uint(value >>> 0));
        }
      }
    },

    // Single-component (float typed) swizzle of a constant vector literal
    // => the selected component wrapped as float.
    swizzle: (node) => {
      const val = node.val;
      if (isLitVecConst(val) && isFloat(node)) {
        return __replaceNode(node, float(val.val[COMPS[node.id]]));
      }
    }
  }
);
/**
 * Special-case simplifications for builtin function calls, dispatching on
 * the call's `id`. Only `exp2` and `pow` are handled; the `DEFAULT` handler
 * returns `false`.
 *
 * A handler returns the result of the `__replace*` helper used if the call
 * was rewritten, otherwise `undefined`.
 *
 * @param x - builtin call node with `id` and `args` (mutated on success)
 */
const foldBuiltin = defmulti(
  (x) => x.id,
  {},
  {
    [DEFAULT]: () => false,

    exp2: (node) => {
      const [exponent] = node.args;
      // exp2(0) => 1
      if (isLitNumOrVecConst(exponent, 0)) {
        return __replaceWithConst(node, exponent, FLOAT1);
      }
      // exp2(1) => 2
      if (isLitNumOrVecConst(exponent, 1)) {
        return __replaceWithConst(node, exponent, FLOAT2);
      }
      return undefined;
    },

    pow: (node) => {
      const base = node.args[0];
      const exponent = node.args[1];
      // pow(x, 0) => 1
      if (isLitNumOrVecConst(exponent, 0)) {
        return __replaceWithConst(node, base, FLOAT1);
      }
      // pow(x, 1) => x
      if (isLitNumOrVecConst(exponent, 1)) {
        return __replaceNode(node, base);
      }
      return undefined;
    }
  }
);
/**
 * Repeatedly applies `foldNode` to the nodes visited by `walk` until a
 * complete traversal no longer reports any change. The tree is modified in
 * place.
 *
 * NOTE(review): the meaning of the trailing `walk` arguments (`null`
 * accumulator, `false` flag) is defined in `@thi.ng/shader-ast` - confirm
 * there if the traversal order matters.
 *
 * @param tree - root of the AST to optimize
 * @returns the same `tree` object
 */
const constantFolding = (tree) => {
  let changed;
  do {
    changed = false;
    const visit = (_, node) => {
      if (foldNode(node)) {
        changed = true;
      }
    };
    walk(visit, allChildren, null, tree, false);
  } while (changed);
  return tree;
};
export {
constantFolding,
foldBuiltin,
foldNode
};