babel-plugin-recursive-tail-calls
Version:
A Babel plugin that performs tail-call optimization on recursive functions.
66 lines (65 loc) • 3.21 kB
JavaScript
import {
  expressionStatement,
  assignmentExpression,
  arrayExpression,
  booleanLiteral,
  continueStatement,
  variableDeclaration,
  variableDeclarator,
  logicalExpression,
  binaryExpression,
  identifier,
  isArgumentPlaceholder,
  isJSXNamespacedName,
  callExpression,
  ifStatement,
  returnStatement,
  blockStatement,
  cloneNode,
} from "@babel/types";
import { findRecursion } from "./tailRecursionFinder.js";
/**
 * Babel visitor that rewrites `ReturnStatement`s containing recursion into
 * explicit forms, and replaces a directly returned recursive call with a
 * loop iteration (see `callExprRewrite`).
 *
 * Visitor state read from `this` (presumably passed as the traversal state by
 * the caller of `path.traverse` — confirm with the caller):
 * - `functionIdentifier`: forwarded to `findRecursion` to detect recursion.
 * - `parameters`, `conditionIdentifier`, `labelIdentifier`: forwarded to
 *   `callExprRewrite` when a recursive call is returned directly.
 * Visitor state written to `this`:
 * - `recursion`: set to `true` once any recursive `return` is seen.
 *
 * Logical, conditional and sequence returns are expanded one layer at a time
 * into statements that end in plain `return`s. Babel re-queues replacement
 * nodes, so those new `return`s are visited again by this same handler until
 * the recursive call itself is the returned expression.
 */
export const callExpressionRewriter = {
  ReturnStatement(path) {
    const argument = path.get("argument");
    // `return;` has no expression; non-recursive returns are left untouched.
    if (!argument.isExpression() ||
      !findRecursion(argument, this.functionIdentifier)) {
      return;
    }
    this.recursion = true;

    if (argument.isLogicalExpression()) {
      path.replaceWithMultiple(logicalExprRewrite(argument.node, path.scope));
    }
    else if (argument.isConditionalExpression()) {
      const { test, consequent, alternate } = argument.node;
      // `return t ? a : b` -> `if (t) { return a } else { return b }`
      path.replaceWith(ifStatement(test, blockStatement([returnStatement(consequent)]), blockStatement([returnStatement(alternate)])));
    }
    else if (argument.isSequenceExpression()) {
      // `return (a, b, c)` evaluates a and b for their effects and returns c,
      // i.e. it is equivalent to `a; b; return c;` — which exposes `c` as a
      // plain returned expression for the next visit.
      const { expressions } = argument.node;
      const leading = expressions
        .slice(0, -1)
        .map((expr) => expressionStatement(expr));
      const last = expressions[expressions.length - 1];
      path.replaceWithMultiple([...leading, returnStatement(last)]);
    }
    else if (argument.isCallExpression()) {
      path.replaceWithMultiple(callExprRewrite(argument, this.parameters, this.conditionIdentifier, this.labelIdentifier));
    }
  },
  Function(path) {
    // Nested functions have their own `return`s; do not rewrite them.
    path.skip();
  },
};
/**
 * Rewrite a returned logical expression into an explicit `IfStatement` so the
 * right-hand operand ends up in tail position.
 *
 * `return left OP right` becomes:
 *
 *     const _symbol = Symbol();
 *     const _evaluation = left OP _symbol;
 *     if (_evaluation === _symbol) { return right; } else { return _evaluation; }
 *
 * The fresh symbol is a sentinel that `left` can never evaluate to, so
 * `left OP _symbol` yields `_symbol` exactly when `OP` would have gone on to
 * evaluate `right`. This holds for `&&`, `||` and `??` alike. `left` is
 * evaluated once, and `right` only when the operator needs it.
 *
 * @param {import("@babel/types").LogicalExpression} node - The returned
 *   logical expression (destructured into `left`, `right`, `operator`).
 * @param {object} scope - Babel scope used to generate collision-free names.
 * @returns {import("@babel/types").Statement[]} Statements replacing the
 *   original `return`.
 */
function logicalExprRewrite({ left, right, operator, }, scope) {
  const symbolIdentifier = scope.generateUidIdentifier("symbol");
  const logicalResultIdentifier = scope.generateUidIdentifier("evaluation");
  // Every occurrence after the first is a clone: an AST node object must not
  // be shared between several positions in the tree.
  return [
    // declare sentinel symbol
    variableDeclaration("const", [
      variableDeclarator(symbolIdentifier, callExpression(identifier("Symbol"), [])),
    ]),
    // evaluate logical expression with the sentinel and store the result
    variableDeclaration("const", [
      variableDeclarator(logicalResultIdentifier, logicalExpression(operator, left, cloneNode(symbolIdentifier))),
    ]),
    ifStatement(
      binaryExpression("===", cloneNode(logicalResultIdentifier), cloneNode(symbolIdentifier)),
      blockStatement([returnStatement(right)]),
      blockStatement([returnStatement(cloneNode(logicalResultIdentifier))]),
    ),
  ];
}
/**
 * Replace a recursive call in tail position with a loop iteration: assign the
 * call's arguments to `parameters`, set the loop condition to `true`, and
 * `continue` the labelled loop.
 *
 * @param {object} path - NodePath of the recursive `CallExpression`.
 * @param {object} parameters - Assignment target receiving the new argument
 *   array (presumably an array pattern of the function's parameters — confirm
 *   with the caller).
 * @param {object} conditionIdentifier - Identifier assigned `true` before continuing.
 * @param {object} labelIdentifier - Label targeted by the emitted `continue`.
 * @returns {object[]} Statements replacing the `return <call>` statement.
 * @throws {Error} If the call has an argument placeholder or a JSX namespaced
 *   name as an argument, which cannot be placed in an array expression.
 */
function callExprRewrite(path, parameters, conditionIdentifier, labelIdentifier) {
  const args = path.node.arguments.map((arg) => {
    if (isArgumentPlaceholder(arg) || isJSXNamespacedName(arg)) {
      throw new Error("Invalid argument type");
    }
    return arg;
  });
  // The target/identifier nodes are owned by the caller and reused at every
  // tail-call site, so each insertion gets its own clone rather than sharing
  // one node object across several positions in the AST.
  return [
    // update arguments: the whole array is evaluated before any assignment,
    // so arguments that read the current parameter values (e.g. `f(b, a)`)
    // see the old values
    expressionStatement(assignmentExpression("=", cloneNode(parameters), arrayExpression(args))),
    // set loop condition
    expressionStatement(assignmentExpression("=", cloneNode(conditionIdentifier), booleanLiteral(true))),
    continueStatement(cloneNode(labelIdentifier)),
  ];
}