File size: 750 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_backbone_type(root_config: dict, nested_config: dict = None):
    """Resolve the backbone type name from one or two config mappings.

    Lookup order (first key present wins, even if its value is None):
      1. ``nested_config['backbone_type']``, where ``nested_config`` defaults
         to ``root_config`` when not given;
      2. ``root_config['backbone_type']``;
      3. ``root_config['diff_decoder_type']``;
      4. the literal ``'wavenet'``.

    Args:
        root_config: Top-level config mapping; supplies the fallbacks.
        nested_config: Optional sub-config checked first. If None,
            ``root_config`` is used in its place.

    Returns:
        The resolved ``backbone_type`` value.
    """
    # Build the fallback chain from the innermost default outwards. This
    # mirrors the call order of the nested ``.get(..., default)`` form, in
    # which every default is evaluated before the outer lookup runs.
    fallback = root_config.get('diff_decoder_type', 'wavenet')
    fallback = root_config.get('backbone_type', fallback)
    lookup = root_config if nested_config is None else nested_config
    return lookup.get('backbone_type', fallback)


def get_backbone_args(config: dict, backbone_type: str):
    """Return the constructor arguments for the configured backbone.

    If ``config['backbone_args']`` is present and not None, that object is
    returned as-is (not copied). Otherwise, for the ``'wavenet'`` backbone,
    a new dict is assembled from the flat keys ``residual_layers``,
    ``residual_channels`` and ``dilation_cycle_length``. Any of those keys
    that is missing maps to None. For any other backbone type, None is
    returned.

    Args:
        config: Config mapping to read from.
        backbone_type: Backbone name, e.g. the result of
            ``get_backbone_type``.

    Returns:
        The explicit ``backbone_args`` value, a dict built from the flat
        wavenet keys, or None.
    """
    explicit_args = config.get('backbone_args')
    if explicit_args is None and backbone_type == 'wavenet':
        # Output argument name -> flat config key it is read from.
        # Insertion order fixes both the result's key order and the
        # order in which config is queried.
        flat_key_for = {
            'num_layers': 'residual_layers',
            'num_channels': 'residual_channels',
            'dilation_cycle_length': 'dilation_cycle_length',
        }
        return {arg: config.get(key) for arg, key in flat_key_for.items()}
    # Either the explicit args, or None for a non-wavenet backbone
    # without explicit args.
    return explicit_args