File size: 2,949 Bytes
5ad4868
 
 
 
 
 
0388c00
5ad4868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0388c00
 
 
 
 
 
 
 
 
 
 
 
 
5ad4868
0388c00
 
 
 
 
 
 
5ad4868
0388c00
 
 
 
 
5ad4868
0388c00
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from torch.utils.data import DataLoader
import librosa
from math import floor
import torch
from torch.nn.functional import pad
from torchaudio.transforms import Resample
from random import randint


def get_dataloader(dataset, device, batch_size=16, shuffle=True):
    """Build a ``DataLoader`` over ``dataset`` that collates with ``prepare_batch``.

    The dataset is first switched to torch formatting on ``device`` through its
    ``with_format`` method. The loader runs 4 persistent workers and is
    configured with ``drop_last=True``.
    """
    torch_view = dataset.with_format("torch", device=device)
    loader_options = dict(
        batch_size=batch_size,
        collate_fn=prepare_batch,
        num_workers=4,
        shuffle=shuffle,
        drop_last=True,
        persistent_workers=True,
    )
    return DataLoader(torch_view, **loader_options)

def resample(x, sr, newsr):
    """Resample ``x`` from sampling rate ``sr`` to ``newsr``.

    Uses torchaudio's ``Resample`` transform with the Kaiser-windowed sinc
    interpolation settings listed below.
    """
    kaiser_options = {
        "resampling_method": "sinc_interp_kaiser",
        "lowpass_filter_width": 16,
        "rolloff": 0.85,
        "beta": 8.555504641634386,
    }
    resampler = Resample(orig_freq=sr, new_freq=newsr, **kaiser_options)
    return resampler(x)

def fixlength(x, L):
    """Crop or zero-pad ``x`` along its last axis so that axis has length ``L``.

    Args:
        x: tensor whose last axis is the one to resize. For a 1-D signal this
            is the only axis.
        L: target length of the last axis.

    Returns:
        A tensor with the same leading dimensions as ``x`` and a last axis of
        exactly ``L`` entries; missing entries are appended as zeros.
    """
    # Crop and pad on the SAME axis. ``pad`` with a 2-tuple always works on the
    # last axis, so the crop must too; otherwise multi-dimensional inputs end
    # up with the wrong length.
    x = x[..., :L]
    missing = L - x.shape[-1]
    return pad(x, (0, missing))

def preprocess(X, newsr, n_fft, win_length, hop_length, gain=0.8, bias=10, power=0.25):
    """Convert waveforms to PCEN-normalised magnitude spectrograms.

    Args:
        X: waveform tensor accepted by ``torch.stft``; the result of the STFT
            is iterated along its first axis, one PCEN call per item.
        newsr: sampling rate forwarded to ``librosa.pcen`` as ``sr``.
        n_fft: FFT size for the STFT.
        win_length: length of the Hann window.
        hop_length: hop between STFT frames, also forwarded to ``librosa.pcen``.
        gain, bias, power: forwarded unchanged to ``librosa.pcen``.

    Returns:
        The stacked PCEN outputs as a ``torch.bfloat16`` tensor on the same
        device as ``X``.
    """
    device = X.device
    # torch.stft needs the window on the same device as its input.
    window = torch.hann_window(win_length, device=device)
    spectrum = torch.stft(
        X,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        onesided=True,
        return_complex=True,
    )
    # librosa works on NumPy arrays, which only CPU tensors can provide; move
    # the whole batch once instead of per item.
    magnitude = torch.abs(spectrum).cpu()
    pcen_items = [
        torch.from_numpy(
            librosa.pcen(
                item.numpy(),
                sr=newsr,
                hop_length=hop_length,
                gain=gain,
                bias=bias,
                power=power,
            )
        )
        for item in magnitude
    ]
    return torch.stack(pcen_items, 0).to(device=device, dtype=torch.bfloat16)


def prepare_batch(samples, newsr=4000, n_fft=2**10, win_length=2**10, clip_seconds=3):
    """Collate raw audio samples into a spectrogram batch and a label tensor.

    Each sample is read as ``sample['label']``, ``sample['audio']['array']``
    and ``sample['audio']['sampling_rate']``. Every clip is resampled to
    ``newsr`` when its rate differs, forced to ``clip_seconds * newsr``
    samples, stacked and passed through ``preprocess``.

    Args:
        samples: iterable of sample mappings with the keys described above.
        newsr: target sampling rate.
        n_fft: FFT size handed to ``preprocess`` (a power of 2 by default).
        win_length: STFT window length handed to ``preprocess``.
        clip_seconds: clip duration, in seconds at ``newsr``.

    Returns:
        ``(batch, labels)`` where ``batch`` is the output of ``preprocess`` and
        ``labels`` is a float64 tensor built from the collected labels.
    """
    hop_length = floor(0.0505*newsr)
    clip_length = int(clip_seconds * newsr)
    labels = []
    signals = []
    for sample in samples:
        labels.append(sample['label'])
        audio = sample['audio']
        sr = audio['sampling_rate']
        x = audio['array']
        # Resample whenever the rates differ, not only when downsampling:
        # a clip recorded below ``newsr`` would otherwise be cropped and
        # analysed as if it were already at ``newsr``. Empty clips are left
        # alone and simply zero-filled by ``fixlength``.
        if sr != newsr and len(x) != 0:
            x = resample(x, sr, newsr)
        signals.append(fixlength(x, clip_length))

    batch = preprocess(torch.stack(signals, 0), newsr, n_fft, win_length, hop_length)
    labels = torch.tensor(labels, dtype=float)
    return batch, labels

# Data augmentation

def random_mask(sample):
    """Zero out random rectangles in each item of a 3-D batch, in place.

    For every item along the first axis, between 3 and 12 rectangles are
    drawn; each is 5-15 entries wide (last axis) and 10-100 entries tall
    (middle axis) and is placed at a random position that keeps it inside
    the item. The input tensor itself is modified and returned.
    """
    n_items, height, width = sample.shape
    for item in range(n_items):
        n_patches = randint(3, 12)
        for _ in range(n_patches):
            # Draw order (width, height, left, top) matters for reproducibility
            # under a fixed random seed.
            patch_w = randint(5, 15)
            patch_h = randint(10, 100)
            left = randint(0, width - patch_w)
            top = randint(0, height - patch_h)
            rows = slice(top, top + patch_h)
            cols = slice(left, left + patch_w)
            sample[item, rows, cols] = 0
    return sample

def timeshift(sample):
    """Shift a 3-D batch right along its last axis by a random 0-6 steps.

    The vacated leading positions are filled with zeros and the tail is
    cropped so the last axis keeps its original length.

    Args:
        sample: 3-D tensor; the shift is applied along dimension 2.

    Returns:
        A new tensor with the same shape as ``sample``. The zero padding is
        float32, so the result dtype follows ``torch.cat``'s promotion rules.
    """
    padsize = randint(0, 6)
    length = sample.size(2)
    # Allocate the padding on the sample's own device; torch.cat cannot mix
    # tensors living on different devices.
    randpad = torch.zeros(
        (sample.size(0), sample.size(1), padsize),
        dtype=torch.float32,
        device=sample.device,
    )
    shifted = torch.cat((randpad, sample), dim=2)
    return shifted[:, :, :length]

def add_noise(sample):
    """Add Gaussian noise scaled to 5% of the sample's maximum value.

    Args:
        sample: tensor to perturb; it is not modified.

    Returns:
        ``sample`` plus standard-normal noise multiplied by
        ``0.05 * sample.max()``.
    """
    scale = 0.05*sample.max()
    # Generate the noise on the sample's device so the addition also works for
    # batches that are not on the CPU.
    noise = scale*torch.randn(sample.shape, dtype=torch.float32, device=sample.device)
    return sample + noise

def augment(sample):
    """Run the augmentation chain: time shift, rectangular masking, then noise."""
    for step in (timeshift, random_mask, add_noise):
        sample = step(sample)
    return sample