sadaco.apis.traintest package

Submodules

sadaco.apis.traintest.common module

sadaco.apis.traintest.common.load_input(input_path, mode='stft', window_size=70, hop_length=25, sample_rate=16000)
sadaco.apis.traintest.common.load_wav(input_path)
sadaco.apis.traintest.common.stft2mel(mag, n_mels=128, sample_rate=16000)
sadaco.apis.traintest.common.recover_wav(mag, phase, window_size, hop_length, sample_rate)
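The defaults of load_input describe an STFT configuration. The sketch below assumes that window_size and hop_length are given in milliseconds. This reference does not confirm that, so treat it as a hypothesis. The sketch shows the resulting sample counts and frame count for a 1-second clip at 16 kHz. It also shows how a magnitude and a phase combine into one complex STFT bin, which is the starting point for recover_wav(mag, phase, ...). All helper names are hypothetical.

```python
import cmath


def ms_to_samples(ms: float, sample_rate: int) -> int:
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(round(ms * sample_rate / 1000))


def num_frames(n_samples: int, win: int, hop: int, center: bool = True) -> int:
    """Number of STFT frames for a signal of n_samples.

    With center=True and an even window length, the signal is padded by
    win // 2 on both sides, which gives 1 + n_samples // hop frames
    (the torch.stft / librosa convention). Without padding, only full
    windows are counted.
    """
    if center:
        return 1 + n_samples // hop
    if n_samples < win:
        return 0
    return 1 + (n_samples - win) // hop


def combine(mag: float, phase: float) -> complex:
    """Rebuild a complex STFT bin from its magnitude and phase (radians)."""
    return cmath.rect(mag, phase)


sr = 16000
win = ms_to_samples(70, sr)  # 1120 samples, if window_size is in ms (assumption)
hop = ms_to_samples(25, sr)  # 400 samples, if hop_length is in ms (assumption)
print(win, hop, num_frames(sr, win, hop))  # 1120 400 41
```

If the parameters are instead sample counts, only the ms_to_samples step falls away. The frame arithmetic is unchanged.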

sadaco.apis.traintest.demo module

class sadaco.apis.traintest.demo.demo_helper(master_cfg, model_cfg)

Bases: object

do_inference(input_path, return_raw=False)
do_explanation(input_path, method, cls)

sadaco.apis.traintest.eval module

sadaco.apis.traintest.eval.move_device(data: DefaultDict, device: torch.device)
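Both the eval and train modules expose move_device(data: DefaultDict, device). A plausible reading is that it moves every tensor in a batch dictionary to the target device. The sketch below is a hypothetical stand-in, not sadaco's code. It duck-types on a .to() method, so it runs without torch installed. It returns a new plain dict rather than mutating the input, although the real function may do either.

```python
from typing import Any


def move_device(data: Any, device: Any) -> Any:
    """Recursively move every value that has a .to() method (e.g. a torch.Tensor)."""
    if hasattr(data, "to"):
        return data.to(device)
    if isinstance(data, dict):
        # Rebuild the mapping so the caller's dict is not mutated.
        return {key: move_device(value, device) for key, value in data.items()}
    if isinstance(data, list):
        return [move_device(value, device) for value in data]
    if isinstance(data, tuple):
        return tuple(move_device(value, device) for value in data)
    # Leave other values (ints, strings, None, ...) unchanged.
    return data
```

With real tensors, calling `move_device(batch, torch.device("cuda:0"))` would return a dict whose tensors live on the GPU. Values without a .to() method, such as sample IDs, pass through untouched.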
sadaco.apis.traintest.eval.test_basic_epoch(model: torch.nn.modules.module.Module, device: torch.device, test_loader: torch.utils.data.dataloader.DataLoader, metrics: sadaco.utils.stats.ICBHI_Metrics, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], epoch: int, verbose: bool = True, preprocessing: Callable = None) → Optional[Union[DefaultDict, numpy.ndarray]]
sadaco.apis.traintest.eval.test_merge_epoch(model: torch.nn.modules.module.Module, device: torch.device, test_loader: torch.utils.data.dataloader.DataLoader, metrics: sadaco.utils.stats.ICBHI_Metrics, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], epoch: int, verbose: bool = True, preprocessing: Callable = None, target_classes: Union[List, Tuple] = None) → Optional[Union[DefaultDict, numpy.ndarray]]
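test_basic_epoch and test_merge_epoch report their results through sadaco.utils.stats.ICBHI_Metrics. This reference does not document its internals. The sketch below assumes that it follows the ICBHI 2017 challenge definition, where the score is the mean of specificity and sensitivity. It also assumes that class 0 is "normal" and classes 1 to 3 are crackle, wheeze and both. The function name is hypothetical.

```python
from typing import Dict, Sequence

NORMAL = 0  # assumed label index of the "normal" class


def icbhi_score(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Specificity, sensitivity and ICBHI score from integer class labels."""
    pairs = list(zip(y_true, y_pred))
    normal = [(t, p) for t, p in pairs if t == NORMAL]
    abnormal = [(t, p) for t, p in pairs if t != NORMAL]
    # Specificity: fraction of normal cycles predicted as normal.
    sp = sum(p == t for t, p in normal) / len(normal) if normal else 0.0
    # Sensitivity: fraction of abnormal cycles predicted with their exact abnormal class.
    se = sum(p == t for t, p in abnormal) / len(abnormal) if abnormal else 0.0
    return {"sp": sp, "se": se, "score": (sp + se) / 2}


print(icbhi_score([0, 0, 0, 0, 1, 1, 2, 3], [0, 0, 0, 1, 1, 2, 2, 0]))
# {'sp': 0.75, 'se': 0.5, 'score': 0.625}
```

Note that an abnormal cycle predicted as a different abnormal class (1 predicted as 2 above) counts against sensitivity under this definition.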

sadaco.apis.traintest.preprocessings module

class sadaco.apis.traintest.preprocessings.Preprocessor(preproc_modules: List = None)

Bases: object

add_module(module)
to(device)
class sadaco.apis.traintest.preprocessings.stft2meldb(n_stft, n_mels=128, sample_rate=16000)

Bases: object

to(device)
class sadaco.apis.traintest.preprocessings.normalize(mean=0, std=1)

Bases: object

to(device)
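Preprocessor, stft2meldb and normalize share a small interface. Each has a to(device) method, and BaseTrainer.preproc is described as a callable. The pure-Python sketch below shows how such a chain can compose. Normalize is a hypothetical list-based stand-in for normalize. How the real classes are implemented is an assumption here, not sadaco's code.

```python
from typing import Callable, List, Optional


class Normalize:
    """Stand-in for preprocessings.normalize: (x - mean) / std, element-wise."""

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.mean, self.std = mean, std

    def to(self, device):
        return self  # nothing is device-specific in this list-based sketch

    def __call__(self, x: List[float]) -> List[float]:
        return [(v - self.mean) / self.std for v in x]


class Preprocessor:
    """Stand-in for preprocessings.Preprocessor: applies its modules in order."""

    def __init__(self, preproc_modules: Optional[List[Callable]] = None):
        self.modules = list(preproc_modules or [])

    def add_module(self, module: Callable) -> None:
        self.modules.append(module)

    def to(self, device):
        # Forward the device move to every module that supports it; like
        # nn.Module.to, each module is assumed to move in place.
        for module in self.modules:
            if hasattr(module, "to"):
                module.to(device)
        return self

    def __call__(self, x):
        for module in self.modules:
            x = module(x)
        return x


pre = Preprocessor([Normalize(mean=2.0, std=2.0)])
pre.add_module(lambda x: [round(v, 3) for v in x])  # plain callables work too
print(pre.to("cpu")([0.0, 2.0, 4.0]))  # [-1.0, 0.0, 1.0]
```

Because every stage only needs to be callable, a mel-dB conversion such as stft2meldb could sit in front of the normalization in the same chain.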

sadaco.apis.traintest.train module

sadaco.apis.traintest.train.move_device(data: DefaultDict, device: torch.device)
sadaco.apis.traintest.train.train_basic_epoch(model: torch.nn.modules.module.Module, device: torch.device, train_loader: torch.utils.data.dataloader.DataLoader, optimizer: torch.optim, criterion: Callable, epoch: int, return_stats: bool = False, verbose: bool = False, preprocessing: Callable = None, grad_thres=None, update_interval=1) → Optional[Union[DefaultDict, numpy.ndarray]]

Train the model for one epoch on the batches from train_loader.

Parameters
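  • model (nn.Module) – Model to train.

  • device (torch.device) – Device on which the model and the batches are placed.

  • train_loader (DataLoader) – Data loader that yields the training batches.

  • optimizer (torch.optim) – Optimizer that updates the model parameters.

  • criterion (Callable) – Loss function used for training.

  • epoch (int) – Index of the current epoch.

  • return_stats (bool, optional) – Whether to return the training statistics; defaults to False.

  • verbose (bool, optional) – Whether to report progress verbosely; defaults to False.

  • preprocessing (Callable, optional) – Preprocessing applied to the inputs before they are passed to the model; defaults to None.

  • grad_thres (optional) – Gradient threshold; defaults to None.

  • update_interval (int, optional) – Interval, in batches, between optimizer updates; defaults to 1.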

Returns

The training statistics when return_stats is True; otherwise None.

Return type

Optional[Union[DefaultDict, np.ndarray]]
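The update_interval parameter suggests gradient accumulation, where several batches contribute gradients before a single optimizer update. That reading is an assumption, because this reference does not say how train_basic_epoch uses the parameter. The stdlib-only sketch below shows only the step schedule. It uses a hypothetical counting optimizer in place of a torch.optim optimizer.

```python
from typing import Iterable


class CountingOptimizer:
    """Hypothetical stand-in that only counts step() and zero_grad() calls."""

    def __init__(self):
        self.steps = 0
        self.zero_grads = 0

    def step(self):
        self.steps += 1

    def zero_grad(self):
        self.zero_grads += 1


def run_epoch(batches: Iterable, optimizer, update_interval: int = 1) -> int:
    """Step the optimizer once every `update_interval` batches; return the batch count."""
    if update_interval < 1:
        raise ValueError("update_interval must be >= 1")
    optimizer.zero_grad()
    n = 0
    for n, _batch in enumerate(batches, start=1):
        # A real loop would compute loss / update_interval and call loss.backward()
        # here; gradients keep accumulating until the next optimizer step.
        if n % update_interval == 0:
            optimizer.step()
            optimizer.zero_grad()
    return n


opt = CountingOptimizer()
print(run_epoch(range(10), opt, update_interval=3), opt.steps)  # 10 3
```

In this schedule, a trailing partial group (batch 10 above) never triggers a step. An implementation could instead flush it at the end of the epoch.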

sadaco.apis.traintest.trainer_base module

class sadaco.apis.traintest.trainer_base.BaseTrainer(train_configs)

Bases: object

Base template class for the trainers. The trainer for each dataset is built on top of this class and inherits the basic functions, such as train, test and validate, that implement the typical pipeline procedures. Users can also override some of these functions to meet specific requirements.

Returns

Trainer instance.

Return type

BaseTrainer

__init__(train_configs)

The trainer parses the master configuration and loads the model and data configurations from the file paths written in it.

Parameters

train_configs (munch object) – Master configuration settings, as loaded from a YAML file.

Variables
  • configs – Master configs given as the train_configs

  • data_configs – Data configs parsed from train_configs.data_configs.file

  • model_configs – Model configs parsed from train_configs.model_configs.file

  • log_configs – Total Configuration containing all of the config settings. Logger will log this as a project configuration.

  • logger – Logger instance that holds the configuration information and the train/validation statistics. We recommend using wandb, since BaseLogger only saves raw data. See https://docs.wandb.ai/quickstart to create a wandb account.

  • model – Model built from the given model configs. It is used for training and inference.

  • optimizer – Optimizer that updates the model during training. With the resume option, users can choose whether the optimizer state is resumed as well.

  • device – Device of the model and the optimizer: cuda:0 by default if CUDA is available, otherwise cpu.

  • train_criterion – Criterion used in training. Currently only a single criterion function is supported; to use multiple objective functions, users must create a hybrid criterion Callable or override the training procedures.

  • valid_criterion – Criterion used in validation. Currently only a single criterion function is supported; to use multiple objective functions, users must create a hybrid criterion Callable or override the validation procedures.

  • scheduler – Training scheduler that controls the hyperparameters (currently only the learning rate) and the model versions.

  • preproc – Preprocessor (callable) containing the input preprocessing pipeline. Ignored when None.

  • _progress – Training progress (number of epochs). The web session uses this to query the state of background training.

build_dataset()

Build the datasets. Must be implemented by the dataset-specific subclass.

Raises

NotImplementedError – If the method is not overridden by a subclass.

build_dataloader()

Build the data loaders for the datasets.

build_model()

Build the model from the model configs.

Returns

The built model.

Return type

_type_

parallel()

Set up parallel execution of the model.

build_optimizer(trainables=None)

Build the optimizer for the model.

Parameters

trainables (optional) – Parameters to optimize; defaults to None.

Returns

The optimizer.

Return type

_type_

build_logger(use_wandb=True)

Build the logger.

Parameters

use_wandb (bool, optional) – Whether to log with wandb instead of the raw-data BaseLogger; defaults to True.

Returns

The logger instance.

Return type

_type_

reset_trainer()

Reset the trainer state.

resume()

Resume training from a saved state, optionally including the optimizer state.

train()

Run the training procedure.

Returns

_description_

Return type

_type_

train_kfold(k)

Train with k-fold cross-validation.

Parameters

k (int) – Number of folds.

Returns

_description_

Return type

_type_

prepare_kfold(i, k)

Prepare the data split for fold i of k-fold cross-validation.

Parameters
  • i (int) – Index of the fold to prepare.

  • k (int) – Number of folds.
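train_kfold(k) and prepare_kfold(i, k) point to k-fold cross-validation. In this scheme, fold i is held out for validation and the remaining k - 1 folds are used for training. The sketch below shows one common way to compute the split with contiguous, unshuffled folds and a 0-based i. This reference does not say whether sadaco shuffles or stratifies (for example by patient), so treat the helper as hypothetical.

```python
from typing import List, Tuple


def kfold_indices(n: int, i: int, k: int) -> Tuple[List[int], List[int]]:
    """Return (train, valid) index lists for fold i (0-based) of k contiguous folds."""
    if not 0 <= i < k:
        raise ValueError("fold index i must satisfy 0 <= i < k")
    # The first n % k folds get one extra sample, so every index is used exactly once.
    sizes = [n // k + (1 if f < n % k else 0) for f in range(k)]
    start = sum(sizes[:i])
    stop = start + sizes[i]
    valid = list(range(start, stop))
    train = list(range(0, start)) + list(range(stop, n))
    return train, valid


print(kfold_indices(10, 1, 3))  # ([0, 1, 2, 3, 7, 8, 9], [4, 5, 6])
```

Running train_kfold(k) then amounts to looping i over range(k), training on the first list and validating on the second.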

validate(return_stats=True)

Run validation.

Parameters

return_stats (bool, optional) – Whether to return the validation statistics; defaults to True.

Returns

_description_

Return type

_type_

test(**kwargs)

Run the test procedure.

Returns

_description_

Return type

_type_

attach_layer_handler(layers)

train_epoch()

Run one training epoch. Must be implemented by a subclass.

Raises

NotImplementedError – If the method is not overridden by a subclass.

validate_epoch()

Run one validation epoch. Must be implemented by a subclass.

Raises

NotImplementedError – If the method is not overridden by a subclass.

sadaco.apis.traintest.trainer_base.build_optimizer(model, train_configs, trainables=None)

Build an optimizer for model according to train_configs.

Parameters
  • model (nn.Module) – Model whose parameters are optimized.

  • train_configs (munch object) – Master configuration settings.

  • trainables (optional) – Parameters to optimize; defaults to None.

Returns

The optimizer.

Return type

_type_

sadaco.apis.traintest.trainer_base.build_dataloader(dataset, train_configs, data_configs)

Build a data loader for dataset from the train and data configurations.

Parameters
  • dataset – Dataset to load from.

  • train_configs (munch object) – Master configuration settings.

  • data_configs – Data configuration settings.

Returns

The data loader.

Return type

_type_

sadaco.apis.traintest.trainer_base.build_criterion(name, mixup=False, **kwargs)

Build a criterion (loss function) by name.

Parameters
  • name (str) – Name of the criterion.

  • mixup (bool, optional) – Whether to build a criterion for mixup training; defaults to False.

Returns

The criterion.

Return type

_type_
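In mixup, two inputs are blended as lam * x_a + (1 - lam) * x_b. The loss is then the same blend of the criterion evaluated against each original label. This reference does not say whether build_criterion(name, mixup=True) wraps the named criterion exactly this way, so that is an assumption. The sketch below only illustrates the standard mixup loss and uses a hypothetical single-sample cross_entropy helper.

```python
import math
from typing import Callable, Sequence


def cross_entropy(probs: Sequence[float], target: int) -> float:
    """Negative log-likelihood of the target class for a single sample."""
    return -math.log(probs[target])


def mixup_criterion(criterion: Callable, pred, y_a, y_b, lam: float) -> float:
    """Mixup loss: the criterion against both label sets, weighted by lam and 1 - lam."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


probs = [0.5, 0.25, 0.25]
loss = mixup_criterion(cross_entropy, probs, 0, 1, lam=0.5)
print(round(loss, 4))  # 1.0397, i.e. 0.5 * ln 2 + 0.5 * ln 4
```

With lam = 1 the mixup loss reduces to the plain criterion on y_a, which is a convenient sanity check for any wrapper.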

Module contents