File size: 1,000 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import functools

from utils.hparams import hparams


class BaseAugmentation:
    """
    Common parent for data augmentation routines.

    Every method of an implementation is expected to be thread-safe.
    Subclasses override:
        *process_item*: transform a single data item and return the result.
    """

    def __init__(self, data_dirs: list, augmentation_args: dict):
        """
        :param data_dirs: raw data directories, stored as ``raw_data_dirs``.
        :param augmentation_args: augmentation-specific options, stored as-is.
        """
        self.raw_data_dirs = data_dirs
        self.augmentation_args = augmentation_args
        # Length of one hop: hop_size divided by audio_sample_rate, both taken
        # from the global hparams (presumably samples / Hz -> seconds; confirm).
        hop = hparams['hop_size']
        sample_rate = hparams['audio_sample_rate']
        self.timestep = hop / sample_rate

    def process_item(self, item: dict, **kwargs) -> dict:
        """Return the augmented form of *item*. Subclasses must override this."""
        raise NotImplementedError


def require_same_keys(func):
    """
    Decorate an augmentation method so it must return an item with exactly
    the same keys as the item it received.

    The decorated callable is assumed to have a method-style signature
    ``(self, item, ...)``. ``item`` may be passed positionally or by keyword.

    :raises AssertionError: if the returned dict's key set differs from the
        input item's key set. The message lists both key sets, sorted.
    """
    @functools.wraps(func)
    def run(*args, **kwargs):
        if len(args) > 1:
            item: dict = args[1]
        else:
            item = kwargs['item']
        # Snapshot the keys BEFORE calling func. An augmentation that mutates
        # `item` in place and returns it would otherwise make both sides of
        # the comparison the same dict and hide the mismatch.
        keys_before = set(item.keys())
        res: dict = func(*args, **kwargs)
        keys_after = set(res.keys())
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if keys_before != keys_after:
            raise AssertionError(
                'Item keys mismatch after augmentation.\n'
                f'Before: {sorted(keys_before)}\n'
                f'After: {sorted(keys_after)}'
            )
        return res
    return run