""" Dataset classes for common uses Extended from vanilla PANet code by Wang et al. """ import random import torch from torch.utils.data import Dataset class BaseDataset(Dataset): """ Base Dataset Args: base_dir: dataset directory """ def __init__(self, base_dir): self._base_dir = base_dir self.aux_attrib = {} self.aux_attrib_args = {} self.ids = [] # must be overloaded in subclass def add_attrib(self, key, func, func_args): """ Add attribute to the data sample dict Args: key: key in the data sample dict for the new attribute e.g. sample['click_map'], sample['depth_map'] func: function to process a data sample and create an attribute (e.g. user clicks) func_args: extra arguments to pass, expected a dict """ if key in self.aux_attrib: raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key)) else: self.set_attrib(key, func, func_args) def set_attrib(self, key, func, func_args): """ Set attribute in the data sample dict Args: key: key in the data sample dict for the new attribute e.g. sample['click_map'], sample['depth_map'] func: function to process a data sample and create an attribute (e.g. user clicks) func_args: extra arguments to pass, expected a dict """ self.aux_attrib[key] = func self.aux_attrib_args[key] = func_args def del_attrib(self, key): """ Remove attribute in the data sample dict Args: key: key in the data sample dict """ self.aux_attrib.pop(key) self.aux_attrib_args.pop(key) def subsets(self, sub_ids, sub_args_lst=None): """ Create subsets by ids Args: sub_ids: a sequence of sequences, each sequence contains data ids for one subset sub_args_lst: a list of args for some subset-specific auxiliary attribute function """ indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids] if sub_args_lst is not None: subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args) for index, args in zip(indices, sub_args_lst)] else: subsets = [Subset(dataset=self, indices=index) for index in indices] return subsets def __len__(self): pass def __getitem__(self, idx): pass class ReloadPairedDataset(Dataset): """ Make pairs of data from dataset Eable only loading part of the entire data in each epoach and then reload to the next part Args: datasets: source datasets, expect a list of Dataset. Each dataset indices a certain class. It contains a list of all z-indices of this class for each scan n_elements: number of elements in a pair curr_max_iters: number of pairs in an epoch pair_based_transforms: some transformation performed on a pair basis, expect a list of functions, each function takes a pair sample and return a transformed one. 
""" def __init__(self, datasets, n_elements, curr_max_iters, pair_based_transforms=None): super().__init__() self.datasets = datasets self.n_datasets = len(self.datasets) self.n_data = [len(dataset) for dataset in self.datasets] self.n_elements = n_elements self.curr_max_iters = curr_max_iters self.pair_based_transforms = pair_based_transforms self.update_index() def update_index(self): """ update the order of batches for the next episode """ # update number of elements for each subset if hasattr(self, 'indices'): n_data_old = self.n_data # DEBUG self.n_data = [len(dataset) for dataset in self.datasets] if isinstance(self.n_elements, list): self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use for i_iter in range(self.curr_max_iters)] # sample iterations elif self.n_elements > self.n_datasets: raise ValueError("When 'same=False', 'n_element' should be no more than n_datasets") else: self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx])) for dataset_idx in random.sample(range(self.n_datasets), k=n_elements)] for i in range(curr_max_iters)] def __len__(self): return self.curr_max_iters def __getitem__(self, idx): sample = [self.datasets[dataset_idx][data_idx] for dataset_idx, data_idx in self.indices[idx]] if self.pair_based_transforms is not None: for transform, args in self.pair_based_transforms: sample = transform(sample, **args) return sample class Subset(Dataset): """ Subset of a dataset at specified indices. Used for seperating a dataset by class in our context Args: dataset: The whole Dataset indices: Indices of samples of the current class in the entire dataset sub_attrib_args: Subset-specific arguments for attribute functions, expected a dict """ def __init__(self, dataset, indices, sub_attrib_args=None): self.dataset = dataset self.indices = indices self.sub_attrib_args = sub_attrib_args def __getitem__(self, idx): if self.sub_attrib_args is not None: for key in self.sub_attrib_args: # Make sure the dataset already has the corresponding attributes # Here we only make the arguments subset dependent # (i.e. pass different arguments for each subset) self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key]) return self.dataset[self.indices[idx]] def __len__(self): return len(self.indices) class ValidationDataset(Dataset): """ Dataset for validation Args: dataset: source dataset with a __getitem__ method test_classes: test classes npart: int. 
            number of parts, used when assigning support images for evaluation
    """
    def __init__(self, dataset, test_classes: list, npart: int):
        super().__init__()
        self.dataset = dataset
        self.__curr_cls = None
        self.test_classes = test_classes
        self.dataset.aux_attrib = None
        self.npart = npart

    def set_curr_cls(self, curr_cls):
        assert curr_cls in self.test_classes
        self.__curr_cls = curr_cls

    def get_curr_cls(self):
        return self.__curr_cls

    def read_dataset(self):
        """
        override original read_dataset to allow reading with z_margin
        """
        raise NotImplementedError

    def __len__(self):
        return len(self.dataset)

    def label_strip(self, label):
        """
        mask unrelated labels out
        """
        out = torch.where(label == self.__curr_cls,
                          torch.ones_like(label), torch.zeros_like(label))
        return out

    def __getitem__(self, idx):
        if self.__curr_cls is None:
            raise Exception("Please set the current class with set_curr_cls() first")
        sample = self.dataset[idx]
        sample["label"] = self.label_strip(sample["label"])
        sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy()

        labelname = self.dataset.all_label_names[self.__curr_cls]
        z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
        z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
        sample["z_min"], sample["z_max"] = z_min, z_max
        try:
            part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart))
        except ZeroDivisionError:
            # the support scan only has one valid slice (z_max == z_min)
            part_assign = 0
        if part_assign < 0:
            part_assign = 0
        elif part_assign >= self.npart:
            part_assign = self.npart - 1
        sample["part_assign"] = part_assign
        sample["case"] = sample["scan_id"]

        return sample

    def get_support_set(self, config, n_support=3):
        support_batched = self.dataset.get_support(curr_class=self.__curr_cls,
                                                   class_idx=[self.__curr_cls],
                                                   scan_idx=config["support_idx"],
                                                   npart=config["task"]["npart"])
        support_images = [img for way in support_batched["support_images"] for img in way]
        support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"]
                          for fgmask in way]
        support_scan_id = self.dataset.potential_support_sid
        return {"support_images": support_images,
                "support_labels": support_labels,
                "support_scan_id": support_scan_id}
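

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical example of how the classes above can
# compose. `_ToySliceDataset` and its 'image'/'id' keys are invented for this
# demo only; the real training code builds class-wise subsets from medical
# image scans (one Subset per class) and pairs them episodically.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _ToySliceDataset(BaseDataset):
        """Tiny in-memory dataset used only to exercise the API."""
        def __init__(self):
            super().__init__(base_dir=None)
            self.ids = list(range(10))

        def __len__(self):
            return len(self.ids)

        def __getitem__(self, idx):
            sample = {'image': torch.randn(1, 32, 32), 'id': self.ids[idx]}
            # one way a subclass might consume registered auxiliary attributes
            for key, func in self.aux_attrib.items():
                sample[key] = func(sample, **self.aux_attrib_args[key])
            return sample

    toy = _ToySliceDataset()
    # pretend ids 0-4 belong to one class and ids 5-9 to another
    class_subsets = toy.subsets([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
    # draw 2-element pairs (one element per "way") for a 3-iteration episode
    paired = ReloadPairedDataset(class_subsets, n_elements=2, curr_max_iters=3)
    for i in range(len(paired)):
        print([elem['id'] for elem in paired[i]])
    paired.update_index()  # reshuffle the pairings for the next episode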