UNPKG

@mondaydotcomorg/atp-compiler

Version:

Production-ready compiler for transforming async iteration patterns into resumable operations with checkpoint-based state management

138 lines 6.5 kB
import * as t from '@babel/types'; import { generateUniqueId } from '../runtime/context.js'; import { containsAwait, extractForOfParamName } from './utils.js'; import { BatchOptimizer } from './batch-optimizer.js'; import { BatchParallelDetector } from './batch-detector.js'; import { findLLMCallExpression } from './array-transformer-utils.js'; export class LoopTransformer { transformCount = 0; batchOptimizer; batchDetector; batchSizeThreshold; constructor(batchSizeThreshold = 10) { this.batchOptimizer = new BatchOptimizer(); this.batchDetector = new BatchParallelDetector(); this.batchSizeThreshold = batchSizeThreshold; } transformForOfLoop(path) { const node = path.node; if (!containsAwait(node.body)) { return false; } const batchResult = this.batchOptimizer.canBatchForOfLoop(node); if (batchResult.canBatch) { const decision = this.batchOptimizer.makeSmartBatchDecision('for...of', batchResult, node.right, this.batchSizeThreshold); if (decision.shouldBatch) { return this.transformForOfToBatch(path, node); } } return this.transformForOfToSequential(path, node); } /** * Transform simple for...of to batch parallel */ transformForOfToBatch(path, node) { const loopId = generateUniqueId('for_of_batch'); const right = node.right; const paramName = extractForOfParamName(node.left); const llmCall = findLLMCallExpression(node.body); if (!llmCall) { return this.transformForOfToSequential(path, node); } const callInfo = this.batchDetector.extractCallInfo(llmCall); if (!callInfo) { return this.transformForOfToSequential(path, node); } const payloadNode = this.batchDetector.extractPayloadNode(llmCall); if (!payloadNode) { return this.transformForOfToSequential(path, node); } const batchCallsArray = t.callExpression(t.memberExpression(right, t.identifier('map')), [ t.arrowFunctionExpression([t.identifier(paramName)], t.objectExpression([ t.objectProperty(t.identifier('type'), t.stringLiteral(callInfo.type)), t.objectProperty(t.identifier('operation'), t.stringLiteral(callInfo.operation)), t.objectProperty(t.identifier('payload'), payloadNode), ])), ]); const batchCall = t.awaitExpression(t.callExpression(t.memberExpression(t.identifier('__runtime'), t.identifier('batchParallel')), [batchCallsArray, t.stringLiteral(loopId)])); path.replaceWith(t.expressionStatement(batchCall)); this.transformCount++; return true; } /** * Transform for...of to sequential with checkpoints (fallback) */ transformForOfToSequential(path, node) { const loopId = generateUniqueId('for_of'); const right = node.right; const paramName = extractForOfParamName(node.left); const bodyStatements = t.isBlockStatement(node.body) ? node.body.body : [node.body]; const callbackFn = t.arrowFunctionExpression([t.identifier(paramName), t.identifier('__index')], t.blockStatement(bodyStatements), true); const runtimeCall = t.awaitExpression(t.callExpression(t.memberExpression(t.identifier('__runtime'), t.identifier('resumableForOf')), [right, callbackFn, t.stringLiteral(loopId)])); path.replaceWith(t.expressionStatement(runtimeCall)); this.transformCount++; return true; } transformWhileLoop(path) { const node = path.node; if (!containsAwait(node.body)) { return false; } const loopId = generateUniqueId('while'); const conditionFn = t.arrowFunctionExpression([], node.test, false); const bodyStatements = t.isBlockStatement(node.body) ? node.body.body : [node.body]; const bodyFn = t.arrowFunctionExpression([t.identifier('__iteration')], t.blockStatement(bodyStatements), true); const runtimeCall = t.awaitExpression(t.callExpression(t.memberExpression(t.identifier('__runtime'), t.identifier('resumableWhile')), [conditionFn, bodyFn, t.stringLiteral(loopId)])); path.replaceWith(t.expressionStatement(runtimeCall)); this.transformCount++; return true; } transformForLoop(path) { const node = path.node; if (!containsAwait(node.body)) { return false; } if (!node.init || !node.test || !node.update) { return false; } const loopId = generateUniqueId('for'); let initValue = t.numericLiteral(0); let loopVar = '__i'; if (t.isVariableDeclaration(node.init)) { const decl = node.init.declarations[0]; if (decl && t.isIdentifier(decl.id) && decl.init) { loopVar = decl.id.name; initValue = decl.init; } } const conditionFn = t.arrowFunctionExpression([t.identifier(loopVar)], node.test, false); const bodyStatements = t.isBlockStatement(node.body) ? node.body.body : [node.body]; const bodyFn = t.arrowFunctionExpression([t.identifier(loopVar)], t.blockStatement(bodyStatements), true); let incrementFn; if (t.isUpdateExpression(node.update)) { if (node.update.operator === '++') { incrementFn = t.arrowFunctionExpression([t.identifier(loopVar)], t.binaryExpression('+', t.identifier(loopVar), t.numericLiteral(1)), false); } else if (node.update.operator === '--') { incrementFn = t.arrowFunctionExpression([t.identifier(loopVar)], t.binaryExpression('-', t.identifier(loopVar), t.numericLiteral(1)), false); } else { return false; } } else { return false; } const runtimeCall = t.awaitExpression(t.callExpression(t.memberExpression(t.identifier('__runtime'), t.identifier('resumableForLoop')), [initValue, conditionFn, incrementFn, bodyFn, t.stringLiteral(loopId)])); path.replaceWith(t.expressionStatement(runtimeCall)); this.transformCount++; return true; } getTransformCount() { return this.transformCount; } resetTransformCount() { this.transformCount = 0; } } //# sourceMappingURL=loop-transformer.js.map