autosnippet
Version:
Extract code patterns into a knowledge base for AI coding assistants
240 lines (239 loc) • 10.1 kB
JavaScript
/**
* AI/ML Enhancement Pack
* 条件: { languages: ['python'], frameworks: ['ml'] }
*
* 覆盖 PyTorch / TensorFlow / HuggingFace 机器学习生态:
* - nn.Module 模型架构
* - Training Loop 模式 (optimizer.zero_grad → loss.backward → optimizer.step)
* - DataLoader / Dataset
* - HuggingFace Trainer / Pipeline
* - 模型保存/加载 (state_dict / safetensors)
* - 分布式训练 (DDP / FSDP)
*/
import { EnhancementPack } from './EnhancementPack.js';
/**
 * Enhancement pack for Python machine-learning code bases
 * (PyTorch / TensorFlow / HuggingFace / Lightning).
 *
 * Contributes three extra scan dimensions, a set of regex-based guard rules
 * that run over whole files, and pattern detection over an AST summary.
 */
class MLEnhancement extends EnhancementPack {
  /** @returns {string} Stable identifier of this pack. */
  get id() {
    return 'python-ml';
  }

  /** @returns {string} Human-readable pack name. */
  get displayName() {
    return 'Python AI/ML (PyTorch/HuggingFace) Enhancement';
  }

  /**
   * Conditions under which this pack is enabled: Python code tagged with the
   * `ml` framework.
   * @returns {{ languages: string[], frameworks: string[] }}
   */
  get conditions() {
    return { languages: ['python'], frameworks: ['ml'] };
  }

  /**
   * Extra scan dimensions: model architecture, training pipeline and data
   * pipeline. Each one is skill-worthy and produces dual output.
   * @returns {object[]} Dimension descriptors.
   */
  getExtraDimensions() {
    return [
      {
        id: 'ml-model-architecture-scan',
        label: '模型架构分析',
        guide: 'nn.Module 子类分析 — 层级结构 (forward 方法调用链)、自定义 Layer、残差连接/Attention 模式、参数量估算、模型注册表',
        tierHint: 2,
        knowledgeTypes: ['architecture', 'code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-models',
          description: 'PyTorch model architectures, custom layers and forward pass patterns (auto-generated by enhancement)',
        },
      },
      {
        id: 'ml-training-pipeline-scan',
        label: 'Training Pipeline 分析',
        guide: 'Training 流程分析 — Training Loop 结构 (epoch → batch → forward → loss → backward → step)、learning rate scheduler、gradient clipping/accumulation、早停策略、checkpoint 保存/恢复、HuggingFace Trainer 配置',
        tierHint: 2,
        knowledgeTypes: ['code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-training',
          description: 'ML training pipeline — training loops, schedulers, checkpointing and HF Trainer (auto-generated by enhancement)',
        },
      },
      {
        id: 'ml-data-pipeline-scan',
        label: '数据管道分析',
        guide: '数据处理管道分析 — Dataset/DataLoader 实现、数据增强 (transforms)、tokenizer 配置、特征工程、数据拆分策略 (train/val/test)',
        tierHint: 2,
        knowledgeTypes: ['code-pattern', 'architecture'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-data',
          description: 'ML data pipelines — Dataset/DataLoader, transforms and tokenization (auto-generated by enhancement)',
        },
      },
    ];
  }

  /**
   * File-level guard rules for Python ML code. Each `pattern` is matched
   * against the file text.
   *
   * The patterns are written so that a "must NOT contain X" condition is
   * actually enforced. A negative lookahead placed right before the token it
   * guards is always satisfied and therefore checks nothing; tempered tokens
   * `(?:(?!X)[\s\S])*?` are used instead.
   * @returns {object[]} Guard rule descriptors.
   */
  getGuardRules() {
    return [
      {
        ruleId: 'ml-no-eval-in-train',
        category: 'correctness',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        // train() … eval() … backward, with no model.train() between the eval()
        // and the backward. The usual epoch loop (train → eval → train →
        // backward) re-enters training mode and must not be reported.
        pattern: /model\.train\(\)[\s\S]*?model\.eval\(\)(?:(?!model\.train\(\))[\s\S])*?loss\.backward/,
        message: '训练循环中意外切换到 eval 模式后执行 backward — 确保 model.train() 在训练阶段',
      },
      {
        ruleId: 'ml-device-mismatch',
        category: 'correctness',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        // Hard-coded device literals, including indexed CUDA devices ('cuda:1')
        // and Apple 'mps'.
        pattern: /\.to\s*\(\s*['"](?:cuda(?::\d+)?|cpu|mps)['"]\s*\)/,
        message: '硬编码 device 字符串 — 建议使用 torch.device() 变量统一管理,支持 MPS/多 GPU',
      },
      {
        ruleId: 'ml-missing-no-grad',
        category: 'performance',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        // The evaluation-style `def` must not be preceded by a @torch.no_grad /
        // @torch.inference_mode decorator (possibly stacked with other
        // decorators). The scan from the signature to the `model(` call must not
        // pass through a torch.no_grad / torch.inference_mode context, and must
        // not run into the next `def`.
        pattern: /(?<!@torch\.(?:no_grad|inference_mode)(?:\(\))?\s*(?:@[^\n]*\n\s*)*)\bdef\s+(?:evaluate|validate|test|predict|inference)\s*\([^)]*\)(?:(?!torch\.(?:no_grad|inference_mode)|\bdef\s)[\s\S])*?model\s*\(/,
        message: '推理/评估函数应使用 @torch.no_grad() 或 with torch.no_grad() — 减少内存消耗并加速',
      },
      {
        ruleId: 'ml-gradient-accumulation-zero',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        // Anchored at the start of the file: report backward → step only when the
        // file never calls `.zero_grad(` (optimizer or model, any arguments).
        // Calling zero_grad before backward at the top of the loop is fine.
        pattern: /^(?![\s\S]*\.zero_grad\s*\()[\s\S]*?loss\.backward\(\)[\s\S]*?optimizer\.step\(\)/,
        message: 'optimizer.step() 后应调用 optimizer.zero_grad() — 否则梯度会累积',
      },
      {
        ruleId: 'ml-random-seed',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        // Advisory: fires on any seeding call to remind about seeding every RNG.
        pattern: /torch\.manual_seed|random\.seed|np\.random\.seed/,
        message: '设置随机种子时建议同时设置 torch.manual_seed / torch.cuda.manual_seed_all / np.random.seed / random.seed 保证完全可复现',
      },
    ];
  }

  /**
   * Detect ML-specific constructs in an AST summary.
   *
   * Patterns are emitted in a fixed group order: nn.Module subclasses,
   * Dataset subclasses, HuggingFace classes, training functions, evaluation
   * functions, `forward` methods, Lightning modules, custom losses, and finally
   * one aggregate entry for ML ecosystem imports.
   *
   * @param {object|null|undefined} astSummary - Summary with optional `classes`
   *   (`{ name, superclass, line }`), `methods` (`{ name, className, line }`)
   *   and `imports` arrays. Import entries are assumed to be module-name strings.
   * @returns {object[]} Detected patterns (empty when nothing matches).
   */
  detectPatterns(astSummary) {
    const summary = astSummary ?? {};
    const classes = summary.classes || [];
    const methods = summary.methods || [];
    const imports = summary.imports || [];
    const patterns = [];

    // Emit one class pattern for every class that has a superclass and satisfies `matches`.
    const collectClasses = (type, confidence, matches) => {
      for (const cls of classes) {
        if (cls.superclass && matches(cls)) {
          patterns.push({ type, className: cls.name, line: cls.line, confidence });
        }
      }
    };

    // Emit one method pattern for every method whose lower-cased name is in `names`.
    const collectMethodsNamed = (type, names) => {
      for (const m of methods) {
        if (names.has(m.name?.toLowerCase() || '')) {
          patterns.push({ type, methodName: m.name, line: m.line, confidence: 0.85 });
        }
      }
    };

    // ── nn.Module subclasses ──
    // `Module$` also covers LightningModule, which is itself an nn.Module.
    // LightningDataModule only wraps data loading, so it is not treated as a model.
    collectClasses('pytorch-model', 0.95, ({ superclass }) =>
      /Module$|nn\.Module/.test(superclass) && !/DataModule$/.test(superclass));

    // ── Dataset subclasses ──
    collectClasses('pytorch-dataset', 0.9, ({ superclass }) =>
      /Dataset$|IterableDataset/.test(superclass));

    // ── HuggingFace model / config / tokenizer / trainer ──
    collectClasses('huggingface-model', 0.9, ({ superclass }) =>
      /PreTrainedModel|PretrainedConfig|PreTrainedTokenizer|Trainer/.test(superclass));

    // ── Training functions (train/train_one_epoch/train_step/...) ──
    collectMethodsNamed('ml-training-function', new Set([
      'train', 'train_one_epoch', 'train_step', 'training_step', 'train_loop',
    ]));

    // ── Evaluation functions ──
    collectMethodsNamed('ml-evaluation-function', new Set([
      'evaluate', 'validate', 'eval_step', 'validation_step', 'test_step',
    ]));

    // ── forward methods declared inside a class ──
    for (const m of methods) {
      if (m.name === 'forward' && m.className) {
        patterns.push({
          type: 'pytorch-forward',
          className: m.className,
          methodName: m.name,
          line: m.line,
          confidence: 0.9,
        });
      }
    }

    // ── Lightning modules ──
    collectClasses('pytorch-lightning-module', 0.9, ({ superclass }) =>
      /LightningModule|LightningDataModule/.test(superclass));

    // ── Loss function classes: name mentions "loss" and the superclass mentions Module ──
    collectClasses('ml-custom-loss', 0.85, ({ name, superclass }) =>
      (name?.toLowerCase() || '').includes('loss') && /Module/.test(superclass));

    // ── ML ecosystem imports (substring match on the import string) ──
    const mlImportKeywords = [
      'torch', 'tensorflow', 'transformers', 'datasets', 'accelerate',
      'lightning', 'sklearn', 'numpy', 'wandb', 'tensorboard',
    ];
    const mlImportCount = imports
      .filter((imp) => mlImportKeywords.some((kw) => imp.includes(kw)))
      .length;
    if (mlImportCount > 0) {
      patterns.push({ type: 'ml-ecosystem-usage', importCount: mlImportCount, confidence: 0.85 });
    }

    return patterns;
  }
}
/** Shared singleton instance of the ML enhancement pack, exported as `pack`. */
const pack = new MLEnhancement();
export { pack };