Source code for sadaco.apis.losses.CustomLoss

import torch
from typing import Union, Callable, Optional, Any

def mixup_criterion(criterion: Callable, **criterion_options):
    """Build a mixup-aware instance of ``criterion``.

    A subclass of ``criterion`` is created on the fly. Its ``forward``
    evaluates the original (un-mixed) loss twice, once against ``label`` and
    once against ``label2``, and blends the two with the mixup coefficient
    ``lam`` before reducing.

    Args:
        criterion: Loss *class* to wrap. It is subclassed, so it must be a
            class rather than an instance. Its ``forward`` is called with
            ``output=`` and ``label=`` keyword arguments (plus any extra
            ``**kwargs`` given to the mixed ``forward``).
        **criterion_options: Keyword arguments forwarded to
            ``criterion.__init__``. ``reduction`` is also read from here when
            the wrapped criterion does not set ``self.reduction`` itself.

    Returns:
        An instance of the generated mixup subclass, constructed with
        ``criterion_options``.
    """

    class MixedCriterion(criterion):
        def __init__(self, **criterion_options):
            # Pre-set to None so that, after the parent __init__ runs, we can
            # tell whether the parent defined its own ``reduction``.
            self.reduction = None
            # Keep a handle to the parent's un-mixed forward, because
            # ``forward`` is overridden below with the mixup signature.
            self.nonmix_forward = super().forward
            super().__init__(**criterion_options)
            if self.reduction is None:
                self.reduction = criterion_options.get('reduction', 'mean')

        def forward(self, output: torch.Tensor, label: torch.Tensor,
                    label2: torch.Tensor, lam: Union[float, torch.Tensor],
                    **kwargs):
            """Return the mixup loss ``lam * L(label) + (1 - lam) * L(label2)``.

            Args:
                output: Model output, passed unchanged to the wrapped forward.
                label: First target set of the mixed pair.
                label2: Second target set of the mixed pair.
                lam: Mixing weight, either a Python float or a tensor. A
                    tensor with fewer dims than the per-call loss gets
                    trailing singleton dims appended, so a per-sample
                    ``lam`` lines up with the loss's leading axis.
                **kwargs: Extra keyword arguments for the wrapped forward.

            Returns:
                Scalar tensor. It is the sum of the mixed loss when
                ``self.reduction == 'sum'``, otherwise its mean.
            """
            loss1 = self.nonmix_forward(output=output, label=label, **kwargs)
            loss2 = self.nonmix_forward(output=output, label=label2, **kwargs)
            if torch.is_tensor(lam):
                # Align lam from the left. A (B,) lam against a (B, C, T) loss
                # must become (B, 1, 1), not (B, 1), which would broadcast
                # against the wrong axes.
                while lam.dim() < loss1.dim():
                    lam = lam.unsqueeze(-1)
            # A plain float lam broadcasts as a scalar, so no reshaping is needed.
            mixup_loss = lam * loss1 + (1 - lam) * loss2
            if self.reduction == 'sum':
                return mixup_loss.sum()
            # Any other reduction (including 'none') falls back to the mean.
            return mixup_loss.mean()

    return MixedCriterion(**criterion_options)