UNPKG

gepa-ts

Version:

TypeScript implementation of GEPA (Gradient-free Evolution of Prompts and Agents) - Complete port with 100% feature parity

359 lines (307 loc) 12.2 kB
import { BaseAdapter, EvaluationBatch } from '../core/adapter.js';
import { ComponentMap, ReflectiveDataset } from '../types/index.js';
import { DSPyExample, DSPyTrace, DSPyPrediction, DSPyProgram, DSPyModule } from './dspy-adapter.js';
import { DSPyExecutor, createDSPyExecutor } from '../executors/dspy-executor.js';

/**
 * DSPy Full Program Adapter - Evolves entire DSPy programs including
 * signatures, modules, and control flow logic.
 *
 * A candidate stores the complete Python source of a DSPy program under the
 * `program` component key (falling back to `baseProgram`). Programs are run
 * through a `DSPyExecutor`. New candidates come either from the optional
 * reflection LM or from simple text-rewriting heuristics.
 *
 * Through the reflection LM this adapter can go beyond optimizing
 * instructions. For example, it can add or remove modules, change module types
 * (Predict -> ChainOfThought), modify control flow, or evolve custom module
 * implementations. The built-in fallback heuristics are much narrower. They
 * upgrade `dspy.Predict` to `dspy.ChainOfThought`, add a `reasoning` output to
 * signature strings, insert a class docstring, or prepend an advisory comment.
 */
export class DSPyFullProgramAdapter extends BaseAdapter<DSPyExample, DSPyTrace, DSPyPrediction> {
  /**
   * User-supplied metric.
   * NOTE(review): not consulted by this class — execution scores examples with
   * the Python code returned by `convertMetricFnToCode`.
   */
  private metricFn: (example: DSPyExample, prediction: DSPyPrediction) => number | Promise<number>;
  /**
   * Callable task LM, if one was passed.
   * NOTE(review): stored but not used here; the executor only receives
   * `taskLmConfig`.
   */
  private taskLm?: (prompt: string) => Promise<string> | string;
  /** Optional LM used by `proposeNewTexts` to rewrite the whole program. */
  private reflectionLm?: (prompt: string) => Promise<string> | string;
  /** Program source used when a candidate has no `program` component. */
  private baseProgram: string; // String representation of the program
  /** Runs Python DSPy programs (created via `createDSPyExecutor`). */
  private executor: DSPyExecutor;
  /** Configured parallelism. NOTE(review): not read anywhere in this class. */
  private numThreads: number;
  /** LM settings passed to the executor on every run. */
  private taskLmConfig: any;

  /**
   * @param config.baseProgram  Initial program source; defaults to a single
   *   `dspy.ChainOfThought("question -> answer")`.
   * @param config.taskLm       Either a callable LM or a config object with
   *   optional `model`, `apiKey`, `maxTokens`, `temperature`.
   * @param config.metricFn     User metric (see the field note above).
   * @param config.reflectionLm Optional LM used to propose program rewrites.
   * @param config.numThreads   Defaults to 10 when missing or 0.
   * @param config.pythonPath   Forwarded to `createDSPyExecutor`.
   */
  constructor(config: {
    baseProgram?: string;
    taskLm?: any; // DSPy LM config or function
    metricFn: (example: DSPyExample, prediction: DSPyPrediction) => any;
    reflectionLm?: (prompt: string) => Promise<string> | string;
    numThreads?: number;
    pythonPath?: string;
  }) {
    super();
    this.baseProgram =
      config.baseProgram || 'import dspy\nprogram = dspy.ChainOfThought("question -> answer")';
    this.metricFn = config.metricFn;
    this.reflectionLm = config.reflectionLm;
    this.numThreads = config.numThreads || 10;
    this.executor = createDSPyExecutor(config.pythonPath);

    // Parse task LM config
    if (typeof config.taskLm === 'function') {
      // A JS callable cannot be shipped to the Python executor, so it is kept
      // on `this.taskLm` and the executor gets the default model config.
      this.taskLm = config.taskLm;
      this.taskLmConfig = { model: 'openai/gpt-4o-mini' };
    } else if (config.taskLm) {
      this.taskLmConfig = {
        model: config.taskLm.model || 'openai/gpt-4o-mini',
        apiKey: config.taskLm.apiKey || process.env.OPENAI_API_KEY,
        maxTokens: config.taskLm.maxTokens || 2000,
        // `??` rather than `||`: an explicit temperature of 0 (deterministic
        // decoding) must be honoured, not silently replaced by 0.7.
        temperature: config.taskLm.temperature ?? 0.7,
      };
    } else {
      this.taskLmConfig = { model: 'openai/gpt-4o-mini' };
    }
  }

  /**
   * Cheap textual check that `programCode` looks like a DSPy program the
   * executor can load. It does not parse Python.
   *
   * @returns an error message, or `undefined` when the shape looks valid.
   */
  private static checkProgramShape(programCode: string): string | undefined {
    if (!programCode.includes('import dspy')) {
      return 'Program must import dspy';
    }
    // Require a real `program = ...` assignment (with or without spaces, or
    // with a type annotation) instead of any substring such as `my_program =`.
    if (!/^[ \t]*program[ \t]*(?::[^=\n]*)?=(?!=)/m.test(programCode)) {
      return 'Program must define a "program" variable';
    }
    return undefined;
  }

  /** Builds the zero-score prediction used to report a failed example. */
  private static failedPrediction(message: string): DSPyPrediction {
    const prediction = { outputs: {}, score: 0, error: message };
    return prediction;
  }

  /** Builds the empty, zero-score trace used to report a failed example. */
  private static failedTrace(message: string): DSPyTrace {
    const trace = { moduleInvocations: [], score: 0, errors: [message] };
    return trace;
  }

  /**
   * Lightweight pre-execution validation (see `checkProgramShape`).
   * Never rejects; unexpected errors are reported as an invalid program.
   */
  private async validateProgram(programCode: string): Promise<{ valid: boolean; error?: string }> {
    try {
      const error = DSPyFullProgramAdapter.checkProgramShape(programCode);
      return error === undefined ? { valid: true } : { valid: false, error };
    } catch (error) {
      return { valid: false, error: String(error) };
    }
  }

  /**
   * Returns the Python source of the metric sent to the executor.
   *
   * This is a fixed exact-match metric on the `answer` field. It does NOT
   * translate `this.metricFn`; a production version would need a real
   * conversion.
   */
  private convertMetricFnToCode(): string {
    return `
def metric_fn(example, pred):
    # Default metric function
    if hasattr(example, 'answer'):
        pred_answer = pred.get('answer', '') if hasattr(pred, 'get') else getattr(pred, 'answer', '')
        return 1.0 if str(pred_answer).strip() == str(example.answer).strip() else 0.0
    return 0.0
`;
  }

  /**
   * Executes `programCode` on `examples` through the Python executor.
   *
   * Failures never throw. If the executor fails, every example gets a
   * zero-score prediction carrying the error. If a single result is missing,
   * only that example is marked failed. When `captureTrace` is set, exactly one
   * trace per example is returned, so traces stay index-aligned with the
   * predictions.
   */
  private async executeProgram(
    programCode: string,
    examples: DSPyExample[],
    captureTrace: boolean = false
  ): Promise<{ predictions: DSPyPrediction[]; traces?: DSPyTrace[] }> {
    const metricCode = this.convertMetricFnToCode();

    try {
      const result = await this.executor.execute(programCode, examples, this.taskLmConfig, metricCode);
      if (!result.success) {
        throw new Error(result.error || 'Execution failed');
      }

      const predictions: DSPyPrediction[] = [];
      const traces: DSPyTrace[] = [];

      for (let i = 0; i < examples.length; i++) {
        const res = result.results?.[i];
        if (!res) {
          // Executor returned fewer results than examples: fail just this one
          // instead of discarding the whole batch.
          const message = `No result returned for example ${i}`;
          predictions.push(DSPyFullProgramAdapter.failedPrediction(message));
          if (captureTrace) {
            traces.push(DSPyFullProgramAdapter.failedTrace(message));
          }
          continue;
        }

        const prediction: DSPyPrediction = {
          outputs: res.outputs,
          score: res.score,
          feedback: res.feedback,
        };
        predictions.push(prediction);

        if (captureTrace) {
          // Push a trace even when the executor sent none, to keep alignment.
          const steps = result.traces?.[i] ?? [];
          const trace: DSPyTrace = {
            moduleInvocations: steps.map((t: any) => ({
              moduleName: t.module,
              inputs: t.inputs,
              outputs: t.outputs,
              timestamp: Date.now(),
            })),
            score: res.score,
          };
          traces.push(trace);
        }
      }

      return { predictions, traces: captureTrace ? traces : undefined };
    } catch (error) {
      const message = String(error);
      const predictions = examples.map(() => DSPyFullProgramAdapter.failedPrediction(message));
      // Still return traces when requested, so reflection can see the failure
      // (mirrors the validation-failure path in `evaluate`).
      const traces = captureTrace
        ? examples.map(() => DSPyFullProgramAdapter.failedTrace(message))
        : undefined;
      return { predictions, traces };
    }
  }

  /**
   * Validates and runs the candidate program (or `baseProgram`) on `batch`.
   * An invalid program scores 0 on every example without being executed.
   */
  async evaluate(
    batch: DSPyExample[],
    candidate: ComponentMap,
    captureTraces: boolean = false
  ): Promise<EvaluationBatch<DSPyTrace, DSPyPrediction>> {
    const programCode = candidate['program'] || this.baseProgram;

    const validation = await this.validateProgram(programCode);
    if (!validation.valid) {
      const message = validation.error ?? 'Invalid program';
      return {
        outputs: batch.map(() => DSPyFullProgramAdapter.failedPrediction(message)),
        scores: batch.map(() => 0),
        trajectories: captureTraces
          ? batch.map(() => DSPyFullProgramAdapter.failedTrace(message))
          : null,
      };
    }

    const { predictions, traces } = await this.executeProgram(programCode, batch, captureTraces);
    return {
      outputs: predictions,
      scores: predictions.map((p) => p.score || 0),
      trajectories: traces || null,
    };
  }

  /**
   * Builds per-example feedback records for the whole program, stored under
   * the `program` key. At most the first 5 records are kept. Returns an empty
   * dataset when the batch was evaluated without traces.
   */
  async makeReflectiveDataset(
    candidate: ComponentMap,
    evalBatch: EvaluationBatch<DSPyTrace, DSPyPrediction>,
    componentsToUpdate: string[]
  ): Promise<ReflectiveDataset> {
    const dataset: ReflectiveDataset = {};
    if (!evalBatch.trajectories) {
      return dataset;
    }

    const currentProgram = candidate['program'] || this.baseProgram;
    const programAnalysis: Array<Record<string, any>> = [];

    for (let i = 0; i < evalBatch.trajectories.length; i++) {
      const trace = evalBatch.trajectories[i];
      const score = evalBatch.scores[i] ?? 0;
      const outputs = evalBatch.outputs[i]?.outputs;
      const invocations = trace?.moduleInvocations ?? [];

      let feedback: string;
      const suggestions: string[] = [];

      if (score < 0.3) {
        feedback = 'Program failed to produce correct output.';
        suggestions.push('Consider adding ChainOfThought for complex reasoning');
        suggestions.push('Add intermediate validation steps');
        suggestions.push('Improve error handling');
      } else if (score < 0.7) {
        feedback = 'Program partially successful.';
        suggestions.push('Refine module signatures for clarity');
        suggestions.push('Add more specific instructions');
        // Multi-module runs without a reasoning output are hard to debug.
        if (!outputs?.reasoning && invocations.length > 1) {
          suggestions.push('Add reasoning trace for debugging');
        }
      } else {
        feedback = 'Program performing well.';
        suggestions.push('Consider optimizing for efficiency');
        suggestions.push('Simplify if over-engineered');
      }

      programAnalysis.push({
        'Current Program': currentProgram,
        'Score': score,
        'Feedback': feedback,
        'Suggestions': suggestions,
        'Failed Examples': score < 0.5 ? invocations : [],
        'Successful Patterns': score > 0.8 ? outputs ?? {} : {},
      });
    }

    dataset['program'] = programAnalysis.slice(0, 5);
    return dataset;
  }

  /**
   * Proposes a new program.
   *
   * If a reflection LM is configured, it is asked for a rewrite. The answer is
   * taken from its first fenced code block, or from the raw reply when that
   * already looks like a DSPy program. Otherwise, or if the reply is unusable,
   * a heuristic is chosen by the mean `Score` of the reflective records.
   */
  async proposeNewTexts(
    candidate: ComponentMap,
    reflectiveDataset: ReflectiveDataset,
    componentsToUpdate: string[]
  ): Promise<ComponentMap> {
    const currentProgram = candidate['program'] || this.baseProgram;
    const examples = reflectiveDataset['program'] || [];

    if (this.reflectionLm) {
      const prompt = this.buildEvolutionPrompt(currentProgram, examples);
      const response = await this.reflectionLm(prompt);

      // Accept ```python, ```py and bare ``` fences (LF or CRLF). The old
      // `python?` only made the trailing "n" optional.
      const codeMatch = response.match(/```(?:python|py)?[ \t]*\r?\n([\s\S]*?)\r?\n[ \t]*```/i);
      if (codeMatch) {
        return { program: codeMatch[1] };
      }

      if (DSPyFullProgramAdapter.checkProgramShape(response) === undefined) {
        return { program: response };
      }
    }

    // Fallback to heuristic evolution; records without a numeric Score count as 0.
    const avgScore =
      examples.reduce(
        (sum, ex) => sum + (typeof ex['Score'] === 'number' ? ex['Score'] : 0),
        0
      ) / (examples.length || 1);

    let evolvedProgram: string;
    if (avgScore < 0.3) {
      evolvedProgram = this.majorRestructure(currentProgram, examples);
    } else if (avgScore < 0.7) {
      evolvedProgram = this.incrementalImprove(currentProgram, examples);
    } else {
      evolvedProgram = this.optimize(currentProgram, examples);
    }

    return { program: evolvedProgram };
  }

  /** Builds the reflection-LM prompt, including up to 3 feedback records as JSON. */
  private buildEvolutionPrompt(currentProgram: string, feedback: any[]): string {
    return `You are evolving a DSPy program based on execution feedback.

Current program:
\`\`\`python
${currentProgram}
\`\`\`

Execution feedback:
${JSON.stringify(feedback.slice(0, 3), null, 2)}

Improve the program by:
1. Analyzing failure patterns
2. Adding error handling where needed
3. Improving module signatures
4. Optimizing the dataflow

Provide the improved program in a Python code block. Ensure it imports dspy and defines a 'program' variable.`;
  }

  /**
   * Heuristic for low-scoring programs: upgrades every `dspy.Predict(` call to
   * `dspy.ChainOfThought(`.
   *
   * Regex-based insertion of `try:` blocks or validation calls is deliberately
   * not attempted. Splicing `try:` after `def forward(self,` lands inside the
   * parameter list (a Python SyntaxError), and `self.validate(result)` is not
   * defined on generic DSPy modules. Either change guarantees a failing
   * candidate.
   */
  private majorRestructure(program: string, _feedback: any[]): string {
    return program.replace(/dspy\.Predict\(/g, 'dspy.ChainOfThought(');
  }

  /**
   * Heuristic for mid-scoring programs. It adds a `reasoning` output to every
   * double-quoted "inputs -> outputs" string, and adds a docstring to the first
   * `dspy.Module` subclass when the program has no triple-quoted string.
   */
  private incrementalImprove(program: string, _feedback: any[]): string {
    // Skip signatures that already list `reasoning`, so repeated generations
    // do not stack "reasoning, reasoning, ...".
    let evolved = program.replace(/"([^"]+)"/g, (match: string, sig: string) => {
      const parts = sig.split('->').map((part) => part.trim());
      if (parts.length !== 2) {
        return match;
      }
      const [inputs = '', outputs = ''] = parts;
      const hasReasoning = outputs
        .split(',')
        .some((field) => (field.split(':')[0] ?? '').trim() === 'reasoning');
      return hasReasoning ? match : `"${inputs} -> reasoning, ${outputs}"`;
    });

    if (!evolved.includes('"""')) {
      const classMatch = evolved.match(/class (\w+)\(dspy\.Module\):/);
      if (classMatch) {
        evolved = evolved.replace(
          classMatch[0],
          `${classMatch[0]}\n    """Enhanced module with improved error handling and reasoning."""`
        );
      }
    }

    return evolved;
  }

  /**
   * Heuristic for high-scoring programs. When more than 5 `dspy.X(` calls are
   * present, it prepends an advisory comment once; the code is not changed.
   */
  private optimize(program: string, _feedback: any[]): string {
    const note = '# Consider consolidating modules for efficiency';
    const moduleCount = (program.match(/dspy\.\w+\(/g) || []).length;
    // Only annotate once; otherwise every high-scoring generation would
    // prepend another copy of the same comment.
    if (moduleCount > 5 && !program.includes(note)) {
      return `${note}\n${program}`;
    }
    return program;
  }
}