import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import open_clip
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data.dataset import Dataset

# Module-wide logger. ``getLogger()`` with no name returns the root logger, so
# records emitted here follow whatever handlers/levels are configured on root.
log = logging.getLogger()


class WavTextClipsDataset(Dataset):
    """Map-style dataset pairing fixed-length audio clips with tokenized captions.

    Two tab-separated tables drive the dataset:

    * ``captions_tsv`` must provide the columns ``id`` and ``caption``.
    * ``clips_tsv`` must provide the columns ``id``, ``name``, ``start_sample``
      and ``end_sample``. ``name`` is both the stem of the audio file under
      ``root`` (``<name>.flac`` is preferred, ``<name>.wav`` is the fallback)
      and the key used to find the caption by its ``id`` in ``captions_tsv``.

    Clips whose ``name`` has no caption are skipped with a warning.

    Args:
        root: Directory containing the audio files.
        captions_tsv: Path of the caption table.
        clips_tsv: Path of the clip table.
        sample_rate: Target sample rate; audio stored at another rate is
            resampled to it.
        num_samples: Exact number of samples in every returned waveform.
        duration: Stored on the instance as ``self.duration``; not otherwise
            used by this class.
        normalize_audio: If true, scale each file so its peak is 0.95.
        reject_silent: If true, files whose peak is below ``1e-6`` yield
            ``None``.
        tokenizer_id: Model name handed to ``open_clip.get_tokenizer``.
        multi_caption: If true, every caption row of an ``id`` produces its
            own dataset entry; otherwise the last caption row of an ``id`` wins.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        duration: int = 10,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
        multi_caption: bool = False
    ):
        super().__init__()
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.duration = duration
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # Listing the directory fails fast when ``root`` is missing or not a
        # directory.
        # NOTE(review): the resulting set of stems is not consulted anywhere
        # below; presumably it was meant to filter captions/clips to files that
        # exist on disk - confirm before wiring it in or dropping it.
        available_stems = {
            Path(entry).stem
            for entry in os.listdir(self.root)
            if entry.endswith(('.wav', '.flac'))
        }

        # Caption table: id -> caption, or id -> [captions] in multi-caption mode.
        self.captions = {}
        caption_rows = pd.read_csv(
            captions_tsv, sep='\t', dtype={'id': str}
        ).to_dict('records')
        for row in caption_rows:
            caption_id = row['id']
            caption = row['caption']
            if multi_caption:
                self.captions.setdefault(caption_id, []).append(caption)
            else:
                self.captions[caption_id] = caption

        # Clip table: one entry per (clip, caption) pair.
        clip_rows = pd.read_csv(
            clips_tsv, sep='\t', dtype={'id': str, 'name': str}
        ).to_dict('records')
        self.clips = []
        for row in clip_rows:
            name = row['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue

            if multi_caption:
                for caption in self.captions[name]:
                    entry = row.copy()
                    entry['caption'] = caption
                    self.clips.append(entry)
            else:
                row['caption'] = self.captions[name]
                self.clips.append(row)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # Cache of resamplers keyed by the source sample rate.
        self.resampler = {}

    def _resolve_audio_path(self, audio_name: str) -> Path:
        """Return ``<root>/<audio_name>.flac`` if present, else the ``.wav`` sibling.

        Raises:
            FileNotFoundError: If neither file exists.
        """
        flac_path = self.root / f'{audio_name}.flac'
        if flac_path.exists():
            return flac_path
        wav_path = self.root / f'{audio_name}.wav'
        if wav_path.exists():
            return wav_path
        # An explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise FileNotFoundError(
            f'No .flac or .wav file named {audio_name} in {self.root}')

    @staticmethod
    def _extract_clip(waveform: torch.Tensor, start_sample: int,
                      end_sample: int) -> torch.Tensor:
        """Return ``waveform[start_sample:end_sample]``, zero-padding a short input.

        The waveform is first padded with trailing zeros up to ``end_sample``
        so the returned window always honours ``start_sample``, including when
        the file ends before ``end_sample``.
        """
        missing = end_sample - waveform.size(0)
        if missing > 0:
            waveform = F.pad(waveform, (0, missing), mode='constant', value=0)
        return waveform[start_sample:end_sample]

    def _get_resampler(self, source_rate: int):
        """Return the cached resampler from ``source_rate`` to ``self.sample_rate``."""
        resampler = self.resampler.get(source_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                source_rate,
                self.sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
            self.resampler[source_rate] = resampler
        return resampler

    def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
        """Load, crop, resample and tokenize the clip at position ``idx``.

        Returns:
            A dict with the keys ``'waveform'`` (channel-averaged audio cut to
            ``num_samples`` samples), ``'id'``, ``'caption'`` and ``'tokens'``
            (first row of the tokenizer output for the caption), or ``None``
            when the audio is rejected as silent or any step of loading or
            processing fails. Failures are logged rather than raised.

        Raises:
            IndexError: If ``idx`` is outside the clip list.
        """
        # Looked up outside the ``try`` so an out-of-range index raises
        # IndexError; swallowing it would make plain iteration over the
        # dataset (``for x in ds`` / ``list(ds)``) never terminate.
        clip = self.clips[idx]

        # Pre-bound so the error handler can run even when the failure happens
        # before a path has been resolved.
        audio_path: Optional[Path] = None
        try:
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = int(clip['start_sample'])
            end_sample = int(clip['end_sample'])

            audio_path = self._resolve_audio_path(audio_name)

            waveform, sample_rate = torchaudio.load(audio_path)
            # Average over the first dimension (channels, per torchaudio.load).
            waveform = waveform.mean(dim=0)
            abs_max = waveform.abs().max()

            if self.reject_silent and abs_max < 1e-6:
                log.warning('Rejecting silent audio')
                return None

            # Skip the division for an all-zero file; 0 / 0 would yield NaNs.
            if self.normalize_audio and abs_max > 0:
                waveform = waveform / abs_max * 0.95

            waveform = self._extract_clip(waveform, start_sample, end_sample)

            if sample_rate != self.sample_rate:
                waveform = self._get_resampler(sample_rate)(waveform)

            if waveform.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            waveform = waveform[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            return {
                'waveform': waveform,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }
        except Exception as e:
            # Deliberate best-effort boundary: a bad sample is logged and
            # reported as ``None`` instead of aborting the caller.
            source = audio_path if audio_path is not None else f'clip {idx}'
            log.error(f'Error reading {source}: {e}')
            return None

    def __len__(self):
        """Return the number of (clip, caption) entries."""
        return len(self.clips)