""" | |
Customized dataset. Extended from vanilla PANet script by Wang et al. | |
""" | |
import os | |
import random | |
import torch | |
import numpy as np | |
from dataloaders.common import ReloadPairedDataset, ValidationDataset | |
from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset | |
def attrib_basic(_sample, class_id):
    """
    Add basic attribute

    Args:
        _sample: data sample
        class_id: class label associated with the data
            (sometimes indicating from which subset the data are drawn)
    """
    return {'class_id': class_id}
def getMaskOnly(label, class_id, class_ids):
    """
    Generate FG/BG masks from the segmentation mask

    Args:
        label:
            semantic mask
        class_id:
            semantic class of interest
        class_ids:
            all class ids in this episode
    """
    # Dense mask: FG is the class of interest; BG is everything that belongs to
    # none of the classes in this episode
    fg_mask = torch.where(label == class_id,
                          torch.ones_like(label), torch.zeros_like(label))
    bg_mask = torch.where(label != class_id,
                          torch.ones_like(label), torch.zeros_like(label))
    for cid in class_ids:
        bg_mask[label == cid] = 0

    return {'fg_mask': fg_mask,
            'bg_mask': bg_mask}
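# A minimal sketch (not part of the original pipeline): what getMaskOnly produces
# for a hypothetical 2D label map. The tensor values and class ids below are made
# up purely for illustration.
def _demo_getMaskOnly():
    label = torch.tensor([[0, 1, 1],
                          [2, 2, 0],
                          [0, 3, 0]])
    masks = getMaskOnly(label, class_id=1, class_ids=[1, 2])
    # fg_mask is 1 exactly where label == 1
    print(masks['fg_mask'])
    # bg_mask is 1 only where the label belongs to none of the episode classes
    # (here: the pixels labelled 0 or 3); pixels of class 1 or 2 are zeroed out
    print(masks['bg_mask'])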
def getMasks(*args, **kwargs):
    raise NotImplementedError
def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only=True):
    """
    Postprocess a paired sample for the few-shot setting.
    For now only 1-way is tested, but multi-way is kept possible (inherited from the original PANet).

    Args:
        paired_sample:
            data sample from a PairedDataset
        n_ways:
            n-way few-shot learning
        n_shots:
            n-shot few-shot learning
        cnt_query:
            number of query images for each class in the support set
        coco:
            MS COCO dataset. Kept from the original PANet code for further extension.
        mask_only:
            only provide masks, no scribbles/instances. Suitable for medical images (for now).
    """
    if not mask_only:
        raise NotImplementedError
    ###### Compose the support and query image list ######
    cumsum_idx = np.cumsum([0, ] + [n_shots + x for x in cnt_query])  # separation between supports and queries

    # support class ids
    class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)]  # class id of each way

    # support images
    support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
                      for i in range(n_ways)]  # fetch support images for each class

    # support image labels
    if coco:
        support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
                           for j in range(n_shots)] for i in range(n_ways)]
    else:
        support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
                          for i in range(n_ways)]

    if not mask_only:
        support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
                             for i in range(n_ways)]
        support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
                         for i in range(n_ways)]
    else:
        # keep the interfaces but leave them empty in the mask-only (medical) case;
        # support_scribbles must be defined here because it is returned below
        support_scribbles = []
        support_insts = []
    # query images, masks and class indices
    query_images = [paired_sample[cumsum_idx[i + 1] - j - 1]['image'] for i in range(n_ways)
                    for j in range(cnt_query[i])]
    if coco:
        query_labels = [paired_sample[cumsum_idx[i + 1] - j - 1]['label'][class_ids[i]]
                        for i in range(n_ways) for j in range(cnt_query[i])]
    else:
        query_labels = [paired_sample[cumsum_idx[i + 1] - j - 1]['label'] for i in range(n_ways)
                        for j in range(cnt_query[i])]
    query_cls_idx = [sorted([0, ] + [class_ids.index(x) + 1
                                     for x in set(np.unique(query_label)) & set(class_ids)])
                     for query_label in query_labels]
    ###### Generate support image masks ######
    if not mask_only:
        support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
                                  class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]
    else:
        support_mask = [[getMaskOnly(support_labels[way][shot],
                                     class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]
    ###### Generate query labels (class indices within one episode, i.e. the ground truth) ######
    query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
    for i, query_label_tmp in enumerate(query_labels_tmp):
        query_label_tmp[query_labels[i] == 255] = 255
        for j in range(n_ways):
            query_label_tmp[query_labels[i] == class_ids[j]] = j + 1
    ###### Generate query masks for each semantic class (including BG) ######
    # BG class
    query_masks = [[torch.where(query_label == 0,
                                torch.ones_like(query_label),
                                torch.zeros_like(query_label))[None, ...], ]
                   for query_label in query_labels]
    # Other classes in the query image
    for i, query_label in enumerate(query_labels):
        for idx in query_cls_idx[i][1:]:
            mask = torch.where(query_label == class_ids[idx - 1],
                               torch.ones_like(query_label),
                               torch.zeros_like(query_label))[None, ...]
            query_masks[i].append(mask)
    return {'class_ids': class_ids,
            'support_images': support_images,
            'support_mask': support_mask,
            'support_inst': support_insts,  # keep these interfaces for compatibility
            'support_scribbles': support_scribbles,
            'query_images': query_images,
            'query_labels': query_labels_tmp,
            'query_masks': query_masks,
            'query_cls_idx': query_cls_idx,
            }
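# A minimal sketch (hypothetical tensors, not part of the original pipeline): how a
# flat paired_sample list is split into supports and queries by fewshot_pairing for
# a 1-way, 2-shot episode with a single query. In the real pipeline the list comes
# from ReloadPairedDataset; here it is mocked by hand.
def _demo_fewshot_pairing():
    def _fake_item(class_id):
        return {'image': torch.rand(3, 8, 8),
                'label': torch.randint(0, 3, (8, 8)),
                'basic_class_id': class_id}

    # layout for the single way: [support_shot_0, support_shot_1, query_0]
    paired_sample = [_fake_item(1) for _ in range(3)]
    episode = fewshot_pairing(paired_sample, n_ways=1, n_shots=2,
                              cnt_query=[1], mask_only=True)
    # 2 support images for the single way, 1 query image, FG/BG masks per support shot
    print(len(episode['support_images'][0]),
          len(episode['query_images']),
          list(episode['support_mask'][0][0].keys()))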
def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,
                transforms, act_labels, n_ways, n_shots, max_iters_per_load,
                min_fg='', n_queries=1, fix_parent_len=None, exclude_list=None, **kwargs):
    """
    Dataset wrapper

    Args:
        dataset_name:
            indicates which dataset to use
        base_dir:
            dataset directory
        idx_split:
            index of the data split
        mode:
            which mode to use
            choose from ('train', 'val', 'trainval', 'trainaug')
        scan_per_load:
            number of scans to load into memory, as the dataset is large;
            use together with reload_buffer
        transforms:
            transformations to be performed on images/masks
        act_labels:
            active labels involved in the training process. Should be a subset of all labels.
        n_ways:
            n-way few-shot learning; should be no more than the number of object class labels
        n_shots:
            n-shot few-shot learning
        max_iters_per_load:
            number of pairs per load (epoch size)
        n_queries:
            number of query images
        fix_parent_len:
            fixed length of the parent dataset
    """
    if exclude_list is None:  # avoid a mutable default argument
        exclude_list = []
    med_set = ManualAnnoDataset
    mydataset = med_set(which_dataset=dataset_name, base_dir=base_dir, idx_split=idx_split,
                        mode=mode, scan_per_load=scan_per_load, transforms=transforms,
                        min_fg=min_fg, fix_length=fix_parent_len,
                        exclude_list=exclude_list, **kwargs)
    mydataset.add_attrib('basic', attrib_basic, {})

    # Create sub-datasets and add a class_id attribute. The class file is loaded
    # and reloaded internally.
    subsets = mydataset.subsets([{'basic': {'class_id': ii}}
                                 for ii, _ in enumerate(mydataset.label_name)])

    # Choose the classes of the queries: number of queries drawn for each way
    cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)

    # Set the number of images for each class: <n_shots> supports + <cnt_query>[i] queries
    n_elements = [n_shots + x for x in cnt_query]

    # Create the paired dataset. The background class is not included.
    paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements,
                                      curr_max_iters=max_iters_per_load,
                                      pair_based_transforms=[
                                          (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
                                                             'cnt_query': cnt_query, 'mask_only': True})])
    return paired_data, mydataset
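# A hypothetical usage sketch for building the episodic training set and wrapping it
# in a torch DataLoader. The dataset name, path, labels and loader settings below are
# placeholders; substitute the values from your own experiment configuration.
def _demo_med_fewshot():
    from torch.utils.data import DataLoader  # local import to keep the sketch self-contained

    train_set, parent_set = med_fewshot(
        dataset_name='SABS_Superpix',   # placeholder dataset identifier
        base_dir='./data/SABS',         # placeholder data directory
        idx_split=0,
        mode='train',
        scan_per_load=-1,               # placeholder; set according to available memory
        transforms=None,                # placeholder; plug in the real augmentation pipeline
        act_labels=[1, 2, 3],           # placeholder subset of label indices to train on
        n_ways=1,
        n_shots=1,
        max_iters_per_load=1000,
        n_queries=1)
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=0)
    return train_loader, parent_set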
def update_loader_dset(loader, parent_set):
    """
    Update the data loader and the parent dataset behind it

    Args:
        loader: the actual dataloader
        parent_set: parent dataset which actually stores the data
    """
    parent_set.reload_buffer()
    loader.dataset.update_index()
    print('###### Loader and dataset have been updated ######')
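# A hypothetical training-loop sketch showing when update_loader_dset would be called:
# the buffer of loaded scans is refreshed every few epochs. The epoch count and reload
# interval are arbitrary placeholders; train_loader and parent_set are assumed to come
# from med_fewshot above.
def _demo_reload_loop(train_loader, parent_set, n_epochs=10, reload_every=2):
    for epoch in range(n_epochs):
        for episode in train_loader:
            pass  # the forward/backward pass of the model would go here
        if (epoch + 1) % reload_every == 0:
            # swap a new set of scans into memory and re-pair the episodes
            update_loader_dset(train_loader, parent_set)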
def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart,
                    fix_length=None, nsup=1, transforms=None, mode='val', **kwargs):
    """
    Validation set for medical images

    Args:
        dataset_name:
            indicates which dataset to use
        base_dir:
            dataset directory (e.g. SABS)
        idx_split:
            index of the data split
        scan_per_load:
            number of scans to load into memory, as the dataset is large;
            use together with reload_buffer
        act_labels:
            active labels involved in the validation process. Should be a subset of all labels.
        npart:
            number of chunks for splitting a 3D volume
        nsup:
            number of support scans, equivalent to n_shots
        mode: (original split)
            which split to use
            choose from ('train', 'val', 'trainval', 'trainaug')
    """
    mydataset = ManualAnnoDataset(which_dataset=dataset_name, base_dir=base_dir, idx_split=idx_split,
                                  mode=mode, scan_per_load=scan_per_load, transforms=transforms,
                                  min_fg=1, fix_length=fix_length, nsup=nsup, **kwargs)
    mydataset.add_attrib('basic', attrib_basic, {})
    valset = ValidationDataset(mydataset, test_classes=act_labels, npart=npart)
    return valset, mydataset
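# A hypothetical usage sketch for building the validation set. As above, the dataset
# name, path and label list are placeholders to be replaced with real configuration values.
def _demo_med_fewshot_val():
    from torch.utils.data import DataLoader  # local import to keep the sketch self-contained

    val_set, parent_val = med_fewshot_val(
        dataset_name='SABS_Superpix',   # placeholder dataset identifier
        base_dir='./data/SABS',         # placeholder data directory
        idx_split=0,
        scan_per_load=-1,               # placeholder; set according to available memory
        act_labels=[1, 2, 3],           # placeholder classes to evaluate
        npart=3,                        # split each 3D volume into 3 chunks
        nsup=1,                         # one support scan (1-shot)
        transforms=None,
        mode='val')
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0)
    return val_loader, parent_val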