import logging
import os
from pathlib import Path
from typing import Union

import open_clip
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class WavTextClipsDataset(Dataset):
    """Pairs fixed-length audio clips with text captions.

    Clips are defined by a TSV of (id, name, start_sample, end_sample)
    partitions; captions come from a separate TSV keyed by audio name.
    Each item yields a mono waveform of exactly ``num_samples`` samples at
    ``sample_rate`` plus the caption and its CLIP token ids.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        duration: int = 10,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
        multi_caption: bool = False,
    ):
        """
        Args:
            root: directory containing the ``.flac``/``.wav`` files.
            captions_tsv: TSV with ``id`` and ``caption`` columns.
            clips_tsv: TSV with ``id``, ``name``, ``start_sample``,
                ``end_sample`` columns describing each clip partition.
            sample_rate: target sample rate; audio is resampled to this.
            num_samples: exact number of samples each returned waveform has.
            duration: nominal clip duration in seconds (stored, not used here).
            normalize_audio: peak-normalize each clip to 0.95.
            reject_silent: return ``None`` for near-silent clips.
            tokenizer_id: open_clip tokenizer name. Only used for CLIP;
                T5/CLAP caption embeddings are computed outside this dataset.
            multi_caption: if True, an audio with k captions yields k clips.
        """
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.duration = duration
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # Read the caption TSV. With multi_caption the value is a list of
        # captions per id; otherwise only ONE caption is kept per id (the
        # last row seen wins, matching the original behavior).
        self.captions = {}
        caption_records = pd.read_csv(captions_tsv, sep='\t',
                                      dtype={'id': str}).to_dict('records')
        for record in caption_records:
            clip_id = record['id']  # file name without extension
            caption = record['caption']
            if multi_caption:
                self.captions.setdefault(clip_id, []).append(caption)
            else:
                self.captions[clip_id] = caption

        # Read the clip (partition) TSV and attach captions; clips whose
        # audio name has no caption are skipped with a warning.
        clip_records = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str,
        }).to_dict('records')
        self.clips = []
        for record in clip_records:
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            if multi_caption:
                # One clip entry per caption, sharing the same partition.
                for caption in self.captions[name]:
                    entry = record.copy()
                    entry['caption'] = caption
                    self.clips.append(entry)
            else:
                record['caption'] = self.captions[name]
                self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')
        # Lazily-built Resample modules, keyed by source sample rate.
        self.resampler = {}

    def __getitem__(self, idx: int):
        """Return a dict with keys ``waveform``/``id``/``caption``/``tokens``,
        or ``None`` if the clip is rejected or loading fails."""
        audio_path = None  # pre-bind so the except handler can always log it
        try:
            clip = self.clips[idx]
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                if not audio_path.exists():
                    # explicit raise instead of assert (asserts vanish under -O)
                    raise FileNotFoundError(f'No .flac or .wav found for {audio_name}')

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono

            abs_max = audio_chunk.abs().max()
            # Reject silence BEFORE normalizing: dividing by a ~0 peak would
            # produce NaN/Inf.
            if self.reject_silent and abs_max < 1e-6:
                log.warning('Rejecting silent audio %s', audio_name)
                return None
            if self.normalize_audio and abs_max > 0:
                audio_chunk = audio_chunk / abs_max * 0.95

            # Zero-pad the tail if the file is shorter than end_sample, then
            # always cut out [start_sample, end_sample). (The original padded
            # without slicing, returning samples from 0 instead of
            # start_sample whenever padding occurred.)
            if audio_chunk.size(0) < end_sample:
                audio_chunk = F.pad(
                    audio_chunk,
                    (0, end_sample - audio_chunk.size(0)),
                    mode='constant',
                    value=0,
                )
            audio_chunk = audio_chunk[start_sample:end_sample]

            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # "kaiser_best" parameters, see
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            return {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }
        except Exception as e:
            log.error(f'Error reading {audio_path}: {e}')
            return None

    def __len__(self) -> int:
        return len(self.clips)