""" | |
Dataset classes for common uses | |
Extended from vanilla PANet code by Wang et al. | |
""" | |
import random | |
import torch | |
from torch.utils.data import Dataset | |


class BaseDataset(Dataset):
    """
    Base Dataset

    Args:
        base_dir:
            dataset directory
    """
    def __init__(self, base_dir):
        self._base_dir = base_dir
        self.aux_attrib = {}
        self.aux_attrib_args = {}
        self.ids = []  # must be overloaded in subclass

    def add_attrib(self, key, func, func_args):
        """
        Add an attribute to the data sample dict

        Args:
            key:
                key in the data sample dict for the new attribute,
                e.g. sample['click_map'], sample['depth_map']
            func:
                function that processes a data sample and creates an attribute (e.g. user clicks)
            func_args:
                extra arguments to pass; expects a dict
        """
        if key in self.aux_attrib:
            raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key))
        else:
            self.set_attrib(key, func, func_args)

    def set_attrib(self, key, func, func_args):
        """
        Set an attribute in the data sample dict

        Args:
            key:
                key in the data sample dict for the new attribute,
                e.g. sample['click_map'], sample['depth_map']
            func:
                function that processes a data sample and creates an attribute (e.g. user clicks)
            func_args:
                extra arguments to pass; expects a dict
        """
        self.aux_attrib[key] = func
        self.aux_attrib_args[key] = func_args

    def del_attrib(self, key):
        """
        Remove an attribute from the data sample dict

        Args:
            key:
                key in the data sample dict
        """
        self.aux_attrib.pop(key)
        self.aux_attrib_args.pop(key)

    def subsets(self, sub_ids, sub_args_lst=None):
        """
        Create subsets by ids

        Args:
            sub_ids:
                a sequence of sequences, where each inner sequence contains the data ids for one subset
            sub_args_lst:
                a list of subset-specific arguments for the auxiliary attribute functions, one dict per subset
        """
        indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids]
        if sub_args_lst is not None:
            subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args)
                       for index, args in zip(indices, sub_args_lst)]
        else:
            subsets = [Subset(dataset=self, indices=index) for index in indices]
        return subsets

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass
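

# --- Illustrative usage sketch (not part of the original code) ---
# A minimal sketch of how a subclass is expected to overload `ids` and how
# `add_attrib`/`subsets` interact. The names below (ToyDataset, click_func)
# are hypothetical, and the sketch assumes each attribute function takes
# (sample, **func_args):
#
#     class ToyDataset(BaseDataset):
#         def __init__(self, base_dir, ids):
#             super().__init__(base_dir)
#             self.ids = ids                      # e.g. ['scan0', 'scan1', ...]
#
#         def __len__(self):
#             return len(self.ids)
#
#         def __getitem__(self, idx):
#             sample = {'id': self.ids[idx]}
#             for key, func in self.aux_attrib.items():
#                 sample[key] = func(sample, **self.aux_attrib_args[key])
#             return sample
#
#     ds = ToyDataset('/data', ids=['a', 'b', 'c', 'd'])
#     ds.add_attrib('click_map', click_func, {'radius': 5})
#     sub0, sub1 = ds.subsets([['a', 'b'], ['c', 'd']],
#                             [{'radius': 3}, {'radius': 7}])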


class ReloadPairedDataset(Dataset):
    """
    Make pairs of data from datasets

    Enables loading only part of the entire data in each epoch, then reloading
    the next part for the following one.

    Args:
        datasets:
            source datasets; expects a list of Dataset.
            Each dataset corresponds to one class and contains a list of all z-indices of that class for each scan
        n_elements:
            number of elements in a pair
        curr_max_iters:
            number of pairs in an epoch
        pair_based_transforms:
            transformations performed on a pair basis; expects a list of functions,
            each taking a pair sample and returning a transformed one.
    """
    def __init__(self, datasets, n_elements, curr_max_iters,
                 pair_based_transforms=None):
        super().__init__()
        self.datasets = datasets
        self.n_datasets = len(self.datasets)
        self.n_data = [len(dataset) for dataset in self.datasets]
        self.n_elements = n_elements
        self.curr_max_iters = curr_max_iters
        self.pair_based_transforms = pair_based_transforms
        self.update_index()

    def update_index(self):
        """
        Update the order of batches for the next episode
        """
        # update the number of elements for each subset
        if hasattr(self, 'indices'):
            n_data_old = self.n_data  # DEBUG
            self.n_data = [len(dataset) for dataset in self.datasets]

        if isinstance(self.n_elements, list):
            # for each of <self.curr_max_iters> episodes: select which way(s)
            # (datasets) to use, then which sample(s) to take from each way
            self.indices = [[(dataset_idx, data_idx)
                             for i, dataset_idx in enumerate(
                                 random.sample(range(self.n_datasets), k=len(self.n_elements)))
                             for data_idx in random.sample(range(self.n_data[dataset_idx]),
                                                           k=self.n_elements[i])]
                            for i_iter in range(self.curr_max_iters)]
        elif self.n_elements > self.n_datasets:
            raise ValueError("'n_elements' should be no more than the number of datasets")
        else:
            self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx]))
                             for dataset_idx in random.sample(range(self.n_datasets),
                                                              k=self.n_elements)]
                            for i in range(self.curr_max_iters)]

    def __len__(self):
        return self.curr_max_iters

    def __getitem__(self, idx):
        sample = [self.datasets[dataset_idx][data_idx]
                  for dataset_idx, data_idx in self.indices[idx]]
        if self.pair_based_transforms is not None:
            for transform, args in self.pair_based_transforms:
                sample = transform(sample, **args)
        return sample
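

# --- Illustrative episodic-loading sketch (not part of the original code) ---
# A hedged example of how ReloadPairedDataset can be driven: wrap per-class
# Subsets (e.g. from BaseDataset.subsets() as sketched above), draw
# `curr_max_iters` episodes per epoch, and resample before the next epoch.
# `class_subsets`, `n_epochs` and `train_step` are assumed names:
#
#     paired = ReloadPairedDataset(datasets=class_subsets,
#                                  n_elements=[2, 1],  # 2 samples from one way, 1 from another
#                                  curr_max_iters=100)
#     for epoch in range(n_epochs):
#         for episode in paired:       # len(paired) == curr_max_iters
#             train_step(episode)
#         paired.update_index()        # resample pairs for the next epoch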


class Subset(Dataset):
    """
    Subset of a dataset at specified indices. Used here for separating a dataset by class.

    Args:
        dataset:
            the whole Dataset
        indices:
            indices of samples of the current class in the entire dataset
        sub_attrib_args:
            subset-specific arguments for the attribute functions; expects a dict
    """
    def __init__(self, dataset, indices, sub_attrib_args=None):
        self.dataset = dataset
        self.indices = indices
        self.sub_attrib_args = sub_attrib_args

    def __getitem__(self, idx):
        if self.sub_attrib_args is not None:
            for key in self.sub_attrib_args:
                # Make sure the dataset already has the corresponding attributes.
                # Here we only make the arguments subset-dependent
                # (i.e. pass different arguments for each subset)
                self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key])
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


class ValidationDataset(Dataset):
    """
    Dataset for validation

    Args:
        dataset:
            source dataset with a __getitem__ method
        test_classes:
            list of test classes
        npart:
            number of parts used when assigning support images during evaluation
    """
    def __init__(self, dataset, test_classes: list, npart: int):
        super().__init__()
        self.dataset = dataset
        self.__curr_cls = None
        self.test_classes = test_classes
        self.dataset.aux_attrib = None  # disable auxiliary attributes during validation
        self.npart = npart

    def set_curr_cls(self, curr_cls):
        assert curr_cls in self.test_classes
        self.__curr_cls = curr_cls

    def get_curr_cls(self):
        return self.__curr_cls

    def read_dataset(self):
        """
        Override the original read_dataset to allow reading with z_margin
        """
        raise NotImplementedError

    def __len__(self):
        return len(self.dataset)

    def label_strip(self, label):
        """
        Mask out unrelated labels
        """
        out = torch.where(label == self.__curr_cls,
                          torch.ones_like(label), torch.zeros_like(label))
        return out

    def __getitem__(self, idx):
        if self.__curr_cls is None:
            raise Exception("Please initialize the current class first")
        sample = self.dataset[idx]
        sample["label"] = self.label_strip(sample["label"])
        sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy()

        labelname = self.dataset.all_label_names[self.__curr_cls]
        z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
        z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
        sample["z_min"], sample["z_max"] = z_min, z_max
        try:
            part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart))
        except ZeroDivisionError:
            # the scan has only one valid slice, i.e. z_max == z_min
            part_assign = 0
        if part_assign < 0:
            part_assign = 0
        elif part_assign >= self.npart:
            part_assign = self.npart - 1
        sample["part_assign"] = part_assign
        sample["case"] = sample["scan_id"]
        return sample

    def get_support_set(self, config, n_support=3):
        support_batched = self.dataset.get_support(
            curr_class=self.__curr_cls, class_idx=[self.__curr_cls],
            scan_idx=config["support_idx"], npart=config["task"]["npart"])
        support_images = [img for way in support_batched["support_images"]
                          for img in way]
        support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"]
                          for fgmask in way]
        support_scan_id = self.dataset.potential_support_sid
        return {"support_images": support_images,
                "support_labels": support_labels,
                "support_scan_id": support_scan_id}