# Source code for sadaco.utils.stats

from typing import DefaultDict, Tuple, Union

import numpy as np
import torch
from sklearn.metrics import precision_recall_fscore_support


class ICBHI_Metrics:
    def __init__(
        self,
        num_classes: int = 4,
        normal_class_label: Union[int, None] = None,
        mixup: bool = False,
        mini_batch: bool = False,
        multi_label: bool = False,
    ) -> None:
        """Evaluation metrics for the ICBHI challenge.

        Definitions:
            Sensitivity: The ability of the classifier to correctly identify
                audio with the disease.
            Specificity: The ability of the classifier to correctly identify
                audio without the disease.
            ICBHI Score: The average of sensitivity and specificity.

        Confusion matrix layout (rows = predicted, columns = true):
            Task 1: Anomaly Cycle Classification

                              (True)  C,    W,    B,    N
                    (y_pred)  C       Cc,   Wc,   Bc,   Nc
                              W       Cw,   Ww,   Bw,   Nw
                              B       Cb,   Wb,   Bb,   Nb
                              N       Cn,   Wn,   Bn,   Nn
                    (Total-True)      Ct,   Wt,   Bt,   Nt

            Task 2: Respiratory Disease Detection

                              (True)  C,    NC,    H
                    (y_pred)  C       Cc,   NCc,   Hc
                              NC      Cnc,  NCnc,  Hnc
                              H       Ch,   NCh,   Hh
                    (Total-True)      Ct,   NCt,   Ht

        Labels:
            TASK1_1 = {"Normal": 0, "Wheezes": 1, "Crackles": 2, "Both": 3}
            TASK1_2 = {"Normal": 0, "Abnormal": 1}
            TASK2_1 = {"Healthy": 0, "Chronic": 1, "NonChronic": 2}
            TASK2_2 = {"Healthy": 0, "Unhealthy": 1}

        Equations:
            Sensitivity Task1-1: (Cc + Ww + Bb) / (Ct + Wt + Bt)
            Sensitivity Task1-2: (Ccwb + Wcwb + Bcwb) / (Ct + Wt + Bt)
            Specificity Task1:   Nn / Nt
            Sensitivity Task2-1: (Cc + NCnc) / (Ct + NCt)
            Sensitivity Task2-2: (Ccnc + NCcnc) / (Ct + NCt)
            Specificity Task2:   Hh / Ht
            Score:               (Sensitivity + Specificity) / 2
            Accuracy:            (TP + TN) / (TP + TN + FP + FN)

        :param num_classes: Number of classes to evaluate, defaults to 4
        :type num_classes: int, optional
        :param normal_class_label: Class label treated as 'Normal'/'Healthy'.
            It must be set before calling :meth:`get_stats`, because
            sensitivity and specificity are defined relative to it.
            Defaults to None.
        :type normal_class_label: int, optional
        :param mixup: Indicate whether the mixup operation is applied,
            defaults to False
        :type mixup: bool, optional
        :param mini_batch: If True, softmax probabilities are not accumulated
            into ``y_pred_prob``. Only labels and predictions are kept.
            Defaults to False.
        :type mini_batch: bool, optional
        :param multi_label: If True, logits and targets are treated as two
            independent binary outputs. Logits are thresholded with
            sigmoid > 0.5, and both are encoded to integer classes with
            :meth:`binary2int`. Defaults to False.
        :type multi_label: bool, optional
        """
        self.num_classes = num_classes
        self.normal_class_label = normal_class_label
        self.mixup = mixup
        self.mini_batch = mini_batch
        self.multi_label = multi_label
        self._init_attr()

    def binary2int(self, tensor: torch.Tensor) -> torch.Tensor:
        """Encode the first two binary columns of a 2-D tensor as 0..3.

        Column 0 contributes 1 and column 1 contributes 2, so the pairs
        (0,0), (1,0), (0,1), (1,1) map to 0, 1, 2, 3.
        """
        return tensor[:, 0] + 2 * tensor[:, 1]

    def int2onehot(self, tensor: torch.Tensor, num_classes: int = 4) -> torch.Tensor:
        """Convert integer class ids to an ``(N, num_classes)`` int32 one-hot tensor.

        Ids outside ``[0, num_classes)`` produce an all-zero row.

        :param tensor: 1-D tensor of class ids.
        :param num_classes: Width of the one-hot encoding. Defaults to 4, the
            range produced by :meth:`binary2int`.
        """
        return torch.stack([tensor == c for c in range(num_classes)], dim=1).int()

    def update_lists(
        self,
        logits: torch.Tensor = None,
        y_true: torch.Tensor = None,
        y_pred: torch.Tensor = None,
    ) -> None:
        """Accumulate one batch of targets and predictions.

        :param logits: Raw model outputs with classes along dim 1.
        :param y_true: Targets. In single-label mode, either integer class ids
            of shape ``(N,)`` or one-hot/soft labels of shape ``(N, C)``
            (reduced with argmax). In multi-label mode, a 2-D binary tensor
            encoded with :meth:`binary2int`.
        :param y_pred: Unused. Predictions are always derived from ``logits``.
            Kept for backward compatibility.
        """
        if self.multi_label:
            # Independent binary heads: threshold each, then encode as 0..3.
            y_pred = self.binary2int((torch.sigmoid(logits) > 0.5).int())
            y_true = self.binary2int(y_true)
        else:
            # Accept plain integer labels as well as one-hot / soft targets.
            if y_true.dim() > 1:
                y_true = y_true.argmax(dim=1)
            y_pred = logits.max(dim=1).indices

        # Store as int64 on CPU so batches of any dtype/device concatenate
        # cleanly with the long-typed buffers created in _init_attr.
        self.y_true = torch.cat((self.y_true, y_true.detach().cpu().long()), dim=0)
        self.y_pred = torch.cat((self.y_pred, y_pred.detach().cpu().long()), dim=0)
        if not self.mini_batch:
            self.y_pred_prob = torch.cat(
                (self.y_pred_prob, logits.softmax(-1).detach().cpu()), dim=0
            )

    def get_stats(self) -> dict:
        r"""Compute ICBHI metrics from all accumulated predictions.

        The metrics are sensitivity, specificity, score, accuracy and
        macro-averaged precision/recall/F1. They are computed from the
        confusion matrix, following the ICBHI challenge definition.

        Returns:
            dict with keys 'Accuracy', 'Sensitivity', 'Specificity', 'Score'
            (0-d tensors) and 'Precision', 'Recall', 'F1-Score'.

        Raises:
            ValueError: if ``normal_class_label`` was not set.
        """
        self._compute_confusion_matrix()
        self._compute_icbhi_scores()
        self.acc = self.confusion_matrix.diag().sum() / self.confusion_matrix.sum()
        self.precision, self.recall, self.f1score, _ = self.get_precision_recall_fbeta()
        return {
            'Accuracy': self.acc,
            'Sensitivity': self.se,
            'Specificity': self.sp,
            'Score': self.sc,
            'Precision': self.precision,
            'Recall': self.recall,
            'F1-Score': self.f1score,
        }

    def get_mixup_stats(self) -> float:
        """Return the running mixup accuracy ``correct / total``.

        Returns 0.0 when nothing has been counted yet.
        """
        self.acc = self.correct / self.total if self.total else 0.0
        return self.acc

    def get_precision_recall_fbeta(
        self, average: str = "macro", fbeta: float = 1.0
    ) -> Tuple[
        Union[float, np.ndarray],
        Union[float, np.ndarray],
        Union[float, np.ndarray],
        Union[None, np.ndarray],
    ]:
        r"""Compute precision, recall and F-beta over the accumulated labels.

        :param average: One of "macro", "micro" or "weighted".
        :param fbeta: The beta of the F-beta score.

        Returns:
            precision, recall, f-beta, support. This is the return value of
            sklearn's ``precision_recall_fscore_support``; support is None
            when an average is requested.

        Raises:
            ValueError: if ``average`` is not one of the supported values.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if average not in ("macro", "micro", "weighted"):
            raise ValueError("average must be macro, micro or weighted")
        return precision_recall_fscore_support(
            y_true=self.y_true.numpy(),
            y_pred=self.y_pred.numpy(),
            average=average,
            beta=fbeta,
            zero_division=0,
        )

    def reset_metrics(self) -> None:
        """Discard all accumulated state."""
        self._init_attr()

    def _init_attr(self) -> None:
        """(Re)create every accumulator.

        All accumulators are created regardless of ``mixup``, so
        :meth:`update_lists` also works in mixup mode.
        """
        # Running counters used by get_mixup_stats.
        self.correct = 0.0
        self.total = 0
        # Accumulated CPU buffers filled by update_lists.
        self.y_true = torch.tensor([], dtype=torch.long)
        self.y_pred = torch.tensor([], dtype=torch.long)
        self.y_pred_prob = torch.tensor([], dtype=torch.float)
        self.acc = 0.0

    def _compute_confusion_matrix(self) -> None:
        """Build ``self.confusion_matrix`` indexed as ``[predicted, true]``."""
        n = self.num_classes
        # Flatten each (true, pred) pair into a single bin index.
        flat = (n * self.y_true + self.y_pred).long()
        # minlength guarantees all n*n bins exist, even for unseen pairs.
        counts = torch.bincount(flat, minlength=n * n)
        # reshape gives [true, pred]; transpose to [pred, true].
        self.confusion_matrix = counts.reshape(n, n).T

    def _compute_icbhi_scores(self) -> None:
        """Compute specificity, sensitivity and score from the confusion matrix."""
        if self.normal_class_label is None:
            raise ValueError(
                "normal_class_label must be set to compute sensitivity/specificity"
            )
        self._compute_sp()
        self._compute_se()
        self._compute_sc()

    def _compute_sp(self) -> None:
        """Specificity: correctly predicted normals / all true normals."""
        k = self.normal_class_label
        self.sp = self.confusion_matrix[k, k] / self.confusion_matrix[:, k].sum()

    def _compute_se(self) -> None:
        """Sensitivity: correctly predicted abnormals / all true abnormals."""
        cm = self.confusion_matrix
        k = self.normal_class_label
        # Every column except the normal one is an abnormal true class. This
        # holds for any position of the normal label, not only label 0.
        abnormal_total = cm.sum() - cm[:, k].sum()
        self.se = (cm.diag().sum() - cm[k, k]) / abnormal_total

    def _compute_sc(self) -> None:
        """ICBHI score: mean of sensitivity and specificity."""
        self.sc = (self.sp + self.se) * 0.5
if __name__ == "__main__":
    # Self-check of ICBHI_Metrics on a small hand-computed example.
    cm = ICBHI_Metrics(num_classes=4, normal_class_label=0)
    logits = torch.tensor(
        [
            [0.0540, 0.0671, 0.2014, -0.1081],  # 2
            [0.0487, 0.0771, 0.0015, -0.1226],  # 1
            [0.0500, 0.0776, 0.0108, -0.1131],  # 1
            [0.0531, 0.0627, 0.0017, -0.1051],  # 1
            [0.0522, 0.0733, 0.0051, -0.1096],  # 1
            [0.0454, 0.0782, -0.0077, 0.1375],  # 3
            [0.0513, 0.0673, -0.0002, -0.1134],  # 1
            [0.0325, 0.0547, 0.0047, -0.0967],  # 1
            [0.0475, 0.0639, 0.0017, -0.1037],  # 1
            [0.1489, 0.0728, 0.0022, -0.1135],  # 0
        ]
    )
    # update_lists reduces 2-D targets with argmax, so pass one-hot labels.
    y_true = torch.nn.functional.one_hot(
        torch.tensor([3, 0, 1, 1, 0, 1, 3, 0, 1, 0]), num_classes=4
    )
    cm.update_lists(logits=logits, y_true=y_true)
    cm.get_stats()
    assert cm.confusion_matrix.shape == (4, 4)
    assert round(cm.se.item(), 4) == 0.5000, "3/6 = 0.5000"
    assert round(cm.sp.item(), 4) == 0.2500, "1/4 = 0.2500"
    assert round(cm.sc.item(), 4) == 0.3750, "(se + sp) / 2 = 0.3750"
    assert round(cm.acc.item(), 4) == 0.4000, "4/10 = 0.4000"
    print(f"Confusion Matrix: \n{cm.confusion_matrix.numpy()}")
    print("==================== Test Passed ====================")