Source code for sadaco.pipelines.build_modules

from re import L
import torch
from sadaco.dataman.loader import _build_dataloader
import sadaco.apis.losses as LF

[docs]def build_optimizer(model, train_configs, trainables = None): if trainables is None: trainables = [p for p in model.parameters() if p.requires_grad] else: pass optimizer = getattr(torch.nn.optim, train_configs.optimizer.name)(trainables, **train_configs.optimizer.params) return optimizer
[docs]def build_dataloader(dataset, train_configs, data_configs): loader = _build_dataloader(dataset, train_configs, data_configs) return loader
[docs]def build_criterion(name, mixup=False, **kwargs): criterion = getattr(LF, name) if mixup : criterion = LF.mixup_criterion(criterion, **kwargs) else: criterion = criterion(**kwargs) return criterion