from functools import lru_cache
from math import floor
from random import randint

import librosa
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from torchaudio.transforms import Resample


def get_dataloader(dataset, device, batch_size=16, shuffle=True, num_workers=4):
    """Build a DataLoader yielding ``(features, labels)`` batches.

    The dataset is switched to torch format on ``device`` and each batch is
    collated by :func:`prepare_batch`. Incomplete final batches are dropped.

    Args:
        dataset: Object exposing ``with_format("torch", device=...)``
            (presumably a Hugging Face ``datasets.Dataset`` -- confirm).
        device: Device the formatted tensors are placed on.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Number of loader worker processes; ``0`` loads in the
            main process.

    Returns:
        A configured :class:`torch.utils.data.DataLoader`.
    """
    # NOTE(review): creating CUDA tensors inside forked worker processes is
    # generally unsupported; with a CUDA ``device`` consider num_workers=0.
    return DataLoader(
        dataset.with_format("torch", device=device),
        batch_size=batch_size,
        collate_fn=prepare_batch,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )


@lru_cache(maxsize=16)
def _resampler(sr, newsr):
    """Return a cached Kaiser-windowed sinc resampler from ``sr`` to ``newsr`` Hz."""
    return Resample(
        orig_freq=sr,
        new_freq=newsr,
        resampling_method="sinc_interp_kaiser",
        lowpass_filter_width=16,
        rolloff=0.85,
        beta=8.555504641634386,
    )


def resample(x, sr, newsr):
    """Resample waveform tensor ``x`` from ``sr`` Hz to ``newsr`` Hz.

    Args:
        x: Waveform tensor; resampling is applied along its last axis.
        sr: Original sampling rate (int, or a 0-d tensor convertible to int).
        newsr: Target sampling rate.

    Returns:
        The resampled waveform tensor.
    """
    # Constructing Resample precomputes its filter kernel, so reuse one
    # transform per (sr, newsr) pair instead of rebuilding it per sample.
    transform = _resampler(int(sr), int(newsr))
    return transform.to(x.device)(x)


def fixlength(x, L):
    """Truncate or right-pad with zeros so the last axis of ``x`` has length ``L``.

    Args:
        x: Waveform tensor (1-D in this pipeline).
        L: Target number of samples.

    Returns:
        Tensor whose last dimension is exactly ``L``.
    """
    x = x[..., :L]
    # pad's (left, right) tuple applies to the last axis, matching the slice.
    return pad(x, (0, L - x.shape[-1]))


def preprocess(X, newsr, n_fft, win_length, hop_length, gain=0.8, bias=10, power=0.25):
    """Turn a batch of waveforms into PCEN-normalised magnitude spectrograms.

    Args:
        X: Tensor of shape (batch, samples) sampled at ``newsr`` Hz.
        newsr: Sampling rate of ``X``; forwarded to ``librosa.pcen``.
        n_fft: FFT size of the STFT.
        win_length: Hann window length.
        hop_length: STFT hop in samples (also forwarded to ``librosa.pcen``).
        gain, bias, power: PCEN parameters forwarded to ``librosa.pcen``.

    Returns:
        bfloat16 tensor of shape (batch, n_fft // 2 + 1, frames) on ``X.device``.
    """
    window = torch.hann_window(win_length, dtype=X.dtype, device=X.device)
    spec = torch.stft(
        X,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        onesided=True,
        return_complex=True,
    ).abs()
    # librosa only accepts NumPy arrays, so each spectrogram is routed through
    # host memory and the stacked result is moved back to the input device.
    normalised = [
        torch.from_numpy(
            librosa.pcen(
                s.cpu().numpy(),
                sr=newsr,
                hop_length=hop_length,
                gain=gain,
                bias=bias,
                power=power,
            )
        )
        for s in spec
    ]
    return torch.stack(normalised, 0).to(device=X.device, dtype=torch.bfloat16)


def prepare_batch(
    samples,
    *,
    newsr=4000,
    n_fft=2**10,
    win_length=2**10,
    hop_length=None,
    duration=3,
):
    """Collate raw audio samples into ``(features, labels)`` tensors.

    Each sample must provide ``sample['label']`` and
    ``sample['audio']['array']`` / ``sample['audio']['sampling_rate']``.
    Waveforms are resampled to ``newsr`` Hz, cut or zero-padded to
    ``duration`` seconds, and converted by :func:`preprocess`.

    Args:
        samples: Sequence of sample dicts as produced by the dataset.
        newsr: Target sampling rate in Hz.
        n_fft: STFT FFT size (power of 2).
        win_length: STFT window length.
        hop_length: STFT hop; defaults to ``floor(0.0505 * newsr)`` (~50.5 ms).
        duration: Clip length in seconds after padding/truncation.

    Returns:
        ``(batch, labels)``: bfloat16 features from :func:`preprocess` and a
        float64 label tensor.
    """
    if hop_length is None:
        hop_length = floor(0.0505 * newsr)
    length = int(duration * newsr)

    labels = []
    signals = []
    for sample in samples:
        labels.append(sample['label'])
        audio = sample['audio']
        sr = int(audio['sampling_rate'])
        x = audio['array']
        # Resample in both directions: a lower-rate clip left untouched would
        # be interpreted as newsr Hz and come out time-compressed.
        if sr != newsr and len(x) != 0:
            x = resample(x, sr, newsr)
        signals.append(fixlength(x, length))

    batch = preprocess(torch.stack(signals, 0), newsr, n_fft, win_length, hop_length)
    return batch, torch.tensor(labels, dtype=torch.float64)


# Data augmentation

def random_mask(sample):
    """Zero out 3-12 random rectangles in every item of a (B, H, W) tensor.

    The tensor is modified in place and also returned.
    """
    B, H, W = sample.shape
    for b in range(B):
        for _ in range(randint(3, 12)):
            # Clamp so the rectangle always fits, even on small inputs
            # (randint(0, negative) would raise ValueError).
            w = min(randint(5, 15), W)
            h = min(randint(10, 100), H)
            x1 = randint(0, W - w)
            y1 = randint(0, H - h)
            sample[b, y1:y1 + h, x1:x1 + w] = 0
    return sample


def timeshift(sample):
    """Shift a (B, H, W) tensor right along the last axis by 0-6 zero columns.

    The output keeps the input width; shifted-out columns are dropped.
    """
    padsize = randint(0, 6)
    length = sample.size(2)
    randpad = torch.zeros(
        (sample.size(0), sample.size(1), padsize),
        dtype=torch.float32,
        device=sample.device,
    )
    return torch.cat((randpad, sample), dim=2)[:, :, :length]


def add_noise(sample):
    """Add Gaussian noise with std ``0.05 * sample.max()`` to ``sample``."""
    noise = 0.05 * sample.max() * torch.randn(
        sample.shape, dtype=torch.float32, device=sample.device
    )
    return sample + noise


def augment(sample):
    """Apply time shift, random rectangular masking and additive noise.

    timeshift returns a new tensor first, so random_mask's in-place writes
    do not touch the caller's ``sample``.
    """
    sample = timeshift(sample)
    sample = random_mask(sample)
    sample = add_noise(sample)
    return sample