""" | |
Dataset for training with pseudolabels | |
TODO: | |
1. Merge with manual annotated dataset | |
2. superpixel_scale -> superpix_config, feed like a dict | |
""" | |
import glob | |
import numpy as np | |
import dataloaders.augutils as myaug | |
import torch | |
import random | |
import os | |
import copy | |
import platform | |
import json | |
import re | |
import cv2 | |
from dataloaders.common import BaseDataset, Subset | |
from dataloaders.dataset_utils import* | |
from pdb import set_trace | |
from util.utils import CircularList | |
from util.consts import IMG_SIZE | |
class SuperpixelDataset(BaseDataset): | |
def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], train_list = [], superpix_scale = 'SMALL', norm_mean=None, norm_std=None, supervised_train=False, use_3_slices=False, **kwargs): | |
""" | |
Pseudolabel dataset | |
Args: | |
which_dataset: name of the dataset to use | |
base_dir: directory of dataset | |
idx_split: index of data split as we will do cross validation | |
mode: 'train', 'val'. | |
nsup: number of scans used as support. currently idle for superpixel dataset | |
transforms: data transform (augmentation) function | |
scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time | |
num_rep: Number of augmentation applied for a same pseudolabel | |
tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images | |
fix_length: fix the length of dataset | |
exclude_list: Labels to be excluded | |
superpix_scale: config of superpixels | |
""" | |
super(SuperpixelDataset, self).__init__(base_dir) | |
self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] | |
self.sep = DATASET_INFO[which_dataset]['_SEP'] | |
self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME'] | |
self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] | |
self.image_size = image_size | |
self.transforms = transforms | |
self.is_train = True if mode == 'train' else False | |
self.supervised_train = supervised_train | |
if self.supervised_train and len(train_list) == 0: | |
raise Exception('Please provide training labels') | |
# assert mode == 'train' | |
self.fix_length = fix_length | |
if self.supervised_train: | |
# self.nclass = len(self.real_label_name) | |
self.nclass = len(self.pseu_label_name) | |
else: | |
self.nclass = len(self.pseu_label_name) | |
self.num_rep = num_rep | |
self.tile_z_dim = tile_z_dim | |
self.use_3_slices = use_3_slices | |
if tile_z_dim > 1 and self.use_3_slices: | |
raise Exception("tile_z_dim and use_3_slices shouldn't be used together") | |
# find scans in the data folder | |
self.nsup = nsup | |
self.base_dir = base_dir | |
self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ] | |
self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) | |
# experiment configs | |
self.exclude_lbs = exclude_list | |
self.train_list = train_list | |
self.superpix_scale = superpix_scale | |
if len(exclude_list) > 0: | |
print(f'###### Dataset: the following classes has been excluded {exclude_list}######') | |
self.idx_split = idx_split | |
self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold | |
self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) | |
self.scan_per_load = scan_per_load | |
self.info_by_scan = None | |
self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold | |
self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()], ct_mean=norm_mean, ct_std=norm_std) | |
if self.is_train: | |
if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch | |
self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) | |
else: # load the entire set without a buffer | |
self.pid_curr_load = self.scan_ids | |
elif mode == 'val': | |
self.pid_curr_load = self.scan_ids | |
else: | |
raise Exception | |
self.use_clahe = False | |
if kwargs['use_clahe']: | |
self.use_clahe = True | |
clip_limit = 4.0 if self.img_modality == 'MR' else 2.0 | |
self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7,7)) | |
self.actual_dataset = self.read_dataset() | |
self.size = len(self.actual_dataset) | |
self.overall_slice_by_cls = self.read_classfiles() | |
print("###### Initial scans loaded: ######") | |
print(self.pid_curr_load) | |
def get_scanids(self, mode, idx_split): | |
""" | |
Load scans by train-test split | |
leaving one additional scan as the support scan. if the last fold, taking scan 0 as the additional one | |
Args: | |
idx_split: index for spliting cross-validation folds | |
""" | |
val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) | |
if mode == 'train': | |
return [ ii for ii in self.img_pids if ii not in val_ids ] | |
elif mode == 'val': | |
return val_ids | |
def reload_buffer(self): | |
""" | |
Reload a only portion of the entire dataset, if the dataset is too large | |
1. delete original buffer | |
2. update self.ids_this_batch | |
3. update other internel variables like __len__ | |
""" | |
if self.scan_per_load <= 0: | |
print("We are not using the reload buffer, doing notiong") | |
return -1 | |
del self.actual_dataset | |
del self.info_by_scan | |
self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) | |
self.actual_dataset = self.read_dataset() | |
self.size = len(self.actual_dataset) | |
self.update_subclass_lookup() | |
print(f'Loader buffer reloaded with a new size of {self.size} slices') | |
def organize_sample_fids(self): | |
out_list = {} | |
for curr_id in self.scan_ids: | |
curr_dict = {} | |
_img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') | |
_lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz') | |
_gt_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') | |
curr_dict["img_fid"] = _img_fid | |
curr_dict["lbs_fid"] = _lb_fid | |
curr_dict["gt_lbs_fid"] = _gt_lb_fid | |
out_list[str(curr_id)] = curr_dict | |
return out_list | |
    def read_dataset(self):
        """
        Read images into memory and store them in 2D
        Build tables for the position of an individual 2D slice in the entire dataset

        Returns:
            out_list: list of per-slice dicts with keys
                img (H x W x 1), lb (H x W x 1), sup_max_cls (max label id in the slice),
                is_start / is_end (volume-boundary flags), nframe (slice count of the
                source scan), scan_id, z_id (slice index within the scan)
        Side effects:
            fills self.scan_z_idx (scan_id -> z index -> global slice index)
            and self.info_by_scan (scan_id -> sitk meta information)
        """
        out_list = []
        self.scan_z_idx = {}
        self.info_by_scan = {} # meta data of each scan
        glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
        for scan_id, itm in self.img_lb_fids.items():
            if scan_id not in self.pid_curr_load:
                # scan is not part of the currently loaded buffer
                continue
            img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
            # read connected graph of labels
            if self.use_clahe:
                # img = nself.clahe.apply(img.astype(np.uint8))
                if self.img_modality == 'MR':
                    # MR intensities are unbounded: rescale each slice to [0, 255] before CLAHE
                    img = np.stack([((slice - slice.min()) / (slice.max() - slice.min())) * 255 for slice in img], axis=0)
                # CLAHE works on uint8 data, applied per axial slice
                img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0)
            # presumably sitk yields [Z, H, W]; reorder to [H, W, Z] — consistent with lb below
            img = img.transpose(1,2,0)
            self.info_by_scan[scan_id] = _info
            img = np.float32(img)
            img = self.norm_func(img)  # modality-specific intensity normalization
            self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
            if self.supervised_train:
                # supervised mode: ground-truth labels
                lb = read_nii_bysitk(itm["gt_lbs_fid"])
            else:
                # self-supervised mode: superpixel pseudolabels
                lb = read_nii_bysitk(itm["lbs_fid"])
            lb = lb.transpose(1,2,0)
            lb = np.int32(lb)
            # resize img and lb to self.image_size
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
            # format of slices: [axial_H x axial_W x Z]
            if self.supervised_train:
                # remove all images that dont have the training labels
                del_indices = [i for i in range(img.shape[-1]) if not np.any(np.isin(lb[..., i], self.train_list))]
                # create an new img and lb without indices in del_indices
                new_img = np.zeros((img.shape[0], img.shape[1], img.shape[2] - len(del_indices)))
                new_lb = np.zeros((lb.shape[0], lb.shape[1], lb.shape[2] - len(del_indices)))
                # NOTE(review): the two zero arrays above are immediately replaced by the
                # boolean-indexed views below — they appear to be dead allocations
                new_img = img[..., ~np.isin(np.arange(img.shape[-1]), del_indices)]
                new_lb = lb[..., ~np.isin(np.arange(lb.shape[-1]), del_indices)]
                img = new_img
                lb = new_lb
                # NOTE(review): `a` is never read afterwards — looks like debug leftover
                a = [i for i in range(img.shape[-1]) if lb[...,i].max() == 0]
            nframes = img.shape[-1]
            assert img.shape[-1] == lb.shape[-1]
            base_idx = img.shape[-1] // 2 # index of the middle slice
            # NOTE(review): base_idx is computed but unused in this method
            # re-organize 3D images into 2D slices and record essential information for each slice
            # first slice of the volume, flagged is_start
            out_list.append( {"img": img[..., 0: 1],
                              "lb":lb[..., 0: 0 + 1],
                              "sup_max_cls": lb[..., 0: 0 + 1].max(),
                              "is_start": True,
                              "is_end": False,
                              "nframe": nframes,
                              "scan_id": scan_id,
                              "z_id":0,
                              })
            self.scan_z_idx[scan_id][0] = glb_idx
            glb_idx += 1
            # interior slices
            for ii in range(1, img.shape[-1] - 1):
                out_list.append( {"img": img[..., ii: ii + 1],
                                  "lb":lb[..., ii: ii + 1],
                                  "is_start": False,
                                  "is_end": False,
                                  "sup_max_cls": lb[..., ii: ii + 1].max(),
                                  "nframe": nframes,
                                  "scan_id": scan_id,
                                  "z_id": ii,
                                  })
                self.scan_z_idx[scan_id][ii] = glb_idx
                glb_idx += 1
            ii += 1 # last slice of a 3D volume
            # NOTE(review): relies on the loop variable leaking out of the for loop;
            # raises NameError if a scan has fewer than 2 slices — confirm inputs never do
            out_list.append( {"img": img[..., ii: ii + 1],
                              "lb":lb[..., ii: ii+ 1],
                              "is_start": False,
                              "is_end": True,
                              "sup_max_cls": lb[..., ii: ii + 1].max(),
                              "nframe": nframes,
                              "scan_id": scan_id,
                              "z_id": ii,
                              })
            self.scan_z_idx[scan_id][ii] = glb_idx
            glb_idx += 1
        return out_list
def read_classfiles(self): | |
""" | |
Load the scan-slice-class indexing file | |
""" | |
with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen: | |
cls_map = json.load( fopen) | |
fopen.close() | |
with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen: | |
self.tp1_cls_map = json.load( fopen) | |
fopen.close() | |
return cls_map | |
    def get_superpixels_similarity(self, sp1, sp2):
        """Placeholder for a superpixel-similarity metric; not implemented yet."""
        pass
    def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None):
        """
        Binarize a superpixel map against one superpixel id.

        Args:
            super_map: integer map of superpixel ids
            bi_val: superpixel id to binarize against; picked uniformly at random
                from the ids present in super_map when None
            sup_max_cls: max superpixel id (only referenced by the commented-out picker)
            conn_graph: optional adjacency structure (id -> neighbors); when given
                together with img, a random subset of bi_val's neighbors is merged
                into bi_val before binarization
            img: unused except as a gate for the merge branch
        Returns:
            float32 binary mask (super_map == bi_val)
        """
        if bi_val is None:
            # bi_val = np.random.randint(1, sup_max_cls)
            bi_val = random.choice(list(np.unique(super_map)))
        if conn_graph is not None and img is not None:
            # get number of neighbors of bi_val
            neighbors = conn_graph[bi_val]
            # pick a random number of neighbors and merge them
            # NOTE(review): np.random.randint(0, 0) raises ValueError when a node
            # has no neighbors — confirm the graph never has isolated superpixels
            n_neighbors = np.random.randint(0, len(neighbors))
            try:
                neighbors = random.sample(neighbors, n_neighbors)
            except TypeError:
                # presumably hit when `neighbors` is a type random.sample rejects
                # (e.g. a set on Python 3.11+); fall back to merging nothing
                neighbors = []
            # merge neighbors
            super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map)
        return np.float32(super_map == bi_val)
def supcls_pick(self, super_map): | |
return random.choice(list(np.unique(super_map))) | |
def get_3_slice_adjacent_image(self, image_t, index): | |
curr_dict = self.actual_dataset[index] | |
prev_image = np.zeros_like(image_t) | |
if index > 1 and not curr_dict["is_start"]: | |
prev_dict = self.actual_dataset[index - 1] | |
prev_image = prev_dict["img"] | |
next_image = np.zeros_like(image_t) | |
if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: | |
next_dict = self.actual_dataset[index + 1] | |
next_image = next_dict["img"] | |
image_t = np.concatenate([prev_image, image_t, next_image], axis=-1) | |
return image_t | |
    def __getitem__(self, index):
        """
        Build one few-shot episode from a single 2D slice.

        The slice's (pseudo)label map is binarized against one picked label, then
        the (image, mask) pair is augmented self.num_rep times; even repetitions
        become the support set, odd repetitions the query set.
        Returns:
            dict with support/query image and mask lists plus slice metadata
        """
        index = index % len(self.actual_dataset)  # __len__ may report fix_length > buffer size
        curr_dict = self.actual_dataset[index]
        sup_max_cls = curr_dict['sup_max_cls']
        if sup_max_cls < 1:
            # slice carries no (pseudo)label at all: use the next slice instead
            # NOTE(review): recursive, so a long run of empty slices deepens the stack
            return self.__getitem__(index + 1)
        image_t = curr_dict["img"]
        label_raw = curr_dict["lb"]
        if self.use_3_slices:
            # replace the single slice with [prev, curr, next] channels
            image_t = self.get_3_slice_adjacent_image(image_t, index)
        for _ex_cls in self.exclude_lbs:
            if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen
                # resample a different random slice instead
                # NOTE(review): torch.randint returns a tensor used as an index — works, but unusual; confirm intended
                return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
        if self.supervised_train:
            # supervised mode: binarize the ground-truth mask against one training label
            superpix_label = -1
            label_t = np.float32(label_raw)
            lb_id = random.choice(list(set(np.unique(label_raw)) & set(self.train_list)))
            label_t[label_t != lb_id] = 0
            label_t[label_t == lb_id] = 1
        else:
            # pseudolabel mode: binarize against one randomly picked superpixel id
            superpix_label = self.supcls_pick(label_raw)
            label_t = np.float32(label_raw == superpix_label)
        pair_buffer = []
        # concatenate image and mask so both receive the same geometric augmentation
        comp = np.concatenate( [image_t, label_t], axis = -1 )
        for ii in range(self.num_rep):
            if self.transforms is not None:
                img, lb = self.transforms(comp, c_img = image_t.shape[-1], c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False)
            else:
                # no augmentation: split the composite back into image and mask
                img, lb = comp[:, :, 0:1], comp[:, :, 1:2]
            # if ii % 2 == 0:
            # label_raw = lb
            # lb = lb == superpix_label
            img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ).float()  # HWC -> CHW
            lb = torch.from_numpy( lb.squeeze(-1)).float()
            # tile the gray slice into tile_z_dim identical channels
            img = img.repeat( [ self.tile_z_dim, 1, 1] )
            is_start = curr_dict["is_start"]
            is_end = curr_dict["is_end"]
            nframe = np.int32(curr_dict["nframe"])
            scan_id = curr_dict["scan_id"]
            z_id = curr_dict["z_id"]
            sample = {"image": img,
                      "label":lb,
                      "is_start": is_start,
                      "is_end": is_end,
                      "nframe": nframe,
                      "scan_id": scan_id,
                      "z_id": z_id
                      }
            # Add auxiliary attributes
            if self.aux_attrib is not None:
                for key_prefix in self.aux_attrib:
                    # Process the data sample, create new attributes and save them in a dictionary
                    aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
                    for key_suffix in aux_attrib_val:
                        # one function may create multiple attributes, so we need suffix to distinguish them
                        sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
            pair_buffer.append(sample)
        support_images = []
        support_mask = []
        support_class = []
        query_images = []
        query_labels = []
        query_class = []
        # alternate the repetitions: even -> support, odd -> query
        for idx, itm in enumerate(pair_buffer):
            if idx % 2 == 0:
                support_images.append(itm["image"])
                support_class.append(1) # pseudolabel class
                support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] ))
            else:
                query_images.append(itm["image"])
                query_class.append(1)
                query_labels.append( itm["label"])
        # scan_id / z_id / nframe are leaked from the augmentation loop — requires num_rep >= 1
        return {'class_ids': [support_class],
                'support_images': [support_images], #
                'superpix_label': superpix_label,
                'superpix_label_raw': label_raw[:,:,0],
                'support_mask': [support_mask],
                'query_images': query_images, #
                'query_labels': query_labels,
                'scan_id': scan_id,
                'z_id': z_id,
                'nframe': nframe,
                }
def __len__(self): | |
""" | |
copy-paste from basic naive dataset configuration | |
""" | |
if self.fix_length != None: | |
assert self.fix_length >= len(self.actual_dataset) | |
return self.fix_length | |
else: | |
return len(self.actual_dataset) | |
def getMaskMedImg(self, label, class_id, class_ids): | |
""" | |
Generate FG/BG mask from the segmentation mask | |
Args: | |
label: semantic mask | |
class_id: semantic class of interest | |
class_ids: all class id in this episode | |
""" | |
fg_mask = torch.where(label == class_id, | |
torch.ones_like(label), torch.zeros_like(label)) | |
bg_mask = torch.where(label != class_id, | |
torch.ones_like(label), torch.zeros_like(label)) | |
for class_id in class_ids: | |
bg_mask[label == class_id] = 0 | |
return {'fg_mask': fg_mask, | |
'bg_mask': bg_mask} | |