import bisect

import torch
from torch.utils.data.dataset import Dataset


class MultiModalDataset(Dataset):
    """Concatenates video and audio datasets behind a single flat index.

    Indexing mirrors torch.utils.data.ConcatDataset: video datasets come
    first, followed by audio datasets, and cumulative sizes map a global
    index to the owning dataset.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths, e.g. [3, 5] for datasets of
        # sizes 3 and 2.
        r, s = [], 0
        for e in sequence:
            length = len(e)
            r.append(length + s)
            s += length
        return r

    def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
        super().__init__()
        self.video_datasets = list(video_datasets) if video_datasets else []
        self.audio_datasets = list(audio_datasets) if audio_datasets else []
        # Video datasets are indexed first, then audio datasets.
        self.datasets = self.video_datasets + self.audio_datasets
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Locate the dataset that owns this global index, then convert the
        # global index into an index local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        if not self.video_datasets:
            raise NotImplementedError(
                'This function should not be called for an audio-text dataset; '
                'please load the latent stats manually instead')
        return self.video_datasets[0].compute_latent_stats()
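

if __name__ == '__main__':
    # Minimal usage sketch of the flat indexing. The real video/audio
    # dataset classes are project-specific and assumed here, so
    # TensorDataset stand-ins are used purely for illustration.
    from torch.utils.data import TensorDataset

    video = TensorDataset(torch.randn(4, 8))  # stand-in "video" dataset, 4 samples
    audio = TensorDataset(torch.randn(2, 8))  # stand-in "audio" dataset, 2 samples
    dataset = MultiModalDataset([video], [audio])

    assert len(dataset) == 6
    # Indices 0-3 resolve to the video dataset, 4-5 to the audio dataset,
    # and negative indices wrap around as usual.
    print(dataset[0], dataset[4], dataset[-1])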