File size: 2,949 Bytes
5ad4868 0388c00 5ad4868 0388c00 5ad4868 0388c00 5ad4868 0388c00 5ad4868 0388c00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
from torch.utils.data import DataLoader
import librosa
from math import floor
import torch
from torch.nn.functional import pad
from torchaudio.transforms import Resample
from random import randint
def get_dataloader(dataset, device, batch_size=16, shuffle=True):
    """Build a ``DataLoader`` over ``dataset`` that collates with ``prepare_batch``.

    Args:
        dataset: Object exposing ``with_format("torch", device=...)``
            (presumably a Hugging Face ``datasets.Dataset`` -- TODO confirm).
        device: Forwarded unchanged to ``dataset.with_format``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data.

    Returns:
        A ``DataLoader`` using 4 persistent worker processes that drops the
        last incomplete batch.
    """
    torch_view = dataset.with_format("torch", device=device)
    loader_options = dict(
        batch_size=batch_size,
        collate_fn=prepare_batch,
        num_workers=4,
        shuffle=shuffle,
        drop_last=True,
        persistent_workers=True,
    )
    return DataLoader(torch_view, **loader_options)
def resample(x, sr, newsr):
    """Resample ``x`` from ``sr`` to ``newsr`` using torchaudio's ``Resample``.

    Args:
        x: Waveform passed straight to the ``Resample`` transform.
        sr: Original sampling rate (``orig_freq``).
        newsr: Target sampling rate (``new_freq``).

    Returns:
        The output of the transform applied to ``x``.
    """
    # NOTE(review): these values look like the "kaiser_fast" configuration
    # from the torchaudio resampling tutorial -- confirm before changing.
    kaiser_settings = {
        "resampling_method": "sinc_interp_kaiser",
        "lowpass_filter_width": 16,
        "rolloff": 0.85,
        "beta": 8.555504641634386,
    }
    resampler = Resample(orig_freq=sr, new_freq=newsr, **kaiser_settings)
    return resampler(x)
def fixlength(x, L):
    """Truncate or right-pad ``x`` with zeros so its last axis has length ``L``.

    Both the truncation and the padding act on the last axis, so a 1-D
    waveform and a tensor with leading axes (e.g. channels) are handled
    consistently. For 1-D input this is identical to slicing ``x[:L]``.

    Args:
        x: Tensor whose last axis is cropped or padded.
        L: Desired length of the last axis.

    Returns:
        A tensor with the same leading shape as ``x`` and last axis ``L``.
    """
    # Crop along the same axis that ``pad`` extends (the last one).
    x = x[..., :L]
    # After cropping the deficit is never negative; zero means no padding.
    x = pad(x, (0, L - x.shape[-1]))
    return x
def preprocess(X, newsr, n_fft, win_length, hop_length, gain=0.8, bias=10, power=0.25):
    """Turn waveforms into PCEN-normalised magnitude spectrograms.

    Args:
        X: Waveform tensor fed to ``torch.stft``; the result is iterated
            along its first axis (assumes a batch axis -- TODO confirm).
        newsr: Sampling rate handed to ``librosa.pcen`` as ``sr``.
        n_fft: FFT size for the STFT.
        win_length: Length of the Hann analysis window.
        hop_length: Hop between frames, used by both the STFT and PCEN.
        gain: PCEN ``gain`` argument.
        bias: PCEN ``bias`` argument.
        power: PCEN ``power`` argument.

    Returns:
        The per-item PCEN outputs stacked along axis 0, cast to
        ``torch.bfloat16``.
    """
    analysis_window = torch.hann_window(win_length)
    spectrum = torch.stft(
        X,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=analysis_window,
        onesided=True,
        return_complex=True,
    )
    magnitude = spectrum.abs()

    # librosa works on NumPy arrays, so each item takes a round trip.
    normalised = []
    for frames in magnitude:
        pcen_frames = librosa.pcen(
            frames.numpy(),
            sr=newsr,
            hop_length=hop_length,
            gain=gain,
            bias=bias,
            power=power,
        )
        normalised.append(torch.from_numpy(pcen_frames))

    return torch.stack(normalised, 0).to(torch.bfloat16)
def prepare_batch(samples):
    """Collate raw dataset samples into a ``(features, labels)`` pair.

    Each sample must provide ``sample['label']`` and ``sample['audio']``
    with ``'sampling_rate'`` and ``'array'`` entries. Non-empty audio whose
    rate exceeds 4000 is resampled to 4000, every clip is forced to
    ``3 * 4000`` samples, and the stacked clips go through ``preprocess``.

    Args:
        samples: Iterable of sample mappings as described above.

    Returns:
        Tuple ``(batch, labels)``: the ``preprocess`` output and the labels
        as a tensor built with ``dtype=float``.
    """
    target_sr = 4000
    fft_size = 2 ** 10  # power of 2
    window_size = 2 ** 10
    hop = floor(0.0505 * target_sr)
    clip_len = 3 * target_sr

    targets = []
    waveforms = []
    for item in samples:
        targets.append(item['label'])
        audio = item['audio']
        rate = audio['sampling_rate']
        wave = audio['array']
        # NOTE(review): clips with a rate at or below the target are left
        # untouched (never upsampled) -- confirm this is intended.
        needs_downsampling = rate > target_sr and len(wave) != 0
        if needs_downsampling:
            wave = resample(wave, rate, target_sr)
        waveforms.append(fixlength(wave, clip_len))

    stacked = torch.stack(waveforms, 0)
    features = preprocess(stacked, target_sr, fft_size, window_size, hop)
    return features, torch.tensor(targets, dtype=float)
# Data augmentation
def random_mask(sample):
    """Zero out random rectangles in every item of a 3-D batch, in place.

    For each item, between 3 and 12 rectangles are drawn. Each is 5-15
    wide (last axis) and 10-100 tall (middle axis), clamped to the tensor's
    extent so that inputs smaller than a mask no longer raise
    ``ValueError`` from ``randint``.

    Args:
        sample: Tensor of shape ``(B, H, W)``. It is modified in place.

    Returns:
        The same ``sample`` object, with the masked regions set to 0.
    """
    B, H, W = sample.shape
    for b in range(B):
        for _ in range(randint(3, 12)):
            # Clamp after drawing so the random sequence is unchanged for
            # inputs that were already large enough.
            w = min(randint(5, 15), W)
            h = min(randint(10, 100), H)
            x1 = randint(0, W - w)
            y1 = randint(0, H - h)
            sample[b, y1:y1 + h, x1:x1 + w] = 0
    return sample
def timeshift(sample):
    """Shift a 3-D batch right along its last axis by 0-6 zero-filled steps.

    The same random shift is applied to the whole batch. Zeros are
    prepended on axis 2 and the result is cropped back to the original
    length, so trailing content falls off the end.

    Args:
        sample: Tensor with at least three axes; axis 2 is shifted.

    Returns:
        A new tensor with the same shape as ``sample``.
    """
    padsize = randint(0, 6)
    length = sample.size(2)
    # Allocate the padding on the input's device so torch.cat does not fail
    # for tensors living off the default device. The dtype stays float32 as
    # before, so the output dtype for non-float32 input still follows
    # torch.cat's type promotion.
    randpad = torch.zeros(
        (sample.size(0), sample.size(1), padsize),
        dtype=torch.float32,
        device=sample.device,
    )
    shifted = torch.cat((randpad, sample), dim=2)
    return shifted[:, :, :length]
def add_noise(sample):
    """Add Gaussian noise scaled to 5% of the tensor's maximum value.

    Args:
        sample: Non-empty tensor to perturb (``sample.max()`` is taken over
            all elements).

    Returns:
        A new tensor equal to ``sample`` plus ``0.05 * sample.max()`` times
        standard-normal noise of the same shape.
    """
    noise_scale = 0.05 * sample.max()
    # Draw the noise on the input's device so the addition does not fail
    # for tensors living off the default device; the dtype stays float32.
    noise = noise_scale * torch.randn(
        sample.shape, dtype=torch.float32, device=sample.device
    )
    return sample + noise
def augment(sample):
    """Run the augmentation chain: time shift, random masking, then noise.

    Args:
        sample: Batch passed through ``timeshift``, ``random_mask`` and
            ``add_noise`` in that order.

    Returns:
        The output of the final augmentation step.
    """
    pipeline = (timeshift, random_mask, add_noise)
    for step in pipeline:
        sample = step(sample)
    return sample