"""
Customized dataset. Extended from vanilla PANet script by Wang et al.
"""
import os
import random
import torch
import numpy as np
from dataloaders.common import ReloadPairedDataset, ValidationDataset
from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
def attrib_basic(_sample, class_id):
"""
Add basic attribute
Args:
_sample: data sample
        class_id: class label associated with the data
            (sometimes indicating from which subset the data are drawn)
"""
return {'class_id': class_id}
def getMaskOnly(label, class_id, class_ids):
"""
Generate FG/BG mask from the segmentation mask
Args:
label:
semantic mask
        class_id:
            semantic class of interest
        class_ids:
            all class ids in this episode
"""
# Dense Mask
fg_mask = torch.where(label == class_id,
torch.ones_like(label), torch.zeros_like(label))
bg_mask = torch.where(label != class_id,
torch.ones_like(label), torch.zeros_like(label))
    for cls_id in class_ids:  # avoid shadowing the class_id argument
        bg_mask[label == cls_id] = 0
return {'fg_mask': fg_mask,
'bg_mask': bg_mask}
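# A minimal illustration of getMaskOnly (hypothetical toy tensors, not part of the actual pipeline):
# for an episode with class_ids = [1, 2] and foreground class class_id = 1,
#
#   label = torch.tensor([[0, 1],
#                         [2, 1]])
#   masks = getMaskOnly(label, class_id=1, class_ids=[1, 2])
#   masks['fg_mask']  # -> tensor([[0, 1], [0, 1]]), pixels of the class of interest
#   masks['bg_mask']  # -> tensor([[1, 0], [0, 0]]), true background only; other episode classes are zeroed out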
def getMasks(*args, **kwargs):
raise NotImplementedError
def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True):
"""
Postprocess paired sample for fewshot settings
    For now only 1-way is tested, but multi-way is kept possible (inherited from the original PANet)
Args:
paired_sample:
data sample from a PairedDataset
n_ways:
n-way few-shot learning
n_shots:
n-shot few-shot learning
cnt_query:
number of query images for each class in the support set
        coco:
            whether the sample comes from MS COCO. Kept from the original PANet code for possible further extension
        mask_only:
            return masks only, without scribbles/instances. Suitable for medical images (for now)
"""
if not mask_only:
raise NotImplementedError
###### Compose the support and query image list ######
    cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # separation between supports and queries
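    # For example (hypothetical values): n_ways=2, n_shots=1 and cnt_query=[1, 1] give cumsum_idx = [0, 2, 4],
    # i.e. paired_sample[0]/[2] are the supports and paired_sample[1]/[3] are the queries of the two ways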
# support class ids
class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query)
# support images
support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
for i in range(n_ways)] # fetch support images for each class
# support image labels
if coco:
support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
for j in range(n_shots)] for i in range(n_ways)]
else:
support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
for i in range(n_ways)]
if not mask_only:
support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
for i in range(n_ways)]
support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
for i in range(n_ways)]
    else:
        support_scribbles = []  # keep the interface defined even when scribbles are unused
        support_insts = []
# query images, masks and class indices
query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways)
for j in range(cnt_query[i])]
if coco:
query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]]
for i in range(n_ways) for j in range(cnt_query[i])]
else:
query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways)
for j in range(cnt_query[i])]
query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1
for x in set(np.unique(query_label)) & set(class_ids)])
for query_label in query_labels]
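    # For example (hypothetical values): with class_ids = [3, 5] and a query label whose pixels contain
    # {0, 5, 255}, query_cls_idx for that image is [0, 2] (0 for BG, 2 because class 5 is the second way)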
###### Generate support image masks ######
if not mask_only:
support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
class_ids[way], class_ids)
for shot in range(n_shots)] for way in range(n_ways)]
else:
support_mask = [[getMaskOnly(support_labels[way][shot],
class_ids[way], class_ids)
for shot in range(n_shots)] for way in range(n_ways)]
    ###### Generate query label (class indices in one episode, i.e. the ground truth) ######
query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
for i, query_label_tmp in enumerate(query_labels_tmp):
query_label_tmp[query_labels[i] == 255] = 255
for j in range(n_ways):
query_label_tmp[query_labels[i] == class_ids[j]] = j + 1
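    # For example (hypothetical values): with class_ids = [6], a pixel originally labelled 6 becomes 1,
    # a pixel labelled 255 stays 255 (ignore label), and everything else stays 0 (background)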
###### Generate query mask for each semantic class (including BG) ######
# BG class
query_masks = [[torch.where(query_label == 0,
torch.ones_like(query_label),
torch.zeros_like(query_label))[None, ...],]
for query_label in query_labels]
# Other classes in query image
for i, query_label in enumerate(query_labels):
for idx in query_cls_idx[i][1:]:
mask = torch.where(query_label == class_ids[idx - 1],
torch.ones_like(query_label),
torch.zeros_like(query_label))[None, ...]
query_masks[i].append(mask)
return {'class_ids': class_ids,
'support_images': support_images,
'support_mask': support_mask,
'support_inst': support_insts, # leave these interfaces
'support_scribbles': support_scribbles,
'query_images': query_images,
'query_labels': query_labels_tmp,
'query_masks': query_masks,
'query_cls_idx': query_cls_idx,
}
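# Sketch of the episode dictionary returned above, assuming 1-way, 1-shot, one query image and
# H x W label maps (exact tensor shapes depend on the upstream dataset and transforms):
#   'class_ids'      : [c]                                    one semantic class per way
#   'support_images' : [[img]]                                n_ways x n_shots nested lists
#   'support_mask'   : [[{'fg_mask': ..., 'bg_mask': ...}]]   per support image, from getMaskOnly
#   'support_scribbles', 'support_inst' : empty placeholders when mask_only=True
#   'query_images'   : [img]                                  sum(cnt_query) images in total
#   'query_labels'   : [H x W tensor]                         0 = BG, j + 1 = j-th way, 255 = ignore
#   'query_masks'    : [[BG mask, FG mask, ...]]              one binary mask per class index present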
def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,
transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = [], **kwargs):
"""
Dataset wrapper
Args:
dataset_name:
indicates what dataset to use
base_dir:
dataset directory
mode:
which mode to use
choose from ('train', 'val', 'trainval', 'trainaug')
idx_split:
index of split
scan_per_load:
number of scans to load into memory as the dataset is large
use that together with reload_buffer
transforms:
transformations to be performed on images/masks
act_labels:
            active labels involved in the training process. Should be a subset of all labels
n_ways:
n-way few-shot learning, should be no more than # of object class labels
n_shots:
n-shot few-shot learning
max_iters_per_load:
number of pairs per load (epoch size)
n_queries:
number of query images
fix_parent_len:
fixed length of the parent dataset
"""
med_set = ManualAnnoDataset
mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\
scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\
exclude_list = exclude_list, **kwargs)
mydataset.add_attrib('basic', attrib_basic, {})
    # Create sub-datasets and add the class_id attribute. The class file is loaded and reloaded internally
subsets = mydataset.subsets([{'basic': {'class_id': ii}}
for ii, _ in enumerate(mydataset.label_name)])
# Choose the classes of queries
    cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)  # number of queries for each way
    # Set the number of images for each class
    n_elements = [n_shots + x for x in cnt_query]  # <n_shots> supports + <cnt_query>[i] queries
# Create paired dataset. We do not include background.
paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load,
pair_based_transforms=[
(fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
'cnt_query': cnt_query, 'mask_only': True})])
return paired_data, mydataset
def update_loader_dset(loader, parent_set):
"""
Update data loader and the parent dataset behind
Args:
loader: actual dataloader
parent_set: parent dataset which actually stores the data
"""
parent_set.reload_buffer()
loader.dataset.update_index()
    print('###### Loader and dataset have been updated ######')
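# A hedged usage sketch for training (dataset name, paths, label ids and transforms below are
# placeholders, not values taken from this repository's configs):
#
#   from torch.utils.data import DataLoader
#   tr_pairs, tr_parent = med_fewshot(
#       dataset_name='SABS', base_dir='./data/SABS', idx_split=0, mode='train',
#       scan_per_load=4, transforms=my_transforms, act_labels=[1, 2, 3],
#       n_ways=1, n_shots=1, max_iters_per_load=1000)
#   tr_loader = DataLoader(tr_pairs, batch_size=1, shuffle=True)
#   # once the current buffer of scans has been exhausted, swap it and refresh the pairing:
#   update_loader_dset(tr_loader, tr_parent)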
def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, transforms=None, mode='val', **kwargs):
"""
    Validation dataset for medical images
Args:
dataset_name:
indicates what dataset to use
base_dir:
SABS dataset directory
mode: (original split)
which split to use
choose from ('train', 'val', 'trainval', 'trainaug')
idx_split:
index of split
        scan_per_load:
number of scans to load into memory as the dataset is large
use that together with reload_buffer
        act_labels:
            labels (classes) to be evaluated. Should be a subset of all labels
npart: number of chunks for splitting a 3d volume
nsup: number of support scans, equivalent to nshot
"""
    mydataset = ManualAnnoDataset(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\
            scan_per_load = scan_per_load, transforms=transforms, min_fg = 1, fix_length = fix_length,\
            nsup = nsup, **kwargs)
mydataset.add_attrib('basic', attrib_basic, {})
valset = ValidationDataset(mydataset, test_classes = act_labels, npart = npart)
    return valset, mydataset
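# A hedged usage sketch for validation (paths, label ids, npart and scan_per_load below are placeholders):
#
#   te_set, te_parent = med_fewshot_val(
#       dataset_name='SABS', base_dir='./data/SABS', idx_split=0, scan_per_load=1,
#       act_labels=[1, 2, 3], npart=3, nsup=1, transforms=None, mode='val')
#   te_loader = torch.utils.data.DataLoader(te_set, batch_size=1, shuffle=False)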