Source code for sadaco.apis.losses.BasicLoss

import torch
import torch.nn
from typing import Union
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss

class CELoss(CrossEntropyLoss):
    """Cross-entropy loss that accepts either one-hot or integer class labels.

    A thin wrapper around :class:`torch.nn.CrossEntropyLoss`.

    - In ``'onehot'`` mode, the label tensor is collapsed to class indices
      with ``argmax`` over its last dimension before delegating to the
      parent loss.
    - In ``'int'`` mode, the label is handed to the parent loss unchanged.

    :param mode: Label format, ``'onehot'`` (or ``0``) or ``'int'`` (or ``1``),
        defaults to ``'onehot'``
    :type mode: Union[str, int], optional
    :param kwargs: Forwarded to :class:`torch.nn.CrossEntropyLoss`
        (e.g. ``weight``, ``reduction``).
    :raises ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, mode: Union[str, int] = 'onehot', **kwargs):
        super().__init__(**kwargs)
        # Keep a handle to the parent implementation so ``forward`` can
        # delegate to it after converting the labels.
        self.base_forward = super().forward
        if mode in ['onehot', 0]:
            self.mode = 0
        elif mode in ['int', 1]:
            self.mode = 1
        else:
            raise ValueError("Currently only Supporting One-hot or Integer")

    def forward(self, output, label, **kwargs):
        """Compute the cross-entropy between ``output`` and ``label``.

        :param output: Unnormalized scores, passed as the parent loss's input.
        :type output: torch.Tensor
        :param label: One-hot labels with the class axis last (mode 0), or
            class indices (mode 1).
        :type label: torch.Tensor
        :param kwargs: Accepted for interface compatibility; ignored.
        :return: The loss computed by :class:`torch.nn.CrossEntropyLoss`.
        :rtype: torch.Tensor
        """
        if self.mode == 0:
            # One-hot -> class index along the last dimension.
            target = torch.argmax(label, dim=-1)
        else:
            # Integer labels are already class indices; use them as-is.
            target = label
        return self.base_forward(output, target)
class BCEWithLogitsLoss(BCEWithLogitsLoss):
    """Binary cross-entropy-with-logits loss restricted to multi-hot labels.

    Subclasses (and deliberately shares the name of)
    :class:`torch.nn.BCEWithLogitsLoss`. The label tensor is passed straight
    through to the parent loss.

    :param mode: Label format; only ``'multihot'`` (or ``0``) is accepted,
        defaults to ``'multihot'``
    :type mode: Union[str, int], optional
    :param max: Never read by this class; kept for signature compatibility,
        defaults to None
    :type max: Any, optional
    :param kwargs: Forwarded to :class:`torch.nn.BCEWithLogitsLoss`
        (e.g. ``weight``, ``reduction``, ``pos_weight``).
    :raises ValueError: If ``mode`` is not a supported label format.
    """

    def __init__(self, mode: Union[str, int] = 'multihot', max=None, **kwargs):
        super().__init__(**kwargs)
        # Handle to the parent implementation used by ``forward``.
        self.base_forward = super().forward
        if mode not in ['multihot', 0]:
            raise ValueError("Currently only Supporting Multi-hot")
        self.mode = 0

    def forward(self, output: torch.Tensor, label: torch.Tensor, **kwargs):
        """Compute BCE-with-logits between ``output`` and multi-hot ``label``.

        :param output: Raw logits, passed as the parent loss's input.
        :type output: torch.Tensor
        :param label: Multi-hot target tensor, used unchanged as the target.
        :type label: torch.Tensor
        :param kwargs: Accepted for interface compatibility; ignored.
        :raises ValueError: If ``self.mode`` is not the multi-hot mode.
        :return: The loss computed by :class:`torch.nn.BCEWithLogitsLoss`.
        :rtype: torch.Tensor
        """
        if self.mode != 0:
            raise ValueError("Currently only Supporting Multi-hot")
        return self.base_forward(output, label)
class Normalized_MSELoss(MSELoss):
    """MSE between L2-normalized predictions and target projections.

    Intended for non-contrastive self-supervised learning (BYOL-style), where
    the loss is taken between the normalized predictions and the normalized
    target projections. Both inputs are L2-normalized along their last
    dimension before being handed to :class:`torch.nn.MSELoss`.

    For unit vectors ``a`` and ``b``, ``||a - b||^2 = 2 - 2 * cos(a, b)``.
    Under the default ``reduction='mean'``, this loss is therefore that
    quantity averaged over every element, including the feature dimension.

    :param kwargs: Forwarded to :class:`torch.nn.MSELoss` (e.g. ``reduction``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Handle to the parent implementation used by ``forward``.
        self.base_forward = super().forward

    def forward(
        self, predictions: torch.Tensor, target_projections: torch.Tensor, **kwargs
    ):
        """Return the MSE between the L2-normalized inputs.

        :param predictions: Online-network predictions; normalized on dim -1.
        :type predictions: torch.Tensor
        :param target_projections: Target-network projections; normalized on dim -1.
        :type target_projections: torch.Tensor
        :param kwargs: Accepted for interface compatibility; ignored.
        :return: MSE loss value as computed by :class:`torch.nn.MSELoss`.
        :rtype: torch.Tensor
        """
        normalize = torch.nn.functional.normalize
        unit_pred, unit_target = (
            normalize(tensor, dim=-1, p=2)
            for tensor in (predictions, target_projections)
        )
        return self.base_forward(unit_pred, unit_target)