# diffsketcher_edit / libs/engine/config_processor.py
# MarkMoHR's picture
# added code
# 7aefe45
import os
from typing import Tuple
from functools import reduce
from argparse import Namespace
from omegaconf import DictConfig, OmegaConf
#################################################################################
# merge yaml and argparse #
#################################################################################
def register_resolver():
    """Register the custom OmegaConf resolvers ``add``, ``multiply`` and ``sub``.

    They can be used in YAML interpolations, e.g. ``${add:1,2}``,
    ``${multiply:2,3,4}`` or ``${sub:10,3}``.

    ``replace=True`` makes the call idempotent. Without it, OmegaConf raises
    ``ValueError`` when a resolver name is already registered, so calling this
    function a second time in the same process would crash.
    """
    OmegaConf.register_new_resolver(
        "add", lambda *numbers: sum(numbers), replace=True
    )
    OmegaConf.register_new_resolver(
        "multiply",
        lambda *numbers: reduce(lambda x, y: x * y, numbers),
        replace=True,
    )
    OmegaConf.register_new_resolver(
        "sub", lambda n1, n2: n1 - n2, replace=True
    )
def _merge_args_and_config(
    cmd_args: Namespace,
    yaml_config: DictConfig,
    read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
    """Merge command-line arguments on top of a YAML config.

    Every attribute of ``cmd_args`` is rendered as a ``key=value`` dotlist
    entry and parsed with ``OmegaConf.from_cli``. For keys present in both,
    the command-line value wins over the YAML value.

    Args:
        cmd_args: parsed argparse arguments.
        yaml_config: config loaded from the YAML file.
        read_only: if True, the merged config is made read-only.

    Returns:
        ``(merged_config, cmd_args_conf, yaml_config)`` where ``cmd_args_conf``
        is the OmegaConf view of ``cmd_args`` and ``yaml_config`` is the input
        config, returned unchanged.
    """
    # OmegaConf parses each dotlist value as YAML. In YAML, ``None`` is just
    # the (truthy) string "None", so an unset argument is rendered as YAML
    # ``null`` to keep it None after parsing.
    # NOTE(review): other string values are YAML-parsed too (e.g. " #" starts
    # a comment, "k: v" becomes a mapping); consider quoting if that matters.
    cmd_args_list = [
        f"{k}={'null' if v is None else v}" for k, v in vars(cmd_args).items()
    ]
    cmd_args_conf = OmegaConf.from_cli(cmd_args_list)

    # Later configs override earlier ones: cmd_args_list > yaml config.
    args_ = OmegaConf.merge(yaml_config, cmd_args_conf)
    if read_only:
        OmegaConf.set_readonly(args_, True)
    return args_, cmd_args_conf, yaml_config
def merge_configs(args, method_cfg_path):
    """Merge command line args (argparse) and config file (OmegaConf).

    Args:
        args: parsed command-line arguments.
        method_cfg_path: YAML file path, joined onto ``./config``.

    Returns:
        The ``(merged_config, cmd_args_conf, yaml_config)`` tuple produced by
        ``_merge_args_and_config`` with ``read_only=False``.

    Raises:
        FileNotFoundError: if the YAML file does not exist. The original
            exception is re-raised, so its traceback and ``filename`` survive.
    """
    yaml_config_path = os.path.join("./", "config", method_cfg_path)
    try:
        yaml_config = OmegaConf.load(yaml_config_path)
    except FileNotFoundError as e:
        print(f"error: {e}")
        print(f"input file path: `{method_cfg_path}`")
        print(f"config path: `{yaml_config_path}` not found.")
        # Re-raise the original error rather than wrapping it in a new one.
        raise
    return _merge_args_and_config(args, yaml_config, read_only=False)
def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
    """Update config file (OmegaConf) with a dotlist of ``key=value`` overrides.

    Args:
        source_args: the config to update.
        update_nodes: overrides, either a whitespace-separated string
            (``"a=1 b.c=2"``) or a list/tuple of such strings. ``None`` or no
            tokens means "nothing to do".
        strict: if True, assert that none of the overridden keys currently
            holds a missing (``???``) value in ``source_args``.
        remove_update_nodes: if True, set the ``update`` key of the result to
            ``''`` so the override string is not kept in the final config.

    Returns:
        ``source_args`` itself when there is nothing to apply. Otherwise the
        result of ``OmegaConf.merge(source_args, overrides)``. Keys that do not
        exist yet are added by the merge (unless the config is in struct mode).
    """
    if update_nodes is None:
        return source_args
    if isinstance(update_nodes, (list, tuple)):
        # e.g. argparse ``nargs='+'``: tokenize every element the same way a
        # single string is tokenized.
        update_args_list = [tok for node in update_nodes for tok in str(node).split()]
    else:
        update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    if strict:
        for item in update_args_list:
            item_key_ = item.split('=')[0]  # get key
            # NOTE(review): OmegaConf.is_missing looks up a direct child key,
            # so dotted keys such as "a.b" are presumably not checked — verify.
            assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."

    # update original yaml params
    update_conf = OmegaConf.from_dotlist(update_args_list)
    merged_args = OmegaConf.merge(source_args, update_conf)

    # remove update_args
    if remove_update_nodes:
        OmegaConf.update(merged_args, 'update', '')
    return merged_args
def update_if_exist(source_args, update_nodes):
    """Update config file (OmegaConf) with a dotlist of ``key=value`` overrides.

    Unlike ``update_configs``, this performs no strictness check and does not
    clear the ``update`` key.

    NOTE(review): despite the name, every override is applied and no
    key-existence filtering is performed; confirm this is intended.

    Args:
        source_args: the config to update.
        update_nodes: overrides, either a whitespace-separated string or a
            list/tuple of such strings. ``None`` or no tokens means
            "nothing to do".

    Returns:
        ``source_args`` itself when there is nothing to apply, otherwise
        ``OmegaConf.merge(source_args, overrides)``.
    """
    if update_nodes is None:
        return source_args
    if isinstance(update_nodes, (list, tuple)):
        update_args_list = [tok for node in update_nodes for tok in str(node).split()]
    else:
        update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    update_conf = OmegaConf.from_dotlist(update_args_list)
    return OmegaConf.merge(source_args, update_conf)
def merge_and_update_config(args):
    """Build the final run configuration from parsed command-line ``args``.

    Steps, in order:
      1. register the custom OmegaConf resolvers;
      2. if ``args.config`` is a path ending in ``.yaml``, merge that YAML file
         with the command-line args via ``merge_configs``; otherwise ``args``
         plays every role and there is no YAML config;
      3. apply the ``args.update`` overrides via ``update_configs``;
      4. attach override-applied copies of the YAML config and of the
         command-line config as ``yaml_config`` / ``cmd_args`` (to simplify
         log output), with the latter's ``update`` field blanked;
      5. replace a negative ``seed`` with a random integer in [0, 65535].

    Returns:
        The config returned by ``update_configs``, with the extra fields set.
    """
    register_resolver()

    has_yaml = args.config is not None and str(args.config).endswith('.yaml')
    if has_yaml:
        merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
    else:
        # No YAML file: the namespace itself stands in for the merged config
        # and for the command-line config.
        merged_args = cmd_args = args
        yaml_config = None

    overrides = args.update
    final_args = update_configs(merged_args, overrides)

    # Override-applied copies of each source, kept for logging.
    yaml_config_update = update_if_exist(yaml_config, overrides)
    cmd_args_update = update_if_exist(cmd_args, overrides)
    cmd_args_update.update = ""  # overrides are already applied; clear them
    # NOTE(review): without a YAML file and without overrides, final_args and
    # cmd_args_update are both ``args``, so ``args.cmd_args`` refers to itself.
    final_args.yaml_config = yaml_config_update
    final_args.cmd_args = cmd_args_update

    if final_args.seed < 0:
        import random
        final_args.seed = random.randint(0, 65535)
    return final_args