File size: 5,683 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
import os
import yaml

try:
    from lightning.pytorch.utilities.rank_zero import rank_zero_only
except ModuleNotFoundError:
    # Lightning is optional. Without it, fall back to an identity decorator so
    # functions decorated with ``@rank_zero_only`` are simply called as-is.
    def rank_zero_only(f):
        """Return ``f`` unchanged (stand-in for Lightning's ``rank_zero_only``)."""
        return f

from utils.multiprocess_utils import is_main_process as mp_is_main_process
# Whether the hparams summary may still be printed; set_hparams() sets this to
# False after printing it once.
global_print_hparams = True
# Shared hparams dict. set_hparams() mutates it in place (clear/update) instead
# of rebinding it, so references to this object held elsewhere see the updates.
hparams = dict()


class Args:
    """Plain attribute container used in place of parsed argparse arguments.

    Every keyword argument becomes an attribute of the instance.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


def override_config(old_config: dict, new_config: dict):
    """Recursively merge ``new_config`` into ``old_config`` in place.

    When a key maps to a dict on BOTH sides, the two dicts are merged key by
    key. In every other case the value from ``new_config`` replaces the old
    one. This includes a dict whose counterpart in ``old_config`` is missing,
    ``None`` or a non-dict value. Replacement values are inserted as-is (not
    copied).

    :param old_config: dict that is updated in place.
    :param new_config: dict whose entries take precedence.
    """
    for k, v in new_config.items():
        # Only recurse when there is a dict to merge into; otherwise e.g. a base
        # value of None/int/list would be item-assigned into and raise TypeError.
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v


def _apply_hparam_overrides(hparams_: dict, hparams_str: str) -> None:
    """Apply comma-separated ``key=value`` overrides to ``hparams_`` in place.

    Items are applied left to right, so a later item wins over an earlier one
    for the same key. Each value string is converted as follows:

    * the key is not present yet: the value is ``eval``-ed as a Python expression;
    * the value is ``'True'``/``'False'``, or the current value is a ``bool``,
      ``None`` or a list/tuple/dict/set: the value is ``eval``-ed, so e.g.
      ``milestones=[10]`` yields a list rather than a list of characters;
    * otherwise: the current value's type converts it (``int('3')``, ``str('abc')``, ...).

    Values cannot contain commas, because the string is split on ``,`` first.
    Only the first ``=`` separates key from value.

    NOTE(security): ``eval`` executes arbitrary code. This is meant for
    overrides typed by whoever launches the job; never pass untrusted input.

    :raises ValueError: if a non-empty item contains no ``=``.
    """
    for item in hparams_str.split(','):
        if item.strip() == '':
            continue
        if '=' not in item:
            raise ValueError(f'Invalid hparams override {item!r}: expected key=value.')
        k, v = item.split('=', 1)
        # Short-circuit: hparams_[k] is only read when the key exists.
        if (k not in hparams_
                or v in ('True', 'False')
                or hparams_[k] is None
                or isinstance(hparams_[k], (bool, list, tuple, dict, set))):
            hparams_[k] = eval(v)
        else:
            hparams_[k] = type(hparams_[k])(v)


def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Build the hparams dict for an experiment.

    Sources are applied in this order, later ones overriding earlier ones:

    1. The config chain: ``config`` and, recursively and depth first, every
       file listed under its ``base_config`` key. Paths starting with ``.``
       are resolved relative to the file that lists them.
    2. Unless ``--reset`` is given, the complete config auto-saved as
       ``checkpoints/<exp_name>/config.yaml``, if it exists. This file
       contains all settings and does not rely on ``base_config``.
    3. ``--hparams`` and then ``hparams_str``: comma-separated ``key=value``
       overrides meant as temporary modifications.

    :param config: path to a config file. If empty, ``--config``,
        ``--exp_name``, ``--hparams``, ``--infer`` and ``--reset`` are read from
        the command line instead, and the ``exp_name`` argument is ignored.
        If non-empty, ``infer`` and ``reset`` are both False.
    :param exp_name: experiment name; the work dir becomes ``checkpoints/<exp_name>``.
    :param hparams_str: extra ``key=value`` overrides, applied after ``--hparams``.
    :param print_hparams: print the resulting hparams. This happens at most
        once per module, tracked by ``global_print_hparams``, and only when
        ``global_hparams`` is also true.
    :param global_hparams: also replace the contents of the module-level
        ``hparams`` dict with the result.
    :return: the resulting hparams dict, which also holds ``work_dir``,
        ``infer`` and ``exp_name``.
    :raises AssertionError: if neither a config nor an experiment name is given.

    Side effect: when an experiment name is given, the config is written to
    ``checkpoints/<exp_name>/config.yaml`` if that file does not exist yet or
    ``--reset`` is given, except in ``--infer`` mode.
    """
    global hparams

    if config == '':
        parser = argparse.ArgumentParser(description='neural music')
        parser.add_argument('--config', type=str, default='',
                            help='path to the config file')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('--hparams', type=str, default='',
                            help='comma-separated key=value hparams overrides')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        # Unknown command-line arguments are ignored rather than rejected.
        args, _ = parser.parse_known_args()

        # Append hparams_str after the CLI overrides; since overrides are applied
        # in order, hparams_str wins when both set the same key.
        tmp_args_hparams = args.hparams.split(',') if args.hparams.strip() != '' else []
        tmp_args_hparams.extend(hparams_str.split(',') if hparams_str.strip() != '' else [])
        args.hparams = ','.join(tmp_args_hparams)
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, reset=False)

    args_work_dir = ''
    if args.exp_name != '':
        args.work_dir = args.exp_name
        args_work_dir = os.path.join('checkpoints', args.work_dir)

    config_chains = []
    loaded_config = set()

    def load_config(config_fn):  # depth first
        """Load ``config_fn`` on top of its (recursively loaded) ``base_config`` files."""
        with open(config_fn, encoding='utf-8') as f:
            cfg = yaml.safe_load(f) or {}  # an empty YAML file loads as None
        # Record normalized paths so a file is recognized however it is spelled;
        # registering it before recursing also stops cyclic base_config chains.
        loaded_config.add(os.path.normpath(config_fn))
        if 'base_config' in cfg:
            ret_hparams = {}
            if not isinstance(cfg['base_config'], list):
                cfg['base_config'] = [cfg['base_config']]
            for c in cfg['base_config']:
                if c.startswith('.'):
                    # Relative to the including file. os.path.join (unlike an
                    # f'{dir}/{c}' string) does not turn './x' into the root path
                    # '/x' when config_fn has no directory part.
                    c = os.path.normpath(os.path.join(os.path.dirname(config_fn), c))
                if os.path.normpath(c) not in loaded_config:
                    override_config(ret_hparams, load_config(c))
            override_config(ret_hparams, cfg)
        else:
            ret_hparams = cfg
        config_chains.append(config_fn)
        return ret_hparams

    assert args.config != '' or args_work_dir != '', 'Either config or exp name should be specified.'
    saved_hparams = {}
    ckpt_config_path = os.path.join(args_work_dir, 'config.yaml')
    if args_work_dir != '' and os.path.exists(ckpt_config_path):
        with open(ckpt_config_path, encoding='utf-8') as f:
            saved_hparams.update(yaml.safe_load(f) or {})

    hparams_ = {}
    if args.config != '':
        hparams_.update(load_config(args.config))

    if not args.reset:
        # Values saved with the experiment take precedence over the config chain.
        hparams_.update(saved_hparams)
    hparams_['work_dir'] = args_work_dir

    _apply_hparam_overrides(hparams_, args.hparams)

    @rank_zero_only
    def dump_hparams():
        if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
            os.makedirs(hparams_['work_dir'], exist_ok=True)
            # NOTE(review): mp_is_main_process is used as a truth value. This
            # assumes utils.multiprocess_utils.is_main_process is a bool flag, not
            # a function (a function object would always be truthy). Confirm.
            if mp_is_main_process:
                # Only the main process will save the config file
                with open(ckpt_config_path, 'w', encoding='utf-8') as f:
                    hparams_non_recursive = hparams_.copy()
                    # The saved config is complete, so it must not chain to bases.
                    hparams_non_recursive['base_config'] = []
                    yaml.safe_dump(hparams_non_recursive, f, allow_unicode=True, encoding='utf-8')
    dump_hparams()

    hparams_['infer'] = args.infer
    # Fill exp_name before publishing, so the module-level dict is only touched
    # when global_hparams is requested.
    if hparams_.get('exp_name') is None:
        hparams_['exp_name'] = args.exp_name
    if global_hparams:
        hparams.clear()
        hparams.update(hparams_)

    @rank_zero_only
    def print_out_hparams():
        global global_print_hparams
        if mp_is_main_process and print_hparams and global_print_hparams and global_hparams:
            print('| Hparams chains: ', config_chains)
            print('| Hparams: ')
            for i, (k, v) in enumerate(sorted(hparams_.items())):
                print(f"\033[0;33m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
            print("")
            global_print_hparams = False
    print_out_hparams()

    return hparams_