autosnippet
Version:
Extract code patterns into a knowledge base for AI coding assistants
347 lines (346 loc) • 13.5 kB
JavaScript
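/**
 * ToolExecutionPipeline — middleware pipeline for tool execution.
 *
 * Splits the ~120 lines of tool-execution logic formerly inlined in reactLoop
 * into independent middlewares, run as: before → execute → after.
 *
 * Each middleware handles one cross-cutting concern:
 *   1. AllowlistGate     — blocks calls to tools outside the capability allowlist
 *   2. EventBusPublisher — event publishing (not part of the default pipeline)
 *   3. ProgressEmitter   — progress callbacks (not part of the default pipeline)
 *   4. SafetyGate        — SafetyPolicy interception
 *   5. CacheCheck        — MemoryCoordinator cache hits
 *   6. ObservationRecord — memory (observation) recording
 *   7. TrackerSignal     — ExplorationTracker signal collection
 *   8. TraceRecord       — ActiveContext reasoning-trace recording
 *   9. SubmitDedup       — submission de-duplication
 *
 * @module core/ToolExecutionPipeline
 */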
import { SafetyPolicy } from '../policies.js';
/**
 * ToolExecutionPipeline — runs one tool call through a chain of middlewares.
 *
 * A middleware is a plain object with an optional `name` and optional hooks:
 *   - `before(call, context, metadata)`: return `{ blocked: true, result }` to reject
 *     the call, or `{ result }` (any value other than `undefined`) to supply a result
 *     and skip real execution (recorded as a cache hit). Any other return continues.
 *   - `after(call, result, context, metadata)`: observes the final result and may
 *     annotate `metadata`.
 * Hooks may be sync or async; they are awaited one at a time in registration order.
 */
export class ToolExecutionPipeline {
    /** Registered middlewares, in execution order. */
    #middlewares = [];

    /**
     * Register a middleware.
     *
     * @param {{ name?: string, before?: Function, after?: Function }} middleware
     * @returns {ToolExecutionPipeline} this pipeline, to allow chaining
     */
    use(middleware) {
        this.#middlewares.push(middleware);
        return this;
    }

    /**
     * Execute a single tool call.
     *
     * Flow:
     *   1. Call each `before` hook in order. The first verdict that is `blocked` or
     *      carries a `result` short-circuits: the remaining `before` hooks and the
     *      real execution are skipped.
     *   2. Otherwise run the tool via `runtime.toolRegistry.execute`. A thrown value
     *      is converted into an `{ error: string }` result instead of propagating.
     *   3. Call each `after` hook in order with the final result.
     *
     * @param {{ name: string, args: object, id?: string }} call
     * @param {{ runtime: object, loopCtx: object, iteration?: number }} context
     * @returns {Promise<{ result: *, metadata: { cacheHit: boolean, blocked: boolean, isNew: boolean, durationMs: number } }>}
     *   `metadata` may be further annotated by `after` hooks.
     */
    async execute(call, context) {
        let toolResult = null;
        // Track short-circuiting explicitly. Using `toolResult === null` as the signal
        // meant a hook returning `{ blocked: true, result: null }` or `{ result: null }`
        // still fell through to real tool execution.
        let shortCircuited = false;
        const metadata = { cacheHit: false, blocked: false, isNew: false, durationMs: 0 };

        // ── before phase ──
        for (const mw of this.#middlewares) {
            if (!mw.before) {
                continue;
            }
            const verdict = await mw.before(call, context, metadata);
            if (verdict?.blocked) {
                toolResult = verdict.result;
                metadata.blocked = true;
                shortCircuited = true;
                break;
            }
            if (verdict?.result !== undefined) {
                toolResult = verdict.result;
                metadata.cacheHit = true;
                shortCircuited = true;
                break;
            }
        }

        // ── execute phase ──
        if (!shortCircuited) {
            const t0 = Date.now();
            try {
                const { runtime, loopCtx } = context;
                const safetyPolicy = runtime.policies.get?.(SafetyPolicy) || null;
                toolResult = await runtime.toolRegistry.execute(call.name, call.args, {
                    agentId: runtime.id,
                    source: loopCtx.source || runtime.presetName,
                    container: runtime.container,
                    safetyPolicy,
                    projectRoot: runtime.projectRoot,
                    fileCache: runtime.fileCache,
                    lang: runtime.lang,
                    logger: runtime.logger || null,
                    aiProvider: runtime.aiProvider || null,
                    // ── bootstrap dimension context (passed through from sharedState) ──
                    _submittedTitles: loopCtx.sharedState?.submittedTitles || null,
                    _submittedPatterns: loopCtx.sharedState?.submittedPatterns || null,
                    _sharedState: loopCtx.sharedState || null,
                    _dimensionMeta: loopCtx.sharedState?._dimensionMeta || null,
                    _projectLanguage: loopCtx.sharedState?._projectLanguage || null,
                    _memoryCoordinator: loopCtx.memoryCoordinator || null,
                    _dimensionScopeId: loopCtx.sharedState?._dimensionScopeId || null,
                    _currentRound: loopCtx.iteration || 0,
                });
            }
            catch (err) {
                // Best-effort: report the failure as a result instead of aborting the loop.
                // Non-Error throwables (e.g. strings) have no usable `.message`; without the
                // fallback the result would be `{ error: undefined }`, which downstream hooks
                // (`!result.error`) would treat as success.
                toolResult = { error: err?.message || String(err) };
            }
            metadata.durationMs = Date.now() - t0;
        }

        // ── after phase ──
        for (const mw of this.#middlewares) {
            if (mw.after) {
                await mw.after(call, toolResult, context, metadata);
            }
        }
        return { result: toolResult, metadata };
    }
}
// ─────────────────────────────────────────────
// 预置中间件
// ─────────────────────────────────────────────
/**
 * AllowlistGate — tool allowlist guard.
 *
 * Blocks tool calls the LLM hallucinated: names that are not among the tool
 * schemas exposed in `loopCtx.toolSchemas`. A schema's name is read from
 * `schema.name` or, failing that, `schema.function.name`.
 *
 * Forge integration: a call missing from the allowlist is still allowed when
 * `runtime.toolRegistry.has(name)` is true (e.g. a tool forged by ToolForge).
 * NOTE(review): this fallback admits any name present in the registry, not only
 * forged tools — confirm that is intended.
 *
 * before: returns `{ blocked: true, result: { error } }` for a disallowed call,
 * otherwise `undefined`. When no schemas are present (all-tools mode) the check
 * is skipped entirely.
 */
export const allowlistGate = {
    name: 'allowlistGate',
    before(call, ctx) {
        const schemas = ctx.loopCtx?.toolSchemas;
        // No schema list means "all tools" mode — nothing to enforce.
        if (!schemas || schemas.length === 0) {
            return undefined;
        }
        // Drop schemas without a resolvable name. Otherwise `undefined` would enter the
        // set and let a malformed call with no name pass the check.
        const allowedNames = new Set(schemas
            .map((s) => s.name || s.function?.name)
            .filter(Boolean));
        if (allowedNames.has(call.name)) {
            return undefined;
        }
        // Forge fallback: not allowlisted, but registered in the ToolRegistry.
        if (ctx.runtime.toolRegistry?.has(call.name)) {
            ctx.runtime.logger.info(`[ToolPipeline] Tool "${call.name}" not in allowlist but exists in registry (forged?) — allowed`);
            return undefined;
        }
        ctx.runtime.logger.warn(`[ToolPipeline] ⛔ Tool "${call.name}" not in allowlist — blocked (hallucinated call)`);
        // Only preview the first few names to keep the error message short.
        const preview = [...allowedNames].slice(0, 5).join(', ');
        const ellipsis = allowedNames.size > 5 ? '...' : '';
        return {
            blocked: true,
            result: {
                error: `工具 "${call.name}" 不可用。当前可用工具: ${preview}${ellipsis}`,
            },
        };
    },
};
/**
 * SafetyGate — SafetyPolicy enforcement.
 *
 * before: asks `runtime.policies.validateToolCall(name, args)`. If the verdict's
 * `ok` is falsy, the call is blocked and `{ error: reason }` becomes the result;
 * otherwise returns `undefined` so the pipeline continues.
 */
export const safetyGate = {
    name: 'safetyGate',
    before(call, ctx) {
        const { policies } = ctx.runtime;
        const verdict = policies.validateToolCall(call.name, call.args);
        if (verdict.ok) {
            return undefined;
        }
        ctx.runtime.logger.warn(`[ToolPipeline] Tool blocked by Policy: ${call.name} — ${verdict.reason}`);
        return { blocked: true, result: { error: verdict.reason } };
    },
};
/**
 * CacheCheck — MemoryCoordinator cache lookup.
 *
 * before: if `loopCtx.memoryCoordinator.getCachedResult(name, args)` yields a
 * non-nullish value, short-circuits with `{ result: cachedValue }` (the pipeline
 * records this as a cache hit). A missing coordinator or missing lookup method
 * is treated as a miss.
 */
export const cacheCheck = {
    name: 'cacheCheck',
    before(call, ctx) {
        const coordinator = ctx.loopCtx.memoryCoordinator;
        const hit = coordinator ? coordinator.getCachedResult?.(call.name, call.args) : undefined;
        // `== null` matches both null and undefined: either one means "not cached".
        if (hit == null) {
            return undefined;
        }
        ctx.runtime.logger.info(`[ToolPipeline] 🔧 CACHE HIT: ${call.name} → skipped execution`);
        return { result: hit };
    },
};
/**
 * ObservationRecord — MemoryCoordinator observation recording.
 *
 * after: forwards `(name, args, result, ctx.iteration, meta.cacheHit)` to
 * `loopCtx.memoryCoordinator.recordObservation` when both the coordinator and
 * the method exist; otherwise does nothing.
 */
export const observationRecord = {
    name: 'observationRecord',
    after(call, result, ctx, meta) {
        const coordinator = ctx.loopCtx.memoryCoordinator;
        const record = coordinator?.recordObservation;
        if (record == null) {
            return;
        }
        // Invoke with the coordinator as `this`, as a method call would.
        record.call(coordinator, call.name, call.args, result, ctx.iteration, meta.cacheHit);
    },
};
/**
 * TrackerSignal — ExplorationTracker signal collection.
 *
 * after: when `loopCtx.tracker` is set, records the call via
 * `tracker.recordToolCall(name, args, result)` and copies the returned `isNew`
 * flag onto `meta.isNew` (read later by traceRecord in the default pipeline).
 */
export const trackerSignal = {
    name: 'trackerSignal',
    after(call, result, ctx, meta) {
        const { tracker } = ctx.loopCtx;
        if (!tracker) {
            return;
        }
        const signal = tracker.recordToolCall(call.name, call.args, result);
        meta.isNew = signal.isNew;
    },
};
/**
 * TraceRecord — ActiveContext reasoning-trace recording.
 *
 * after: when `loopCtx.trace` is set, records the action and its observation via
 * `trace.recordToolCall(name, args, result, meta.isNew)`.
 */
export const traceRecord = {
    name: 'traceRecord',
    after(call, result, ctx, meta) {
        const trace = ctx.loopCtx.trace;
        if (trace == null) {
            return;
        }
        trace.recordToolCall(call.name, call.args, result, meta.isNew);
    },
};
/**
 * SubmitDedup — submission de-duplication.
 *
 * after: only for `submit_knowledge` / `submit_with_check` calls whose result is
 * not an object marked rejected/error, and only when `loopCtx.sharedState` has a
 * `submittedTitles` set:
 *   - a trigger already in `submittedTriggers` → sets `meta.dedupMessage`, stops;
 *   - a (lower-cased, trimmed) title already in `submittedTitles` → sets
 *     `meta.dedupMessage`;
 *   - otherwise registers the title, the trigger, and (when the pattern is long
 *     enough) a whitespace/comment-stripped pattern fingerprint, then sets
 *     `meta.isSubmit = true`.
 */
export const submitDedup = {
    name: 'submitDedup',
    after(call, result, ctx, meta) {
        const shared = ctx.loopCtx.sharedState;
        if (!shared) {
            return;
        }
        const isSubmitTool = call.name === 'submit_knowledge' || call.name === 'submit_with_check';
        if (!isSubmitTool) {
            return;
        }
        const rawTitle = String(call.args?.title || call.args?.category || '');
        // Only object results can signal rejection/failure; anything else counts as accepted.
        const failed = typeof result === 'object'
            && (result?.status === 'rejected' || result?.error || result?.status === 'error');
        if (failed || !shared.submittedTitles) {
            return;
        }
        const titleKey = rawTitle.toLowerCase().trim();
        // Trigger de-dup guards against cross-dimension duplicates that differ only in title.
        const triggerKey = String(call.args?.trigger || '').toLowerCase().trim();
        if (triggerKey && shared.submittedTriggers?.has(triggerKey)) {
            meta.dedupMessage = `⚠ 重复 trigger: "${triggerKey}" 已被其他候选占用。`;
            ctx.runtime.logger.info(`[ToolPipeline] 🔁 duplicate trigger: "${triggerKey}"`);
            return;
        }
        if (shared.submittedTitles.has(titleKey)) {
            meta.dedupMessage = `⚠ 重复提交: "${rawTitle}" 已存在。`;
            ctx.runtime.logger.info(`[ToolPipeline] 🔁 duplicate: "${rawTitle}"`);
            return;
        }
        shared.submittedTitles.add(titleKey);
        if (triggerKey && shared.submittedTriggers) {
            shared.submittedTriggers.add(triggerKey);
        }
        // Pattern fingerprint: strip line and block comments and all whitespace, lower-case,
        // then keep the first 200 characters.
        const patternSource = String(call.args?.content?.pattern || '');
        if (patternSource.length >= 30 && shared.submittedPatterns) {
            const withoutLineComments = patternSource.replace(/\/\/[^\n]*/g, '');
            const withoutComments = withoutLineComments.replace(/\/\*[\s\S]*?\*\//g, '');
            const fingerprint = withoutComments.replace(/\s+/g, '').toLowerCase().slice(0, 200);
            if (fingerprint.length >= 20) {
                shared.submittedPatterns.add(fingerprint);
            }
        }
        meta.isSubmit = true;
    },
};
/**
 * ProgressEmitter — progress callbacks (optional; uses `runtime.emitProgress`
 * when it exists).
 *
 * before: emits `tool_call` with the tool name and args.
 * after:  emits `tool_end` with the duration and an ok/error status.
 *
 * NOTE: not included in the default pipeline, because the caller's `tool_end`
 * event needs `resultStr.length`, which is computed outside the pipeline.
 * #processToolCalls emits these events directly instead.
 */
export const progressEmitter = {
    name: 'progressEmitter',
    before(call, ctx) {
        const emit = ctx.runtime.emitProgress;
        if (emit == null) {
            return;
        }
        emit.call(ctx.runtime, 'tool_call', { tool: call.name, args: call.args });
    },
    after(call, result, ctx, meta) {
        const emit = ctx.runtime.emitProgress;
        if (emit == null) {
            return;
        }
        const failure = result?.error;
        emit.call(ctx.runtime, 'tool_end', {
            tool: call.name,
            duration: meta.durationMs,
            status: failure ? 'error' : 'ok',
            error: failure || undefined,
        });
    },
};
/**
 * EventBusPublisher — EventBus event publishing (optional; active only when
 * `runtime.bus.publish` exists).
 *
 * before: publishes `tool:call:start` with the agent id and tool name.
 * after:  publishes `tool:call:end` with the duration and `success = !result.error`.
 * Both events are published with `{ source: runtime.id }`.
 *
 * NOTE: not included in the default pipeline. #processToolCalls publishes these
 * events directly, to keep exactly the same event order as the original reactLoop.
 */
export const eventBusPublisher = {
    name: 'eventBusPublisher',
    before(call, ctx) {
        const { runtime } = ctx;
        if (!runtime.bus?.publish) {
            return;
        }
        const payload = { agentId: runtime.id, tool: call.name };
        runtime.bus.publish('tool:call:start', payload, { source: runtime.id });
    },
    after(call, result, ctx, meta) {
        const { runtime } = ctx;
        if (!runtime.bus?.publish) {
            return;
        }
        const payload = {
            agentId: runtime.id,
            tool: call.name,
            durationMs: meta.durationMs,
            success: !result?.error,
        };
        runtime.bus.publish('tool:call:end', payload, { source: runtime.id });
    },
};
// ─────────────────────────────────────────────
// Factory helper
// ─────────────────────────────────────────────
/**
 * Create a tool execution pipeline with the default middleware stack.
 *
 * Middlewares run in registration order:
 *   1. allowlistGate     (allowlist guard — may short-circuit)
 *   2. safetyGate        (safety policy — may short-circuit)
 *   3. cacheCheck        (cache lookup — may short-circuit)
 *   4. observationRecord (memory recording)
 *   5. trackerSignal     (signal collection; sets metadata.isNew)
 *   6. traceRecord       (reasoning trace; reads metadata.isNew, so it comes after trackerSignal)
 *   7. submitDedup       (submission de-duplication)
 *
 * NOTE: eventBusPublisher and progressEmitter are intentionally excluded.
 * #processToolCalls emits those events itself, to keep exactly the original
 * reactLoop event order (its `tool_end` progress event needs `resultStr.length`,
 * which is computed outside the pipeline).
 *
 * @returns {ToolExecutionPipeline}
 */
export function createToolPipeline() {
    const stages = [
        allowlistGate,
        safetyGate,
        cacheCheck,
        observationRecord,
        trackerSignal,
        traceRecord,
        submitDedup,
    ];
    const pipeline = new ToolExecutionPipeline();
    for (const stage of stages) {
        pipeline.use(stage);
    }
    return pipeline;
}