Advanced Usage Guide
This guide covers advanced features and usage patterns in Torchium, including meta-optimizers, experimental algorithms, and sophisticated loss combinations.
Meta-Optimizers
Adaptive SAM (ASAM)
ASAM adaptively rescales the perturbation radius to the magnitude of each parameter, making sharpness-aware training less sensitive to the initial choice of rho:
optimizer = torchium.optimizers.ASAM(
model.parameters(),
lr=1e-3,
rho=0.5, # Initial perturbation radius
eta=0.01 # Adaptation rate
)
Gradient Surgery Methods
PCGrad for Multi-Task Learning
class MultiTaskModel(nn.Module):
def __init__(self):
super().__init__()
self.shared = nn.Linear(10, 64)
self.task1 = nn.Linear(64, 1)
self.task2 = nn.Linear(64, 5)
model = MultiTaskModel()
# Use PCGrad for gradient surgery
optimizer = torchium.optimizers.PCGrad(
model.parameters(),
lr=1e-3
)
# Training with multiple tasks
for epoch in range(100):
optimizer.zero_grad()
# Forward passes for different tasks
shared_features = model.shared(x)
task1_output = model.task1(shared_features)
task2_output = model.task2(shared_features)
# Compute losses
loss1 = criterion1(task1_output, y1)
loss2 = criterion2(task2_output, y2)
# PCGrad handles gradient conflicts
optimizer.step([loss1, loss2])
GradNorm for Dynamic Loss Balancing
optimizer = torchium.optimizers.GradNorm(
model.parameters(),
lr=1e-3,
alpha=1.5 # Restoring force hyperparameter
)
Second-Order Optimizers
LBFGS for Well-Conditioned Problems
# LBFGS works best with full batch or large batches
optimizer = torchium.optimizers.LBFGS(
model.parameters(),
lr=1.0,
max_iter=20,
max_eval=None,
tolerance_grad=1e-7,
tolerance_change=1e-9,
history_size=100,
line_search_fn="strong_wolfe"
)
# Training loop for LBFGS
def closure():
optimizer.zero_grad()
output = model(x)
loss = criterion(output, y)
loss.backward()
return loss
for epoch in range(100):
optimizer.step(closure)
Shampoo for Large Models
optimizer = torchium.optimizers.Shampoo(
model.parameters(),
lr=1e-3,
momentum=0.9,
weight_decay=1e-4,
epsilon=1e-4,
update_freq=1
)
Experimental Optimizers
CMA-ES for Global Optimization
# CMA-ES for non-convex optimization
optimizer = torchium.optimizers.CMAES(
model.parameters(),
population_size=20,
sigma=0.1,
max_generations=1000
)
# Training loop for CMA-ES
for generation in range(1000):
optimizer.step()
if optimizer.should_stop():
break
Differential Evolution
optimizer = torchium.optimizers.DifferentialEvolution(
model.parameters(),
population_size=30,
mutation_factor=0.8,
crossover_probability=0.9,
max_generations=1000
)
Particle Swarm Optimization
optimizer = torchium.optimizers.ParticleSwarmOptimization(
model.parameters(),
swarm_size=20,
inertia_weight=0.9,
cognitive_weight=2.0,
social_weight=2.0,
max_iterations=1000
)
Advanced Loss Combinations
Multi-Task Learning with Uncertainty Weighting
class MultiTaskLoss(nn.Module):
def __init__(self, num_tasks):
super().__init__()
self.uncertainty_loss = torchium.losses.UncertaintyWeightingLoss(num_tasks)
self.task_losses = [
torchium.losses.MSELoss(),
torchium.losses.CrossEntropyLoss(),
torchium.losses.DiceLoss()
]
def forward(self, predictions, targets):
losses = []
for i, (pred, target) in enumerate(zip(predictions, targets)):
loss = self.task_losses[i](pred, target)
losses.append(loss)
return self.uncertainty_loss(losses)
criterion = MultiTaskLoss(num_tasks=3)
Combined Segmentation Loss
class CombinedSegmentationLoss(nn.Module):
def __init__(self):
super().__init__()
self.dice = torchium.losses.DiceLoss(smooth=1e-5)
self.focal = torchium.losses.FocalLoss(alpha=0.25, gamma=2.0)
self.tversky = torchium.losses.TverskyLoss(alpha=0.3, beta=0.7)
self.lovasz = torchium.losses.LovaszLoss()
def forward(self, pred, target):
dice_loss = self.dice(pred, target)
focal_loss = self.focal(pred, target)
tversky_loss = self.tversky(pred, target)
lovasz_loss = self.lovasz(pred, target)
# Weighted combination
total_loss = (0.4 * dice_loss +
0.3 * focal_loss +
0.2 * tversky_loss +
0.1 * lovasz_loss)
return total_loss
criterion = CombinedSegmentationLoss()
Generative Model Loss Combinations
class GANLossCombination(nn.Module):
def __init__(self):
super().__init__()
self.gan_loss = torchium.losses.GANLoss()
self.perceptual_loss = torchium.losses.PerceptualLoss()
self.feature_matching_loss = torchium.losses.FeatureMatchingLoss()
def forward(self, fake_pred, real_pred, fake_features, real_features):
gan_loss = self.gan_loss(fake_pred, real_pred)
perceptual_loss = self.perceptual_loss(fake_features, real_features)
feature_matching_loss = self.feature_matching_loss(fake_features, real_features)
return gan_loss + 0.1 * perceptual_loss + 0.1 * feature_matching_loss
Custom Parameter Groups
Advanced Parameter Grouping
# Different optimizers for different parts
param_groups = [
{
'params': model.backbone.parameters(),
'lr': 1e-4,
'weight_decay': 1e-4
},
{
'params': model.classifier.parameters(),
'lr': 1e-3,
'weight_decay': 1e-5
},
{
'params': model.bn.parameters(),
'lr': 1e-3,
'weight_decay': 0 # No weight decay for batch norm
}
]
optimizer = torchium.optimizers.AdamW(param_groups)
# Or use factory function for complex grouping
optimizer = torchium.utils.factory.create_optimizer_with_groups(
model,
'adamw',
lr=1e-3,
weight_decay=1e-4,
no_decay=['bias', 'bn', 'ln'] # Exclude these from weight decay
)
Learning Rate Scheduling
Custom Learning Rate Schedules
# Warmup + cosine annealing
def get_lr_scheduler(optimizer, warmup_epochs, total_epochs):
def lr_lambda(epoch):
if epoch < warmup_epochs:
return epoch / warmup_epochs
else:
return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scheduler = get_lr_scheduler(optimizer, warmup_epochs=10, total_epochs=100)
# Training loop with scheduler
for epoch in range(100):
# Training step
train_one_epoch(model, optimizer, criterion, dataloader)
# Update learning rate
scheduler.step()
Gradient Clipping
Advanced Gradient Clipping
# Gradient clipping with different methods
def train_with_clipping(model, optimizer, criterion, dataloader, max_norm=1.0):
for batch in dataloader:
optimizer.zero_grad()
output = model(batch.input)
loss = criterion(output, batch.target)
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
# Or use built-in clipping for some optimizers
optimizer = torchium.optimizers.AdamW(
model.parameters(),
lr=1e-3,
max_grad_norm=1.0 # Built-in gradient clipping
)
Mixed Precision Training
Automatic Mixed Precision (note: the torch.cuda.amp entry points used below are deprecated in recent PyTorch releases; prefer torch.amp.autocast('cuda') and torch.amp.GradScaler('cuda'))
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for epoch in range(100):
for batch in dataloader:
optimizer.zero_grad()
with autocast():
output = model(batch.input)
loss = criterion(output, batch.target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Distributed Training
Multi-GPU Training
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Initialize distributed training
dist.init_process_group(backend='nccl')
# Wrap model with DDP
model = DDP(model)
# Use LARS for distributed training
optimizer = torchium.optimizers.LARS(
model.parameters(),
lr=1e-3,
momentum=0.9,
weight_decay=1e-4
)
Performance Optimization
Memory Optimization
# Use Lion for memory efficiency
optimizer = torchium.optimizers.Lion(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=1e-2
)
# Gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint
class CheckpointedModel(nn.Module):
def forward(self, x):
return checkpoint(self._forward, x)
def _forward(self, x):
# Your model forward pass
return self.layers(x)
Profiling and Debugging
# Profile optimizer performance
import torch.profiler
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler')
) as prof:
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
output = model(batch.input)
loss = criterion(output, batch.target)
loss.backward()
optimizer.step()
prof.step()
Best Practices
Choose the Right Optimizer: use SAM for better generalization, Lion for memory efficiency, LBFGS for well-conditioned problems, and CMA-ES for global optimization.
Combine Losses Wisely: use uncertainty weighting for multi-task learning, combine complementary losses (e.g., Dice + Focal), and balance loss weights carefully.
Parameter Grouping: set different learning rates for different layers, exclude batch norm parameters from weight decay, and use appropriate weight decay values.
Learning Rate Scheduling: use warmup for stable training, use cosine annealing for better convergence, and monitor the learning rate during training.
Gradient Management: use gradient clipping for stability, monitor gradient norms, and use gradient surgery for multi-task learning.
Memory Management: use Lion for memory efficiency, use gradient checkpointing for large models, and use mixed precision training when possible.