CantusSVS-hf / basics /base_augmentation.py
liampond
Clean deploy snapshot
c42fe7e
from utils.hparams import hparams
class BaseAugmentation:
"""
Base class for data augmentation.
All methods of this class should be thread-safe.
1. *process_item*:
Apply augmentation to one piece of data.
"""
def __init__(self, data_dirs: list, augmentation_args: dict):
self.raw_data_dirs = data_dirs
self.augmentation_args = augmentation_args
self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
def process_item(self, item: dict, **kwargs) -> dict:
raise NotImplementedError()
def require_same_keys(func):
def run(*args, **kwargs):
item: dict = args[1]
res: dict = func(*args, **kwargs)
assert set(item.keys()) == set(res.keys()), 'Item keys mismatch after augmentation.\n' \
f'Before: {sorted(item.keys())}\n' \
f'After: {sorted(res.keys())}'
return res
return run