Source code for sadaco.apis.models.custom

from sadaco.utils.config_parser import parse_config_obj
import torch
import torch.nn as nn

[docs]def custom_model(yml_path): configs = parse_config_obj(yml_path) model = nn.Sequential() for k in configs.layers: layer = getattr(nn, configs.layers[k].name)(**configs.layers[k].params) model.add_module(str(k), layer) return model
if __name__ == "__main__": model = custom_model('custom_example.yml') dummy = torch.randn((1,3,224,224)) print(dummy) out = model(dummy) print(out)