# Source code for torchium.losses.regression.mse_variants

"""
MSE and its variants implementation.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...utils.registry import register_loss

@register_loss("mseloss")
class MSELoss(nn.MSELoss):
    """Mean Squared Error loss, registered under the name ``"mseloss"``.

    Thin registered alias of :class:`torch.nn.MSELoss`; behavior is
    inherited unchanged.
    """
@register_loss("maeloss")
class MAELoss(nn.L1Loss):
    """Mean Absolute Error loss, registered under the name ``"maeloss"``.

    Thin registered alias of :class:`torch.nn.L1Loss`; behavior is
    inherited unchanged.
    """
@register_loss("huberloss")
class HuberLoss(nn.HuberLoss):
    """Huber loss, registered under the name ``"huberloss"``.

    Thin registered alias of :class:`torch.nn.HuberLoss`; behavior is
    inherited unchanged.
    """
@register_loss("smoothl1loss")
class SmoothL1Loss(nn.SmoothL1Loss):
    """Smooth L1 loss, registered under the name ``"smoothl1loss"``.

    Thin registered alias of :class:`torch.nn.SmoothL1Loss`; behavior is
    inherited unchanged.
    """
@register_loss("quantileloss")
class QuantileLoss(nn.Module):
    """Quantile (pinball) loss for quantile regression.

    For error ``e = target - input`` the per-element loss is
    ``max(q * e, (q - 1) * e)``: under-predictions are weighted by ``q``
    and over-predictions by ``1 - q``, so minimizing it estimates the
    ``q``-th conditional quantile of the target.

    Args:
        quantile: Target quantile ``q``; must lie strictly in ``(0, 1)``.
            Defaults to ``0.5`` (median regression, equivalent to MAE
            up to a factor of 2).
        reduction: ``'mean'``, ``'sum'``, or anything else for no
            reduction (element-wise loss tensor is returned).

    Raises:
        ValueError: If ``quantile`` is not strictly between 0 and 1 —
            outside that range the pinball loss is meaningless
            (non-positive everywhere), so fail fast instead of training
            silently on garbage.
    """

    def __init__(self, quantile=0.5, reduction='mean'):
        super().__init__()
        if not 0.0 < quantile < 1.0:
            raise ValueError(
                f"quantile must be strictly between 0 and 1, got {quantile}"
            )
        self.quantile = quantile
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the pinball loss between ``input`` and ``target``."""
        error = target - input
        # max(q*e, (q-1)*e): picks q*e when e >= 0, (q-1)*e when e < 0.
        loss = torch.max(self.quantile * error, (self.quantile - 1) * error)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
@register_loss("logcoshloss")
class LogCoshLoss(nn.Module):
    """Log-Cosh loss for robust regression.

    ``log(cosh(x))`` behaves like ``x**2 / 2`` for small errors and like
    ``|x| - log(2)`` for large ones, giving MSE-like smoothness with
    MAE-like robustness to outliers.

    The loss is evaluated via the numerically stable identity
    ``log(cosh(x)) = |x| + softplus(-2|x|) - log(2)``. The naive
    ``torch.log(torch.cosh(x))`` overflows to ``inf`` once ``|x|``
    exceeds ~45 in float32 (~710 in float64), which this form avoids.

    Args:
        reduction: ``'mean'``, ``'sum'``, or anything else for no
            reduction (element-wise loss tensor is returned).
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the log-cosh loss between ``input`` and ``target``."""
        x = input - target
        abs_x = x.abs()
        # |x| + log1p(exp(-2|x|)) - log(2) == log(cosh(x)), overflow-free:
        # the softplus argument is always <= 0, so exp never blows up.
        loss = abs_x + F.softplus(-2.0 * abs_x) - math.log(2.0)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss