"""
Dataset classes for common uses
Extended from vanilla PANet code by Wang et al.
"""
import random
import torch
from torch.utils.data import Dataset
class BaseDataset(Dataset):
"""
Base Dataset
Args:
base_dir:
dataset directory
"""
def __init__(self, base_dir):
self._base_dir = base_dir
self.aux_attrib = {}
self.aux_attrib_args = {}
        self.ids = []  # must be overridden in subclass
def add_attrib(self, key, func, func_args):
"""
Add attribute to the data sample dict
Args:
key:
key in the data sample dict for the new attribute
e.g. sample['click_map'], sample['depth_map']
func:
function to process a data sample and create an attribute (e.g. user clicks)
func_args:
extra arguments to pass, expected a dict
"""
if key in self.aux_attrib:
raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key))
else:
self.set_attrib(key, func, func_args)
def set_attrib(self, key, func, func_args):
"""
Set attribute in the data sample dict
Args:
key:
key in the data sample dict for the new attribute
e.g. sample['click_map'], sample['depth_map']
func:
function to process a data sample and create an attribute (e.g. user clicks)
func_args:
extra arguments to pass, expected a dict
"""
self.aux_attrib[key] = func
self.aux_attrib_args[key] = func_args
def del_attrib(self, key):
"""
Remove attribute in the data sample dict
Args:
key:
key in the data sample dict
"""
self.aux_attrib.pop(key)
self.aux_attrib_args.pop(key)
def subsets(self, sub_ids, sub_args_lst=None):
"""
Create subsets by ids
Args:
sub_ids:
a sequence of sequences, each sequence contains data ids for one subset
sub_args_lst:
                a list of dicts, one per subset, mapping attribute keys to subset-specific
                arguments for the corresponding auxiliary attribute functions
"""
indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids]
if sub_args_lst is not None:
subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args)
for index, args in zip(indices, sub_args_lst)]
else:
subsets = [Subset(dataset=self, indices=index) for index in indices]
return subsets
def __len__(self):
pass
def __getitem__(self, idx):
pass
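
# A minimal usage sketch (not part of the original code): a hypothetical
# BaseDataset subclass showing how registered auxiliary attributes are expected
# to be applied inside __getitem__. The class name, the toy data fields, and the
# 'click_map' attribute mentioned below are illustrative assumptions.
class _ToyDataset(BaseDataset):
    def __init__(self, base_dir, images, labels):
        super().__init__(base_dir)
        self.images = images                   # list of image tensors
        self.labels = labels                   # list of label tensors
        self.ids = list(range(len(images)))    # one id per sample

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sample = {'image': self.images[idx], 'label': self.labels[idx]}
        # apply every registered auxiliary attribute function, e.g. one added via
        # dataset.add_attrib('click_map', simulate_clicks, {'n_clicks': 1})
        for key, func in self.aux_attrib.items():
            sample[key] = func(sample, **self.aux_attrib_args[key])
        return sample
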
class ReloadPairedDataset(Dataset):
"""
Make pairs of data from dataset
    Enable loading only part of the entire dataset in each epoch, then reloading the next part
Args:
datasets:
source datasets, expect a list of Dataset.
            Each dataset indexes a certain class and contains a list of all z-indices of that class for each scan
n_elements:
number of elements in a pair
curr_max_iters:
number of pairs in an epoch
pair_based_transforms:
            transformations performed on a pair basis; expects a list of (function, kwargs) tuples,
            where each function takes a pair of samples and returns a transformed one.
"""
def __init__(self, datasets, n_elements, curr_max_iters,
pair_based_transforms=None):
super().__init__()
self.datasets = datasets
self.n_datasets = len(self.datasets)
self.n_data = [len(dataset) for dataset in self.datasets]
self.n_elements = n_elements
self.curr_max_iters = curr_max_iters
self.pair_based_transforms = pair_based_transforms
self.update_index()
def update_index(self):
"""
update the order of batches for the next episode
"""
# update number of elements for each subset
if hasattr(self, 'indices'):
n_data_old = self.n_data # DEBUG
self.n_data = [len(dataset) for dataset in self.datasets]
if isinstance(self.n_elements, list):
self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use
for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use
for i_iter in range(self.curr_max_iters)] # sample <self.curr_max_iters> iterations
        elif self.n_elements > self.n_datasets:
            raise ValueError("'n_elements' should be no more than the number of datasets")
        else:
            self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx]))
                             for dataset_idx in random.sample(range(self.n_datasets),
                                                              k=self.n_elements)]
                            for i_iter in range(self.curr_max_iters)]
def __len__(self):
return self.curr_max_iters
def __getitem__(self, idx):
sample = [self.datasets[dataset_idx][data_idx]
for dataset_idx, data_idx in self.indices[idx]]
if self.pair_based_transforms is not None:
for transform, args in self.pair_based_transforms:
sample = transform(sample, **args)
return sample
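
# A usage sketch (not part of the original code): pairing per-class subsets into
# episodes and re-drawing the pairs between epochs. The function name, the
# `class_subsets` argument, and the iteration counts are illustrative assumptions.
def _example_paired_training(class_subsets, n_epochs=3):
    paired = ReloadPairedDataset(datasets=class_subsets,
                                 n_elements=2,        # e.g. one support and one query entry
                                 curr_max_iters=100)  # episodes per epoch
    for _ in range(n_epochs):
        for i in range(len(paired)):
            episode = paired[i]   # a list of `n_elements` samples
            # ... run one episodic training step on `episode` here ...
        paired.update_index()     # reshuffle the pairs for the next epoch
    return paired
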
class Subset(Dataset):
"""
    Subset of a dataset at specified indices. Used for separating a dataset by class in our context
Args:
dataset:
The whole Dataset
indices:
Indices of samples of the current class in the entire dataset
sub_attrib_args:
Subset-specific arguments for attribute functions, expected a dict
"""
def __init__(self, dataset, indices, sub_attrib_args=None):
self.dataset = dataset
self.indices = indices
self.sub_attrib_args = sub_attrib_args
def __getitem__(self, idx):
if self.sub_attrib_args is not None:
for key in self.sub_attrib_args:
# Make sure the dataset already has the corresponding attributes
# Here we only make the arguments subset dependent
# (i.e. pass different arguments for each subset)
self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key])
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
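
# A usage sketch (not part of the original code): splitting a dataset into
# per-class Subsets while overriding an auxiliary-attribute argument for each
# subset. It assumes a 'click_map' attribute was previously registered on the
# dataset via add_attrib; `ids_by_class` is a hypothetical list of per-class id lists.
def _example_per_class_subsets(dataset, ids_by_class):
    # one {attribute_key: kwargs} dict per subset, merged into aux_attrib_args
    # by Subset.__getitem__ before the sample is fetched
    sub_args = [{'click_map': {'n_clicks': i + 1}} for i in range(len(ids_by_class))]
    return dataset.subsets(ids_by_class, sub_args_lst=sub_args)
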
class ValidationDataset(Dataset):
"""
Dataset for validation
Args:
dataset:
source dataset with a __getitem__ method
test_classes:
test classes
        npart:
            number of parts used when assigning support images during evaluation
"""
def __init__(self, dataset, test_classes: list, npart: int):
super().__init__()
self.dataset = dataset
self.__curr_cls = None
self.test_classes = test_classes
self.dataset.aux_attrib = None
self.npart = npart
def set_curr_cls(self, curr_cls):
assert curr_cls in self.test_classes
self.__curr_cls = curr_cls
def get_curr_cls(self):
return self.__curr_cls
def read_dataset(self):
"""
override original read_dataset to allow reading with z_margin
"""
raise NotImplementedError
def __len__(self):
return len(self.dataset)
def label_strip(self, label):
"""
mask unrelated labels out
"""
out = torch.where(label == self.__curr_cls,
torch.ones_like(label), torch.zeros_like(label))
return out
def __getitem__(self, idx):
if self.__curr_cls is None:
            raise Exception("Please set the current class via set_curr_cls() first")
sample = self.dataset[idx]
sample["label"] = self.label_strip( sample["label"] )
sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy()
labelname = self.dataset.all_label_names[self.__curr_cls]
z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
sample["z_min"], sample["z_max"] = z_min, z_max
        try:
            part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart))
        except ZeroDivisionError:
            # the support scan has only one valid slice for this class
            part_assign = 0
if part_assign < 0:
part_assign = 0
elif part_assign >= self.npart:
part_assign = self.npart - 1
sample["part_assign"] = part_assign
sample["case"] = sample["scan_id"]
return sample
    def get_support_set(self, config, n_support=3):
        support_batched = self.dataset.get_support(curr_class=self.__curr_cls,
                                                   class_idx=[self.__curr_cls],
                                                   scan_idx=config["support_idx"],
                                                   npart=config["task"]["npart"])
        support_images = [img for way in support_batched["support_images"] for img in way]
        support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"]
                          for fgmask in way]
        support_scan_id = self.dataset.potential_support_sid
        return {"support_images": support_images,
                "support_labels": support_labels,
                "support_scan_id": support_scan_id}