Spaces:
Sleeping
Sleeping
def get_backbone_type(root_config: dict, nested_config: dict = None):
    """Resolve the backbone type name from one or two config dicts.

    Lookup order (first key found wins):
      1. ``nested_config['backbone_type']`` -- ``nested_config`` falls back to
         ``root_config`` when it is omitted or ``None``;
      2. ``root_config['backbone_type']``;
      3. ``root_config['diff_decoder_type']``;
      4. the literal string ``'wavenet'``.

    A key that is present is returned even when its value is ``None``.
    """
    source = root_config if nested_config is None else nested_config
    # Build the fallback chain from the innermost default outwards.
    fallback = root_config.get('diff_decoder_type', 'wavenet')
    fallback = root_config.get('backbone_type', fallback)
    return source.get('backbone_type', fallback)
def get_backbone_args(config: dict, backbone_type: str):
    """Return the constructor arguments for the selected backbone.

    - If ``config['backbone_args']`` exists and is not ``None``, that object is
      returned as-is (the same object, not a copy).
    - Otherwise, when ``backbone_type == 'wavenet'``, a new dict is assembled
      from the flat keys ``residual_layers``, ``residual_channels`` and
      ``dilation_cycle_length``; any of these missing from ``config`` maps
      to ``None``.
    - For every other backbone type, ``None`` is returned.
    """
    explicit = config.get('backbone_args')
    if explicit is not None:
        return explicit
    if backbone_type != 'wavenet':
        return None
    # Output argument name -> flat config key it is read from.
    key_map = {
        'num_layers': 'residual_layers',
        'num_channels': 'residual_channels',
        'dilation_cycle_length': 'dilation_cycle_length',
    }
    return {arg: config.get(key) for arg, key in key_map.items()}