|
import os

import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset as TorchDataset
|
|
|
class ChainsawDataset(TorchDataset):
    """Map-style dataset of audio clips indexed by a label CSV.

    Each row of the CSV is unpacked positionally as ``(file, label)``, so the
    CSV must have exactly two columns. ``file`` is resolved relative to
    ``path`` and decoded with ``torchaudio.load``.
    """

    def __init__(self, path="../datasets/freesound/", labels_file="labels.csv"):
        """Load the label index.

        Args:
            path: Directory holding the audio files and the label CSV. A
                relative path is resolved against the current working
                directory, not this module's location. A trailing separator
                is optional.
            labels_file: File name of the label CSV inside ``path``.
        """
        self.path = path
        self.ds = pd.read_csv(os.path.join(path, labels_file))

    def __getitem__(self, index):
        """Load and return the clip at row ``index``.

        Returns:
            A dict ``{'audio': {'path', 'array', 'sampling_rate'}, 'label'}``.
            ``path`` is the file name exactly as listed in the CSV.
            ``sampling_rate`` and ``label`` are wrapped in tensors.
        """
        # Positional unpacking: raises if the CSV does not have exactly two columns.
        filename, label = self.ds.iloc[index]
        waveform, sample_rate = torchaudio.load(os.path.join(self.path, filename))
        # torchaudio.load returns (channels, frames). Drop only a size-1 channel
        # axis, so mono (1, T) -> (T,). A bare squeeze() would also collapse
        # a 1-sample clip into a 0-d tensor.
        waveform = waveform.squeeze(0)
        return {
            'audio': {
                'path': filename,
                'array': waveform,
                'sampling_rate': torch.tensor(sample_rate),
            },
            # NOTE(review): assumes the label column is numeric; torch.tensor
            # rejects string labels.
            'label': torch.tensor(label),
        }

    def __len__(self):
        """Return the number of rows in the label CSV."""
        return len(self.ds)