sadaco.apis.contrastive package
Submodules
sadaco.apis.contrastive.train_byol module
sadaco.apis.contrastive.train_supcon module
-
sadaco.apis.contrastive.train_supcon.train_mixcon_epoch(model: torch.nn.modules.module.Module, device: torch.device, train_loader: torch.utils.data.dataloader.DataLoader, optimizer: torch.optim, base_criterion: Callable, contrast_criterion: Callable, epoch: int, weights: List[int] = [1, 1], return_stats: bool = True, verbose: bool = False, preprocessing: Callable = None, grad_thres=None, update_interval=1) → Optional[Union[DefaultDict, numpy.ndarray]]
-
sadaco.apis.contrastive.train_supcon.train_mixcon_epoch2(model: torch.nn.modules.module.Module, device: torch.device, train_loader: torch.utils.data.dataloader.DataLoader, optimizer: torch.optim, base_criterion: Callable, contrast_criterion: Callable, epoch: int, weights: List[int] = [1, 1], return_stats: bool = True, verbose: bool = False, preprocessing: Callable = None, grad_thres=None, update_interval=1, multi_label=False) → Optional[Union[DefaultDict, numpy.ndarray]]
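The signatures above do not say how `weights` and `update_interval` are used. The following is a minimal sketch of one common reading, assuming (not confirmed by the source) that `weights` scales the base and contrastive loss terms before they are summed, and that `update_interval` means gradient accumulation with one optimizer step every `update_interval` batches. The helper names `combine_losses` and `step_batches` are hypothetical and are not part of sadaco.

```python
from typing import List, Sequence


def combine_losses(base_loss: float, contrast_loss: float,
                   weights: Sequence[float] = (1, 1)) -> float:
    """Weighted sum of the two loss terms (assumed meaning of `weights`)."""
    if len(weights) != 2:
        raise ValueError("weights must hold exactly two values")
    return weights[0] * base_loss + weights[1] * contrast_loss


def step_batches(num_batches: int, update_interval: int = 1) -> List[int]:
    """0-based batch indices after which the optimizer would step,
    assuming `update_interval` means gradient accumulation."""
    if update_interval < 1:
        raise ValueError("update_interval must be >= 1")
    return [i for i in range(num_batches) if (i + 1) % update_interval == 0]


# Hypothetical values, for illustration only.
total = combine_losses(2.0, 0.5, weights=[1, 2])  # 1 * 2.0 + 2 * 0.5
steps = step_batches(6, update_interval=2)
print(total, steps)
```

Under this reading, the default `weights=[1, 1]` is a plain sum of the two losses and the default `update_interval=1` steps the optimizer after every batch.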
sadaco.apis.contrastive.trainer_contrastive module
-
class sadaco.apis.contrastive.trainer_contrastive.ContrastTrainer(train_configs)
Bases: sadaco.apis.traintest.trainer_base.BaseTrainer
-
attach_layer_handler(layers)
-
build_dataloader()
-
build_dataset()
- Raises
NotImplementedError
-
build_logger(use_wandb=True)
- Parameters
use_wandb (bool, optional) – defaults to True
-
build_model()
-
build_optimizer(trainables=None)
- Parameters
trainables (optional) – defaults to None
-
parallel()
-
prepare_kfold(i, k)
- Parameters
i
k
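The parameters of `prepare_kfold(i, k)` are undocumented. As a minimal sketch, assuming `i` is a 0-based fold index and `k` is the number of folds (an assumption based only on the names), a contiguous k-fold split could be computed as below. `kfold_split` is a hypothetical helper and not part of sadaco.

```python
from typing import List, Tuple


def kfold_split(n_samples: int, i: int, k: int) -> Tuple[List[int], List[int]]:
    """Return (train_indices, val_indices) for fold `i` of `k` contiguous folds."""
    if not 0 <= i < k:
        raise ValueError("fold index i must satisfy 0 <= i < k")
    # Spread the remainder over the first folds so sizes differ by at most one.
    base, extra = divmod(n_samples, k)
    start = i * base + min(i, extra)
    stop = start + base + (1 if i < extra else 0)
    val = list(range(start, stop))
    train = list(range(0, start)) + list(range(stop, n_samples))
    return train, val


# Hypothetical sizes, for illustration only.
train_idx, val_idx = kfold_split(n_samples=10, i=1, k=3)
print(val_idx)
```

Each sample lands in exactly one validation fold, so looping `i` over `range(k)` covers the whole dataset once.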
-
reset_trainer()
-
resume()
-
test(**kwargs)
-
train()
-
train_epoch()
- Raises
NotImplementedError
-
train_kfold(k)
- Parameters
k
-
validate(return_stats=True)
- Parameters
return_stats (bool, optional) – defaults to True
-
validate_epoch()
- Raises
NotImplementedError
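`build_dataset()`, `train_epoch()`, and `validate_epoch()` are listed as raising NotImplementedError, which suggests they are hooks that a subclass is expected to override. The sketch below illustrates that pattern with a hypothetical stand-in base class. `TrainerSketch` and `ToyTrainer` are not sadaco classes, and the shape of the `train()` loop is an assumption.

```python
class TrainerSketch:
    """Hypothetical stand-in for a base trainer; not the sadaco class."""

    def __init__(self, epochs: int):
        self.epochs = epochs

    def train_epoch(self):
        raise NotImplementedError("subclasses must implement train_epoch")

    def validate_epoch(self):
        raise NotImplementedError("subclasses must implement validate_epoch")

    def train(self):
        # Run one training pass and one validation pass per epoch.
        history = []
        for epoch in range(self.epochs):
            history.append((epoch, self.train_epoch(), self.validate_epoch()))
        return history


class ToyTrainer(TrainerSketch):
    """Overrides both hooks so that train() can run end to end."""

    def train_epoch(self):
        return "trained"

    def validate_epoch(self):
        return "validated"


history = ToyTrainer(epochs=2).train()
print(history)
```

Calling `train()` on the base class directly fails at the first hook, which is the behavior the Raises entries above point to.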
-