File size: 3,001 Bytes
5ad4868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from torch.utils.data import DataLoader
import librosa
from math import floor
import torch
from torch.nn.functional import pad
from torchaudio.transforms import Resample
#from random import randint


def get_dataloader(dataset, device, batch_size=16, shuffle=True, num_workers=4):
    """Build a ``DataLoader`` over *dataset* that collates with ``prepare_batch``.

    Args:
        dataset: Object exposing ``with_format("torch", device=...)``
            (presumably a Hugging Face ``datasets.Dataset`` -- confirm
            against the caller).
        device: Forwarded unchanged to ``dataset.with_format``.
        batch_size: Number of samples handed to ``prepare_batch`` per batch.
        shuffle: Whether the loader reshuffles the data.
        num_workers: Number of loader worker processes. The default of 4
            matches the value that used to be hard-coded; pass 0 to load in
            the calling process.

    Returns:
        A ``DataLoader`` that drops the last incomplete batch.
    """
    return DataLoader(
        dataset.with_format("torch", device=device),
        batch_size=batch_size,
        collate_fn=prepare_batch,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
        # DataLoader raises if persistent_workers is requested without any
        # worker processes, so only enable it when workers exist.
        persistent_workers=num_workers > 0,
    )

def resample(x, sr, newsr):
    """Resample *x* from rate ``sr`` to rate ``newsr``.

    A fresh torchaudio ``Resample`` transform is built on every call, using
    the ``sinc_interp_kaiser`` method with the filter settings below, and is
    applied to *x* immediately.
    """
    kaiser_settings = {
        "resampling_method": "sinc_interp_kaiser",
        "lowpass_filter_width": 16,
        "rolloff": 0.85,
        "beta": 8.555504641634386,
    }
    resampler = Resample(orig_freq=sr, new_freq=newsr, **kaiser_settings)
    return resampler(x)

def fixlength(x, L):
    """Crop or right-zero-pad *x* so its last dimension has exactly ``L`` entries.

    Cropping and padding both act on the last dimension, so a 1-D signal and
    a multi-dimensional one (e.g. ``(channels, time)``) are handled the same
    way. Leading dimensions are left untouched.

    Args:
        x: Tensor to adjust.
        L: Target length of the last dimension.

    Returns:
        A tensor whose last dimension has size ``L``; entries beyond the
        original data are zeros.
    """
    # Slice the same axis that ``pad`` extends (the last one); slicing axis 0
    # would give the wrong length for anything that is not 1-D.
    x = x[..., :L]
    return pad(x, (0, L - x.shape[-1]))

def preprocess(X, newsr, n_fft, win_length, hop_length, gain=0.8, bias=10, power=0.25):
    """Turn waveforms into PCEN-scaled magnitude spectrograms in bfloat16.

    Steps: Hann-windowed one-sided STFT, magnitude, ``librosa.pcen`` applied
    to each entry along the first dimension, then a cast to ``bfloat16``.

    Args:
        X: Real-valued waveform tensor accepted by ``torch.stft``
            (``prepare_batch`` passes a stacked 2-D batch).
        newsr: Sampling rate passed to ``librosa.pcen`` as ``sr``.
        n_fft: FFT size for the STFT.
        win_length: Length of the Hann window.
        hop_length: Hop between STFT frames; also passed to ``librosa.pcen``.
        gain: PCEN ``gain`` parameter.
        bias: PCEN ``bias`` parameter.
        power: PCEN ``power`` parameter.

    Returns:
        A ``torch.bfloat16`` tensor on the same device as the input ``X``.
    """
    device = X.device
    # The window has to live on the same device as the signal; the default
    # (CPU) window breaks the STFT for inputs that were placed on another device.
    window = torch.hann_window(win_length, device=device)
    spectrum = torch.stft(
        X,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        onesided=True,
        return_complex=True,
    )
    # librosa operates on NumPy arrays, which need a detached CPU tensor.
    magnitude = torch.abs(spectrum).detach().cpu()
    pcen_frames = [
        torch.from_numpy(
            librosa.pcen(
                x.numpy(),
                sr=newsr,
                hop_length=hop_length,
                gain=gain,
                bias=bias,
                power=power,
            )
        )
        for x in magnitude
    ]
    X = torch.stack(pcen_frames, 0)
    # Hand the result back on the device the caller supplied.
    return X.to(device=device, dtype=torch.bfloat16)



def prepare_batch(samples):
    """Collate a list of samples into ``(features, labels)``.

    Each sample is indexed as ``sample['label']``,
    ``sample['audio']['sampling_rate']`` and ``sample['audio']['array']``.
    Non-empty signals whose rate exceeds 4000 are resampled down to 4000;
    every signal is then forced to ``3 * 4000`` samples, the signals are
    stacked, and the stack is passed through ``preprocess``.

    Returns:
        A tuple ``(batch, labels)`` where ``batch`` is the output of
        ``preprocess`` and ``labels`` is a tensor built with ``dtype=float``.
    """
    # Fixed front-end settings.
    newsr = 4000
    n_fft = win_length = 2**10  # power of 2
    hop_length = floor(0.0505*newsr)
    target_len = 3*newsr

    def _fixed_signal(sample):
        # Signals at or below the target rate, and empty ones, are not resampled.
        audio = sample['audio']
        rate, wave = audio['sampling_rate'], audio['array']
        if rate > newsr and len(wave) != 0:
            wave = resample(wave, rate, newsr)
        return fixlength(wave, target_len)

    label_values = [sample['label'] for sample in samples]
    stacked = torch.stack([_fixed_signal(sample) for sample in samples], 0)
    batch = preprocess(stacked, newsr, n_fft, win_length, hop_length)
    return batch, torch.tensor(label_values, dtype=float)

# def random_mask(sample):
#     # random rectangular mask
#     B, H, W = sample.shape
#     for b in range(B):
#         for _ in range(randint(3,12)):
#             w = randint(5, 15)
#             h = randint(10, 100)
#             x1 = randint(0, W-w)
#             y1 = randint(0, H-h)
#             sample[b, y1:y1+h, x1:x1+w] = 0
#     return sample

# def timeshift(sample):
#     padsize = randint(0, 6)
#     length = sample.size(2)
#     randpad = torch.zeros((sample.size(0), sample.size(1), padsize), dtype=torch.float32)
#     sample = torch.cat((randpad, sample), dim=2)
#     sample = sample[:,:,:length]
#     return sample

# def add_noise(sample):
#     #noise = np.random.normal(0, 0.05*sample.max(), sample.shape)
#     noise = 0.05*sample.max()*torch.randn(sample.shape, dtype=torch.float32)
#     sample = sample + noise
#     return sample

# def augment(sample):
#     sample = timeshift(sample)
#     sample = random_mask(sample)
#     sample = add_noise(sample)
#     return sample