# NOTE: removed stray non-Python artifact lines ("Spaces:" / "Sleeping")
# that preceded the module docstring and were not valid Python.
""" | |
Manually labeled dataset | |
TODO: | |
1. Merge with superpixel dataset | |
""" | |
import glob | |
import numpy as np | |
import dataloaders.augutils as myaug | |
import torch | |
import random | |
import os | |
import copy | |
import platform | |
import json | |
import re | |
import cv2 | |
from dataloaders.common import BaseDataset, Subset, ValidationDataset | |
# from common import BaseDataset, Subset | |
from dataloaders.dataset_utils import* | |
from pdb import set_trace | |
from util.utils import CircularList | |
from util.consts import IMG_SIZE | |
# Item-retrieval modes for ManualAnnoDataset.__getitem__:
MODE_DEFAULT = "default"      # serve one 2D slice (image + label) per index
MODE_FULL_SCAN = "full_scan"  # serve one whole 3D scan (all slices stacked) per index
class ManualAnnoDataset(BaseDataset):
    def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = None, extern_normalize_func = None, **kwargs):
        """
        Manually labeled dataset
        Args:
            which_dataset:      name of the dataset to use (key into DATASET_INFO)
            base_dir:           directory of dataset (holds image_<pid>.nii / label_<pid>.nii.gz)
            idx_split:          index of data split as we will do cross validation
            mode:               'train' or 'val'
            image_size:         target in-plane size; slices are resized to image_size x image_size
            transforms:         data transform (augmentation) function; required when training
            min_fg:             minimum number of positive pixels in a 2D slice, mainly for
                                stabilizing training when trained on manually labeled dataset
            scan_per_load:      loading a portion of the entire dataset, in case the dataset is
                                too large to fit into memory. Set to -1 to load everything at once
            tile_z_dim:         number of identical slices to tile along the channel dimension, for
                                fitting 2D single-channel medical images into RGB-style networks
            nsup:               number of support scans
            fix_length:         if set, overrides the reported dataset length
            exclude_list:       label ids to be excluded (treated as unseen); default: none
            extern_normalize_func: externally supplied normalization op; keeps train/test consistent
        Keyword Args:
            use_clahe:          apply CLAHE contrast enhancement when loading scans (default False)
            use_3_slices:       build 3-channel inputs from adjacent z-slices instead of tiling
            get_item_mode:      MODE_DEFAULT (per-slice) or MODE_FULL_SCAN (per-scan) retrieval
        """
        super(ManualAnnoDataset, self).__init__(base_dir)
        # bugfix: avoid the shared-mutable-default-argument pitfall (was `exclude_list = []`)
        if exclude_list is None:
            exclude_list = []
        self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
        self.sep = DATASET_INFO[which_dataset]['_SEP']
        self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
        self.image_size = image_size
        self.transforms = transforms
        self.is_train = mode == 'train'
        self.phase = mode
        self.fix_length = fix_length
        self.all_label_names = self.label_name
        self.nclass = len(self.label_name)
        self.tile_z_dim = tile_z_dim
        self.base_dir = base_dir
        self.nsup = nsup
        # patient ids parsed from file names; raw string avoids the invalid
        # escape-sequence warning that '\d' raises on modern Python
        self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ]
        self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for the ease of spliting folds
        self.use_clahe = kwargs.get('use_clahe', False)
        if self.use_clahe:
            self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7,7))
        self.use_3_slices = kwargs.get('use_3_slices', False)
        if self.use_3_slices:
            # adjacent slices supply the 3 channels, so channel tiling is disabled
            self.tile_z_dim = 1
        self.get_item_mode = kwargs.get('get_item_mode', MODE_DEFAULT)
        self.exclude_lbs = exclude_list
        if len(exclude_list) > 0:
            print(f'###### Dataset: the following classes has been excluded {exclude_list}######')
        self.idx_split = idx_split
        self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
        self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
        self.scan_per_load = scan_per_load
        self.info_by_scan = None
        self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
        if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset.
            self.norm_func = extern_normalize_func
            print(f'###### Dataset: using external normalization statistics ######')
        else:
            self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()])
            print(f'###### Dataset: using normalization statistics calculated from loaded data ######')
        if self.is_train:
            if scan_per_load > 0: # buffer needed
                self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
            else: # load the entire set without a buffer
                self.pid_curr_load = self.scan_ids
        elif mode == 'val':
            self.pid_curr_load = self.scan_ids
            self.potential_support_sid = []
        else:
            # ValueError is a subclass of Exception, so existing handlers still work
            raise ValueError(f"Unknown mode: {mode}")
        self.actual_dataset = self.read_dataset()
        self.size = len(self.actual_dataset)
        self.overall_slice_by_cls = self.read_classfiles()
        self.update_subclass_lookup()
def get_scanids(self, mode, idx_split): | |
val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) | |
self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index | |
if mode == 'train': | |
return [ ii for ii in self.img_pids if ii not in val_ids ] | |
elif mode == 'val': | |
return val_ids | |
def reload_buffer(self): | |
""" | |
Reload a portion of the entire dataset, if the dataset is too large | |
1. delete original buffer | |
2. update self.ids_this_batch | |
3. update other internel variables like __len__ | |
""" | |
if self.scan_per_load <= 0: | |
print("We are not using the reload buffer, doing notiong") | |
return -1 | |
del self.actual_dataset | |
del self.info_by_scan | |
self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) | |
self.actual_dataset = self.read_dataset() | |
self.size = len(self.actual_dataset) | |
self.update_subclass_lookup() | |
print(f'Loader buffer reloaded with a new size of {self.size} slices') | |
def organize_sample_fids(self): | |
out_list = {} | |
for curr_id in self.scan_ids: | |
curr_dict = {} | |
_img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') | |
_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') | |
curr_dict["img_fid"] = _img_fid | |
curr_dict["lbs_fid"] = _lb_fid | |
out_list[str(curr_id)] = curr_dict | |
return out_list | |
    def read_dataset(self):
        """
        Build index pointers to individual slices
        Also keep a look-up table from scan_id, slice to index

        Each entry of the returned list describes one 2D slice:
            img / lb : (H, W, 1) arrays
            is_start / is_end : whether this is the first / last slice of its scan
            nframe : number of slices in the scan (stored on the first slice only, else -1)
            scan_id, z_id : scan identifier and slice position within the scan
        """
        out_list = []
        self.scan_z_idx = {}  # scan_id -> [global slice index for each z position]
        self.info_by_scan = {} # meta data of each scan
        glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
        for scan_id, itm in self.img_lb_fids.items():
            if scan_id not in self.pid_curr_load:
                continue  # scan not part of the currently loaded buffer
            img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
            img = img.transpose(1,2,0)  # (Z, H, W) -> (H, W, Z); z becomes the last axis
            self.info_by_scan[scan_id] = _info
            if self.use_clahe:
                # NOTE(review): img is already (H, W, Z) here, so iterating over it yields
                # (W, Z) planes rather than axial (H, W) slices -- CLAHE is applied across
                # the depth axis. Looks like this was meant to run before the transpose
                # above; confirm intended behavior.
                img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0)
            img = np.float32(img)
            img = self.norm_func(img)
            self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
            lb = read_nii_bysitk(itm["lbs_fid"])
            lb = lb.transpose(1,2,0)
            lb = np.float32(lb)
            # cv2.resize treats the z axis as channels, resizing every slice in one call
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)  # nearest keeps labels integral
            assert img.shape[-1] == lb.shape[-1]
            base_idx = img.shape[-1] // 2 # index of the middle slice (currently unused)
            # write the beginning frame
            out_list.append( {"img": img[..., 0: 1],
                              "lb":lb[..., 0: 0 + 1],
                              "is_start": True,
                              "is_end": False,
                              "nframe": img.shape[-1],
                              "scan_id": scan_id,
                              "z_id":0})
            self.scan_z_idx[scan_id][0] = glb_idx
            glb_idx += 1
            # middle frames
            for ii in range(1, img.shape[-1] - 1):
                out_list.append( {"img": img[..., ii: ii + 1],
                                  "lb":lb[..., ii: ii + 1],
                                  "is_start": False,
                                  "is_end": False,
                                  "nframe": -1,
                                  "scan_id": scan_id,
                                  "z_id": ii
                                  })
                self.scan_z_idx[scan_id][ii] = glb_idx
                glb_idx += 1
            ii += 1 # last frame, note the is_end flag
            # NOTE(review): if a scan has fewer than 3 slices the loop above never runs
            # and `ii` is undefined here (NameError) -- assumes every scan has >= 3 slices.
            out_list.append( {"img": img[..., ii: ii + 1],
                              "lb":lb[..., ii: ii+ 1],
                              "is_start": False,
                              "is_end": True,
                              "nframe": -1,
                              "scan_id": scan_id,
                              "z_id": ii
                              })
            self.scan_z_idx[scan_id][ii] = glb_idx
            glb_idx += 1
        return out_list
def read_classfiles(self): | |
with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen: | |
cls_map = json.load( fopen) | |
fopen.close() | |
with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen: | |
self.tp1_cls_map = json.load( fopen) | |
fopen.close() | |
return cls_map | |
def __getitem__(self, index): | |
if self.get_item_mode == MODE_DEFAULT: | |
return self.__getitem_default__(index) | |
elif self.get_item_mode == MODE_FULL_SCAN: | |
return self.__get_ct_scan___(index) | |
else: | |
raise NotImplementedError("Unknown mode") | |
def __get_ct_scan___(self, index): | |
scan_n = index % len(self.scan_z_idx) | |
scan_id = list(self.scan_z_idx.keys())[scan_n] | |
scan_slices = self.scan_z_idx[scan_id] | |
scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) | |
scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) | |
scan_imgs = np.float32(scan_imgs) | |
scan_lbs = np.float32(scan_lbs) | |
scan_imgs = torch.from_numpy(scan_imgs).unsqueeze(0) | |
scan_lbs = torch.from_numpy(scan_lbs) | |
if self.tile_z_dim: | |
scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1) | |
assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}' | |
# # reshape to C, D, H, W | |
# scan_imgs = scan_imgs.permute(1, 0, 2, 3) | |
# scan_lbs = scan_lbs.permute(1, 0, 2, 3) | |
sample = {"image": scan_imgs, | |
"label":scan_lbs, | |
"scan_id": scan_id, | |
} | |
return sample | |
def get_3_slice_adjacent_image(self, image_t, index): | |
curr_dict = self.actual_dataset[index] | |
prev_image = np.zeros_like(image_t) | |
if index > 1 and not curr_dict["is_start"]: | |
prev_dict = self.actual_dataset[index - 1] | |
prev_image = prev_dict["img"] | |
next_image = np.zeros_like(image_t) | |
if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: | |
next_dict = self.actual_dataset[index + 1] | |
next_image = next_dict["img"] | |
image_t = np.concatenate([prev_image, image_t, next_image], axis=-1) | |
return image_t | |
    def __getitem_default__(self, index):
        """
        Return one 2D slice (image + label) as a sample dict.

        Training: slices containing excluded (unseen) classes are resampled,
        and augmentation is applied jointly to image and label. Validation:
        slices are returned as stored.
        """
        index = index % len(self.actual_dataset)
        curr_dict = self.actual_dataset[index]
        if self.is_train:
            if len(self.exclude_lbs) > 0:
                for _ex_cls in self.exclude_lbs:
                    if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen
                        # resample: recurse with a random offset instead of this slice
                        return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
            # concatenate image and label along channel so both get the same augmentation
            comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 )
            if self.transforms is not None:
                img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False)
            else:
                raise Exception("No transform function is provided")
        else:
            img = curr_dict['img']
            lb = curr_dict['lb']
        img = np.float32(img)
        lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
        if self.use_3_slices:
            # replace the single channel with [prev, curr, next] slices
            img = self.get_3_slice_adjacent_image(img, index)
        img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )  # (H, W, C) -> (C, H, W)
        lb = torch.from_numpy( lb)
        if self.tile_z_dim:
            # tile along the channel axis for RGB-style backbones
            img = img.repeat( [ self.tile_z_dim, 1, 1] )
            assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
        is_start = curr_dict["is_start"]
        is_end = curr_dict["is_end"]
        nframe = np.int32(curr_dict["nframe"])
        scan_id = curr_dict["scan_id"]
        z_id = curr_dict["z_id"]
        sample = {"image": img,
                  "label":lb,
                  "is_start": is_start,
                  "is_end": is_end,
                  "nframe": nframe,
                  "scan_id": scan_id,
                  "z_id": z_id
                  }
        # Add auxiliary attributes (self.aux_attrib is set by the BaseDataset parent)
        if self.aux_attrib is not None:
            for key_prefix in self.aux_attrib:
                # Process the data sample, create new attributes and save them in a dictionary
                aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
                for key_suffix in aux_attrib_val:
                    # one function may create multiple attributes, so we need suffix to distinguish them
                    sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
        return sample
def __len__(self): | |
""" | |
copy-paste from basic naive dataset configuration | |
""" | |
if self.get_item_mode == MODE_FULL_SCAN: | |
return len(self.scan_z_idx) | |
if self.fix_length != None: | |
assert self.fix_length >= len(self.actual_dataset) | |
return self.fix_length | |
else: | |
return len(self.actual_dataset) | |
    def update_subclass_lookup(self):
        """
        Updating the class-slice indexing list
        Args:
            [internal] overall_slice_by_cls:
                {
                    class1: {pid1: [slice1, slice2, ....],
                             pid2: [slice1, slice2]},
                    ...}
                    class2:
                    ...
                }
        out[internal]:
                {
                    class1: [ idx1, idx2, ... ],
                    class2: [ idx1, idx2, ... ],
                    ...
                }
        """
        # delete previous ones if any
        assert self.overall_slice_by_cls is not None
        if not hasattr(self, 'idx_by_class'):
            self.idx_by_class = {}
        # filter the new one given the actual list
        for cls in self.label_name:
            if cls not in self.idx_by_class.keys():
                self.idx_by_class[cls] = []
            else:
                # clear IN PLACE: Subset objects created by subsets() hold a
                # reference to these lists, so rebinding (idx_by_class[cls] = [])
                # instead of clearing would silently detach them
                del self.idx_by_class[cls][:]
        for cls, dict_by_pid in self.overall_slice_by_cls.items():
            for pid, slice_list in dict_by_pid.items():
                if pid not in self.pid_curr_load:
                    continue  # scan not currently loaded into the buffer
                # map (scan, z position) to the slice's global index in actual_dataset
                self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ]
        print("###### index-by-class table has been reloaded ######")
def getMaskMedImg(self, label, class_id, class_ids): | |
""" | |
Generate FG/BG mask from the segmentation mask. Used when getting the support | |
""" | |
# Dense Mask | |
fg_mask = torch.where(label == class_id, | |
torch.ones_like(label), torch.zeros_like(label)) | |
bg_mask = torch.where(label != class_id, | |
torch.ones_like(label), torch.zeros_like(label)) | |
for class_id in class_ids: | |
bg_mask[label == class_id] = 0 | |
return {'fg_mask': fg_mask, | |
'bg_mask': bg_mask} | |
def subsets(self, sub_args_lst=None): | |
""" | |
Override base-class subset method | |
Create subsets by scan_ids | |
output: list [[<fid in each class>] <class1>, <class2> ] | |
""" | |
if sub_args_lst is not None: | |
subsets = [] | |
ii = 0 | |
for cls_name, index_list in self.idx_by_class.items(): | |
subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) ) | |
ii += 1 | |
else: | |
subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()] | |
return subsets | |
    def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int):
        """
        getting (probably multi-shot) support set for evaluation
        sample from 50% (1shot) or 20 35 50 65 80 (5shot)
        Args:
            curr_class: current class to segment, starts from 1
            class_idx: a list of all foreground class in nways, starts from 1
            npart: how may chunks used to split the support
            scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan
                being served as support, in self.pid_curr_load
        """
        assert npart % 2 == 1
        assert curr_class != 0; assert 0 not in class_idx
        # assert not self.is_train
        self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
        # print(f'###### Using {len(scan_idx)} shot evaluation!')
        if npart == 1:
            pcts = [0.5]  # 1-shot: take the slice at the middle of the class extent
        else:
            # centers of npart equally sized chunks of [0, 1]
            half_part = 1 / (npart * 2)
            part_interval = (1.0 - 1.0 / npart) / (npart - 1)
            pcts = [ half_part + part_interval * ii for ii in range(npart) ]
        # print(f'###### Parts percentage: {pcts} ######')
        # norm_func = get_normalize_op(modality='MR', fids=None)
        out_buffer = [] # [{scanid, img, lb}]
        for _part in range(npart):
            concat_buffer = [] # for each fold do a concat in image and mask in batch dimension
            for scan_order in scan_idx:
                _scan_id = self.pid_curr_load[ scan_order ]
                print(f'Using scan {_scan_id} as support!')
                # for _pc in pcts:
                # z positions of this scan that contain curr_class (min_fg=1 map)
                _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices
                _zid = _zlist[int(pcts[_part] * len(_zlist))]
                _glb_idx = self.scan_z_idx[_scan_id][_zid]
                # almost copy-paste __getitem__ but no augmentation
                curr_dict = self.actual_dataset[_glb_idx]
                img = curr_dict['img']
                lb = curr_dict['lb']
                if self.use_3_slices:
                    # NOTE(review): duplicates get_3_slice_adjacent_image, including its
                    # `> 1` lower bound -- for global index 1 the previous slice is
                    # wrongly replaced by zeros; confirm and unify with that helper
                    prev_image = np.zeros_like(img)
                    if _glb_idx > 1 and not curr_dict["is_start"]:
                        prev_dict = self.actual_dataset[_glb_idx - 1]
                        prev_image = prev_dict["img"]
                    next_image = np.zeros_like(img)
                    if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
                        next_dict = self.actual_dataset[_glb_idx + 1]
                        next_image = next_dict["img"]
                    img = np.concatenate([prev_image, img, next_image], axis=-1)
                img = np.float32(img)
                lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
                img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
                lb = torch.from_numpy( lb )
                if self.tile_z_dim:
                    img = img.repeat( [ self.tile_z_dim, 1, 1] )
                    assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
                is_start = curr_dict["is_start"]
                is_end = curr_dict["is_end"]
                nframe = np.int32(curr_dict["nframe"])
                scan_id = curr_dict["scan_id"]
                z_id = curr_dict["z_id"]
                sample = {"image": img,
                          "label":lb,
                          "is_start": is_start,
                          "inst": None,
                          "scribble": None,
                          "is_end": is_end,
                          "nframe": nframe,
                          "scan_id": scan_id,
                          "z_id": z_id
                          }
                concat_buffer.append(sample)
            # one multi-shot batch per percentile chunk
            out_buffer.append({
                "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
                "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
            })
            # do the concat, and add to output_buffer
        # post-processing, including keeping the foreground and suppressing background.
        support_images = []
        support_mask = []
        support_class = []
        for itm in out_buffer:
            support_images.append(itm["image"])
            support_class.append(curr_class)
            support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
        return {'class_ids': [support_class],
                'support_images': [support_images], #
                'support_mask': [support_mask],
                }
def get_support_scan(self, curr_class: int, class_idx: list, scan_idx: list): | |
self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] | |
# print(f'###### Using {len(scan_idx)} shot evaluation!') | |
scan_slices = self.scan_z_idx[self.potential_support_sid[0]] | |
scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) | |
scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) | |
# binarize the labels | |
scan_lbs[scan_lbs != curr_class] = 0 | |
scan_lbs[scan_lbs == curr_class] = 1 | |
scan_imgs = torch.from_numpy(np.float32(scan_imgs)).unsqueeze(0) | |
scan_lbs = torch.from_numpy(np.float32(scan_lbs)) | |
if self.tile_z_dim: | |
scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1) | |
assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}' | |
# reshape to C, D, H, W | |
sample = {"scan": scan_imgs, | |
"labels":scan_lbs, | |
} | |
return sample | |
def get_support_multiple_classes(self, classes: list, scan_idx: list, npart: int, use_3_slices=False): | |
""" | |
getting (probably multi-shot) support set for evaluation | |
sample from 50% (1shot) or 20 35 50 65 80 (5shot) | |
Args: | |
curr_cls: current class to segment, starts from 1 | |
class_idx: a list of all foreground class in nways, starts from 1 | |
npart: how may chunks used to split the support | |
scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan | |
being served as support, in self.pid_curr_load | |
""" | |
assert npart % 2 == 1 | |
# assert curr_class != 0; assert 0 not in class_idx | |
# assert not self.is_train | |
self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] | |
# print(f'###### Using {len(scan_idx)} shot evaluation!') | |
if npart == 1: | |
pcts = [0.5] | |
else: | |
half_part = 1 / (npart * 2) | |
part_interval = (1.0 - 1.0 / npart) / (npart - 1) | |
pcts = [ half_part + part_interval * ii for ii in range(npart) ] | |
# print(f'###### Parts percentage: {pcts} ######') | |
out_buffer = [] # [{scanid, img, lb}] | |
for _part in range(npart): | |
concat_buffer = [] # for each fold do a concat in image and mask in batch dimension | |
for scan_order in scan_idx: | |
_scan_id = self.pid_curr_load[ scan_order ] | |
print(f'Using scan {_scan_id} as support!') | |
# for _pc in pcts: | |
zlist = [] | |
for curr_class in classes: | |
zlist.append(self.tp1_cls_map[self.label_name[curr_class]][_scan_id]) # list of indices | |
# merge all the lists in zlist and keep only the unique elements | |
# _zlist = sorted(list(set([item for sublist in zlist for item in sublist]))) | |
# take only the indices that appear in all of the sublist | |
_zlist = sorted(list(set.intersection(*map(set, zlist)))) | |
_zid = _zlist[int(pcts[_part] * len(_zlist))] | |
_glb_idx = self.scan_z_idx[_scan_id][_zid] | |
# almost copy-paste __getitem__ but no augmentation | |
curr_dict = self.actual_dataset[_glb_idx] | |
img = curr_dict['img'] | |
lb = curr_dict['lb'] | |
if use_3_slices: | |
prev_image = np.zeros_like(img) | |
if _glb_idx > 1 and not curr_dict["is_start"]: | |
prev_dict = self.actual_dataset[_glb_idx - 1] | |
assert prev_dict["scan_id"] == curr_dict["scan_id"] | |
assert prev_dict["z_id"] == curr_dict["z_id"] - 1 | |
prev_image = prev_dict["img"] | |
next_image = np.zeros_like(img) | |
if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: | |
next_dict = self.actual_dataset[_glb_idx + 1] | |
assert next_dict["scan_id"] == curr_dict["scan_id"] | |
assert next_dict["z_id"] == curr_dict["z_id"] + 1 | |
next_image = next_dict["img"] | |
img = np.concatenate([prev_image, img, next_image], axis=-1) | |
img = np.float32(img) | |
lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure | |
# zero all labels that are not in the classes arg | |
mask = np.zeros_like(lb) | |
for cls in classes: | |
mask[lb == cls] = 1 | |
lb[~mask.astype(np.bool)] = 0 | |
img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) | |
lb = torch.from_numpy( lb ) | |
if self.tile_z_dim: | |
img = img.repeat( [ self.tile_z_dim, 1, 1] ) | |
assert img.ndimension() == 3, f'actual dim {img.ndimension()}' | |
is_start = curr_dict["is_start"] | |
is_end = curr_dict["is_end"] | |
nframe = np.int32(curr_dict["nframe"]) | |
scan_id = curr_dict["scan_id"] | |
z_id = curr_dict["z_id"] | |
sample = {"image": img, | |
"label":lb, | |
"is_start": is_start, | |
"inst": None, | |
"scribble": None, | |
"is_end": is_end, | |
"nframe": nframe, | |
"scan_id": scan_id, | |
"z_id": z_id | |
} | |
concat_buffer.append(sample) | |
out_buffer.append({ | |
"image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), | |
"label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), | |
}) | |
# do the concat, and add to output_buffer | |
# post-processing, including keeping the foreground and suppressing background. | |
support_images = [] | |
support_mask = [] | |
support_class = [] | |
for itm in out_buffer: | |
support_images.append(itm["image"]) | |
support_class.append(curr_class) | |
# support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) | |
support_mask.append(itm["label"]) | |
return {'class_ids': [support_class], | |
'support_images': [support_images], # | |
'support_mask': [support_mask], | |
'scan_id': scan_id | |
} | |
def get_nii_dataset(config, image_size, **kwargs):
    """
    Build the validation dataset for a .nii organ-segmentation benchmark.

    Args:
        config: experiment config dict; reads 'dataset', 'eval_fold', 'curr_cls',
            'path'[dataset]['data_dir'] and 'task' ('n_shots', 'npart') entries
        image_size: in-plane size slices are resized to
        **kwargs: forwarded to ManualAnnoDataset (e.g. use_clahe, use_3_slices)

    Returns:
        (traindataset, testdataset) -- traindataset is currently always None.
    """
    print(f"Check config: {config}")
    # organ name -> label id, keyed by lower-cased base dataset name
    organ_mapping = {
        "sabs":{
            "rk": 2,
            "lk": 3,
            "liver": 6,
            "spleen": 1
        },
        "chaost2":{
            "liver": 1,
            "rk": 2,
            "lk": 3,
            "spleen": 4
        }}
    transforms = None  # validation only: no augmentation
    data_name = config['dataset']
    if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672':
        baseset_name = 'SABS'
        max_label = 13
        modality="CT"
    elif data_name == 'C0_Superpix':
        raise NotImplementedError
        # NOTE(review): the two assignments below are unreachable (after the raise)
        baseset_name = 'C0'
        max_label = 3
    elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672':
        baseset_name = 'CHAOST2'
        max_label = 4
        modality="MR"
    elif 'lits' in data_name.lower():
        # NOTE(review): 'lits17' has no entry in organ_mapping, so the lookup below
        # raises KeyError for LITS datasets; `modality` is also left unset here
        baseset_name = 'LITS17'
        max_label = 4
    else:
        raise ValueError(f'Dataset: {data_name} not found')
    # norm_func = get_normalize_op(modality=modality, fids=None) # TODO add global statistics
    # norm_func = None
    test_label = organ_mapping[baseset_name.lower()][config["curr_cls"]]
    base_dir = config['path'][data_name]['data_dir']
    testdataset = ManualAnnoDataset(which_dataset=baseset_name,
                                    base_dir=base_dir,
                                    idx_split = config['eval_fold'],
                                    mode = 'val',
                                    scan_per_load = 1,
                                    transforms=transforms,
                                    min_fg=1,
                                    nsup = config["task"]["n_shots"],
                                    fix_length=None,
                                    image_size=image_size,
                                    # extern_normalize_func=norm_func
                                    **kwargs)
    # wrap for per-class evaluation over npart support chunks
    testdataset = ValidationDataset(testdataset, test_classes = [test_label], npart = config["task"]["npart"])
    testdataset.set_curr_cls(test_label)
    traindataset = None # TODO make this the support set later
    return traindataset, testdataset