Spaces:
Runtime error
Runtime error
| from typing import List | |
| import pytorch_lightning as pl | |
| import torch.nn as nn | |
def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    The task-specific callback factory is imported inside the matching
    branch, so only the module for the selected task is loaded.

    Args:
        task_name: str, one of 'musdb18', 'voicebank-demand', 'vctk-musdb18',
            'violin-piano' or 'piano-symphony'
        config_yaml: str, config yaml file
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Returns:
        callbacks: List[pl.Callback]

    Raises:
        NotImplementedError: if ``task_name`` is not one of the supported tasks.
    """
    # Every task-specific factory below takes exactly these keyword arguments.
    factory_kwargs = dict(
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    instruments_tasks = ['vctk-musdb18', 'violin-piano', 'piano-symphony']

    if task_name == 'musdb18':
        from bytesep.callbacks.musdb18 import get_musdb18_callbacks

        return get_musdb18_callbacks(**factory_kwargs)

    elif task_name == 'voicebank-demand':
        from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks

        return get_voicebank_demand_callbacks(**factory_kwargs)

    elif task_name in instruments_tasks:
        from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks

        return get_instruments_callbacks(**factory_kwargs)

    else:
        supported = ['musdb18', 'voicebank-demand'] + instruments_tasks
        raise NotImplementedError(
            f"No callbacks implemented for task_name={task_name!r}; "
            f"supported tasks are {supported}."
        )