UNPKG

claude-flow

Version:

Ruflo - Enterprise AI agent orchestration for Claude Code. Deploy 60+ specialized agents in coordinated swarms with self-learning, fault-tolerant consensus, vector memory, and MCP integration

260 lines 10.4 kB
/**
 * GAIA Hardness Predictor — Linear Classifier (ADR-136 Track Q)
 *
 * Classifies GAIA questions into easy / medium / hard using a
 * hand-rolled logistic regression (no external ML dependencies).
 *
 * Training:
 *   `predictor.train(labeledData)` — fits weights via gradient descent
 *   on cross-entropy loss using the 17-dim feature vectors.
 *
 * Inference:
 *   `predictor.predict(question)` — returns difficulty class + confidence
 *   + a ComputeBudget that drives model/turns/voting choices in gaia-bench.
 *
 * Cold-start:
 *   When untrained (weights = null), classifies everything as "medium".
 *   This is the correct safe default: no wasted Haiku-on-hard, no missed
 *   Sonnet-on-easy.
 *
 * Compute budget policy (from ADR-136 Track Q research):
 *   easy   → Haiku,  max 4 turns,  1 attempt
 *   medium → Sonnet, max 8 turns,  1 attempt
 *   hard   → Sonnet, max 12 turns, 3-vote (Track A)
 *
 * Conservative threshold:
 *   If in doubt, classify UP (medium→hard preferred over medium→easy).
 *   `conservativeMode: true` (default) shifts the easy/medium boundary
 *   so fewer questions fall into "easy": a question is only labelled
 *   "easy" when its easy probability reaches EASY_THRESHOLD_CONSERVATIVE.
 *
 * Refs: ADR-136, ADR-135, #2156
 */
import { extractFeatures } from './features.js';

// ---------------------------------------------------------------------------
// Compute budget policy
// ---------------------------------------------------------------------------

export const COMPUTE_BUDGETS = {
    easy: {
        model: 'haiku',
        maxTurns: 4,
        votingAttempts: 1,
    },
    medium: {
        model: 'sonnet',
        maxTurns: 8,
        votingAttempts: 1,
    },
    hard: {
        model: 'sonnet',
        maxTurns: 12,
        votingAttempts: 3,
    },
};

// ---------------------------------------------------------------------------
// Internal constants
// ---------------------------------------------------------------------------

const FEATURE_DIM = 17;
const NUM_CLASSES = 3; // easy=0, medium=1, hard=2
const CLASS_NAMES = ['easy', 'medium', 'hard'];
const EASY = 0;
const MEDIUM = 1;
const HARD = 2;

// Learning-rate and regularisation tuned for ~300-example datasets
const LEARNING_RATE = 0.05;
const REGULARISATION_LAMBDA = 0.01;
const TRAINING_EPOCHS = 200;

// Median used when no example carries a usable `turns` value.
const DEFAULT_MEDIAN_TURNS = 6;

// Minimum number of labelled examples required before training is attempted.
const MIN_TRAINING_EXAMPLES = 10;

// Conservative mode: shift the probability threshold so fewer questions
// are classified as "easy" (reduces risk of underpowering hard questions).
// With conservativeMode, "easy" must reach a probability of 0.55; merely
// being the largest of the three class probabilities is not enough.
const EASY_THRESHOLD_CONSERVATIVE = 0.55;

// ---------------------------------------------------------------------------
// Math helpers (pure functions, no deps)
// ---------------------------------------------------------------------------

/**
 * Numerically stable softmax: subtracts the largest logit before
 * exponentiating so large logits do not overflow.
 *
 * @param {number[]} logits
 * @returns {number[]} probabilities in the same order as `logits`
 */
function softmax(logits) {
    const maxLogit = Math.max(...logits);
    const exps = logits.map((l) => Math.exp(l - maxLogit));
    const sumExps = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / sumExps);
}

/**
 * Dot product over the length of `a`.
 *
 * @param {number[]} a
 * @param {number[]} b
 * @returns {number}
 */
function dot(a, b) {
    let sum = 0;
    for (let i = 0; i < a.length; i++) {
        sum += a[i] * b[i];
    }
    return sum;
}

/**
 * Index of the largest value; ties resolve to the lowest index.
 *
 * @param {number[]} values
 * @returns {number}
 */
function argmax(values) {
    let best = 0;
    for (let k = 1; k < values.length; k++) {
        if (values[k] > values[best]) {
            best = k;
        }
    }
    return best;
}

/**
 * True when `values` is an array of exactly `length` finite numbers.
 *
 * @param {unknown} values
 * @param {number} length
 * @returns {boolean}
 */
function isFiniteVector(values, length) {
    return Array.isArray(values)
        && values.length === length
        && values.every((v) => Number.isFinite(v));
}

/**
 * True when `state` has the shape produced by `HardnessPredictor#export()`:
 * a NUM_CLASSES × FEATURE_DIM weight matrix plus NUM_CLASSES biases, all
 * finite numbers.
 *
 * @param {unknown} state
 * @returns {boolean}
 */
function isValidWeightState(state) {
    if (state === null || typeof state !== 'object') {
        return false;
    }
    const { weights, biases } = state;
    return Array.isArray(weights)
        && weights.length === NUM_CLASSES
        && weights.every((row) => isFiniteVector(row, FEATURE_DIM))
        && isFiniteVector(biases, NUM_CLASSES);
}

// ---------------------------------------------------------------------------
// HardnessPredictor
// ---------------------------------------------------------------------------

export class HardnessPredictor {
    /**
     * Weight matrix: weights[classIdx][featureIdx].
     * null = untrained (cold-start: return medium for everything).
     */
    weights = null;

    /** Bias terms per class. */
    biases = null;

    /** Whether conservative mode is active (default: true). */
    conservativeMode;

    /**
     * @param {{ conservativeMode?: boolean }} [options]
     */
    constructor(options = {}) {
        this.conservativeMode = options.conservativeMode ?? true;
    }

    /**
     * Returns true when the predictor has been trained and is ready
     * to make non-trivial predictions.
     */
    get isTrained() {
        return this.weights !== null;
    }

    // ── Training ─────────────────────────────────────────────────────────────

    /**
     * Train the linear classifier using labelled examples from prior runs.
     *
     * Labelling strategy (weak supervision):
     *   - Correct + turns ≤ median turns → easy
     *   - Correct + turns > median turns → medium
     *   - Incorrect                      → hard
     *
     * Examples without a finite numeric `turns` are ignored when computing
     * the median and are treated as having exactly the median number of turns.
     *
     * With < 10 examples, refuses to train and resets to the untrained
     * state (cold-start is safer). Otherwise trains with the fixed
     * hyper-parameters declared at the top of this module.
     *
     * @param {Array<{ question: unknown, wasCorrect: boolean, turns?: number }>} labeledData
     * @returns {void}
     */
    train(labeledData) {
        if (labeledData.length < MIN_TRAINING_EXAMPLES) {
            // Too few examples for meaningful generalisation.
            this.weights = null;
            this.biases = null;
            return;
        }

        // Only finite numbers count as observed turns; null/NaN would
        // otherwise be coerced by the sort and skew the median.
        const observedTurns = labeledData
            .map((d) => d.turns)
            .filter((t) => Number.isFinite(t));
        const medianTurns = observedTurns.length > 0
            ? median(observedTurns)
            : DEFAULT_MEDIAN_TURNS;

        // Derive feature rows and labels using the weak-supervision strategy.
        const X = [];
        const y = [];
        for (const example of labeledData) {
            const fv = extractFeatures(example.question);
            X.push(fv.values);
            let classIdx;
            if (example.wasCorrect) {
                const t = Number.isFinite(example.turns) ? example.turns : medianTurns;
                classIdx = t <= medianTurns ? EASY : MEDIUM;
            } else {
                classIdx = HARD;
            }
            y.push(classIdx);
        }

        // Initialise weights to 0.
        const W = Array.from({ length: NUM_CLASSES }, () => new Array(FEATURE_DIM).fill(0));
        const b = new Array(NUM_CLASSES).fill(0);
        const N = X.length;

        // Full-batch gradient descent: every epoch uses the whole dataset.
        for (let epoch = 0; epoch < TRAINING_EPOCHS; epoch++) {
            // Accumulate gradients.
            const dW = Array.from({ length: NUM_CLASSES }, () => new Array(FEATURE_DIM).fill(0));
            const db = new Array(NUM_CLASSES).fill(0);

            for (let n = 0; n < N; n++) {
                const x = X[n];
                const trueClass = y[n];

                // Compute logits and softmax probabilities.
                const logits = W.map((w, k) => dot(w, x) + b[k]);
                const probs = softmax(logits);

                // Cross-entropy gradient for each class: p_k - 1[k == true class].
                for (let k = 0; k < NUM_CLASSES; k++) {
                    const grad = probs[k] - (k === trueClass ? 1 : 0);
                    for (let f = 0; f < FEATURE_DIM; f++) {
                        dW[k][f] += grad * x[f];
                    }
                    db[k] += grad;
                }
            }

            // Update weights with L2 regularisation (biases are not regularised).
            for (let k = 0; k < NUM_CLASSES; k++) {
                for (let f = 0; f < FEATURE_DIM; f++) {
                    W[k][f] -= LEARNING_RATE * (dW[k][f] / N + REGULARISATION_LAMBDA * W[k][f]);
                }
                b[k] -= LEARNING_RATE * (db[k] / N);
            }
        }

        this.weights = W;
        this.biases = b;
    }

    // ── Inference ─────────────────────────────────────────────────────────────

    /**
     * Predict the hardness class of a single GAIA question.
     *
     * Cold-start (untrained): returns medium with confidence=0.5. The same
     * fallback is used if the model produces non-finite probabilities, so a
     * numerically broken model can never silently select the "easy" budget.
     *
     * Conservative mode:
     *   - easy probability ≥ threshold: the excess above the threshold is
     *     moved to medium (easy still wins, with capped confidence).
     *   - easy is the most likely class but below the threshold: its whole
     *     probability is moved to medium, so the question is classified up.
     *
     * @param {unknown} question - passed through to `extractFeatures`
     * @returns {{ difficulty: string, confidence: number, budget: object, features: object }}
     */
    predict(question) {
        const features = extractFeatures(question);
        const coldStart = {
            difficulty: 'medium',
            confidence: 0.5,
            budget: COMPUTE_BUDGETS.medium,
            features,
        };

        if (!this.weights || !this.biases) {
            // Cold-start fallback: medium for everything.
            return coldStart;
        }

        const logits = this.weights.map((w, k) => dot(w, features.values) + this.biases[k]);
        const probs = softmax(logits);
        if (!probs.every((p) => Number.isFinite(p))) {
            // NaN would fall through argmax to class 0 ("easy"); stay safe instead.
            return coldStart;
        }

        // Conservative mode: down-weight easy probability.
        let adjustedProbs = [...probs];
        if (this.conservativeMode) {
            if (probs[EASY] >= EASY_THRESHOLD_CONSERVATIVE) {
                // Confident easy: transfer only the excess above the threshold to medium.
                const easyExcess = probs[EASY] - EASY_THRESHOLD_CONSERVATIVE;
                adjustedProbs[EASY] = probs[EASY] - easyExcess;
                adjustedProbs[MEDIUM] = probs[MEDIUM] + easyExcess;
            } else if (argmax(probs) === EASY) {
                // Easy leads but is not confident enough: classify up to medium.
                adjustedProbs[MEDIUM] = probs[MEDIUM] + probs[EASY];
                adjustedProbs[EASY] = 0;
            }
        }

        // Re-normalise after adjustment.
        const sumAdj = adjustedProbs.reduce((a, b) => a + b, 0);
        adjustedProbs = adjustedProbs.map((p) => p / sumAdj);

        // Pick argmax.
        const bestClass = argmax(adjustedProbs);
        const difficulty = CLASS_NAMES[bestClass];
        const confidence = adjustedProbs[bestClass];

        return {
            difficulty,
            confidence,
            budget: COMPUTE_BUDGETS[difficulty],
            features,
        };
    }

    // ── Serialisation ─────────────────────────────────────────────────────────

    /**
     * Export weights as a plain JSON-serialisable object.
     * Returns null if untrained. The returned arrays are copies, so mutating
     * them does not affect this predictor.
     *
     * @returns {{ weights: number[][], biases: number[] } | null}
     */
    export() {
        if (!this.weights || !this.biases) {
            return null;
        }
        return {
            weights: this.weights.map((row) => [...row]),
            biases: [...this.biases],
        };
    }

    /**
     * Import previously exported weights.
     *
     * @param {{ weights: number[][], biases: number[] }} state
     * @throws {Error} when `state` is not a NUM_CLASSES × FEATURE_DIM matrix
     *   of finite numbers plus NUM_CLASSES finite biases
     */
    import(state) {
        if (!isValidWeightState(state)) {
            throw new Error(`Invalid weight state: expected ${NUM_CLASSES}×${FEATURE_DIM} matrix + ${NUM_CLASSES} biases`);
        }
        this.weights = state.weights.map((row) => [...row]);
        this.biases = [...state.biases];
    }
}

// ---------------------------------------------------------------------------
// Utility
// ---------------------------------------------------------------------------

/**
 * Median of a list of numbers; 0 for an empty list. Does not mutate `values`.
 *
 * @param {number[]} values
 * @returns {number}
 */
function median(values) {
    if (values.length === 0) {
        return 0;
    }
    const sorted = [...values].sort((a, b) => a - b);
    const mid = Math.floor(sorted.length / 2);
    return sorted.length % 2 === 0
        ? (sorted[mid - 1] + sorted[mid]) / 2
        : sorted[mid];
}
//# sourceMappingURL=predictor.js.map