UNPKG

autosnippet

Version:

Extract code patterns into a knowledge base for AI coding assistants

240 lines (239 loc) 10.1 kB
/**
 * AI/ML Enhancement Pack
 * Conditions: { languages: ['python'], frameworks: ['ml'] }
 *
 * Covers the PyTorch / TensorFlow / HuggingFace machine-learning ecosystem:
 * - nn.Module model architectures
 * - Training-loop pattern (optimizer.zero_grad → loss.backward → optimizer.step)
 * - DataLoader / Dataset
 * - HuggingFace Trainer / Pipeline
 * - Model save/load (state_dict / safetensors)
 * - Distributed training (DDP / FSDP)
 *
 * NOTE(review): save/load and distributed training are listed above, but no
 * dimension, guard rule, or detector in this pack targets them specifically.
 */
import { EnhancementPack } from './EnhancementPack.js';

/** Method names (compared lower-cased) that mark a training entry point. */
const TRAINING_METHOD_NAMES = new Set([
  'train',
  'train_one_epoch',
  'train_step',
  'training_step',
  'train_loop',
]);

/** Method names (compared lower-cased) that mark an evaluation entry point. */
const EVALUATION_METHOD_NAMES = new Set([
  'evaluate',
  'validate',
  'eval_step',
  'validation_step',
  'test_step',
]);

/** Substrings that mark an import as belonging to the ML ecosystem. */
const ML_IMPORT_KEYWORDS = [
  'torch',
  'tensorflow',
  'transformers',
  'datasets',
  'accelerate',
  'lightning',
  'sklearn',
  'numpy',
  'wandb',
  'tensorboard',
];

/**
 * True when the class summary has a (truthy) superclass that `re` matches.
 * @param {{ superclass?: string }} cls - class entry from the AST summary
 * @param {RegExp} re - must not carry the `g`/`y` flag (`test` would become stateful)
 * @returns {boolean}
 */
function superclassMatches(cls, re) {
  return Boolean(cls.superclass) && re.test(cls.superclass);
}

/**
 * Build a class-level detection record.
 * @returns {{ type: string, className: *, line: *, confidence: number }}
 */
function classPattern(type, cls, confidence) {
  return { type, className: cls.name, line: cls.line, confidence };
}

/**
 * Build a method-level detection record.
 * @returns {{ type: string, methodName: *, line: *, confidence: number }}
 */
function methodPattern(type, method, confidence) {
  return { type, methodName: method.name, line: method.line, confidence };
}

/**
 * Enhancement pack for Python ML projects (PyTorch / HuggingFace / Lightning).
 *
 * Contributes three extra analysis dimensions, five regex guard rules and an
 * AST-summary pattern detector. How these are consumed is defined by
 * `EnhancementPack` and its callers, which are not visible in this module.
 */
class MLEnhancement extends EnhancementPack {
  get id() {
    return 'python-ml';
  }

  get displayName() {
    return 'Python AI/ML (PyTorch/HuggingFace) Enhancement';
  }

  /** Activation conditions for this pack (matching is done by the caller). */
  get conditions() {
    return { languages: ['python'], frameworks: ['ml'] };
  }

  /**
   * Extra analysis dimensions: model architecture, training pipeline and data
   * pipeline. Each one is marked skill-worthy with dual output and carries the
   * metadata for its generated skill.
   * @returns {object[]}
   */
  getExtraDimensions() {
    return [
      {
        id: 'ml-model-architecture-scan',
        label: '模型架构分析',
        guide: 'nn.Module 子类分析 — 层级结构 (forward 方法调用链)、自定义 Layer、残差连接/Attention 模式、参数量估算、模型注册表',
        tierHint: 2,
        knowledgeTypes: ['architecture', 'code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-models',
          description: 'PyTorch model architectures, custom layers and forward pass patterns (auto-generated by enhancement)',
        },
      },
      {
        id: 'ml-training-pipeline-scan',
        label: 'Training Pipeline 分析',
        guide: 'Training 流程分析 — Training Loop 结构 (epoch → batch → forward → loss → backward → step)、learning rate scheduler、gradient clipping/accumulation、早停策略、checkpoint 保存/恢复、HuggingFace Trainer 配置',
        tierHint: 2,
        knowledgeTypes: ['code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-training',
          description: 'ML training pipeline — training loops, schedulers, checkpointing and HF Trainer (auto-generated by enhancement)',
        },
      },
      {
        id: 'ml-data-pipeline-scan',
        label: '数据管道分析',
        guide: '数据处理管道分析 — Dataset/DataLoader 实现、数据增强 (transforms)、tokenizer 配置、特征工程、数据拆分策略 (train/val/test)',
        tierHint: 2,
        knowledgeTypes: ['code-pattern', 'architecture'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-data',
          // Previously split mid-literal across two physical lines; rejoined.
          description: 'ML data pipelines — Dataset/DataLoader, transforms and tokenization (auto-generated by enhancement)',
        },
      },
    ];
  }

  /**
   * File-level regex guard rules for common PyTorch mistakes.
   *
   * The patterns rely on lookbehind and tempered tokens (ES2018+). Each pattern
   * is a single non-global literal, so `RegExp#test` stays stateless.
   * @returns {object[]}
   */
  getGuardRules() {
    return [
      {
        ruleId: 'ml-no-eval-in-train',
        category: 'correctness',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        // train() … eval() … backward with NO model.train() in between. The
        // tempered token stops the match from crossing a re-issued train().
        pattern: /model\.train\(\)[\s\S]*?model\.eval\(\)(?:(?!model\.train\(\))[\s\S])*?loss\.backward/,
        message: '训练循环中意外切换到 eval 模式后执行 backward — 确保 model.train() 在训练阶段',
      },
      {
        ruleId: 'ml-device-mismatch',
        category: 'correctness',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        // Hard-coded `.to('cpu')`, `.to('cuda')`, `.to('cuda:N')` or `.to('mps')`.
        pattern: /\.to\s*\(\s*['"](?:cuda(?::\d+)?|cpu|mps)['"]\s*\)/,
        message: '硬编码 device 字符串 — 建议使用 torch.device() 变量统一管理,支持 MPS/多 GPU',
      },
      {
        ruleId: 'ml-missing-no-grad',
        category: 'performance',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        // An eval/predict-style `def` that calls `model(` with no grad guard.
        // - The lookbehind skips functions directly decorated with
        //   @torch.no_grad / @torch.inference_mode (with or without parens).
        // - The tempered token stops at a `with torch.no_grad()` /
        //   `torch.inference_mode()` block before the call.
        // The previous lookahead sat immediately before the literal `model` and
        // therefore could never fail.
        pattern: /(?<!@torch\.(?:no_grad|inference_mode)(?:\(\))?\s*)def\s+(?:evaluate|validate|test|predict|inference)\s*\([^)]*\)(?:(?!torch\.(?:no_grad|inference_mode))[\s\S])*?model\s*\(/,
        message: '推理/评估函数应使用 @torch.no_grad() 或 with torch.no_grad() — 减少内存消耗并加速',
      },
      {
        ruleId: 'ml-gradient-accumulation-zero',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        // loss.backward() … optimizer.step() in a file with no `.zero_grad(` call
        // anywhere (checked both before and after the backward call). The
        // previous trailing lazy lookahead matched zero characters and flagged
        // every backward/step pair, including correctly zeroed loops.
        pattern: /(?<!\.zero_grad\s*\([\s\S]*)loss\.backward\(\)(?![\s\S]*\.zero_grad\s*\()[\s\S]*?optimizer\.step\(\)/,
        message: 'optimizer.step() 后应调用 optimizer.zero_grad() — 否则梯度会累积',
      },
      {
        ruleId: 'ml-random-seed',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        // Fires on any seed call; the message reminds the author to seed every RNG.
        pattern: /torch\.manual_seed|random\.seed|np\.random\.seed/,
        message: '设置随机种子时建议同时设置 torch.manual_seed / torch.cuda.manual_seed_all / np.random.seed / random.seed 保证完全可复现',
      },
    ];
  }

  /**
   * Scan an AST summary for ML-specific structures.
   *
   * Records are emitted grouped by detector, in this fixed order: models,
   * datasets, HuggingFace subclasses, training methods, evaluation methods,
   * `forward` methods, Lightning modules, custom losses, then at most one
   * ecosystem-usage record.
   *
   * @param {object} astSummary - reads optional `classes`, `methods` and
   *   `imports` arrays (missing ones are treated as empty). Import entries are
   *   assumed to be strings, since only `.includes` is called on them.
   * @returns {object[]} detection records
   */
  detectPatterns(astSummary) {
    const classes = astSummary.classes || [];
    const methods = astSummary.methods || [];
    const imports = astSummary.imports || [];
    const patterns = [];

    // Push one class-level record per class accepted by `isMatch`.
    const scanClasses = (type, confidence, isMatch) => {
      for (const cls of classes) {
        if (isMatch(cls)) {
          patterns.push(classPattern(type, cls, confidence));
        }
      }
    };

    // Push one method-level record per method whose lower-cased name is in `names`.
    const scanMethodNames = (type, names) => {
      for (const method of methods) {
        if (names.has(method.name?.toLowerCase() || '')) {
          patterns.push(methodPattern(type, method, 0.85));
        }
      }
    };

    // ── nn.Module subclasses ──
    // LightningDataModule also ends in "Module" but is not an nn.Module; it is
    // still reported by the Lightning detector below.
    scanClasses(
      'pytorch-model',
      0.95,
      (cls) => superclassMatches(cls, /Module$|nn\.Module/) && !/DataModule$/.test(cls.superclass),
    );

    // ── Dataset subclasses ──
    scanClasses('pytorch-dataset', 0.9, (cls) => superclassMatches(cls, /Dataset$|IterableDataset/));

    // ── HuggingFace model/config/trainer subclasses ──
    scanClasses('huggingface-model', 0.9, (cls) =>
      superclassMatches(cls, /PreTrainedModel|PretrainedConfig|Trainer/),
    );

    // ── Training and evaluation entry points (by method name) ──
    scanMethodNames('ml-training-function', TRAINING_METHOD_NAMES);
    scanMethodNames('ml-evaluation-function', EVALUATION_METHOD_NAMES);

    // ── forward methods that belong to a class ──
    for (const method of methods) {
      if (method.name === 'forward' && method.className) {
        patterns.push({
          type: 'pytorch-forward',
          className: method.className,
          methodName: method.name,
          line: method.line,
          confidence: 0.9,
        });
      }
    }

    // ── Lightning modules ──
    scanClasses('pytorch-lightning-module', 0.9, (cls) =>
      superclassMatches(cls, /LightningModule|LightningDataModule/),
    );

    // ── Custom losses: the name contains "loss" and the base is an nn.Module or
    //    one of torch's loss bases (`_Loss`, `_WeightedLoss`) ──
    scanClasses(
      'ml-custom-loss',
      0.85,
      (cls) =>
        (cls.name?.toLowerCase() || '').includes('loss') &&
        superclassMatches(cls, /Module|_(?:Weighted)?Loss$/),
    );

    // ── ML ecosystem imports (substring match against known package names) ──
    const mlImportCount = imports.filter((imp) =>
      ML_IMPORT_KEYWORDS.some((keyword) => imp.includes(keyword)),
    ).length;
    if (mlImportCount > 0) {
      patterns.push({
        type: 'ml-ecosystem-usage',
        importCount: mlImportCount,
        confidence: 0.85,
      });
    }

    return patterns;
  }
}

export const pack = new MLEnhancement();