sadaco.apis.contrastive package

Submodules

sadaco.apis.contrastive.train_byol module

sadaco.apis.contrastive.train_supcon module

sadaco.apis.contrastive.train_supcon.move_device(data: DefaultDict, device: torch.device)[source]
sadaco.apis.contrastive.train_supcon.train_mixcon_epoch(model: torch.nn.modules.module.Module, device: torch.device, train_loader: torch.utils.data.dataloader.DataLoader, optimizer: torch.optim.Optimizer, base_criterion: Callable, contrast_criterion: Callable, epoch: int, weights: List[int] = [1, 1], return_stats: bool = True, verbose: bool = False, preprocessing: Callable = None, grad_thres=None, update_interval=1) → Optional[Union[DefaultDict, numpy.ndarray]][source]
sadaco.apis.contrastive.train_supcon.train_mixcon_epoch2(model: torch.nn.modules.module.Module, device: torch.device, train_loader: torch.utils.data.dataloader.DataLoader, optimizer: torch.optim.Optimizer, base_criterion: Callable, contrast_criterion: Callable, epoch: int, weights: List[int] = [1, 1], return_stats: bool = True, verbose: bool = False, preprocessing: Callable = None, grad_thres=None, update_interval=1, multi_label=False) → Optional[Union[DefaultDict, numpy.ndarray]][source]

sadaco.apis.contrastive.trainer_contrastive module

class sadaco.apis.contrastive.trainer_contrastive.ContrastTrainer(train_configs)[source]

Bases: sadaco.apis.traintest.trainer_base.BaseTrainer

attach_extractor()[source]
wrap_model()[source]
attach_layer_handler(layers)
build_dataloader()

_summary_

build_dataset()

_summary_

Raises

NotImplementedError – _description_

build_logger(use_wandb=True)

_summary_

Parameters

use_wandb (bool, optional) – _description_, defaults to True

Returns

_description_

Return type

_type_

build_model()

_summary_

Returns

_description_

Return type

_type_

build_optimizer(trainables=None)

_summary_

Parameters

trainables (_type_, optional) – _description_, defaults to None

Returns

_description_

Return type

_type_

parallel()

_summary_

prepare_kfold(i, k)

_summary_

Parameters
  • i (_type_) – _description_

  • k (_type_) – _description_

reset_trainer()

_summary_

resume()

_summary_

test(**kwargs)

_summary_

Returns

_description_

Return type

_type_

train()

_summary_

Returns

_description_

Return type

_type_

train_epoch()

_summary_

Raises

NotImplementedError – _description_

train_kfold(k)

_summary_

Parameters

k (_type_) – _description_

Returns

_description_

Return type

_type_

validate(return_stats=True)

_summary_

Parameters

return_stats (bool, optional) – _description_, defaults to True

Returns

_description_

Return type

_type_

validate_epoch()

_summary_

Raises

NotImplementedError – _description_

class sadaco.apis.contrastive.trainer_contrastive.ContrastiveWrapper(model, use_mapper, mapper_classify, keepdims)[source]

Bases: torch.nn.modules.module.Module

to(device)[source]
forward(x)[source]
training: bool
class sadaco.apis.contrastive.trainer_contrastive.NormLayer[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]
training: bool

Module contents