import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class AudioCapsData(Dataset):
    """Caption dataset backed by a CSV file with ``name`` and ``caption`` columns.

    Every CSV row becomes one item, regardless of whether a corresponding
    audio file exists in ``audio_path``; rows without audio are only reported
    through a warning. No audio is loaded by this class.
    """

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        """Read the CSV and index the audio directory.

        Args:
            audio_path: Directory that is listed for ``.wav`` / ``.flac`` files.
            csv_path: CSV file readable by ``pandas.read_csv``; it must provide
                the ``name`` and ``caption`` columns (other columns are ignored).

        Raises:
            OSError: If ``audio_path`` cannot be listed (e.g.
                ``FileNotFoundError`` when the directory does not exist).
            KeyError: If a row lacks the ``name`` or ``caption`` column.
        """
        rows = pd.read_csv(csv_path).to_dict(orient='records')

        # Listing the directory doubles as a fail-fast check that audio_path
        # exists; the stems are used below to report rows without audio.
        audio_stems = {
            Path(f).stem for f in os.listdir(audio_path) if f.endswith(('.wav', '.flac'))
        }

        # Keep only the two fields consumers read; rows are NOT filtered by
        # audio availability.
        self.data = [{'name': row['name'], 'caption': row['caption']} for row in rows]

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')

        # The message above counts CSV rows, not verified files, so surface
        # any mismatch explicitly. str() lets numeric names parsed by pandas
        # compare against file stems.
        missing = sum(1 for item in self.data if str(item['name']) not in audio_stems)
        if missing:
            log.warning(f'{missing} of {len(self.data)} entries in {self.csv_path} '
                        f'have no .wav/.flac file in {self.audio_path}')

    def __getitem__(self, idx: int) -> dict:
        """Return the ``{'name': ..., 'caption': ...}`` record at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of CSV rows held by the dataset."""
        return len(self.data)