import os
import argparse
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Union
from munch import DefaultMunch
import yaml
## define custom tag handler
[docs]def join(loader, node):
seq = loader.construct_sequence(node)
return ''.join([str(i) for i in seq])
## register the tag handler
yaml.add_constructor('!join', join)
[docs]def parse_config_dict(
yml_path: str = None,
arg_type: str = "data",
) -> Dict :
'''
Take a yaml file and return the corresponding arguments.
Args:
arg_type (str): The type of arguments to return. One of "data", "frontend", "model", "train". (default: "data")
Returns:
Dict: The corresponding arguments in dictionary form.
'''
if yml_path is None:
yml_file = os.path.join(os.getcwd(), "configs_", arg_type + ".yml")
else:
yml_file = yml_path
with open(yml_file, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return config
[docs]def parse_config_obj(yml_path: str = None):
with open(yml_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
obj = DefaultMunch.fromDict(config)
return obj
[docs]class ArgsParser():
def __init__(self, argv=None):
# Do argv default this way, as doing it in the functional
# declaration sets it at compile time.
if argv is None:
self.argv = sys.argv
else:
self.argv = argv
# Parse any conf_file specification
# We make this parser with add_help=False so that
# it doesn't parse -h and print help.
self.default_parser = argparse.ArgumentParser(
description=__doc__, # printed with -h/--help
# Don't mess with format of description
formatter_class=argparse.RawDescriptionHelpFormatter,
# Turn off help, so we print all options in response to -h
add_help=False,
)
self.default_parser.add_argument(
"-c", "--conf_file", help="Specify the master config file", metavar="FILE", required=False,
default = None
)
self.default_args, self.remaining_argv = self.default_parser.parse_known_args()
self.parser = argparse.ArgumentParser(
parents=[self.default_parser]
)
if self.default_args.conf_file is not None:
configs = parse_config_obj(yml_path=self.default_args.conf_file)
self.parser.set_defaults(**configs.__dict__)
else:
pass
[docs] def add_argument(self, *opts, **kwopts):
self.parser.add_argument(*opts, **kwopts)
[docs] def get_args(self):
args = self.parser.parse_args(self.remaining_argv)
args = DefaultMunch.fromDict(args.__dict__)
return args
# def ParseArgswithConfig(argv=None):
# # Parse rest of arguments
# # Don't suppress add_help here so it will handle -h
# parser = argparse.ArgumentParser(
# # Inherit options from config_parser
# parents=[conf_parser]
# )
# parser.set_defaults(**configs.__dict__)
# parser.add_argument('--seed')
# parser.add_argument('--gpus')
# parser.add_argument('--model_configs')
# parser.add_argument('--data_configs')
# return args
if __name__ == "__main__":
my_parser = ArgsParser()
print(my_parser.default_args)
my_parser.add_argument("--prefix", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving')
my_parser.add_argument("--mixup", action='store_true')
args = my_parser.get_args()
print(args)
breakpoint()