babel-plugin-recursive-tail-calls
Version:
A Babel plugin that performs tail-call optimization on self-recursive functions.
76 lines (75 loc) • 3.57 kB
JavaScript
import { template } from "@babel/core";
import {
  isReturnStatement,
  expressionStatement,
  assignmentExpression,
  arrayPattern,
  arrayExpression,
  booleanLiteral,
  continueStatement,
  variableDeclaration,
  variableDeclarator,
  logicalExpression,
  binaryExpression,
  nullLiteral,
  identifier,
  cloneNode,
} from "@babel/types";
import { tailRecursionFinder } from './tailRecursionFinder.js';
import { isRecCall } from "./utils.js";
/**
 * Babel visitor run over the body of a self-recursive function. It turns
 * `return f(...)` self calls into a parameter reassignment followed by a
 * labeled `continue`.
 *
 * State read from / written to `this` (supplied by whoever calls
 * `path.traverse(callExpVisitor, state)`; not visible in this module):
 * - functionIdentifier: the function being optimized. It is passed to
 *   `isRecCall` and `tailRecursionFinder`.
 * - arguments: array of `{ identifier, defaultValue }`, in parameter order.
 * - conditionIdentifier: assigned `true` right before each generated `continue`.
 * - labelIdentifier: label targeted by the generated `continue`.
 * - recursion: set to `true` once at least one call has been rewritten.
 */
export const callExpVisitor = {
  /**
   * Rewrite `return f(x, y);`, where `f` is the function itself, into
   *   [p1, p2] = [x, y];
   *   <conditionIdentifier> = true;
   *   continue <labelIdentifier>;
   * A parameter with no matching argument receives its `defaultValue`.
   */
  CallExpression(path) {
    const callsItself = isRecCall(path, this.functionIdentifier);
    const isLast = isReturnStatement(path.parent);
    if (!callsItself || !isLast)
      return;
    const callArgs = path.node.arguments;
    // Leave the call recursive (still correct, just not optimized) when its
    // arguments cannot be mapped 1:1 onto parameter slots:
    // - a spread `f(...xs)` has a length unknown at compile time, so a
    //   positional array assignment would misplace values;
    // - surplus arguments would otherwise be dropped without being evaluated.
    if (callArgs.some((arg) => arg.type === "SpreadElement") ||
      callArgs.length > this.arguments.length)
      return;
    this.recursion = true;
    const targets = [];
    const values = [];
    this.arguments.forEach(({ identifier: param, defaultValue }, index) => {
      // Shared nodes (params, defaults) are cloned so no AST node object ends
      // up with several parents when multiple call sites are rewritten.
      targets.push(cloneNode(param));
      values.push(index < callArgs.length ? callArgs[index] : cloneNode(defaultValue));
    });
    // Destructuring assignment: every new value is evaluated before any
    // parameter is overwritten, matching the semantics of a real call.
    const assignParams = expressionStatement(assignmentExpression("=", arrayPattern(targets), arrayExpression(values)));
    const setFlag = expressionStatement(assignmentExpression("=", cloneNode(this.conditionIdentifier), booleanLiteral(true)));
    const jump = continueStatement(cloneNode(this.labelIdentifier));
    // path.parent was checked above to be the ReturnStatement; swap it for
    // the jump sequence.
    const returnPath = path.parentPath;
    returnPath.insertBefore([assignParams, setFlag, jump]);
    returnPath.remove();
  },
  /**
   * When a return statement contains a tail self call (per
   * `tailRecursionFinder`), rewrite a logical or conditional return argument
   * into if/else statements. The call then sits directly under a `return`,
   * where `CallExpression` above can handle it.
   */
  ReturnStatement(path) {
    // look for a tail self call among the children
    const state = { found: false, functionIdentifier: this.functionIdentifier };
    path.traverse(tailRecursionFinder, state);
    if (!state.found)
      return;
    const returnExpression = path.get("argument");
    if (returnExpression.isLogicalExpression()) {
      path.replaceWithMultiple(logicalExprRewrite(returnExpression.node, path.scope));
    }
    else if (returnExpression.isConditionalExpression()) {
      path.replaceWith(buildIfStatement(returnExpression.node.test, returnExpression.node.consequent, returnExpression.node.alternate));
    }
  },
  /** Do not descend into nested functions: their calls are not tail calls of this one. */
  Function(path) {
    path.skip();
  },
};
// Statement template for an if/else whose two branches both `return`.
// `%%condition%%`, `%%caseTrue%%` and `%%caseFalse%%` are @babel/template
// syntactic placeholders; buildIfStatement fills them with AST nodes.
const ifTemplate = template.statement(`
if (%%condition%%) {
return %%caseTrue%%;
} else {
return %%caseFalse%%;
}
`);
/**
 * Build `if (test) { return whenTrue; } else { return whenFalse; }`.
 *
 * @param {object} test - expression node used as the if condition
 * @param {object} whenTrue - expression node returned when `test` is truthy
 * @param {object} whenFalse - expression node returned otherwise
 * @returns {object} the statement node produced by `ifTemplate`
 */
function buildIfStatement(test, whenTrue, whenFalse) {
  const statement = ifTemplate({
    condition: test,
    caseTrue: whenTrue,
    caseFalse: whenFalse,
  });
  return statement;
}
/**
 * Rewrite the argument of `return left <op> right;` as explicit control flow.
 * A tail call inside `right` then ends up directly under a `return`
 * statement, where `callExpVisitor.CallExpression` can optimize it.
 *
 * `left` is evaluated exactly once: it is bound to a fresh `const`, and the
 * generated `if` only reads that binding.
 *
 *   return a && b;  ->  const _left = a; if (_left) { return b; } else { return _left; }
 *   return a || b;  ->  const _left = a; if (_left) { return _left; } else { return b; }
 *   return a ?? b;  ->  const _left = a; if (_left != null) { return _left; } else { return b; }
 *
 * @param {object} node - LogicalExpression node (`left`, `right`, `operator`)
 * @param {object} scope - Babel scope used to generate a collision-free name
 * @returns {object[]} `[variableDeclaration, ifStatement]`, meant to replace the return statement
 * @throws {Error} when `operator` is not `&&`, `||` or `??`
 */
function logicalExprRewrite({ left, right, operator, }, scope) {
  const resultIdentifier = scope.generateUidIdentifier("left");
  // Every position that reads the binding gets its own Identifier node, so no
  // single node object is shared between several parents in the AST.
  const readResult = () => cloneNode(resultIdentifier);
  // assign left to a variable so we don't evaluate it twice
  const resultDeclaration = variableDeclaration("const", [
    variableDeclarator(resultIdentifier, left),
  ]);
  let ifStatement;
  switch (operator) {
    case "&&":
      ifStatement = buildIfStatement(readResult(), right, readResult());
      break;
    case "||":
      ifStatement = buildIfStatement(readResult(), readResult(), right);
      break;
    case "??":
      // `a ?? b` yields `a` unless it is null or undefined. `!= null` covers
      // both, and unlike the `undefined` identifier it cannot be shadowed.
      ifStatement = buildIfStatement(binaryExpression("!=", readResult(), nullLiteral()), readResult(), right);
      break;
    default:
      throw new Error("Unknown LogicalExpression operator: " + operator);
  }
  return [resultDeclaration, ifStatement];
}