"""
Dataset for training with pseudolabels (superpixels), or with ground-truth
labels when ``supervised_train=True``.

TODO:
    1. Merge with manual annotated dataset
    2. superpixel_scale -> superpix_config, feed like a dict
"""
import glob
import numpy as np
import dataloaders.augutils as myaug
import torch
import random
import os
import copy
import platform
import json
import re
import cv2
from dataloaders.common import BaseDataset, Subset
from dataloaders.dataset_utils import *
from pdb import set_trace
from util.utils import CircularList
from util.consts import IMG_SIZE


class SuperpixelDataset(BaseDataset):
    def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load,
                 num_rep=2, min_fg='', nsup=1, fix_length=None, tile_z_dim=3, exclude_list=None,
                 train_list=None, superpix_scale='SMALL', norm_mean=None, norm_std=None,
                 supervised_train=False, use_3_slices=False, **kwargs):
        """
        Pseudolabel dataset

        Args:
            which_dataset: name of the dataset to use (key into DATASET_INFO)
            base_dir: directory of dataset
            idx_split: index of data split as we will do cross validation
            mode: 'train' or 'val'
            image_size: side length that every slice is resized to
            transforms: data transform (augmentation) function, or None
            scan_per_load: number of scans held in memory at a time, in case the dataset
                is too large to fit into memory. Set to -1 (any value <= 0) to load the
                entire dataset at once
            num_rep: number of augmented copies generated from the same slice/pseudolabel;
                even-indexed copies become supports, odd-indexed copies queries
            min_fg: suffix of the `.classmap_{min_fg}.json` slice-class index file
            nsup: number of extra scans appended to the validation range in get_scanids
            fix_length: if not None, the value reported by __len__ (must be >= the
                number of loaded slices)
            tile_z_dim: number of identical copies of the image tiled along the channel
                dimension, for fitting 2D single-channel medical images into off-the-shelf
                networks designed for RGB natural images
            exclude_list: indices into REAL_LABEL_NAME; slices containing these labels
                (according to `.classmap_1.json`) are never returned. Defaults to none.
            train_list: label values used as targets when supervised_train=True
            superpix_scale: config of superpixels; selects `superpix-{scale}_{id}.nii.gz`
            norm_mean, norm_std: forwarded to get_normalize_op as ct_mean / ct_std
            supervised_train: use ground-truth label files instead of superpixels
            use_3_slices: stack previous/current/next slices as 3 input channels
                (cannot be combined with tile_z_dim > 1)
            **kwargs: `use_clahe` (default False) enables CLAHE preprocessing

        Raises:
            ValueError: if mode is not 'train' or 'val'
            Exception: if supervised_train is set without train_list, or if
                tile_z_dim > 1 is combined with use_3_slices
        """
        super(SuperpixelDataset, self).__init__(base_dir)
        # Validate early: get_scanids() would otherwise return None for an unknown mode
        # and fail later with an unrelated TypeError.
        if mode not in ('train', 'val'):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
        # None defaults avoid sharing one mutable list between instances.
        exclude_list = [] if exclude_list is None else exclude_list
        train_list = [] if train_list is None else train_list

        self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
        self.sep = DATASET_INFO[which_dataset]['_SEP']
        self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME']
        self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']

        self.image_size = image_size
        self.transforms = transforms
        self.is_train = mode == 'train'
        self.supervised_train = supervised_train
        if self.supervised_train and len(train_list) == 0:
            raise Exception('Please provide training labels')
        self.fix_length = fix_length
        # NOTE: nclass is derived from the pseudolabel names in both supervised and
        # pseudolabel mode (len(self.real_label_name) is not used).
        self.nclass = len(self.pseu_label_name)
        self.num_rep = num_rep
        self.tile_z_dim = tile_z_dim
        self.use_3_slices = use_3_slices
        if tile_z_dim > 1 and self.use_3_slices:
            raise Exception("tile_z_dim and use_3_slices shouldn't be used together")

        # find scans in the data folder
        self.nsup = nsup
        self.base_dir = base_dir
        # NOTE(review): this globs `image_*.nii`, while organize_sample_fids() reads
        # `image_{id}.nii.gz`; confirm both file variants exist in base_dir.
        self.img_pids = [re.findall(r'\d+', fid)[-1]
                         for fid in glob.glob(self.base_dir + "/image_*.nii")]
        self.img_pids = CircularList(sorted(self.img_pids, key=lambda x: int(x)))

        # experiment configs
        self.exclude_lbs = exclude_list
        self.train_list = train_list
        self.superpix_scale = superpix_scale
        if len(exclude_list) > 0:
            print(f'###### Dataset: the following classes has been excluded {exclude_list}######')
        self.idx_split = idx_split
        self.scan_ids = self.get_scanids(mode, idx_split)  # patient ids of the entire fold
        self.min_fg = str(min_fg)
        self.scan_per_load = scan_per_load
        self.info_by_scan = None
        self.img_lb_fids = self.organize_sample_fids()  # information of scans of the entire fold
        self.norm_func = get_normalize_op(
            self.img_modality,
            [fid_pair['img_fid'] for fid_pair in self.img_lb_fids.values()],
            ct_mean=norm_mean, ct_std=norm_std)

        if self.is_train and scan_per_load > 0:
            # if the dataset is too large, only (re)load a random subset of scans per sub-epoch
            self.pid_curr_load = np.random.choice(
                self.scan_ids, replace=False, size=self.scan_per_load)
        else:
            # load the entire set without a buffer (always the case in 'val' mode)
            self.pid_curr_load = self.scan_ids

        self.use_clahe = bool(kwargs.get('use_clahe', False))
        if self.use_clahe:
            clip_limit = 4.0 if self.img_modality == 'MR' else 2.0
            self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7, 7))

        self.actual_dataset = self.read_dataset()
        self.size = len(self.actual_dataset)
        self.overall_slice_by_cls = self.read_classfiles()

        print("###### Initial scans loaded: ######")
        print(self.pid_curr_load)

    def get_scanids(self, mode, idx_split):
        """
        Return the scan ids of the requested split.

        The validation ids are `img_pids[sep[idx_split] : sep[idx_split + 1] + nsup]`,
        i.e. the fold plus `nsup` additional scans (presumably wrapping around to the
        first scans for the last fold via CircularList -- TODO confirm). Training ids
        are all remaining scans.

        Args:
            mode: 'train' or 'val'
            idx_split: index for splitting cross-validation folds
        """
        val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
        if mode == 'train':
            return [ii for ii in self.img_pids if ii not in val_ids]
        elif mode == 'val':
            return val_ids

    def reload_buffer(self):
        """
        Reload only a portion of the entire dataset, if the dataset is too large:
            1. delete the original buffer
            2. draw a new random subset of scans into self.pid_curr_load
            3. update other internal variables like self.size

        Returns:
            -1 (and does nothing) when the buffer is not in use (scan_per_load <= 0).
        """
        if self.scan_per_load <= 0:
            print("We are not using the reload buffer, doing notiong")
            return -1

        del self.actual_dataset
        del self.info_by_scan

        self.pid_curr_load = np.random.choice(self.scan_ids, size=self.scan_per_load, replace=False)
        self.actual_dataset = self.read_dataset()
        self.size = len(self.actual_dataset)
        # NOTE(review): update_subclass_lookup is not defined in this class; presumably it is
        # provided elsewhere (e.g. BaseDataset) -- confirm, otherwise this raises AttributeError.
        self.update_subclass_lookup()
        print(f'Loader buffer reloaded with a new size of {self.size} slices')

    def organize_sample_fids(self):
        """
        Map every scan id of the fold to its file paths.

        Returns:
            dict: str(scan_id) -> {"img_fid": image path,
                                   "lbs_fid": superpixel pseudolabel path,
                                   "gt_lbs_fid": ground-truth label path}
        """
        out_list = {}
        for curr_id in self.scan_ids:
            out_list[str(curr_id)] = {
                "img_fid": os.path.join(self.base_dir, f'image_{curr_id}.nii.gz'),
                "lbs_fid": os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz'),
                "gt_lbs_fid": os.path.join(self.base_dir, f'label_{curr_id}.nii.gz'),
            }
        return out_list

    @staticmethod
    def _minmax_to_255(sl):
        """Linearly rescale one slice to [0, 255]; a constant slice maps to all zeros (not NaN)."""
        lo = sl.min()
        span = sl.max() - lo
        if span == 0:
            return np.zeros(sl.shape, dtype=np.float64)
        return ((sl - lo) / span) * 255

    def _apply_clahe(self, img):
        """
        Apply self.clahe slice by slice along the first axis of `img`.

        MR slices are min-max rescaled to [0, 255] first.
        NOTE(review): other modalities are cast to uint8 directly, which wraps values
        outside [0, 255]; confirm this is intended.
        """
        if self.img_modality == 'MR':
            img = np.stack([self._minmax_to_255(sl) for sl in img], axis=0)
        return np.stack([self.clahe.apply(sl.astype(np.uint8)) for sl in img], axis=0)

    def _resize_volume(self, vol, interpolation):
        """
        Resize the in-plane dims of an [H x W x Z] volume to image_size x image_size.

        cv2.resize returns a 2D array when the trailing axis has size 1, so the Z axis
        is restored in that case to keep the [H x W x Z] layout.
        """
        out = cv2.resize(vol, (self.image_size, self.image_size), interpolation=interpolation)
        if out.ndim == 2:
            out = out[..., np.newaxis]
        return out

    def read_dataset(self):
        """
        Read the scans in self.pid_curr_load into memory and split them into 2D slices.

        Also (re)builds:
            self.info_by_scan: scan_id -> meta information from read_nii_bysitk(peel_info=True)
            self.scan_z_idx: scan_id -> list mapping a slice index (z_id) to its global index
                in the returned list; entries stay -1 for slices that are not loaded
                (e.g. dropped in supervised mode)

        Returns:
            list of per-slice dicts with keys img, lb, sup_max_cls, is_start, is_end,
            nframe, scan_id, z_id. Slices of one scan are contiguous and in z order,
            which get_3_slice_adjacent_image relies on.
        """
        out_list = []
        self.scan_z_idx = {}
        self.info_by_scan = {}  # meta data of each scan
        glb_idx = 0  # global index of a certain slice in a certain scan in entire dataset

        for scan_id, itm in self.img_lb_fids.items():
            if scan_id not in self.pid_curr_load:
                continue

            img, _info = read_nii_bysitk(itm["img_fid"], peel_info=True)  # get the meta information out
            if self.use_clahe:
                img = self._apply_clahe(img)
            # The first axis is iterated as the slice axis (see _apply_clahe);
            # reorder to [axial_H x axial_W x Z].
            img = img.transpose(1, 2, 0)
            self.info_by_scan[scan_id] = _info

            img = np.float32(img)
            img = self.norm_func(img)
            self.scan_z_idx[scan_id] = [-1] * img.shape[-1]

            lb_fid = itm["gt_lbs_fid"] if self.supervised_train else itm["lbs_fid"]
            lb = np.int32(read_nii_bysitk(lb_fid).transpose(1, 2, 0))

            # resize img and lb to self.image_size (nearest neighbour keeps label ids intact)
            img = self._resize_volume(img, cv2.INTER_LINEAR)
            lb = self._resize_volume(lb, cv2.INTER_NEAREST)

            if self.supervised_train:
                # keep only the slices that contain at least one of the training labels
                keep = np.isin(lb, self.train_list).any(axis=(0, 1))
                img = img[..., keep]
                lb = lb[..., keep]

            nframes = img.shape[-1]
            assert nframes == lb.shape[-1]

            # re-organize 3D images into 2D slices and record essential information for each slice
            for z_id in range(nframes):
                out_list.append({
                    "img": img[..., z_id: z_id + 1],
                    "lb": lb[..., z_id: z_id + 1],
                    "sup_max_cls": lb[..., z_id: z_id + 1].max(),
                    "is_start": z_id == 0,
                    "is_end": z_id == nframes - 1,
                    "nframe": nframes,
                    "scan_id": scan_id,
                    "z_id": z_id,
                })
                self.scan_z_idx[scan_id][z_id] = glb_idx
                glb_idx += 1

        return out_list

    def read_classfiles(self):
        """
        Load the scan-slice-class indexing files from base_dir.

        Side effect: sets self.tp1_cls_map from `.classmap_1.json` (used to exclude labels).

        Returns:
            the parsed contents of `.classmap_{self.min_fg}.json`
        """
        with open(os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json'), 'r') as fopen:
            cls_map = json.load(fopen)

        with open(os.path.join(self.base_dir, '.classmap_1.json'), 'r') as fopen:
            self.tp1_cls_map = json.load(fopen)

        return cls_map

    def get_superpixels_similarity(self, sp1, sp2):
        """Not implemented; always returns None."""
        pass

    def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None):
        """
        Binarize a superpixel map around one superpixel id.

        Args:
            super_map: array of superpixel ids
            sup_max_cls: unused (kept for interface compatibility)
            bi_val: superpixel id to keep; drawn at random from super_map when None
            conn_graph: optional mapping id -> iterable of neighbouring ids. When given
                together with `img`, a random number (0 .. len(neighbours) - 1) of
                neighbours is merged into `bi_val` first.
            img: only checked for None (enables neighbour merging)

        Returns:
            float32 array, 1 where super_map == bi_val (after merging), else 0
        """
        if bi_val is None:
            bi_val = random.choice(list(np.unique(super_map)))
        if conn_graph is not None and img is not None:
            # random.sample() rejects sets on Python 3.11+, so sample from a list copy.
            neighbors = list(conn_graph[bi_val])
            # np.random.randint(0, 0) raises ValueError, so handle "no neighbours" explicitly
            n_neighbors = np.random.randint(0, len(neighbors)) if neighbors else 0
            neighbors = random.sample(neighbors, n_neighbors)
            # merge neighbors
            super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map)
        return np.float32(super_map == bi_val)

    def supcls_pick(self, super_map):
        """Return one of the ids present in `super_map`, chosen uniformly at random."""
        return random.choice(list(np.unique(super_map)))

    def get_3_slice_adjacent_image(self, image_t, index):
        """
        Concatenate (previous, current, next) slices along the last axis.

        Neighbours outside the current scan (is_start / is_end) or outside the buffer
        are replaced with zeros.
        """
        curr_dict = self.actual_dataset[index]

        prev_image = np.zeros_like(image_t)
        # index 0 has no predecessor; any later non-start slice does (fixes former `index > 1`)
        if index > 0 and not curr_dict["is_start"]:
            prev_image = self.actual_dataset[index - 1]["img"]

        next_image = np.zeros_like(image_t)
        if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
            next_image = self.actual_dataset[index + 1]["img"]

        return np.concatenate([prev_image, image_t, next_image], axis=-1)

    def _slice_has_excluded_label(self, curr_dict):
        """True if, per `.classmap_1.json`, this slice contains any label in self.exclude_lbs."""
        return any(
            curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]
            for _ex_cls in self.exclude_lbs)

    def __getitem__(self, index):
        """
        Build one support/query episode from a single slice.

        Slices whose label map is all zero are skipped (moving to the next slice); slices
        containing an excluded label are replaced by a randomly drawn slice.

        Raises:
            IndexError: if no slices are loaded
            RuntimeError: if no usable slice is found within a bounded number of attempts
        """
        n_slices = len(self.actual_dataset)
        if n_slices == 0:
            raise IndexError('SuperpixelDataset: no slices are loaded')
        index = int(index) % n_slices

        # Iterative resampling; the former recursive version could hit RecursionError.
        max_tries = max(1000, 2 * n_slices)
        for _ in range(max_tries):
            curr_dict = self.actual_dataset[index]
            if curr_dict['sup_max_cls'] < 1:
                index = (index + 1) % n_slices
                continue
            if self._slice_has_excluded_label(curr_dict):
                # if using setting 1, this slice need to be excluded since it contains label
                # which is supposed to be unseen; high is exclusive, so use len(self)
                index = int(torch.randint(low=0, high=len(self), size=(1,)).item()) % n_slices
                continue
            break
        else:
            raise RuntimeError(f'Could not find a usable slice after {max_tries} attempts')

        image_t = curr_dict["img"]
        label_raw = curr_dict["lb"]
        if self.use_3_slices:
            image_t = self.get_3_slice_adjacent_image(image_t, index)

        if self.supervised_train:
            # pick one of the training labels present in this slice as the foreground
            present = sorted(set(np.unique(label_raw).tolist()) & set(self.train_list))
            lb_id = random.choice(present)
            label_t = np.float32(label_raw == lb_id)
        else:
            superpix_label = self.supcls_pick(label_raw)
            label_t = np.float32(label_raw == superpix_label)

        n_img_ch = image_t.shape[-1]  # 3 when use_3_slices, else 1
        comp = np.concatenate([image_t, label_t], axis=-1)

        is_start = curr_dict["is_start"]
        is_end = curr_dict["is_end"]
        nframe = np.int32(curr_dict["nframe"])
        scan_id = curr_dict["scan_id"]
        z_id = curr_dict["z_id"]

        pair_buffer = []
        for _ in range(self.num_rep):
            if self.transforms is not None:
                img, lb = self.transforms(comp, c_img=n_img_ch, c_label=1, nclass=self.nclass,
                                          is_train=True, use_onehot=False)
            else:
                # the label channel sits right after all image channels
                img, lb = comp[:, :, :n_img_ch], comp[:, :, n_img_ch:n_img_ch + 1]

            img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
            lb = torch.from_numpy(lb.squeeze(-1)).float()
            img = img.repeat([self.tile_z_dim, 1, 1])

            sample = {"image": img,
                      "label": lb,
                      "is_start": is_start,
                      "is_end": is_end,
                      "nframe": nframe,
                      "scan_id": scan_id,
                      "z_id": z_id,
                      }

            # Add auxiliary attributes
            if self.aux_attrib is not None:
                for key_prefix in self.aux_attrib:
                    # Process the data sample, create new attributes and save them in a dictionary
                    aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
                    for key_suffix in aux_attrib_val:
                        # one function may create multiple attributes, so we need suffix to distinguish them
                        sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]

            pair_buffer.append(sample)

        # even-indexed augmentations become supports, odd-indexed ones queries
        # (query labels are not returned)
        support_images = []
        support_mask = []
        support_class = []
        query_images = []
        for idx, itm in enumerate(pair_buffer):
            if idx % 2 == 0:
                support_images.append(itm["image"])
                support_class.append(1)  # pseudolabel class
                support_mask.append(self.getMaskMedImg(itm["label"], 1, [1]))
            else:
                query_images.append(itm["image"])

        return {'class_ids': [support_class],
                'support_images': [support_images],
                'superpix_label_raw': label_raw[:, :, 0],
                'support_mask': [support_mask],
                'query_images': query_images,
                'scan_id': scan_id,
                'z_id': z_id,
                'nframe': nframe,
                }

    def __len__(self):
        """
        Return fix_length when set (must be >= the number of loaded slices), otherwise
        the number of loaded slices.
        """
        if self.fix_length is not None:
            assert self.fix_length >= len(self.actual_dataset)
            return self.fix_length
        return len(self.actual_dataset)

    def getMaskMedImg(self, label, class_id, class_ids):
        """
        Generate FG/BG mask from the segmentation mask

        Args:
            label: semantic mask (torch tensor)
            class_id: semantic class of interest
            class_ids: all class ids in this episode; none of them counts as background

        Returns:
            {'fg_mask': 1 where label == class_id,
             'bg_mask': 1 where label matches none of class_ids / class_id}
        """
        fg_mask = torch.where(label == class_id, torch.ones_like(label), torch.zeros_like(label))
        bg_mask = torch.where(label != class_id, torch.ones_like(label), torch.zeros_like(label))
        for cid in class_ids:
            bg_mask[label == cid] = 0

        return {'fg_mask': fg_mask,
                'bg_mask': bg_mask}