import logging
from pathlib import Path
from typing import Union, Optional
import pandas as pd
import torch
from tensordict import TensorDict
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from meanaudio.utils.dist_utils import local_rank
import numpy as np
import glob
import torch.nn.functional as F
log = logging.getLogger()


class ExtractedAudio(Dataset):
    """Dataset over pre-extracted audio latents and text features stored as npz files.

    For every row ``i`` of the tsv (columns: ``id``, ``caption``), a file
    ``{npz_dir}/{i}.npz`` must exist containing the arrays ``mean``, ``std``,
    ``text_features`` and ``text_features_c``.  Optionally, ``{repa_npz_dir}/{i}.npz``
    holds representation-alignment features under key ``es``, returned as ``zs``.

    Args:
        tsv_path: tab-separated metadata file with ``id`` and ``caption`` columns.
        npz_dir: directory of per-index latent/text npz files.
        data_dim: expected dimensions, keys ``latent_seq_len``, ``text_seq_len``,
            ``text_dim``, ``text_c_dim`` (plus ``repa_seq_len``/``repa_seq_dim``
            when ``repa_npz_dir`` is given).
        concat_text_fc: if True, the sequence-pooled ``text_features`` are
            concatenated onto ``text_features_c`` (-> dim ``d + d_c``).
        repa_npz_dir: if given, repa features (``zs``) are returned per item.
        exclude_cls: drop the leading cls token from ``zs`` (repa version 1/None).
        repa_version: 2 = 8x time-pooled EAT features (cls kept),
            3 = cls token only, 1/None = raw sequence (optionally minus cls).
    """

    def __init__(
        self,
        tsv_path: Union[str, Path],
        *,
        npz_dir: Union[str, Path],
        data_dim: dict[str, int],
        concat_text_fc: bool = False,
        repa_npz_dir: Optional[Union[str, Path]] = None,
        exclude_cls: Optional[bool] = None,
        repa_version: Optional[int] = None,
    ):
        super().__init__()
        self.data_dim = data_dim
        self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')  # id, caption
        self.ids = [str(d['id']) for d in self.df_list]
        npz_files = glob.glob(f"{npz_dir}/*.npz")
        self.concat_text_fc = concat_text_fc
        self.exclude_cls = exclude_cls
        self.repa_version = repa_version
        if self.concat_text_fc:
            log.info('We will concat the pooled text_features and text_features_c for text condition')

        # Dimension check against the first sample (files are assumed to be
        # named 0.npz .. N-1.npz to match tsv row order).
        sample = np.load(f'{npz_dir}/0.npz')
        mean_s = [len(npz_files)] + list(sample['mean'].shape)
        std_s = [len(npz_files)] + list(sample['std'].shape)
        text_features_s = [len(npz_files)] + list(sample['text_features'].shape)
        text_features_c_s = [len(npz_files)] + list(sample['text_features_c'].shape)
        if self.concat_text_fc:
            # pooled text feature (dim d) gets appended to the conditioning feature (dim d_c)
            text_features_c_s[-1] = text_features_c_s[-1] + text_features_s[-1]
        log.info('Loading %d npz files from %s', len(npz_files), npz_dir)
        log.info('Loaded mean: %s.', mean_s)
        log.info('Loaded std: %s.', std_s)
        log.info('Loaded text features: %s.', text_features_s)
        log.info('Loaded text features_c: %s.', text_features_c_s)
        assert len(npz_files) == len(self.df_list), 'Number mismatch between npz files and tsv items'
        assert mean_s[1] == self.data_dim['latent_seq_len'], \
            f'{mean_s[1]} != {self.data_dim["latent_seq_len"]}'
        assert std_s[1] == self.data_dim['latent_seq_len'], \
            f'{std_s[1]} != {self.data_dim["latent_seq_len"]}'
        assert text_features_s[1] == self.data_dim['text_seq_len'], \
            f'{text_features_s[1]} != {self.data_dim["text_seq_len"]}'
        assert text_features_s[-1] == self.data_dim['text_dim'], \
            f'{text_features_s[-1]} != {self.data_dim["text_dim"]}'
        assert text_features_c_s[-1] == self.data_dim['text_c_dim'], \
            f'{text_features_c_s[-1]} != {self.data_dim["text_c_dim"]}'
        self.npz_dir = npz_dir

        if repa_npz_dir is not None:
            self.repa_npz_dir = repa_npz_dir
            sample = np.load(f'{repa_npz_dir}/0.npz')
            repa_npz_files = glob.glob(f"{repa_npz_dir}/*.npz")
            log.info('Loading %d npz representations from %s', len(repa_npz_files), repa_npz_dir)
            es_s = [len(repa_npz_files)] + list(sample['es'].shape)
            if self.repa_version == 2:
                es_s[1] = 65  # ad-hoc 8x downsampling for EAT: cls token + 512/8 frames
            elif self.repa_version == 3:
                es_s[1] = 1  # we only use the cls token for alignment
            else:
                # version 1 / None: raw sequence, optionally without the cls token
                if self.exclude_cls:
                    es_s[1] = es_s[1] - 1
            log.info('Loaded es: %s', es_s)
            assert len(repa_npz_files) == len(npz_files), 'Number mismatch between repa npz files and latent npz files'
            assert es_s[1] == self.data_dim['repa_seq_len'], \
                f'{es_s[1]} != {self.data_dim["repa_seq_len"]}'
            assert es_s[-1] == self.data_dim['repa_seq_dim'], \
                f'{es_s[-1]} != {self.data_dim["repa_seq_dim"]}'
        else:
            self.repa_npz_dir = None

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        # !TODO here we may consider load pre-computed latent mean & std
        raise NotImplementedError('Please manually compute latent stats outside. ')

    def __getitem__(self, idx):
        """Return one item: latent stats, text features, caption and (optionally) repa features."""
        np_data = np.load(f'{self.npz_dir}/{idx}.npz')
        text_features = torch.from_numpy(np_data['text_features'])
        text_features_c = torch.from_numpy(np_data['text_features_c'])
        if self.concat_text_fc:
            # pool over the sequence axis and append to the conditioning feature -> [d + d_c]
            text_features_c = torch.cat([text_features.mean(dim=-2),
                                         text_features_c], dim=-1)
        out_dict = {
            'id': str(self.df_list[idx]['id']),
            'a_mean': torch.from_numpy(np_data['mean']),
            'a_std': torch.from_numpy(np_data['std']),
            'text_features': text_features,
            'text_features_c': text_features_c,
            'caption': self.df_list[idx]['caption'],
        }
        if self.repa_npz_dir is not None:
            repa_np_data = np.load(f'{self.repa_npz_dir}/{idx}.npz')
            zs = torch.from_numpy(repa_np_data['es'])
            if self.repa_version == 2:
                # keep the cls token, 8x average-pool the remaining frames along time
                z_cls = zs[0]  # (dim)
                pooled = F.avg_pool2d(zs[1:, :].unsqueeze(0),
                                      kernel_size=(8, 1),
                                      stride=(8, 1)).squeeze()  # (64, 768)
                zs = torch.cat((z_cls.unsqueeze(0), pooled), dim=0)
            elif self.repa_version == 3:
                zs = zs[0].unsqueeze(0)  # cls token only
            elif self.exclude_cls:
                # Mirrors the shape adjustment in __init__, which decrements the
                # expected length for ANY version not in {2, 3} (i.e. 1 or None);
                # the original code only stripped cls for version == 1.
                zs = zs[1:, :]
            out_dict['zs'] = zs
        return out_dict

    def __len__(self):
        return len(self.ids)
if __name__ == '__main__':
    # Smoke test: build the dataset under torch.distributed and pull one batch.
    from meanaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
    import torch.distributed as distributed
    from datetime import timedelta
    from torch.utils.data.distributed import DistributedSampler

    def distributed_setup():
        # NCCL backend requires one GPU per process; long timeout for slow filesystems.
        distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
        log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
        return local_rank, world_size

    distributed_setup()

    tsv_path = '/hpc_stor03/sjtu_home/xiquan.li/TTA/MMAudio/training/audiocaps/train-memmap-t5-clap.tsv'
    # FIX: npz_dir was referenced below but never defined (NameError).
    # TODO(review): point this at the actual directory of extracted npz latents.
    npz_dir = '/hpc_stor03/sjtu_home/xiquan.li/TTA/MMAudio/training/audiocaps/npz'
    data_dim = {'latent_seq_len': 312,
                'text_seq_len': 77,
                'text_dim': 1024,
                'text_c_dim': 512}
    # FIX: the original call omitted the required keyword-only arguments.
    dataset = ExtractedAudio(tsv_path=tsv_path,
                             concat_text_fc=False,
                             npz_dir=npz_dir,
                             data_dim=data_dim,
                             repa_npz_dir=None,
                             exclude_cls=None,
                             repa_version=None)
    # FIX: build the sampler first and actually hand it to the DataLoader
    # (it was created after the loader and never used).
    train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=True)
    loader = DataLoader(dataset,
                        batch_size=16,
                        sampler=train_sampler,
                        num_workers=8,
                        persistent_workers=True,  # FIX: must be a bool, was 8
                        pin_memory=False)
    for b in loader:
        print(b['a_mean'].shape)
        break