File size: 5,035 Bytes
7aefe45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
from typing import Tuple
from functools import reduce
from argparse import Namespace
from omegaconf import DictConfig, OmegaConf
#################################################################################
# merge yaml and argparse #
#################################################################################
def register_resolver():
OmegaConf.register_new_resolver(
"add", lambda *numbers: sum(numbers)
)
OmegaConf.register_new_resolver(
"multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers)
)
OmegaConf.register_new_resolver(
"sub", lambda n1, n2: n1 - n2
)
def _merge_args_and_config(
cmd_args: Namespace,
yaml_config: DictConfig,
read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
# convert cmd line args to OmegaConf
cmd_args_dict = vars(cmd_args)
cmd_args_list = []
for k, v in cmd_args_dict.items():
cmd_args_list.append(f"{k}={v}")
cmd_args_conf = OmegaConf.from_cli(cmd_args_list)
# The following overrides the previous configuration
# cmd_args_list > configs
args_ = OmegaConf.merge(yaml_config, cmd_args_conf)
if read_only:
OmegaConf.set_readonly(args_, True)
return args_, cmd_args_conf, yaml_config
def merge_configs(args, method_cfg_path):
"""merge command line args (argparse) and config file (OmegaConf)"""
yaml_config_path = os.path.join("./", "config", method_cfg_path)
try:
yaml_config = OmegaConf.load(yaml_config_path)
except FileNotFoundError as e:
print(f"error: {e}")
print(f"input file path: `{method_cfg_path}`")
print(f"config path: `{yaml_config_path}` not found.")
raise FileNotFoundError(e)
return _merge_args_and_config(args, yaml_config, read_only=False)
def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
"""update config file (OmegaConf) with dotlist"""
if update_nodes is None:
return source_args
update_args_list = str(update_nodes).split()
if len(update_args_list) < 1:
return source_args
# check update_args
for item in update_args_list:
item_key_ = str(item).split('=')[0] # get key
# item_val_ = str(item).split('=')[1] # get value
if strict:
# Tests if a key is existing
# assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing."
# Tests if a value is missing
assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."
# if keys is None, then add key and set the value
if OmegaConf.select(source_args, item_key_) is None:
source_args.item_key_ = item_key_
# update original yaml params
update_nodes = OmegaConf.from_dotlist(update_args_list)
merged_args = OmegaConf.merge(source_args, update_nodes)
# remove update_args
if remove_update_nodes:
OmegaConf.update(merged_args, 'update', '')
return merged_args
def update_if_exist(source_args, update_nodes):
"""update config file (OmegaConf) with dotlist"""
if update_nodes is None:
return source_args
upd_args_list = str(update_nodes).split()
if len(upd_args_list) < 1:
return source_args
update_args_list = []
for item in upd_args_list:
item_key_ = str(item).split('=')[0] # get key
# if a key is existing
# if OmegaConf.select(source_args, item_key_) is not None:
# update_args_list.append(item)
update_args_list.append(item)
# update source_args if key be selected
if len(update_args_list) < 1:
merged_args = source_args
else:
update_nodes = OmegaConf.from_dotlist(update_args_list)
merged_args = OmegaConf.merge(source_args, update_nodes)
return merged_args
def merge_and_update_config(args):
register_resolver()
# if yaml_config is existing, then merge command line args and yaml_config
# if os.path.isfile(args.config) and args.config is not None:
if args.config is not None and str(args.config).endswith('.yaml'):
merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
else:
merged_args, cmd_args, yaml_config = args, args, None
# update the yaml_config with the cmd '-update' flag
update_nodes = args.update
final_args = update_configs(merged_args, update_nodes)
# to simplify log output, we empty this
yaml_config_update = update_if_exist(yaml_config, update_nodes)
cmd_args_update = update_if_exist(cmd_args, update_nodes)
cmd_args_update.update = "" # clear update params
final_args.yaml_config = yaml_config_update
final_args.cmd_args = cmd_args_update
# update seed
if final_args.seed < 0:
import random
final_args.seed = random.randint(0, 65535)
return final_args
|