Advanced Usage Guide

This guide covers advanced features and usage patterns in Torchium, including meta-optimizers, experimental algorithms, and sophisticated loss combinations.

Meta-Optimizers

Sharpness-Aware Minimization (SAM)

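SAM (Sharpness-Aware Minimization) seeks parameters in flat regions of the loss landscape, which tends to improve generalization. Each update takes two forward/backward passes: the first determines a worst-case perturbation of the weights, the second computes the gradient used for the actual update: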

import torch
import torch.nn as nn
import torchium

model = nn.Linear(10, 1)

# Basic SAM. `rho` is the radius of the neighbourhood in which SAM looks for
# the worst-case weights. `adaptive` toggles a scale-adaptive perturbation
# (NOTE(review): confirm its exact meaning in torchium's SAM docstring).
optimizer = torchium.optimizers.SAM(
    model.parameters(),
    lr=1e-3,
    rho=0.05,
    adaptive=False,
)


def sam_forward_backward():
    """Run one forward/backward pass on (x, y) and return the loss."""
    prediction = model(x)
    batch_loss = criterion(prediction, y)
    batch_loss.backward()
    return batch_loss


# Training loop with SAM: every update costs two forward/backward passes.
for epoch in range(100):
    # Pass 1: gradients at the current weights drive the perturbation step,
    # which also clears the gradients.
    sam_forward_backward()
    optimizer.first_step(zero_grad=True)

    # Pass 2: gradients at the perturbed weights drive the real update step,
    # which clears the gradients again for the next iteration.
    sam_forward_backward()
    optimizer.second_step(zero_grad=True)

Adaptive SAM (ASAM)

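ASAM (Adaptive SAM) scales the perturbation radius to the magnitude of each parameter, which makes the sharpness measure invariant to parameter rescaling: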

# ASAM hyperparameters. NOTE(review): in the ASAM paper (Kwon et al., 2021)
# `rho` is the neighbourhood radius and `eta` is the small offset in the
# |w| + eta normalisation rather than a learning rate — confirm how
# torchium's ASAM interprets both.
asam_hparams = dict(lr=1e-3, rho=0.5, eta=0.01)
optimizer = torchium.optimizers.ASAM(model.parameters(), **asam_hparams)

Gradient Surgery Methods

PCGrad for Multi-Task Learning

class MultiTaskModel(nn.Module):
    """Two-head model: a shared linear trunk feeding a 1-output and a 5-output head.

    The submodules stay individually accessible (``model.shared``,
    ``model.task1``, ``model.task2``), so a training loop can drive them
    by hand. ``forward`` bundles the same computation so the model can also
    be called directly as ``model(x)``.
    """

    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(10, 64)  # shared feature extractor
        self.task1 = nn.Linear(64, 1)    # head for task 1
        self.task2 = nn.Linear(64, 5)    # head for task 2

    def forward(self, x):
        """Return ``(task1_output, task2_output)`` for input ``x``.

        Like the manual forward pass in the training loop, no activation is
        applied between the shared layer and the heads.
        """
        shared_features = self.shared(x)
        return self.task1(shared_features), self.task2(shared_features)

model = MultiTaskModel()

# PCGrad performs "gradient surgery": it resolves conflicts between the
# per-task gradients before applying the update.
optimizer = torchium.optimizers.PCGrad(model.parameters(), lr=1e-3)

# Multi-task training: one loss per head, handed to PCGrad together.
for epoch in range(100):
    optimizer.zero_grad()

    features = model.shared(x)
    head1_out = model.task1(features)
    head2_out = model.task2(features)

    task_losses = [
        criterion1(head1_out, y1),
        criterion2(head2_out, y2),
    ]

    # PCGrad receives the list of task losses instead of a pre-summed loss.
    optimizer.step(task_losses)

GradNorm for Dynamic Loss Balancing

# GradNorm balances the task losses dynamically; `alpha` is the strength of
# the restoring force that pulls the tasks' training rates together.
optimizer = torchium.optimizers.GradNorm(model.parameters(), lr=1e-3, alpha=1.5)

Second-Order Optimizers

LBFGS for Well-Conditioned Problems

# LBFGS estimates curvature from recent gradients, so it wants low-noise
# gradients: use the full batch, or at least large batches.
# (Argument names match torch.optim.LBFGS.)
lbfgs_settings = {
    "lr": 1.0,
    "max_iter": 20,             # iterations per optimizer.step() call
    "max_eval": None,           # evaluation budget per step; None -> default
    "tolerance_grad": 1e-7,
    "tolerance_change": 1e-9,
    "history_size": 100,        # past updates kept for the curvature estimate
    "line_search_fn": "strong_wolfe",
}
optimizer = torchium.optimizers.LBFGS(model.parameters(), **lbfgs_settings)

# LBFGS may evaluate the objective several times per step, so step() takes a
# closure that recomputes the loss and its gradients on demand.
def closure():
    """Clear gradients, recompute the loss on (x, y), backpropagate and return it."""
    optimizer.zero_grad()
    current_loss = criterion(model(x), y)
    current_loss.backward()
    return current_loss


for _epoch in range(100):
    optimizer.step(closure)

Shampoo for Large Models

# Shampoo preconditions each tensor's gradient with per-dimension statistics.
# NOTE(review): confirm in torchium's docs that `epsilon` is the regulariser
# used when inverting the preconditioners and `update_freq` is how often
# (in steps) they are recomputed.
shampoo_kwargs = dict(
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4,
    epsilon=1e-4,
    update_freq=1,
)
optimizer = torchium.optimizers.Shampoo(model.parameters(), **shampoo_kwargs)

Experimental Optimizers

CMA-ES for Global Optimization

# CMA-ES is derivative-free: it samples a population of candidate weight
# vectors and adapts the sampling distribution between generations, which
# makes it suitable for non-convex objectives.
max_generations = 1000
optimizer = torchium.optimizers.CMAES(
    model.parameters(),
    population_size=20,
    sigma=0.1,                        # initial sampling step size
    max_generations=max_generations,
)

# One step() per generation; stop early once the optimizer reports it is done.
# NOTE(review): population-based optimizers usually need a closure that
# returns the loss to score each candidate — check whether CMAES.step expects one.
for generation in range(max_generations):
    optimizer.step()
    if optimizer.should_stop():
        break

Differential Evolution

# Differential evolution: candidates are mutated by adding scaled differences
# of other population members (mutation_factor) and then recombined with
# their parent (crossover_probability).
de_options = dict(
    population_size=30,
    mutation_factor=0.8,
    crossover_probability=0.9,
    max_generations=1000,
)
optimizer = torchium.optimizers.DifferentialEvolution(model.parameters(), **de_options)

Particle Swarm Optimization

# Particle swarm optimisation: each particle's velocity mixes its previous
# velocity (inertia_weight), a pull towards its own best position
# (cognitive_weight) and a pull towards the swarm's best (social_weight).
pso_options = dict(
    swarm_size=20,
    inertia_weight=0.9,
    cognitive_weight=2.0,
    social_weight=2.0,
    max_iterations=1000,
)
optimizer = torchium.optimizers.ParticleSwarmOptimization(model.parameters(), **pso_options)

Advanced Loss Combinations

Multi-Task Learning with Uncertainty Weighting

class MultiTaskLoss(nn.Module):
    """Combine per-task losses with torchium's ``UncertaintyWeightingLoss``.

    Args:
        num_tasks: Number of tasks, passed to ``UncertaintyWeightingLoss``.
        task_losses: Optional sequence of per-task criteria (``nn.Module``
            instances), one per task, in the same order as the predictions.
            Defaults to MSE, cross-entropy and Dice (a three-task setup).

    Raises:
        ValueError: If ``task_losses`` is given and its length differs from
            ``num_tasks``, or if ``forward`` receives mismatched inputs.
    """

    def __init__(self, num_tasks, task_losses=None):
        super().__init__()
        if task_losses is None:
            task_losses = [
                torchium.losses.MSELoss(),
                torchium.losses.CrossEntropyLoss(),
                torchium.losses.DiceLoss(),
            ]
        elif len(task_losses) != num_tasks:
            raise ValueError(
                f"got {len(task_losses)} task losses for num_tasks={num_tasks}"
            )
        self.uncertainty_loss = torchium.losses.UncertaintyWeightingLoss(num_tasks)
        # nn.ModuleList (unlike a plain list) registers the criteria as
        # submodules, so .to(device), .train()/.eval() and state_dict() reach them.
        self.task_losses = nn.ModuleList(task_losses)

    def forward(self, predictions, targets):
        """Return the uncertainty-weighted combination of the per-task losses."""
        predictions, targets = list(predictions), list(targets)
        # zip() would silently drop tasks on a length mismatch; fail loudly instead.
        if len(predictions) != len(targets):
            raise ValueError(
                f"got {len(predictions)} predictions but {len(targets)} targets"
            )
        if len(predictions) > len(self.task_losses):
            raise ValueError(
                f"got {len(predictions)} predictions but only "
                f"{len(self.task_losses)} task losses"
            )
        losses = [
            loss_fn(pred, target)
            for loss_fn, pred, target in zip(self.task_losses, predictions, targets)
        ]
        return self.uncertainty_loss(losses)


criterion = MultiTaskLoss(num_tasks=3)

Combined Segmentation Loss

class CombinedSegmentationLoss(nn.Module):
    """Weighted sum of Dice, Focal, Tversky and Lovasz losses.

    Args:
        weights: Four weights applied, in order, to the Dice, Focal, Tversky
            and Lovasz terms. Defaults to ``(0.4, 0.3, 0.2, 0.1)``.

    Raises:
        ValueError: If ``weights`` does not contain exactly four values.
    """

    def __init__(self, weights=(0.4, 0.3, 0.2, 0.1)):
        super().__init__()
        weights = tuple(float(w) for w in weights)
        if len(weights) != 4:
            raise ValueError(
                f"expected 4 weights (dice, focal, tversky, lovasz), got {len(weights)}"
            )
        self.weights = weights
        self.dice = torchium.losses.DiceLoss(smooth=1e-5)
        self.focal = torchium.losses.FocalLoss(alpha=0.25, gamma=2.0)
        # In the usual Tversky convention alpha weights false positives and beta
        # false negatives (NOTE(review): confirm torchium uses the same order).
        self.tversky = torchium.losses.TverskyLoss(alpha=0.3, beta=0.7)
        self.lovasz = torchium.losses.LovaszLoss()

    def forward(self, pred, target):
        """Return the weighted sum of the four losses for ``pred`` vs ``target``."""
        terms = (
            self.dice(pred, target),
            self.focal(pred, target),
            self.tversky(pred, target),
            self.lovasz(pred, target),
        )
        return sum(weight * term for weight, term in zip(self.weights, terms))


criterion = CombinedSegmentationLoss()

Generative Model Loss Combinations

class GANLossCombination(nn.Module):
    """Adversarial loss plus weighted perceptual and feature-matching terms.

    Args:
        perceptual_weight: Multiplier on the perceptual term (default 0.1).
        feature_matching_weight: Multiplier on the feature-matching term
            (default 0.1).
    """

    def __init__(self, perceptual_weight=0.1, feature_matching_weight=0.1):
        super().__init__()
        self.perceptual_weight = perceptual_weight
        self.feature_matching_weight = feature_matching_weight
        self.gan_loss = torchium.losses.GANLoss()
        self.perceptual_loss = torchium.losses.PerceptualLoss()
        self.feature_matching_loss = torchium.losses.FeatureMatchingLoss()

    def forward(self, fake_pred, real_pred, fake_features, real_features):
        """Return ``gan + w_p * perceptual + w_fm * feature_matching``."""
        adversarial = self.gan_loss(fake_pred, real_pred)
        # NOTE(review): both feature terms receive the same feature tensors;
        # check whether torchium's PerceptualLoss expects raw images instead.
        perceptual = self.perceptual_loss(fake_features, real_features)
        feature_matching = self.feature_matching_loss(fake_features, real_features)
        return (
            adversarial
            + self.perceptual_weight * perceptual
            + self.feature_matching_weight * feature_matching
        )

Custom Parameter Groups

Advanced Parameter Grouping

# One optimizer, different hyperparameters per part of the model: each dict
# overrides the optimizer defaults for its parameters. A parameter may appear
# in at most one group (torch.optim rejects duplicates).
param_groups = [
    # Backbone: smaller learning rate.
    dict(params=model.backbone.parameters(), lr=1e-4, weight_decay=1e-4),
    # Classifier head: larger learning rate, lighter weight decay.
    dict(params=model.classifier.parameters(), lr=1e-3, weight_decay=1e-5),
    # Batch-norm parameters: no weight decay.
    dict(params=model.bn.parameters(), lr=1e-3, weight_decay=0),
]
optimizer = torchium.optimizers.AdamW(param_groups)

# Alternatively, let the factory build the groups. Parameters matching the
# `no_decay` entries (presumably by name: biases, batch-norm, layer-norm) are
# excluded from weight decay.
optimizer = torchium.utils.factory.create_optimizer_with_groups(
    model,
    'adamw',
    lr=1e-3,
    weight_decay=1e-4,
    no_decay=['bias', 'bn', 'ln'],
)

Learning Rate Scheduling

Custom Learning Rate Schedules

# Warmup + cosine annealing
import math

import torch


def get_lr_scheduler(optimizer, warmup_epochs, total_epochs):
    """Build a LambdaLR schedule: linear warmup, then cosine decay to zero.

    The multiplier applied to each param group's base lr is
    ``(epoch + 1) / warmup_epochs`` for the first ``warmup_epochs`` epochs, so
    the very first epoch already trains with a non-zero lr. After that it is
    ``0.5 * (1 + cos(pi * progress))``, where ``progress`` runs from 0 at the
    end of warmup to 1 at ``total_epochs`` and is clamped afterwards.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_epochs: Number of linear-warmup epochs (0 disables warmup).
        total_epochs: Epoch at which the multiplier reaches 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``; call ``step()`` once per epoch.
    """
    # max(1, ...) avoids a ZeroDivisionError when total_epochs == warmup_epochs.
    decay_epochs = max(1, total_epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        # Clamp so the cosine does not climb back up past total_epochs.
        progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
        return 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler = get_lr_scheduler(optimizer, warmup_epochs=10, total_epochs=100)

# Per-epoch schedule: run the epoch's optimizer updates first, then advance
# the scheduler exactly once (PyTorch expects optimizer.step() before
# scheduler.step()).
num_epochs = 100
for _ in range(num_epochs):
    train_one_epoch(model, optimizer, criterion, dataloader)
    scheduler.step()

Gradient Clipping

Advanced Gradient Clipping

# Clip the global gradient norm of every batch before the optimizer update.
def train_with_clipping(model, optimizer, criterion, dataloader, max_norm=1.0):
    """Run one pass over ``dataloader`` with gradient-norm clipping.

    Each batch must expose ``.input`` and ``.target``. After backpropagation,
    the gradients of all ``model`` parameters are rescaled together so that
    their combined L2 norm is at most ``max_norm``; then the optimizer steps.
    """
    for sample in dataloader:
        optimizer.zero_grad()
        prediction = model(sample.input)
        criterion(prediction, sample.target).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

# Some torchium optimizers take the clipping threshold directly instead of an
# explicit clip_grad_norm_ call in the training loop.
# NOTE(review): torch.optim.AdamW has no `max_grad_norm` argument — confirm
# that torchium's AdamW accepts it before relying on this.
optimizer = torchium.optimizers.AdamW(model.parameters(), lr=1e-3, max_grad_norm=1.0)

Mixed Precision Training

Automatic Mixed Precision

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()

        # Forward pass and loss under autocast: eligible ops run in reduced precision.
        with autocast():
            prediction = model(batch.input)
            loss = criterion(prediction, batch.target)

        # Backprop the scaled loss so small fp16 gradients do not underflow;
        # scaler.step() unscales before stepping and update() adapts the scale.
        # If you also clip gradients, call scaler.unscale_(optimizer) before
        # clipping. (Recent PyTorch releases prefer torch.amp.autocast("cuda")
        # and torch.amp.GradScaler("cuda") over the torch.cuda.amp names.)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Distributed Training

Multi-GPU Training

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize distributed training. Launch one process per GPU, e.g. with
# `torchrun --nproc_per_node=<num_gpus> train.py`; torchrun sets LOCAL_RANK
# and the rendezvous environment variables that init_process_group reads.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# NCCL only communicates CUDA tensors, so move the model onto this process's
# GPU before wrapping it, and tell DDP which device it owns.
device = torch.device("cuda", local_rank)
model = DDP(model.to(device), device_ids=[local_rank])

# Use LARS for distributed training: layer-wise learning-rate scaling is aimed
# at the large effective batch sizes that data-parallel training produces.
optimizer = torchium.optimizers.LARS(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4
)

Performance Optimization

Memory Optimization

# Lion keeps a single momentum buffer per parameter (Adam keeps two), which
# reduces optimizer-state memory. It is typically run with a smaller learning
# rate and a larger weight decay than AdamW.
lion_config = dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)
optimizer = torchium.optimizers.Lion(model.parameters(), **lion_config)

# Gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint


class CheckpointedModel(nn.Module):
    """Wrap a sub-network so its activations are recomputed during backward.

    Activations inside ``self.layers`` are not stored during the forward
    pass; they are recomputed when gradients are needed, trading extra
    compute for lower activation memory.

    Args:
        layers: The module to checkpoint. Optional so that subclasses can
            assign ``self.layers`` in their own ``__init__`` instead.
    """

    def __init__(self, layers=None):
        super().__init__()
        if layers is not None:
            self.layers = layers

    def forward(self, x):
        # use_reentrant=False is the variant PyTorch recommends; omitting the
        # argument triggers a warning in recent releases.
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        # Your model forward pass
        return self.layers(x)

Profiling and Debugging

# Profile optimizer performance
import torch.profiler

# Profile a few training steps and write traces for TensorBoard. Each cycle
# skips 1 step, warms up for 1 and records 3; the cycle runs twice.
# Drop ProfilerActivity.CUDA on CPU-only machines.
profiler_schedule = torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2)
activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]

with torch.profiler.profile(
    activities=activities,
    schedule=profiler_schedule,
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler'),
) as prof:
    for batch in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(batch.input), batch.target)
        loss.backward()
        optimizer.step()
        prof.step()  # mark the end of one profiled step

Best Practices

  1. Choose the right optimizer:

     - SAM for better generalization
     - Lion for memory efficiency
     - LBFGS for well-conditioned, full-batch (or large-batch) problems
     - CMA-ES for global optimization

  2. Combine losses wisely:

     - Use uncertainty weighting for multi-task learning
     - Combine complementary losses (e.g., Dice + Focal)
     - Balance loss weights carefully

  3. Group parameters:

     - Use different learning rates for different layers
     - Exclude batch-norm (and bias) parameters from weight decay
     - Use appropriate weight decay values

  4. Schedule the learning rate:

     - Use warmup for stable training
     - Use cosine annealing for better convergence
     - Monitor the learning rate during training

  5. Manage gradients:

     - Use gradient clipping for stability
     - Monitor gradient norms
     - Use gradient surgery (e.g., PCGrad) for multi-task learning

  6. Manage memory:

     - Use Lion for memory efficiency
     - Use gradient checkpointing for large models
     - Use mixed precision training when possible