# LoGoSAM_demo / dataloaders / GenericSuperDatasetv2.py
# Uploaded by quandn2003 with huggingface_hub (commit 427d150, verified).
"""
Dataset for training with pseudolabels
TODO:
1. Merge with manual annotated dataset
2. superpixel_scale -> superpix_config, feed like a dict
"""
import glob
import numpy as np
import dataloaders.augutils as myaug
import torch
import random
import os
import copy
import platform
import json
import re
import cv2
from dataloaders.common import BaseDataset, Subset
from dataloaders.dataset_utils import*
from pdb import set_trace
from util.utils import CircularList
from util.consts import IMG_SIZE
class SuperpixelDataset(BaseDataset):
    """Episodic dataset built from superpixel pseudolabels (or ground-truth labels).

    Scans are read into memory and split into 2D axial slices. Every item is an
    episode made from a single slice: ``num_rep`` transformed copies are
    produced; copies at an even position become support samples and copies at
    an odd position become query samples. All copies share one binary
    foreground: a randomly picked superpixel id, or (with ``supervised_train``)
    a randomly picked class from ``train_list``.
    """

    def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load,
                 num_rep=2, min_fg='', nsup=1, fix_length=None, tile_z_dim=3, exclude_list=None,
                 train_list=None, superpix_scale='SMALL', norm_mean=None, norm_std=None,
                 supervised_train=False, use_3_slices=False, **kwargs):
        """
        Pseudolabel dataset
        Args:
            which_dataset: name of the dataset; key into DATASET_INFO
            base_dir: directory holding image_<id>.nii.gz, label_<id>.nii.gz,
                superpix-<scale>_<id>.nii.gz and the .classmap_*.json index files
            idx_split: index of the cross-validation fold (indexes DATASET_INFO[...]['_SEP'])
            mode: 'train' or 'val'; any other value raises ValueError
            image_size: side length every slice is resized to (square output)
            transforms: data transform (augmentation) function, or None for no transform
            scan_per_load: in training, number of scans kept in memory at a time (see
                reload_buffer). A value <= 0 loads the entire fold at once. Values larger
                than the fold are clamped to the fold size
            num_rep: number of transformed copies produced from the same slice per item
            min_fg: suffix of the '.classmap_<min_fg>.json' index file to load
            nsup: number of extra scans appended to the validation fold (support scans);
                these scans are removed from the training scans
            fix_length: if given, the reported dataset length (must be >= number of slices)
            tile_z_dim: number of times the image channels are tiled along the channel axis,
                for fitting single-channel medical images into networks designed for RGB images
            exclude_list: indices into REAL_LABEL_NAME; slices containing these classes
                (according to .classmap_1.json) are never returned. Defaults to no exclusion
            train_list: label values used when ``supervised_train`` is set; slices without
                any of them are dropped
            superpix_scale: superpixel config; selects the 'superpix-<scale>_<id>.nii.gz' files
            norm_mean, norm_std: forwarded to get_normalize_op as ct_mean / ct_std
            supervised_train: use ground-truth labels instead of superpixel pseudolabels
            use_3_slices: stack previous / current / next slice as 3 image channels
                (cannot be combined with tile_z_dim > 1)
            **kwargs: only ``use_clahe`` (bool, default False) is read; it enables per-slice CLAHE
        """
        super().__init__(base_dir)
        if mode not in ('train', 'val'):
            raise ValueError(f"Unsupported mode '{mode}', expected 'train' or 'val'")
        # avoid shared mutable defaults
        exclude_list = [] if exclude_list is None else exclude_list
        train_list = [] if train_list is None else train_list

        dataset_info = DATASET_INFO[which_dataset]
        self.img_modality = dataset_info['MODALITY']
        self.sep = dataset_info['_SEP']
        self.pseu_label_name = dataset_info['PSEU_LABEL_NAME']
        self.real_label_name = dataset_info['REAL_LABEL_NAME']
        self.image_size = image_size
        self.transforms = transforms
        self.is_train = mode == 'train'
        self.supervised_train = supervised_train
        if self.supervised_train and len(train_list) == 0:
            raise Exception('Please provide training labels')
        self.fix_length = fix_length
        # NOTE(review): supervised training also uses len(PSEU_LABEL_NAME) as the class
        # count (len(REAL_LABEL_NAME) was left commented out upstream) - confirm intended.
        self.nclass = len(self.pseu_label_name)
        self.num_rep = num_rep
        self.tile_z_dim = tile_z_dim
        self.use_3_slices = use_3_slices
        if tile_z_dim > 1 and self.use_3_slices:
            raise Exception("tile_z_dim and use_3_slices shouldn't be used together")

        # find scans in the data folder
        self.nsup = nsup
        self.base_dir = base_dir
        self.img_pids = CircularList(self._find_scan_pids(self.base_dir))

        # experiment configs
        self.exclude_lbs = exclude_list
        self.train_list = train_list
        self.superpix_scale = superpix_scale
        if len(exclude_list) > 0:
            print(f'###### Dataset: the following classes has been excluded {exclude_list}######')
        self.idx_split = idx_split
        self.scan_ids = self.get_scanids(mode, idx_split)  # patient ids of the entire fold
        self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
        self.scan_per_load = scan_per_load
        self.info_by_scan = None
        self.img_lb_fids = self.organize_sample_fids()  # information of scans of the entire fold
        self.norm_func = get_normalize_op(
            self.img_modality,
            [fid_pair['img_fid'] for fid_pair in self.img_lb_fids.values()],
            ct_mean=norm_mean,
            ct_std=norm_std,
        )

        if self.is_train and scan_per_load > 0:
            # dataset too large: only keep a random subset of scans in memory
            self.pid_curr_load = self._sample_scans_to_load()
        else:
            # load the entire fold without a buffer (always the case for validation)
            self.pid_curr_load = self.scan_ids

        self.use_clahe = bool(kwargs.get('use_clahe', False))
        if self.use_clahe:
            clip_limit = 4.0 if self.img_modality == 'MR' else 2.0
            self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7, 7))

        self.actual_dataset = self.read_dataset()
        self.size = len(self.actual_dataset)
        self.overall_slice_by_cls = self.read_classfiles()
        print("###### Initial scans loaded: ######")
        print(self.pid_curr_load)

    @staticmethod
    def _find_scan_pids(base_dir):
        """Return the numeric ids of the ``image_<id>.nii`` / ``image_<id>.nii.gz`` files, sorted numerically.

        Both suffixes are searched because the scans are later loaded as
        ``image_<id>.nii.gz``; ids present under both suffixes are counted once.
        """
        pids = {
            re.findall(r'\d+', fid)[-1]
            for pattern in ("/image_*.nii", "/image_*.nii.gz")
            for fid in glob.glob(base_dir + pattern)
        }
        return sorted(pids, key=int)

    def _sample_scans_to_load(self):
        """Randomly pick (without replacement) the scan ids to hold in the buffer."""
        n_pick = min(self.scan_per_load, len(self.scan_ids))
        return np.random.choice(self.scan_ids, replace=False, size=n_pick)

    def get_scanids(self, mode, idx_split):
        """
        Load scans by train-test split.

        The validation fold is ``img_pids[sep[idx_split] : sep[idx_split + 1] + nsup]``,
        i.e. the fold itself plus ``nsup`` additional (support) scans. The slicing relies
        on CircularList, presumably wrapping to scan 0 for the last fold - confirm.
        Training uses every scan not in that slice.
        Args:
            mode: 'train' or 'val'
            idx_split: index for splitting cross-validation folds
        Raises:
            ValueError: for any other mode
        """
        val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
        if mode == 'train':
            val_set = set(val_ids)
            return [ii for ii in self.img_pids if ii not in val_set]
        if mode == 'val':
            return val_ids
        raise ValueError(f"Unsupported mode '{mode}', expected 'train' or 'val'")

    def reload_buffer(self):
        """
        Reload only a portion of the entire dataset, if the dataset is too large
        1. delete the original buffer
        2. sample a new set of scans (self.pid_curr_load) and read them
        3. update other internal variables like self.size
        Returns -1 (and does nothing) when the buffer is disabled (scan_per_load <= 0).
        """
        if self.scan_per_load <= 0:
            print("We are not using the reload buffer, doing notiong")
            return -1
        del self.actual_dataset
        del self.info_by_scan
        self.pid_curr_load = self._sample_scans_to_load()
        self.actual_dataset = self.read_dataset()
        self.size = len(self.actual_dataset)
        # NOTE(review): update_subclass_lookup is not defined in this class; presumably
        # inherited from BaseDataset - confirm, otherwise this raises AttributeError.
        self.update_subclass_lookup()
        print(f'Loader buffer reloaded with a new size of {self.size} slices')

    def organize_sample_fids(self):
        """Map every scan id of the fold (as str) to its image / pseudolabel / ground-truth file paths."""
        return {
            str(curr_id): {
                "img_fid": os.path.join(self.base_dir, f'image_{curr_id}.nii.gz'),
                "lbs_fid": os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz'),
                "gt_lbs_fid": os.path.join(self.base_dir, f'label_{curr_id}.nii.gz'),
            }
            for curr_id in self.scan_ids
        }

    def _rescale_slice_to_255(self, sl):
        """Min-max rescale one slice to [0, 255]; a constant slice maps to zeros (instead of NaN)."""
        lo, hi = sl.min(), sl.max()
        if hi == lo:
            return np.zeros(sl.shape, dtype=np.float64)
        return (sl - lo) / (hi - lo) * 255

    def _apply_clahe(self, img):
        """Apply CLAHE independently to every slice along the first axis of ``img``."""
        if self.img_modality == 'MR':
            img = np.stack([self._rescale_slice_to_255(sl) for sl in img], axis=0)
        # NOTE(review): non-MR volumes are cast to uint8 without rescaling, so values
        # outside [0, 255] do not survive the cast - confirm this is intended.
        return np.stack([self.clahe.apply(sl.astype(np.uint8)) for sl in img], axis=0)

    def _resize_volume(self, vol, interpolation):
        """Resize the first two axes of an [H, W, Z] volume to (image_size, image_size).

        cv2.resize returns a 2-D array for single-channel input, so the slice axis is
        restored to keep the result 3-D.
        NOTE(review): OpenCV caps the channel count of an image (CV_CN_MAX), so very deep
        volumes may fail here - confirm for the datasets in use.
        """
        out = cv2.resize(vol, (self.image_size, self.image_size), interpolation=interpolation)
        if out.ndim == 2:
            out = out[..., np.newaxis]
        return out

    def read_dataset(self):
        """
        Read the scans in ``self.pid_curr_load`` into memory and split them into 2D slices.

        Also rebuilds ``self.info_by_scan`` (meta information per scan) and
        ``self.scan_z_idx`` (per scan: slice position -> global index in the returned list).
        Returns:
            list of dicts with keys img, lb ([image_size, image_size, 1] arrays), sup_max_cls,
            is_start, is_end, nframe, scan_id, z_id
        """
        out_list = []
        self.scan_z_idx = {}
        self.info_by_scan = {}  # meta data of each scan
        glb_idx = 0  # global index of a certain slice in a certain scan in entire dataset
        for scan_id, itm in self.img_lb_fids.items():
            if scan_id not in self.pid_curr_load:
                continue
            img, _info = read_nii_bysitk(itm["img_fid"], peel_info=True)  # get the meta information out
            if self.use_clahe:
                img = self._apply_clahe(img)
            img = img.transpose(1, 2, 0)
            self.info_by_scan[scan_id] = _info
            img = np.float32(img)
            img = self.norm_func(img)
            # sized with the slice count *before* supervised filtering; unused positions stay -1
            self.scan_z_idx[scan_id] = [-1] * img.shape[-1]

            lb_fid = itm["gt_lbs_fid"] if self.supervised_train else itm["lbs_fid"]
            lb = np.int32(read_nii_bysitk(lb_fid).transpose(1, 2, 0))

            # resize img and lb to self.image_size; format of slices: [axial_H x axial_W x Z]
            img = self._resize_volume(img, cv2.INTER_LINEAR)
            lb = self._resize_volume(lb, cv2.INTER_NEAREST)

            if self.supervised_train:
                # keep only slices containing at least one training label.
                # NOTE(review): z_id below then indexes the *filtered* stack, while
                # tp1_cls_map (used for exclusion) presumably uses original slice indices.
                keep = np.array([np.isin(lb[..., z], self.train_list).any() for z in range(lb.shape[-1])],
                                dtype=bool)
                img = img[..., keep]
                lb = lb[..., keep]

            nframes = img.shape[-1]
            assert img.shape[-1] == lb.shape[-1]
            if nframes == 0:
                print(f'###### Scan {scan_id} has no usable slices, skipped ######')
                continue

            # re-organize 3D images into 2D slices and record essential information for each slice
            for z_id in range(nframes):
                out_list.append({
                    "img": img[..., z_id: z_id + 1],
                    "lb": lb[..., z_id: z_id + 1],
                    "sup_max_cls": lb[..., z_id: z_id + 1].max(),
                    "is_start": z_id == 0,
                    "is_end": z_id == nframes - 1,
                    "nframe": nframes,
                    "scan_id": scan_id,
                    "z_id": z_id,
                })
                self.scan_z_idx[scan_id][z_id] = glb_idx
                glb_idx += 1
        return out_list

    def read_classfiles(self):
        """
        Load the scan-slice-class indexing files.

        Stores '.classmap_1.json' as ``self.tp1_cls_map`` (used by __getitem__ to find the
        slices of a class in a scan) and returns the content of '.classmap_<min_fg>.json'.
        """
        with open(os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json'), 'r') as fopen:
            cls_map = json.load(fopen)
        with open(os.path.join(self.base_dir, '.classmap_1.json'), 'r') as fopen:
            self.tp1_cls_map = json.load(fopen)
        return cls_map

    def get_superpixels_similarity(self, sp1, sp2):
        """Placeholder, not implemented (always returns None)."""
        pass

    def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None):
        """Binarize ``super_map`` around one superpixel id.

        Args:
            super_map: superpixel id map
            sup_max_cls: unused, kept for interface compatibility
            bi_val: superpixel id to use as foreground; picked at random from the ids
                present in ``super_map`` when None
            conn_graph: optional mapping id -> neighbouring ids; when given together with
                ``img``, a random number (0 .. len-1) of neighbours is merged into ``bi_val``
            img: only checked for None
        Returns:
            float32 mask, 1 where the (merged) superpixel is, 0 elsewhere
        """
        if bi_val is None:
            bi_val = random.choice(list(np.unique(super_map)))
        if conn_graph is not None and img is not None:
            # random.sample needs a sequence (sets are rejected since Python 3.11)
            neighbors = list(conn_graph[bi_val])
            n_neighbors = np.random.randint(0, len(neighbors)) if neighbors else 0
            neighbors = random.sample(neighbors, n_neighbors)
            # merge neighbors
            super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map)
        return np.float32(super_map == bi_val)

    def supcls_pick(self, super_map):
        """Pick a random id among the values present in ``super_map``.

        NOTE(review): 0 is a candidate too if it occurs in the map - confirm intended.
        """
        return random.choice(list(np.unique(super_map)))

    def get_3_slice_adjacent_image(self, image_t, index):
        """Stack [previous, current, next] slices of the same scan along the last axis.

        A missing neighbour (first / last slice of a scan) is replaced by zeros.
        """
        curr_dict = self.actual_dataset[index]
        prev_image = np.zeros_like(image_t)
        if index > 0 and not curr_dict["is_start"]:
            prev_image = self.actual_dataset[index - 1]["img"]
        next_image = np.zeros_like(image_t)
        if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
            next_image = self.actual_dataset[index + 1]["img"]
        return np.concatenate([prev_image, image_t, next_image], axis=-1)

    def _has_excluded_class(self, curr_dict):
        """True if the slice contains a class of ``exclude_lbs`` according to ``tp1_cls_map``.

        With setting 1, such slices must be skipped since they contain labels that are
        supposed to be unseen.
        """
        return any(
            curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[ex_cls]][curr_dict["scan_id"]]
            for ex_cls in self.exclude_lbs
        )

    def _pick_usable_slice(self, index):
        """Starting at ``index``, find a slice with foreground and no excluded class.

        Slices whose label maximum is below 1 are skipped by moving to the next slice;
        slices containing an excluded class are replaced by a random slice.
        Returns the accepted (index, entry) pair.
        Raises:
            RuntimeError: if no usable slice is found within a bounded number of attempts
        """
        n_slices = len(self.actual_dataset)
        index = index % n_slices
        for _ in range(max(1000, 10 * n_slices)):
            curr_dict = self.actual_dataset[index]
            if curr_dict['sup_max_cls'] < 1:
                index = (index + 1) % n_slices
            elif self._has_excluded_class(curr_dict):
                index = int(torch.randint(low=0, high=self.__len__(), size=(1,))) % n_slices
            else:
                return index, curr_dict
        raise RuntimeError('No slice with foreground and without excluded classes was found')

    def __getitem__(self, index):
        """Build one support/query episode from a single slice.

        Returns:
            dict with class_ids, support_images, superpix_label (-1 in supervised mode),
            superpix_label_raw, support_mask, query_images, query_labels, scan_id, z_id, nframe
        """
        index, curr_dict = self._pick_usable_slice(int(index))
        image_t = curr_dict["img"]
        label_raw = curr_dict["lb"]
        if self.use_3_slices:
            image_t = self.get_3_slice_adjacent_image(image_t, index)

        if self.supervised_train:
            superpix_label = -1
            lb_id = random.choice(list(set(np.unique(label_raw)) & set(self.train_list)))
            label_t = np.float32(label_raw == lb_id)
        else:
            superpix_label = self.supcls_pick(label_raw)
            label_t = np.float32(label_raw == superpix_label)

        c_img = image_t.shape[-1]
        comp = np.concatenate([image_t, label_t], axis=-1)
        is_start = curr_dict["is_start"]
        is_end = curr_dict["is_end"]
        nframe = np.int32(curr_dict["nframe"])
        scan_id = curr_dict["scan_id"]
        z_id = curr_dict["z_id"]

        pair_buffer = []
        for _ in range(self.num_rep):
            if self.transforms is not None:
                img, lb = self.transforms(comp, c_img=c_img, c_label=1, nclass=self.nclass,
                                          is_train=True, use_onehot=False)
            else:
                # split the stacked array back into all image channels and the label channel
                img, lb = comp[..., :c_img], comp[..., c_img:c_img + 1]
            img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
            lb = torch.from_numpy(lb.squeeze(-1)).float()
            img = img.repeat([self.tile_z_dim, 1, 1])
            sample = {"image": img,
                      "label": lb,
                      "is_start": is_start,
                      "is_end": is_end,
                      "nframe": nframe,
                      "scan_id": scan_id,
                      "z_id": z_id,
                      }
            # Add auxiliary attributes (aux_attrib presumably configured by BaseDataset)
            if self.aux_attrib is not None:
                for key_prefix in self.aux_attrib:
                    # Process the data sample, create new attributes and save them in a dictionary
                    aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
                    for key_suffix in aux_attrib_val:
                        # one function may create multiple attributes, so we need suffix to distinguish them
                        sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
            pair_buffer.append(sample)

        support_images = []
        support_mask = []
        support_class = []
        query_images = []
        query_labels = []
        query_class = []
        for idx, itm in enumerate(pair_buffer):
            if idx % 2 == 0:
                support_images.append(itm["image"])
                support_class.append(1)  # pseudolabel class
                support_mask.append(self.getMaskMedImg(itm["label"], 1, [1]))
            else:
                query_images.append(itm["image"])
                query_class.append(1)
                query_labels.append(itm["label"])

        return {'class_ids': [support_class],
                'support_images': [support_images],
                'superpix_label': superpix_label,
                'superpix_label_raw': label_raw[:, :, 0],
                'support_mask': [support_mask],
                'query_images': query_images,
                'query_labels': query_labels,
                'scan_id': scan_id,
                'z_id': z_id,
                'nframe': nframe,
                }

    def __len__(self):
        """Number of slices in the buffer, or ``fix_length`` when set (must not be smaller)."""
        n_slices = len(self.actual_dataset)
        if self.fix_length is None:
            return n_slices
        assert self.fix_length >= n_slices
        return self.fix_length

    def getMaskMedImg(self, label, class_id, class_ids):
        """
        Generate FG/BG mask from the segmentation mask
        Args:
            label: semantic mask (tensor)
            class_id: semantic class of interest
            class_ids: all class ids in this episode; none of them counts as background
        Returns:
            dict with 'fg_mask' and 'bg_mask', tensors shaped like ``label``
        """
        fg_mask = torch.where(label == class_id, torch.ones_like(label), torch.zeros_like(label))
        bg_mask = torch.where(label != class_id, torch.ones_like(label), torch.zeros_like(label))
        for other_id in class_ids:
            bg_mask[label == other_id] = 0
        return {'fg_mask': fg_mask,
                'bg_mask': bg_mask}