|
import logging |
|
from pathlib import Path |
|
from typing import Union, Optional |
|
|
|
import pandas as pd |
|
import torch |
|
from tensordict import TensorDict |
|
from torch.utils.data.dataset import Dataset |
|
from torch.utils.data import DataLoader |
|
|
|
from meanaudio.utils.dist_utils import local_rank |
|
import numpy as np |
|
import glob |
|
import torch.nn.functional as F |
|
log = logging.getLogger() |
|
|
|
|
|
class ExtractedAudio(Dataset):
    """Dataset over pre-extracted audio latents stored as per-item npz files.

    Item ``i`` is loaded from ``{npz_dir}/{i}.npz`` and must contain the arrays
    ``mean``, ``std``, ``text_features`` and ``text_features_c``. Optionally, a
    parallel directory ``repa_npz_dir`` provides ``{i}.npz`` files with an ``es``
    array of representation features ("zs") for REPA-style alignment.

    The tsv at ``tsv_path`` supplies per-item metadata; its ``id`` and ``caption``
    columns are surfaced in each sample dict. Row order in the tsv is assumed to
    correspond to the npz file indices (item ``idx`` pairs ``df_list[idx]`` with
    ``{idx}.npz``) — TODO confirm this invariant with the extraction script.
    """

    def __init__(
        self,
        tsv_path: Union[str, Path],
        *,
        concat_text_fc: bool,
        npz_dir: Union[str, Path],
        data_dim: dict[str, int],
        repa_npz_dir: Optional[Union[str, Path]],
        exclude_cls: Optional[bool],
        repa_version: Optional[int],
    ):
        """
        Args:
            tsv_path: Tab-separated metadata file with at least ``id`` and
                ``caption`` columns, one row per npz item.
            concat_text_fc: If True, the pooled (mean over the sequence axis)
                ``text_features`` are concatenated onto ``text_features_c``.
            npz_dir: Directory of ``{idx}.npz`` latent files.
            data_dim: Expected dimensions, keys: ``latent_seq_len``,
                ``text_seq_len``, ``text_dim``, ``text_c_dim`` and — when
                ``repa_npz_dir`` is given — ``repa_seq_len``, ``repa_seq_dim``.
            repa_npz_dir: Optional directory of ``{idx}.npz`` representation
                files (``es`` array), or None to disable REPA features.
            exclude_cls: For ``repa_version`` 1 (and the shape check fallback),
                drop the leading CLS token from the representations.
            repa_version: Selects how ``es`` is post-processed in
                ``__getitem__`` (1: optional CLS drop, 2: CLS + 8x pooled
                patches, 3: CLS only).
        """
        super().__init__()
        self.data_dim = data_dim
        self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
        self.ids = [str(d['id']) for d in self.df_list]
        npz_files = glob.glob(f"{npz_dir}/*.npz")
        self.concat_text_fc = concat_text_fc
        self.exclude_cls = exclude_cls
        self.repa_version = repa_version

        if self.concat_text_fc:
            log.info(f'We will concat the pooled text_features and text_features_c for text condition')

        # Probe item 0 to derive the dataset-wide shapes used for validation below.
        sample = np.load(f'{npz_dir}/0.npz')
        mean_s = [len(npz_files)] + list(sample['mean'].shape)
        std_s = [len(npz_files)] + list(sample['std'].shape)
        text_features_s = [len(npz_files)] + list(sample['text_features'].shape)
        text_features_c_s = [len(npz_files)] + list(sample['text_features_c'].shape)
        if self.concat_text_fc:
            # After concatenation the conditioning dim grows by the pooled text dim.
            text_features_c_s[-1] = text_features_c_s[-1] + text_features_s[-1]

        log.info(f'Loading {len(npz_files)} npz files from {npz_dir}')
        log.info(f'Loaded mean: {mean_s}.')
        log.info(f'Loaded std: {std_s}.')
        log.info(f'Loaded text features: {text_features_s}.')
        log.info(f'Loaded text features_c: {text_features_c_s}.')

        # Sanity-check that the extracted features match the configured model dims.
        assert len(npz_files) == len(self.df_list), 'Number mismatch between npz files and tsv items'
        assert mean_s[1] == self.data_dim['latent_seq_len'], \
            f'{mean_s[1]} != {self.data_dim["latent_seq_len"]}'
        assert std_s[1] == self.data_dim['latent_seq_len'], \
            f'{std_s[1]} != {self.data_dim["latent_seq_len"]}'
        assert text_features_s[1] == self.data_dim['text_seq_len'], \
            f'{text_features_s[1]} != {self.data_dim["text_seq_len"]}'
        assert text_features_s[-1] == self.data_dim['text_dim'], \
            f'{text_features_s[-1]} != {self.data_dim["text_dim"]}'
        assert text_features_c_s[-1] == self.data_dim['text_c_dim'], \
            f'{text_features_c_s[-1]} != {self.data_dim["text_c_dim"]}'

        self.npz_dir = npz_dir
        if repa_npz_dir is not None:  # fixed: was `!= None` (PEP 8 identity comparison)
            self.repa_npz_dir = repa_npz_dir
            sample = np.load(f'{repa_npz_dir}/0.npz')
            repa_npz_files = glob.glob(f"{repa_npz_dir}/*.npz")
            log.info(f'Loading {len(repa_npz_files)} npz representations from {repa_npz_dir}')
            es_s = [len(repa_npz_files)] + list(sample['es'].shape)
            # Adjust the expected sequence length to match the per-version
            # post-processing applied in __getitem__.
            if self.repa_version == 2:
                # 1 CLS token + 512 patches avg-pooled by 8 -> 65 tokens.
                es_s[1] = 65
            elif self.repa_version == 3:
                # CLS token only.
                es_s[1] = 1
            else:
                if self.exclude_cls:
                    es_s[1] = es_s[1] - 1

            log.info(f'Loaded es: {es_s}')
            assert len(repa_npz_files) == len(npz_files), 'Number mismatch between repa npz files and latent npz files'
            assert es_s[1] == self.data_dim['repa_seq_len'], \
                f'{es_s[1]} != {self.data_dim["repa_seq_len"]}'
            assert es_s[-1] == self.data_dim['repa_seq_dim'], \
                f'{es_s[-1]} != {self.data_dim["repa_seq_dim"]}'
        else:
            self.repa_npz_dir = None

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Not supported: latent statistics must be computed offline."""
        raise NotImplementedError('Please manually compute latent stats outside. ')

    def __getitem__(self, idx):
        """Load item ``idx`` from disk and return a dict of tensors.

        Returns a dict with keys ``id``, ``a_mean``, ``a_std``,
        ``text_features``, ``text_features_c``, ``caption`` and — when a repa
        directory was configured — ``zs``.
        """
        npz_path = f'{self.npz_dir}/{idx}.npz'
        np_data = np.load(npz_path)
        text_features = torch.from_numpy(np_data['text_features'])
        text_features_c = torch.from_numpy(np_data['text_features_c'])
        if self.concat_text_fc:
            # Pool the sequence axis of text_features and append to the
            # conditioning features (matches the dim check in __init__).
            text_features_c = torch.cat([text_features.mean(dim=-2),
                                         text_features_c], dim=-1)

        out_dict = {
            'id': str(self.df_list[idx]['id']),
            'a_mean': torch.from_numpy(np_data['mean']),
            'a_std': torch.from_numpy(np_data['std']),
            'text_features': text_features,
            'text_features_c': text_features_c,
            'caption': self.df_list[idx]['caption'],
        }
        if self.repa_npz_dir is not None:  # fixed: was `!= None` (PEP 8 identity comparison)
            repa_npz_path = f'{self.repa_npz_dir}/{idx}.npz'
            repa_np_data = np.load(repa_npz_path)
            zs = torch.from_numpy(repa_np_data['es'])

            # Versions are mutually exclusive; use a single if/elif chain.
            if self.repa_version == 1:
                if self.exclude_cls:
                    # Drop the leading CLS token, keep all patch tokens.
                    zs = zs[1:, :]
            elif self.repa_version == 2:
                # Keep CLS, average-pool the patch tokens 8x along the
                # sequence axis, then re-prepend CLS.
                z_cls = zs[0]
                zs = F.avg_pool2d(zs[1:, :].unsqueeze(0),
                                  kernel_size=(8, 1),
                                  stride=(8, 1)).squeeze()
                zs = torch.cat((z_cls.unsqueeze(0), zs), dim=0)
            elif self.repa_version == 3:
                # CLS token only.
                zs = zs[0].unsqueeze(0)

            out_dict['zs'] = zs

        return out_dict

    def __len__(self):
        """Number of items, as defined by the tsv rows."""
        return len(self.ids)
|
|
|
|
|
if __name__ == '__main__':
    # Distributed smoke test: build the dataset and pull one batch.
    from meanaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
    import torch.distributed as distributed
    from datetime import timedelta
    from torch.utils.data.distributed import DistributedSampler

    def distributed_setup():
        """Initialize the default NCCL process group (2h timeout) and report rank info."""
        distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
        log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
        return local_rank, world_size

    distributed_setup()

    tsv_path = '/hpc_stor03/sjtu_home/xiquan.li/TTA/MMAudio/training/audiocaps/train-memmap-t5-clap.tsv'
    # FIX: npz_dir was referenced below but never defined (NameError).
    # TODO(review): point this at the actual extracted-latent directory.
    npz_dir = '/hpc_stor03/sjtu_home/xiquan.li/TTA/MMAudio/training/audiocaps/npz'

    data_dim = {'latent_seq_len': 312,
                'text_seq_len': 77,
                'text_dim': 1024,
                'text_c_dim': 512}

    # FIX: ExtractedAudio has required keyword-only parameters; the previous
    # call omitted them and raised a TypeError before any data was read.
    dataset = ExtractedAudio(tsv_path=tsv_path,
                             concat_text_fc=False,
                             npz_dir=npz_dir,
                             data_dim=data_dim,
                             repa_npz_dir=None,
                             exclude_cls=None,
                             repa_version=None)

    # FIX: the sampler was constructed after the DataLoader and never used;
    # build it first and pass it in so each rank sees its own shard.
    train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=True)
    loader = DataLoader(dataset,
                        batch_size=16,
                        sampler=train_sampler,
                        num_workers=8,
                        # FIX: persistent_workers is a bool flag, not a worker count.
                        persistent_workers=True,
                        pin_memory=False)

    for b in loader:
        print(b['a_mean'].shape)
        break