UNPKG

gepa-ts

Version:

TypeScript implementation of GEPA (Gradient-free Evolution of Prompts and Agents) - Complete port with 100% feature parity

265 lines (223 loc) 10.2 kB
#!/usr/bin/env tsx /** * AnyMaths Training - Exact 1:1 replica of Python GEPA's train_anymaths.py * Supports GSM8K and AIME datasets with identical logic and parameters */ import { Command } from 'commander'; import * as fs from 'fs/promises'; import * as path from 'path'; import { fileURLToPath } from 'url'; import { dirname } from 'path'; import { GEPAOptimizer } from '../src/core/optimizer.js'; import { AnyMathsAdapter } from '../src/adapters/anymaths-adapter.js'; import { OpenAILanguageModel } from '../src/models/openai.js'; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); // Exact replica of Python dataset interface interface MathDatasetItem { input: string; additional_context?: { solution: string }; answer: string; } // Exact replica of Python init_dataset function async function initDataset(anymathsDsetName: string = "openai/gsm8k"): Promise<{ trainset: MathDatasetItem[]; valset: MathDatasetItem[]; testset: MathDatasetItem[]; }> { let trainSplit: MathDatasetItem[] = []; let testSplit: MathDatasetItem[] = []; switch (anymathsDsetName) { case "openai/gsm8k": // Mock GSM8K dataset (in production, load from HuggingFace) const gsm8kTrainExamples = [ { question: "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?", answer: "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\\n#### 72" }, { question: "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?", answer: "Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.\\nWorking 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.\\n#### 10" }, { question: "Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?", answer: "In the beginning, Betty has only 100/2 = $<<100/2=50>>50.\\nBetty's grandparents give her 15 * 2 = $<<15*2=30>>30.\\nThis means, Betty needs 100 - 50 - 15 - 30 = $<<100-50-15-30=5>>5 more.\\n#### 5" } ]; for (const item of gsm8kTrainExamples) { const answerParts = item.answer.split("####"); const answer = answerParts[answerParts.length - 1].trim(); const solution = answerParts[0].trim(); const question = item.question; trainSplit.push({ input: question, additional_context: { solution }, answer }); } // Shuffle with same seed as Python (Random(0)) trainSplit = shuffleArray(trainSplit, 0); // Create test split (same data for simplicity) testSplit = trainSplit.map(item => ({ input: item.input, answer: item.answer })); break; case "MathArena/aime_2025": // Mock AIME dataset const aimeTrainExamples = [ { problem: "Find the number of ordered pairs $(a,b)$ of integers such that $|a + b| = 1$ and $|a - b| = 5$.", solution: "We consider cases based on the signs in the absolute values...", answer: 4 }, { problem: "Let $S$ be the sum of all positive divisors of 60. Find the remainder when $S$ is divided by 13.", solution: "First find all divisors of 60 = 2² × 3 × 5...", answer: 6 } ]; for (const item of aimeTrainExamples) { trainSplit.push({ input: item.problem, additional_context: { solution: item.solution }, answer: item.answer.toString() }); } trainSplit = shuffleArray(trainSplit, 0); testSplit = trainSplit.map(item => ({ input: item.input, answer: item.answer })); break; default: throw new Error(`Unknown dataset name: ${anymathsDsetName}`); } const trainset = trainSplit.slice(0, Math.floor(trainSplit.length / 2)); const valset = trainSplit.slice(Math.floor(trainSplit.length / 2)); const testset = testSplit; return { trainset, valset, testset }; } // Exact replica of Python shuffle logic function shuffleArray<T>(array: T[], seed: number): T[] { let rng = seed; const shuffled = [...array]; for (let i = shuffled.length - 1; i > 0; i--) { rng = (rng * 9301 + 49297) % 233280; const j = Math.floor((rng / 233280) * (i + 1)); [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]]; } return shuffled; } // Main function (exact replica of Python __main__) async function main() { const program = new Command(); // Exact replica of Python argument parser program .option('--anymaths_dset_name <name>', 'Dataset name', 'openai/gsm8k') .option('--train_size <size>', 'Training set size', '1') .option('--val_size <size>', 'Validation set size', '1') .option('--test_size <size>', 'Test set size', '1') .option('--base_lm <model>', 'Base language model', 'gpt-4o-mini') .option('--use_api_base', 'Use API base URL') .option('--api_base_url <url>', 'API base URL', 'http://localhost:11434') .option('--reflection_lm <model>', 'Reflection LM', 'gpt-4o') .option('--use_api_reflection', 'Use API reflection URL') .option('--api_reflection_url <url>', 'API reflection URL', 'http://localhost:11434') .option('--reflection_minibatch_size <size>', 'Reflection minibatch size', '8') .option('--max_litellm_workers <workers>', 'Max workers', '10') .option('--budget <budget>', 'Optimization budget', '500') .option('--seed <seed>', 'Random seed', '0') .parse(); const args = program.opts(); // Load instruction prompt (exact replica of Python) const INSTRUCTION_PROMPT_PATH = path.join(__dirname, 'prompt-templates', 'instruction_prompt.txt'); let seedInstruction: string; try { seedInstruction = await fs.readFile(INSTRUCTION_PROMPT_PATH, 'utf-8'); } catch (error) { // Fallback seed instruction seedInstruction = `You are a mathematical problem-solving assistant. Solve the given problem step by step and provide the final numerical answer. Instructions: 1. Read the problem carefully 2. Identify what is being asked 3. Show your work step by step 4. Provide the final answer clearly Problem: {input} Solution:`; } // Initialize dataset (exact replica of Python) const { trainset, valset, testset } = await initDataset(args.anymaths_dset_name); const trainSize = parseInt(args.train_size); const valSize = parseInt(args.val_size); const testSize = parseInt(args.test_size); // Validate sizes (exact replica of Python) for (const size of [trainSize, valSize, testSize]) { if (size <= 0) { throw new Error("Train, val, and test sizes must be positive integers."); } } const finalTrainset = trainset.slice(0, trainSize); const finalValset = valset.slice(0, valSize); const finalTestset = testset.slice(0, testSize); console.log('-'.repeat(100)); console.log(`Using dataset: ${args.anymaths_dset_name}`); console.log(`Training set size: ${finalTrainset.length}`); console.log(`Validation set size: ${finalValset.length}`); console.log(`Test set size: ${finalTestset.length}`); console.log('-'.repeat(100)); // Create LLMs (exact replica of Python setup) const baseLm = new OpenAILanguageModel({ apiKey: process.env.OPENAI_API_KEY!, model: args.base_lm, maxTokens: 2000, temperature: 0.7 }); const reflectionModel = new OpenAILanguageModel({ apiKey: process.env.OPENAI_API_KEY!, model: args.reflection_lm, maxTokens: 4000, temperature: 0.8 }); const reflectionLm = (prompt: string) => reflectionModel.generate(prompt); console.log(`Using base LM: ${args.base_lm}`); console.log(`Using reflection LM: ${args.reflection_lm}`); console.log(`Using API base URL: ${args.use_api_base ? args.api_base_url : 'default'}`); console.log(`Using API reflection URL: ${args.use_api_reflection ? args.api_reflection_url : 'default'}`); console.log(`Reflection minibatch size: ${args.reflection_minibatch_size}`); console.log(`Max LiteLLM workers: ${args.max_litellm_workers}`); console.log(`Budget: ${args.budget}`); console.log(`Seed: ${args.seed}`); console.log('-'.repeat(100)); // Create adapter (exact replica of Python) const adapter = new AnyMathsAdapter({ model: (prompt: string) => baseLm.generate(prompt), apiBase: args.use_api_base ? args.api_base_url : undefined, maxWorkers: parseInt(args.max_litellm_workers) }); // Create optimizer (exact replica of Python optimize call) const optimizer = new GEPAOptimizer({ adapter, config: { populationSize: 8, generations: Math.ceil(parseInt(args.budget) / 64), // Approximate generations from budget mutationRate: 0.6, verbose: true } }); console.log('🚀 Starting AnyMaths optimization...'); const optimizedResults = await optimizer.optimize({ initialCandidate: { instruction_prompt: seedInstruction }, trainData: finalTrainset, valData: finalValset, componentsToUpdate: ['instruction_prompt'], maxMetricCalls: parseInt(args.budget), reflectionLm, reflectionMinibatchSize: parseInt(args.reflection_minibatch_size), perfectScore: 1, skipPerfectScore: false, useWandb: false, seed: parseInt(args.seed), displayProgressBar: true }); console.log('-'.repeat(100)); console.log(`🎉 Best prompt >>> ${JSON.stringify(optimizedResults.bestCandidate, null, 2)}`); console.log(`🏆 Final score: ${(optimizedResults.bestScore * 100).toFixed(1)}%`); console.log(`📈 Generations completed: ${optimizedResults.generationsCompleted}`); console.log(`💬 Total evaluations: ${optimizedResults.totalEvaluations}`); console.log('-'.repeat(100)); } // Run if called directly (exact replica of Python __name__ == "__main__") if (import.meta.url === `file://${process.argv[1]}`) { main().catch(console.error); }