File size: 5,468 Bytes
3a1da90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import logging
import os
from pathlib import Path
from typing import Union

import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F

log = logging.getLogger()


class WavTextClipsDataset(Dataset):
    """Pairs fixed-length audio clips (.wav/.flac) with text captions from TSVs.

    Each item is a dict with keys:
        'waveform' - mono float tensor of exactly ``num_samples`` samples,
        'id'       - the clip id from the clips TSV,
        'caption'  - the raw caption string,
        'tokens'   - CLIP tokens for the caption.
    ``__getitem__`` returns ``None`` on any per-item failure (missing file,
    silent audio when ``reject_silent``, too-short audio), so the collate
    function is expected to filter ``None`` entries.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        duration: int = 10,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
        multi_caption: bool = False
    ):
        """
        Args:
            root: directory containing the .wav/.flac files.
            captions_tsv: TSV with columns 'id' and 'caption'.
            clips_tsv: TSV with columns 'id', 'name', 'start_sample',
                'end_sample' describing clip windows within each audio file.
            sample_rate: target sample rate after resampling.
            num_samples: fixed number of samples returned per clip.
            duration: nominal clip duration in seconds (stored for callers;
                not used directly in this class).
            normalize_audio: peak-normalize each clip to 0.95.
            reject_silent: return None for (near-)silent clips.
            tokenizer_id: open_clip tokenizer name. Only used for CLIP;
                T5/CLAP caption embeddings are computed outside this dataset.
            multi_caption: if True, keep every caption of an id and emit one
                dataset entry per (clip, caption) pair; otherwise keep a
                single caption per id (the last one in the TSV wins).
        """
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.duration = duration
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # captions: id -> caption (str), or id -> list[str] when multi_caption.
        self.captions = {}
        for record in pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records'):
            clip_id = record['id']  # file name without extension
            caption = record['caption']
            if multi_caption:
                self.captions.setdefault(clip_id, []).append(caption)
            else:
                # Only ONE caption is kept per audio id; later rows overwrite.
                self.captions[clip_id] = caption

        # Read the clip (partition) TSV and attach captions to each record.
        records = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in records:
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue

            if multi_caption:
                # One dataset entry per caption of this clip.
                for caption in self.captions[name]:
                    r = record.copy()
                    r['caption'] = caption
                    self.clips.append(r)
            else:
                record['caption'] = self.captions[name]
                self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # Resample transforms are expensive to build; cache one per source rate.
        self.resampler = {}

    def __getitem__(self, idx: int) -> Union[dict, None]:
        """Load, window, (optionally) normalize, resample, and tokenize one clip.

        Returns None (instead of raising) on any per-item failure so that a
        filtering collate function can drop bad items.
        """
        clip = self.clips[idx]
        audio_name = clip['name']  # bound before try so the handler can log it
        try:
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                if not audio_path.exists():
                    # raise (not assert): assert is stripped under python -O
                    raise FileNotFoundError(f'{audio_name} not found in {self.root}')

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono
            abs_max = audio_chunk.abs().max()

            # Reject silence BEFORE normalizing so we never divide by ~0
            # (which would produce NaN/inf waveforms).
            if self.reject_silent and abs_max < 1e-6:
                log.warning(f'Rejecting silent audio {audio_name}')
                return None
            if self.normalize_audio and abs_max > 0:
                audio_chunk = audio_chunk / abs_max * 0.95

            # Zero-pad short audio up to end_sample, then ALWAYS cut the
            # [start_sample, end_sample) window. (Previously the slice was
            # skipped after padding, returning the window from sample 0.)
            if audio_chunk.size(0) < end_sample:
                audio_chunk = F.pad(
                    audio_chunk,
                    (0, end_sample - audio_chunk.size(0)),
                    mode='constant',
                    value=0
                )
            audio_chunk = audio_chunk[start_sample:end_sample]

            # Resample to the target rate if needed.
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # kaiser_best settings, see
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            return {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }
        except Exception as e:
            # Best-effort dataset: log and let the caller drop the None.
            log.error(f'Error reading {audio_name}: {e}')
            return None

    def __len__(self) -> int:
        return len(self.clips)