File size: 5,035 Bytes
7aefe45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
from typing import Tuple
from functools import reduce

from argparse import Namespace
from omegaconf import DictConfig, OmegaConf


#################################################################################
#                             merge yaml and argparse                           #
#################################################################################

def register_resolver():
    """Register the arithmetic interpolation resolvers used by the YAML configs.

    After this call, config values may use:

    * ``${add:a,b,...}``      -> sum of all arguments
    * ``${multiply:a,b,...}`` -> product of all arguments
    * ``${sub:a,b}``          -> ``a - b``

    ``replace=True`` makes this function safe to call more than once:
    OmegaConf raises ``ValueError`` when a resolver name is registered twice,
    so without it a second call in the same process would crash.
    """
    OmegaConf.register_new_resolver(
        "add", lambda *numbers: sum(numbers), replace=True
    )
    OmegaConf.register_new_resolver(
        "multiply",
        lambda *numbers: reduce(lambda x, y: x * y, numbers),
        replace=True,
    )
    OmegaConf.register_new_resolver(
        "sub", lambda n1, n2: n1 - n2, replace=True
    )


def _merge_args_and_config(
        cmd_args: Namespace,
        yaml_config: DictConfig,
        read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
    """Merge parsed command-line arguments on top of a YAML config.

    Every attribute of ``cmd_args`` is rendered as a ``key=value`` dotlist
    entry and parsed by ``OmegaConf.from_cli``, so each value is re-parsed
    from its string form by OmegaConf's dotlist parser.

    Args:
        cmd_args: argparse namespace; its values take precedence over
            ``yaml_config`` on key collisions.
        yaml_config: config loaded from the YAML file.
        read_only: if True, the merged config is made read-only.

    Returns:
        ``(merged, cmd_args_conf, yaml_config)``: the merged config, the
        command-line arguments as an OmegaConf config, and the unchanged
        ``yaml_config``.
    """
    # ``None`` must be written as the YAML literal ``null``: ``f"{k}={None}"``
    # yields ``k=None``, which the YAML value parser reads as the *string*
    # "None" and which would then override the YAML value with that string.
    cmd_args_list = [
        f"{k}=null" if v is None else f"{k}={v}"
        for k, v in vars(cmd_args).items()
    ]
    cmd_args_conf = OmegaConf.from_cli(cmd_args_list)

    # Later arguments win in OmegaConf.merge: command line > YAML config.
    args_ = OmegaConf.merge(yaml_config, cmd_args_conf)

    if read_only:
        OmegaConf.set_readonly(args_, True)

    return args_, cmd_args_conf, yaml_config


def merge_configs(args, method_cfg_path):
    """Load ``./config/<method_cfg_path>`` and merge ``args`` on top of it.

    Args:
        args: argparse namespace whose values override the YAML config.
        method_cfg_path: YAML file path, joined onto ``./config`` with
            ``os.path.join`` (so an absolute path is used as-is).

    Returns:
        The ``(merged, cmd_args_conf, yaml_config)`` tuple produced by
        ``_merge_args_and_config`` (merged config is writable).

    Raises:
        FileNotFoundError: if the YAML file does not exist. Diagnostic lines
            are printed first, then the original exception is re-raised so
            its ``errno``/``filename`` and traceback are preserved.
    """
    yaml_config_path = os.path.join("./", "config", method_cfg_path)
    try:
        yaml_config = OmegaConf.load(yaml_config_path)
    except FileNotFoundError as e:
        print(f"error: {e}")
        print(f"input file path: `{method_cfg_path}`")
        print(f"config path: `{yaml_config_path}` not found.")
        # Bare ``raise`` keeps the original exception intact; wrapping it in
        # a new FileNotFoundError(e) discarded errno/filename.
        raise
    return _merge_args_and_config(args, yaml_config, read_only=False)


def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
    """Apply whitespace-separated ``key=value`` overrides to a config.

    Args:
        source_args: config to update; this function does not assign to it,
            the result of ``OmegaConf.merge`` is returned instead.
        update_nodes: overrides such as ``"a.b=1 c=foo"``; ``None`` or an
            empty/blank string means "no overrides".
        strict: if True, assert that no overridden key currently holds
            OmegaConf's missing value.
        remove_update_nodes: if True, set the ``update`` node of the result
            to ``''`` so the raw override string is not carried along.

    Returns:
        The merged config, or ``source_args`` itself when there is nothing
        to apply.

    Raises:
        AssertionError: in strict mode, if an overridden key is missing.
    """
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    if strict:
        for item in update_args_list:
            item_key_ = item.partition('=')[0]  # text before the first '='
            # NOTE(review): is_missing is called with the raw key; confirm it
            # resolves dotted keys like ``a.b`` in the OmegaConf version used.
            assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."
            # Keys absent from source_args need no pre-creation: the merge
            # below adds them with their override value.

    merged_args = OmegaConf.merge(source_args, OmegaConf.from_dotlist(update_args_list))

    if remove_update_nodes:
        OmegaConf.update(merged_args, 'update', '')
    return merged_args


def update_if_exist(source_args, update_nodes):
    """Apply whitespace-separated ``key=value`` overrides to ``source_args``.

    Despite the name, every override is applied. The filter that would keep
    only keys already present in ``source_args`` is disabled, so new keys are
    added as well.

    Args:
        source_args: config to merge the overrides into.
        update_nodes: overrides such as ``"a.b=1 c=foo"``; ``None`` or a
            blank string means "no overrides".

    Returns:
        ``OmegaConf.merge(source_args, <overrides>)``, or ``source_args``
        unchanged when there is nothing to apply.
    """
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    # NOTE(review): to honour the function name, keep only items whose key
    # (``item.partition('=')[0]``) satisfies
    # ``OmegaConf.select(source_args, key) is not None`` before merging.
    # That filter is left disabled here to preserve current behaviour.
    return OmegaConf.merge(source_args, OmegaConf.from_dotlist(update_args_list))


def merge_and_update_config(args):
    """Build the final run config from argparse ``args``.

    Steps:
      1. register the ``add``/``multiply``/``sub`` resolvers;
      2. if ``args.config`` ends with ``.yaml``, merge ``args`` over
         ``./config/<args.config>``; otherwise use ``args`` directly;
      3. apply the whitespace-separated ``key=value`` overrides in
         ``args.update``;
      4. attach the YAML config and the command-line args, both with the
         overrides applied (and ``update`` cleared on the latter), as the
         ``yaml_config`` / ``cmd_args`` nodes for logging;
      5. replace a negative ``seed`` with a random integer in [0, 65535].

    Returns:
        The final config: the merged config in the YAML path, otherwise
        whatever ``update_configs`` returns for the ``args`` namespace.
    """
    register_resolver()

    if args.config is not None and str(args.config).endswith('.yaml'):
        merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
    else:
        # Snapshot the namespace for the ``cmd_args`` log node. Reusing
        # ``args`` itself would make ``final_args.cmd_args`` point back at
        # ``final_args`` (a reference cycle: repr() then raises
        # RecursionError) and would clear the caller's ``args.update``.
        # NOTE(review): here merged_args is an argparse namespace; confirm
        # update_configs accepts it when ``args.update`` is non-empty.
        import copy
        merged_args, cmd_args, yaml_config = args, copy.copy(args), None

    # Apply the overrides passed via the ``update`` argument.
    update_nodes = args.update
    final_args = update_configs(merged_args, update_nodes)

    # To simplify log output, the logged copies carry the overrides already
    # applied and an empty ``update`` field.
    yaml_config_update = update_if_exist(yaml_config, update_nodes)
    cmd_args_update = update_if_exist(cmd_args, update_nodes)
    cmd_args_update.update = ""  # clear update params

    final_args.yaml_config = yaml_config_update
    final_args.cmd_args = cmd_args_update

    # A negative seed requests a random one.
    if final_args.seed < 0:
        import random
        final_args.seed = random.randint(0, 65535)

    return final_args