Source code for sadaco.pipelines.ICBHI

from turtle import update
from sadaco.apis.traintest import BaseTrainer,  train_basic_epoch, test_basic_epoch
from sadaco.apis.contrastive import ContrastTrainer, train_mixcon_epoch
from sadaco.apis.traintest import preprocessings as preps
from sadaco.utils.stats import ICBHI_Metrics, print_stats
from sadaco.utils.config_parser import ArgsParser
from sadaco.pipelines.build_modules import build_criterion
from sadaco.pipelines.scheduler import BaseScheduler

from sadaco.dataman.icbhi.dummy import RespiDatasetSTFT

import torch
# torch.autograd.set_detect_anomaly(True)
from torch.utils.data import DataLoader

[docs]class ICBHI_Basic_Trainer(BaseTrainer): def __init__(self, configs): super().__init__(configs) self.resume() self.preproc = preps.Preprocessor( [preps.stft2meldb(n_stft=self.train_dataset.n_stft, n_mels = self.train_dataset.num_mel, sample_rate=self.train_dataset.sample_rate)] ) self.evaluator = ICBHI_Metrics(num_classes=4, normal_class_label=0)
[docs] def build_dataset(self): self.train_dataset = RespiDatasetSTFT(split='train', **self.data_configs.train) self.val_dataset = RespiDatasetSTFT(split='val', **self.data_configs.val)
[docs] def train_epoch(self, epoch): train_stats = train_basic_epoch(model=self.model, device=self.device, train_loader=self.train_loader, optimizer=self.optimizer,criterion=self.train_criterion, epoch=epoch, return_stats=True, verbose = False, preprocessing=self.preproc, grad_thres=10., update_interval=self.configs.train.update_interval) return train_stats
[docs] def validate_epoch(self, epoch): val_stats = test_basic_epoch(self.model,self.device, self.val_loader, self.evaluator, criterion=self.valid_criterion, epoch=epoch, verbose=False, preprocessing=self.preproc) return val_stats
[docs]class ICBHI_Contrast_Trainer(ContrastTrainer): def __init__(self, configs): super().__init__(configs) self.preproc = preps.Preprocessor( [preps.stft2meldb(n_stft=self.train_dataset.n_stft, n_mels = self.train_dataset.num_mel, sample_rate=self.train_dataset.sample_rate)] ) self.attach_extractor() self.wrap_model() # Should redo optimizer building since the model is wrapped self.optimizer = self.build_optimizer() self.resume() self.contrast_criterion = build_criterion(self.configs.train.contrast_criterion.name, mixup=self.configs.train.contrast_criterion.loss_mixup) self.train_criterion = build_criterion(self.configs.train.criterion.name, mixup=self.configs.train.criterion.loss_mixup, **self.configs.train.criterion.params) self.valid_criterion = build_criterion(self.configs.train.criterion.name, mixup=False, **self.configs.train.criterion.params) self.scheduler = BaseScheduler(self.configs, self.optimizer, self.model, exp_id=self.logger.name, parallel=self.model_configs.data_parallel) self.evaluator = ICBHI_Metrics(num_classes=4, normal_class_label=0)
[docs] def build_dataset(self): self.train_dataset = RespiDatasetSTFT(split='train', **self.data_configs.train) self.val_dataset = RespiDatasetSTFT(split='val', **self.data_configs.val)
[docs] def train_epoch(self, epoch): train_stats = train_mixcon_epoch(model=self.model, device=self.device, train_loader=self.train_loader, optimizer=self.optimizer, base_criterion=self.train_criterion, contrast_criterion=self.contrast_criterion, epoch=epoch, return_stats=True, verbose = False, preprocessing = self.preproc, grad_thres=10., update_interval=self.configs.train.update_interval) return train_stats
[docs] def validate_epoch(self, epoch): val_stats = test_basic_epoch(self.model,self.device, self.val_loader, self.evaluator, criterion=self.valid_criterion, epoch=epoch, verbose=False, preprocessing=self.preproc) return val_stats
def main(configs): if configs.train.method == 'contrastive': trainer = ICBHI_Contrast_Trainer(configs) elif configs.train.method == 'basic': trainer = ICBHI_Basic_Trainer(configs) else: raise ValueError("Method is not on the available list of [basic, contrastive]") if configs.fold is None: trainer.train() else: trainer.train_kfold(configs.fold) # results = trainer.test(return_stats=True) # print(print_stats(results)) def parse_configs(): parser = ArgsParser() # One can use the config files for the default settings, # and override settings by manually giving the arguments # Currently, overriding only the top-level arguments are available parser.add_argument("--mixup", action='store_true') parser.add_argument("--fold", default=None, type=int) parser.add_argument("--seed", default=None, type=int) args = parser.get_args() return args if __name__ == "__main__": configs = parse_configs() from sadaco.utils.misc import seed_everything seed_everything(configs.seed) main(configs)