|
from .catalog import DatasetCatalog |
|
from ldm.util import instantiate_from_config |
|
import torch |
|
|
|
|
|
|
|
|
|
class ConCatDataset():
    """Concatenate several catalog datasets into one flat, indexable dataset.

    Each entry of ``dataset_name_list`` maps a dataset name (an attribute of
    ``DatasetCatalog``) to an optional dict of yaml parameter overrides.
    A dataset may be repeated ``repeats[i]`` times; a global index is
    translated back to (dataset, local index) via precomputed lookup tensors.
    """

    def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None):
        """Instantiate and concatenate the requested datasets.

        Args:
            dataset_name_list: mapping of dataset name -> yaml param overrides
                (dict) or None for no overrides.
            ROOT: root path handed to ``DatasetCatalog``.
            which_embedder: embedder identifier handed to ``DatasetCatalog``.
            train: use the catalog's 'train_params' when True, otherwise
                'val_params'.
            repeats: optional per-dataset repeat counts; defaults to 1 each.
                Must match ``dataset_name_list`` in length when given.
        """
        self.datasets = []
        cul_previous_dataset_length = 0  # running total of (repeated) lengths
        offset_map = []      # per-global-index offset of the owning dataset
        which_dataset = []   # per-global-index owning dataset's position

        if repeats is None:
            repeats = [1] * len(dataset_name_list)
        else:
            assert len(repeats) == len(dataset_name_list)

        Catalog = DatasetCatalog(ROOT, which_embedder)
        for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()):
            repeat = repeats[dataset_idx]

            dataset_dict = getattr(Catalog, dataset_name)

            target = dataset_dict['target']
            # Copy before applying overrides: the original updated the
            # catalog's params dict in place, leaking yaml overrides into any
            # later lookup of the same catalog entry.
            params = dict(dataset_dict['train_params'] if train else dataset_dict['val_params'])
            if yaml_params is not None:
                params.update(yaml_params)
            dataset = instantiate_from_config(dict(target=target, params=params))

            self.datasets.append(dataset)
            for _ in range(repeat):
                # Each repeat contributes another full copy of the dataset's
                # index range to the global index space.
                offset_map.append(torch.ones(len(dataset)) * cul_previous_dataset_length)
                which_dataset.append(torch.ones(len(dataset)) * dataset_idx)
                cul_previous_dataset_length += len(dataset)
        offset_map = torch.cat(offset_map, dim=0).long()
        self.total_length = cul_previous_dataset_length

        # global idx - offset == local idx within the owning dataset
        self.mapping = torch.arange(self.total_length) - offset_map
        self.which_dataset = torch.cat(which_dataset, dim=0).long()

    def total_images(self):
        """Return the sum of total_images() over the underlying datasets.

        Repeats are NOT counted here — each dataset is counted once.
        """
        count = 0
        for dataset in self.datasets:
            print(dataset.total_images())
            count += dataset.total_images()
        return count

    def __getitem__(self, idx):
        """Route a global index to the owning dataset's local index."""
        dataset = self.datasets[self.which_dataset[idx]]
        return dataset[self.mapping[idx]]

    def __len__(self):
        return self.total_length
|
|
|
|
|
|
|
|
|
|
|
|