Execute training runs with proper monitoring, checkpointing, and experiment tracking. Use when starting training, resuming training, debugging training issues, or setting up multi-GPU/distributed training with PyTorch Lightning and Hydra.
Install: npx claudepluginhub nishide-dev/claude-code-ml-research
Execute training runs with proper monitoring, checkpointing, and experiment tracking using PyTorch Lightning and Hydra.
Choose a training template based on your setup:
Basic training:
python src/train.py
With specific experiment config:
python src/train.py experiment=my_experiment
With CLI overrides:
python src/train.py \
model.learning_rate=1e-3 \
data.batch_size=64 \
trainer.max_epochs=100
Resume from checkpoint:
python src/train.py ckpt_path="checkpoints/epoch_42.ckpt"
Multi-GPU training:
python src/train.py \
trainer.devices=4 \
trainer.strategy=ddp
Hyperparameter sweep:
python src/train.py --multirun \
model.learning_rate=1e-4,1e-3,1e-2 \
data.batch_size=32,64,128
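These commands assume a Hydra entry point that instantiates the model, data module, and Trainer from config. A minimal sketch of such an entry point (the config layout and _target_ paths are assumptions, not from this skill):
# src/train.py — minimal Hydra + Lightning entry point (sketch)
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="../configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # Instantiate objects declared via _target_ in the config tree
    model = hydra.utils.instantiate(cfg.model)
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = pl.Trainer(**cfg.trainer)
    trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

if __name__ == "__main__":
    main()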
Before starting training, verify:
Environment:
# Check Python version
python --version # Should be >= 3.10
# Check CUDA availability
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}')"
# Check package installation
python -c "import pytorch_lightning as pl; print(f'Lightning: {pl.__version__}')"
# Validate config
python src/train.py --cfg job
# Dry run
python src/train.py trainer.fast_dev_run=5
Disk space:
Estimate checkpoint storage (MB): model_size_mb × save_top_k × num_epochs / checkpoint_freq
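As a quick back-of-envelope check of that formula (the numbers below are illustrative, not from the source):
# Worst-case checkpoint storage from the formula above
model_size_mb = 500      # size of one checkpoint on disk
save_top_k = 3           # ModelCheckpoint(save_top_k=...)
num_epochs = 100
checkpoint_freq = 1      # save every N epochs
required_gb = model_size_mb * save_top_k * num_epochs / checkpoint_freq / 1024
print(f"Budget up to {required_gb:.0f} GB for checkpoints")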
Experiment tracking:
# Initialize W&B (if using)
wandb login
export WANDB_PROJECT="your-project-name"
# Sync offline runs (if needed)
wandb sync wandb/offline-run-*
Real-time monitoring:
watch -n 1 nvidia-smi
Key metrics to watch: training/validation loss, learning rate, and GPU utilization and memory.
Red flags and common fixes:
Gradient issues:
# Add to trainer config
trainer = Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
)
# Note: track_grad_norm was removed in Lightning 2.x; log gradient norms
# from the LightningModule instead (see the sketch below).
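A minimal sketch of that gradient-norm logging, assuming Lightning 2.x:
import pytorch_lightning as pl
from pytorch_lightning.utilities import grad_norm

class MyModel(pl.LightningModule):
    def on_before_optimizer_step(self, optimizer):
        # Compute the 2-norm of gradients per layer plus the total, then log them
        norms = grad_norm(self, norm_type=2)
        self.log_dict(norms)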
Memory issues:
# Use mixed precision + gradient accumulation
trainer = Trainer(
    precision="16-mixed",
    accumulate_grad_batches=4,
)
# Compile model (PyTorch 2.0+)
model = torch.compile(model)
Slow data loading:
# Profile to identify bottleneck
trainer = Trainer(profiler="simple")
# Optimize data loading
data_module = DataModule(
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)
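Those arguments are ultimately DataLoader arguments; a hypothetical LightningDataModule wiring them up:
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

class DataModule(pl.LightningDataModule):
    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,    # parallel worker processes
            pin_memory=True,                 # faster host-to-GPU copies
            persistent_workers=True,         # keep workers alive across epochs
            prefetch_factor=2,               # batches prefetched per worker
        )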
Overfitting:
# Add early stopping
from pytorch_lightning.callbacks import EarlyStopping
trainer = Trainer(
    callbacks=[
        EarlyStopping(monitor="val/loss", patience=10, min_delta=0.001)
    ]
)
# Increase regularization (Hydra model config)
model:
  dropout: 0.3
  weight_decay: 0.0001
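One hypothetical way the model side consumes those keys, with weight_decay applied as L2 regularization in the optimizer:
import torch
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, dropout: float = 0.3, weight_decay: float = 1e-4, learning_rate: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.dropout = torch.nn.Dropout(dropout)

    def configure_optimizers(self):
        # AdamW applies weight decay decoupled from the gradient update
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )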
DDP (Distributed Data Parallel):
trainer.strategy=ddp trainer.devices=4
FSDP (Fully Sharded Data Parallel):
trainer.strategy=fsdp trainer.devices=4
DeepSpeed:
trainer.strategy=deepspeed_stage_3
Mixed precision training:
python src/train.py trainer.precision=16-mixed
Gradient checkpointing (save memory):
model.gradient_checkpointing_enable()  # Hugging Face transformers models
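For plain PyTorch models (the call above is the Hugging Face transformers API), a sketch of the same idea with torch.utils.checkpoint:
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
        self.block2 = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
        self.head = torch.nn.Linear(256, 10)

    def forward(self, x):
        # Each block's activations are recomputed during backward instead of
        # stored, trading compute for memory; use_reentrant=False is the
        # recommended mode in recent PyTorch releases.
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return self.head(x)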
Learning rate finder:
from pytorch_lightning.tuner import Tuner

trainer = Trainer()
lr_finder = Tuner(trainer).lr_find(model, datamodule=dm)
fig = lr_finder.plot(suggest=True)
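To apply the result (a sketch; assumes the model exposes learning_rate as a hyperparameter):
# Use the suggested learning rate for the real run
model.hparams.learning_rate = lr_finder.suggestion()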
Stochastic Weight Averaging:
from pytorch_lightning.callbacks import StochasticWeightAveraging
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
Load best checkpoint:
best_model_path = trainer.checkpoint_callback.best_model_path
model = MyModel.load_from_checkpoint(best_model_path)
Evaluate on test set:
trainer.test(model, datamodule=dm, ckpt_path="best")
Generate predictions:
predictions = trainer.predict(model, datamodule=dm)
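predict() returns one entry per batch; assuming each step returns a tensor, combine them with:
import torch

# Flatten the per-batch list into a single tensor
all_preds = torch.cat(predictions, dim=0)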
For graph neural networks, see GNN training guide:
# Node classification
python src/train.py \
model=gnn \
data=graph \
data.dataset_name=Cora
# Graph classification with batching
python src/train.py \
model=gnn \
data=graph \
data.dataset_name=PROTEINS \
data.batch_size=32
# Large graph sampling
python src/train.py \
model=gnn \
data=graph \
data.use_sampling=true \
"data.num_neighbors=[15,10,5]"
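Under the hood, a model=gnn config might point at a LightningModule like this hypothetical node classifier (PyTorch Geometric assumed; names are illustrative):
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv

class GNNNodeClassifier(pl.LightningModule):
    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

    def training_step(self, batch, batch_idx):
        out = self(batch.x, batch.edge_index)
        # Node classification: evaluate the loss on the training mask only
        loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)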
For GNN-specific metrics to monitor, plus architectures, sampling strategies, and troubleshooting, see the complete GNN guide.
Grid sweep:
python src/train.py --multirun \
model.learning_rate=1e-4,1e-3,1e-2 \
data.batch_size=32,64,128
Optuna sweep (requires the hydra-optuna-sweeper plugin):
python src/train.py \
--multirun \
hydra/sweeper=optuna \
hydra.sweeper.n_trials=50
For advanced sweep configurations (random search, Bayesian optimization, multi-objective), see reference guide.
# In LightningModule
def training_step(self, batch, batch_idx):
    loss = ...
    # Log metrics
    self.log("train/loss", loss)
    self.log("train/acc", accuracy)
    self.log("lr", self.optimizers().param_groups[0]["lr"])
    return loss
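accuracy above is a placeholder; one common way to compute it is torchmetrics, sketched here with a hypothetical classifier:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics

class Classifier(pl.LightningModule):
    def __init__(self, in_dim: int = 32, num_classes: int = 10):
        super().__init__()
        self.net = torch.nn.Linear(in_dim, num_classes)
        # torchmetrics handles device placement and epoch-level aggregation
        self.train_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        self.train_acc(logits, y)
        self.log("train/loss", loss)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())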
W&B can automatically log system metrics, hyperparameters, and model checkpoints (with log_model=true). See the reference guide for logging confusion matrices, sample predictions, and custom artifacts.
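Wiring that up in code (the project name is a placeholder):
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# log_model=True uploads checkpoints to W&B as artifacts
wandb_logger = WandbLogger(project="your-project-name", log_model=True)
trainer = Trainer(logger=wandb_logger)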
# Quick debug run (5 batches)
python src/train.py trainer.fast_dev_run=5
# Overfit single batch (check model capacity)
python src/train.py trainer.overfit_batches=1 trainer.max_epochs=100
# Profile training (identify bottlenecks)
python src/train.py trainer.profiler=advanced trainer.max_epochs=1
For the complete command reference and full training examples, see the reference guide.
Training doesn't start:
python src/train.py --cfg job
python -c "import pytorch_lightning; import hydra"
python -c "import torch; print(torch.cuda.is_available())"
Training is unstable: try gradient clipping and mixed-precision settings (see Gradient issues above).
Training is slow:
trainer.profiler="advanced"
torch.compile() (PyTorch 2.0+)
For advanced topics, see the reference guide.
Happy training! 🚀