""" Manually labeled dataset TODO: 1. Merge with superpixel dataset """ import glob import numpy as np import dataloaders.augutils as myaug import torch import random import os import copy import platform import json import re import cv2 from dataloaders.common import BaseDataset, Subset, ValidationDataset # from common import BaseDataset, Subset from dataloaders.dataset_utils import* from pdb import set_trace from util.utils import CircularList from util.consts import IMG_SIZE MODE_DEFAULT = "default" MODE_FULL_SCAN = "full_scan" class ManualAnnoDataset(BaseDataset): def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None, **kwargs): """ Manually labeled dataset Args: which_dataset: name of the dataset to use base_dir: directory of dataset idx_split: index of data split as we will do cross validation mode: 'train', 'val'. transforms: data transform (augmentation) function min_fg: minimum number of positive pixels in a 2D slice, mainly for stablize training when trained on manually labeled dataset scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images nsup: number of support scans fix_length: fix the length of dataset exclude_list: Labels to be excluded extern_normalize_function: normalization function used for data pre-processing """ super(ManualAnnoDataset, self).__init__(base_dir) self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] self.sep = DATASET_INFO[which_dataset]['_SEP'] self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] self.image_size = image_size self.transforms = transforms self.is_train = True if mode == 'train' else False self.phase = mode self.fix_length = fix_length self.all_label_names = self.label_name self.nclass = len(self.label_name) self.tile_z_dim = tile_z_dim self.base_dir = base_dir self.nsup = nsup self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ] self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for the ease of spliting folds if 'use_clahe' not in kwargs: self.use_clahe = False else: self.use_clahe = kwargs['use_clahe'] if self.use_clahe: self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7,7)) self.use_3_slices = kwargs["use_3_slices"] if 'use_3_slices' in kwargs else False if self.use_3_slices: self.tile_z_dim=1 self.get_item_mode = MODE_DEFAULT if 'get_item_mode' in kwargs: self.get_item_mode = kwargs['get_item_mode'] self.exclude_lbs = exclude_list if len(exclude_list) > 0: print(f'###### Dataset: the following classes has been excluded {exclude_list}######') self.idx_split = idx_split self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) self.scan_per_load = scan_per_load self.info_by_scan = None self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset. self.norm_func = extern_normalize_func print(f'###### Dataset: using external normalization statistics ######') else: self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()]) print(f'###### Dataset: using normalization statistics calculated from loaded data ######') if self.is_train: if scan_per_load > 0: # buffer needed self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) else: # load the entire set without a buffer self.pid_curr_load = self.scan_ids elif mode == 'val': self.pid_curr_load = self.scan_ids self.potential_support_sid = [] else: raise Exception self.actual_dataset = self.read_dataset() self.size = len(self.actual_dataset) self.overall_slice_by_cls = self.read_classfiles() self.update_subclass_lookup() def get_scanids(self, mode, idx_split): val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index if mode == 'train': return [ ii for ii in self.img_pids if ii not in val_ids ] elif mode == 'val': return val_ids def reload_buffer(self): """ Reload a portion of the entire dataset, if the dataset is too large 1. delete original buffer 2. update self.ids_this_batch 3. update other internel variables like __len__ """ if self.scan_per_load <= 0: print("We are not using the reload buffer, doing notiong") return -1 del self.actual_dataset del self.info_by_scan self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) self.actual_dataset = self.read_dataset() self.size = len(self.actual_dataset) self.update_subclass_lookup() print(f'Loader buffer reloaded with a new size of {self.size} slices') def organize_sample_fids(self): out_list = {} for curr_id in self.scan_ids: curr_dict = {} _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') _lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') curr_dict["img_fid"] = _img_fid curr_dict["lbs_fid"] = _lb_fid out_list[str(curr_id)] = curr_dict return out_list def read_dataset(self): """ Build index pointers to individual slices Also keep a look-up table from scan_id, slice to index """ out_list = [] self.scan_z_idx = {} self.info_by_scan = {} # meta data of each scan glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset for scan_id, itm in self.img_lb_fids.items(): if scan_id not in self.pid_curr_load: continue img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out img = img.transpose(1,2,0) self.info_by_scan[scan_id] = _info if self.use_clahe: img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0) img = np.float32(img) img = self.norm_func(img) self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])] lb = read_nii_bysitk(itm["lbs_fid"]) lb = lb.transpose(1,2,0) lb = np.float32(lb) img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR) lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST) assert img.shape[-1] == lb.shape[-1] base_idx = img.shape[-1] // 2 # index of the middle slice # write the beginning frame out_list.append( {"img": img[..., 0: 1], "lb":lb[..., 0: 0 + 1], "is_start": True, "is_end": False, "nframe": img.shape[-1], "scan_id": scan_id, "z_id":0}) self.scan_z_idx[scan_id][0] = glb_idx glb_idx += 1 for ii in range(1, img.shape[-1] - 1): out_list.append( {"img": img[..., ii: ii + 1], "lb":lb[..., ii: ii + 1], "is_start": False, "is_end": False, "nframe": -1, "scan_id": scan_id, "z_id": ii }) self.scan_z_idx[scan_id][ii] = glb_idx glb_idx += 1 ii += 1 # last frame, note the is_end flag out_list.append( {"img": img[..., ii: ii + 1], "lb":lb[..., ii: ii+ 1], "is_start": False, "is_end": True, "nframe": -1, "scan_id": scan_id, "z_id": ii }) self.scan_z_idx[scan_id][ii] = glb_idx glb_idx += 1 return out_list def read_classfiles(self): with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen: cls_map = json.load( fopen) fopen.close() with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen: self.tp1_cls_map = json.load( fopen) fopen.close() return cls_map def __getitem__(self, index): if self.get_item_mode == MODE_DEFAULT: return self.__getitem_default__(index) elif self.get_item_mode == MODE_FULL_SCAN: return self.__get_ct_scan___(index) else: raise NotImplementedError("Unknown mode") def __get_ct_scan___(self, index): scan_n = index % len(self.scan_z_idx) scan_id = list(self.scan_z_idx.keys())[scan_n] scan_slices = self.scan_z_idx[scan_id] scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) scan_imgs = np.float32(scan_imgs) scan_lbs = np.float32(scan_lbs) scan_imgs = torch.from_numpy(scan_imgs).unsqueeze(0) scan_lbs = torch.from_numpy(scan_lbs) if self.tile_z_dim: scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1) assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}' # # reshape to C, D, H, W # scan_imgs = scan_imgs.permute(1, 0, 2, 3) # scan_lbs = scan_lbs.permute(1, 0, 2, 3) sample = {"image": scan_imgs, "label":scan_lbs, "scan_id": scan_id, } return sample def get_3_slice_adjacent_image(self, image_t, index): curr_dict = self.actual_dataset[index] prev_image = np.zeros_like(image_t) if index > 1 and not curr_dict["is_start"]: prev_dict = self.actual_dataset[index - 1] prev_image = prev_dict["img"] next_image = np.zeros_like(image_t) if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: next_dict = self.actual_dataset[index + 1] next_image = next_dict["img"] image_t = np.concatenate([prev_image, image_t, next_image], axis=-1) return image_t def __getitem_default__(self, index): index = index % len(self.actual_dataset) curr_dict = self.actual_dataset[index] if self.is_train: if len(self.exclude_lbs) > 0: for _ex_cls in self.exclude_lbs: if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 ) if self.transforms is not None: img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False) else: raise Exception("No transform function is provided") else: img = curr_dict['img'] lb = curr_dict['lb'] img = np.float32(img) lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure if self.use_3_slices: img = self.get_3_slice_adjacent_image(img, index) img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) lb = torch.from_numpy( lb) if self.tile_z_dim: img = img.repeat( [ self.tile_z_dim, 1, 1] ) assert img.ndimension() == 3, f'actual dim {img.ndimension()}' is_start = curr_dict["is_start"] is_end = curr_dict["is_end"] nframe = np.int32(curr_dict["nframe"]) scan_id = curr_dict["scan_id"] z_id = curr_dict["z_id"] sample = {"image": img, "label":lb, "is_start": is_start, "is_end": is_end, "nframe": nframe, "scan_id": scan_id, "z_id": z_id } # Add auxiliary attributes if self.aux_attrib is not None: for key_prefix in self.aux_attrib: # Process the data sample, create new attributes and save them in a dictionary aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix]) for key_suffix in aux_attrib_val: # one function may create multiple attributes, so we need suffix to distinguish them sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix] return sample def __len__(self): """ copy-paste from basic naive dataset configuration """ if self.get_item_mode == MODE_FULL_SCAN: return len(self.scan_z_idx) if self.fix_length != None: assert self.fix_length >= len(self.actual_dataset) return self.fix_length else: return len(self.actual_dataset) def update_subclass_lookup(self): """ Updating the class-slice indexing list Args: [internal] overall_slice_by_cls: { class1: {pid1: [slice1, slice2, ....], pid2: [slice1, slice2]}, ...} class2: ... } out[internal]: { class1: [ idx1, idx2, ... ], class2: [ idx1, idx2, ... ], ... } """ # delete previous ones if any assert self.overall_slice_by_cls is not None if not hasattr(self, 'idx_by_class'): self.idx_by_class = {} # filter the new one given the actual list for cls in self.label_name: if cls not in self.idx_by_class.keys(): self.idx_by_class[cls] = [] else: del self.idx_by_class[cls][:] for cls, dict_by_pid in self.overall_slice_by_cls.items(): for pid, slice_list in dict_by_pid.items(): if pid not in self.pid_curr_load: continue self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ] print("###### index-by-class table has been reloaded ######") def getMaskMedImg(self, label, class_id, class_ids): """ Generate FG/BG mask from the segmentation mask. Used when getting the support """ # Dense Mask fg_mask = torch.where(label == class_id, torch.ones_like(label), torch.zeros_like(label)) bg_mask = torch.where(label != class_id, torch.ones_like(label), torch.zeros_like(label)) for class_id in class_ids: bg_mask[label == class_id] = 0 return {'fg_mask': fg_mask, 'bg_mask': bg_mask} def subsets(self, sub_args_lst=None): """ Override base-class subset method Create subsets by scan_ids output: list [[] , ] """ if sub_args_lst is not None: subsets = [] ii = 0 for cls_name, index_list in self.idx_by_class.items(): subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) ) ii += 1 else: subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()] return subsets def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int): """ getting (probably multi-shot) support set for evaluation sample from 50% (1shot) or 20 35 50 65 80 (5shot) Args: curr_cls: current class to segment, starts from 1 class_idx: a list of all foreground class in nways, starts from 1 npart: how may chunks used to split the support scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan being served as support, in self.pid_curr_load """ assert npart % 2 == 1 assert curr_class != 0; assert 0 not in class_idx # assert not self.is_train self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] # print(f'###### Using {len(scan_idx)} shot evaluation!') if npart == 1: pcts = [0.5] else: half_part = 1 / (npart * 2) part_interval = (1.0 - 1.0 / npart) / (npart - 1) pcts = [ half_part + part_interval * ii for ii in range(npart) ] # print(f'###### Parts percentage: {pcts} ######') # norm_func = get_normalize_op(modality='MR', fids=None) out_buffer = [] # [{scanid, img, lb}] for _part in range(npart): concat_buffer = [] # for each fold do a concat in image and mask in batch dimension for scan_order in scan_idx: _scan_id = self.pid_curr_load[ scan_order ] print(f'Using scan {_scan_id} as support!') # for _pc in pcts: _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices _zid = _zlist[int(pcts[_part] * len(_zlist))] _glb_idx = self.scan_z_idx[_scan_id][_zid] # almost copy-paste __getitem__ but no augmentation curr_dict = self.actual_dataset[_glb_idx] img = curr_dict['img'] lb = curr_dict['lb'] if self.use_3_slices: prev_image = np.zeros_like(img) if _glb_idx > 1 and not curr_dict["is_start"]: prev_dict = self.actual_dataset[_glb_idx - 1] prev_image = prev_dict["img"] next_image = np.zeros_like(img) if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: next_dict = self.actual_dataset[_glb_idx + 1] next_image = next_dict["img"] img = np.concatenate([prev_image, img, next_image], axis=-1) img = np.float32(img) lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) lb = torch.from_numpy( lb ) if self.tile_z_dim: img = img.repeat( [ self.tile_z_dim, 1, 1] ) assert img.ndimension() == 3, f'actual dim {img.ndimension()}' is_start = curr_dict["is_start"] is_end = curr_dict["is_end"] nframe = np.int32(curr_dict["nframe"]) scan_id = curr_dict["scan_id"] z_id = curr_dict["z_id"] sample = {"image": img, "label":lb, "is_start": is_start, "inst": None, "scribble": None, "is_end": is_end, "nframe": nframe, "scan_id": scan_id, "z_id": z_id } concat_buffer.append(sample) out_buffer.append({ "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), }) # do the concat, and add to output_buffer # post-processing, including keeping the foreground and suppressing background. support_images = [] support_mask = [] support_class = [] for itm in out_buffer: support_images.append(itm["image"]) support_class.append(curr_class) support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) return {'class_ids': [support_class], 'support_images': [support_images], # 'support_mask': [support_mask], } def get_support_scan(self, curr_class: int, class_idx: list, scan_idx: list): self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] # print(f'###### Using {len(scan_idx)} shot evaluation!') scan_slices = self.scan_z_idx[self.potential_support_sid[0]] scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1) # binarize the labels scan_lbs[scan_lbs != curr_class] = 0 scan_lbs[scan_lbs == curr_class] = 1 scan_imgs = torch.from_numpy(np.float32(scan_imgs)).unsqueeze(0) scan_lbs = torch.from_numpy(np.float32(scan_lbs)) if self.tile_z_dim: scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1) assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}' # reshape to C, D, H, W sample = {"scan": scan_imgs, "labels":scan_lbs, } return sample def get_support_multiple_classes(self, classes: list, scan_idx: list, npart: int, use_3_slices=False): """ getting (probably multi-shot) support set for evaluation sample from 50% (1shot) or 20 35 50 65 80 (5shot) Args: curr_cls: current class to segment, starts from 1 class_idx: a list of all foreground class in nways, starts from 1 npart: how may chunks used to split the support scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan being served as support, in self.pid_curr_load """ assert npart % 2 == 1 # assert curr_class != 0; assert 0 not in class_idx # assert not self.is_train self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] # print(f'###### Using {len(scan_idx)} shot evaluation!') if npart == 1: pcts = [0.5] else: half_part = 1 / (npart * 2) part_interval = (1.0 - 1.0 / npart) / (npart - 1) pcts = [ half_part + part_interval * ii for ii in range(npart) ] # print(f'###### Parts percentage: {pcts} ######') out_buffer = [] # [{scanid, img, lb}] for _part in range(npart): concat_buffer = [] # for each fold do a concat in image and mask in batch dimension for scan_order in scan_idx: _scan_id = self.pid_curr_load[ scan_order ] print(f'Using scan {_scan_id} as support!') # for _pc in pcts: zlist = [] for curr_class in classes: zlist.append(self.tp1_cls_map[self.label_name[curr_class]][_scan_id]) # list of indices # merge all the lists in zlist and keep only the unique elements # _zlist = sorted(list(set([item for sublist in zlist for item in sublist]))) # take only the indices that appear in all of the sublist _zlist = sorted(list(set.intersection(*map(set, zlist)))) _zid = _zlist[int(pcts[_part] * len(_zlist))] _glb_idx = self.scan_z_idx[_scan_id][_zid] # almost copy-paste __getitem__ but no augmentation curr_dict = self.actual_dataset[_glb_idx] img = curr_dict['img'] lb = curr_dict['lb'] if use_3_slices: prev_image = np.zeros_like(img) if _glb_idx > 1 and not curr_dict["is_start"]: prev_dict = self.actual_dataset[_glb_idx - 1] assert prev_dict["scan_id"] == curr_dict["scan_id"] assert prev_dict["z_id"] == curr_dict["z_id"] - 1 prev_image = prev_dict["img"] next_image = np.zeros_like(img) if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]: next_dict = self.actual_dataset[_glb_idx + 1] assert next_dict["scan_id"] == curr_dict["scan_id"] assert next_dict["z_id"] == curr_dict["z_id"] + 1 next_image = next_dict["img"] img = np.concatenate([prev_image, img, next_image], axis=-1) img = np.float32(img) lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure # zero all labels that are not in the classes arg mask = np.zeros_like(lb) for cls in classes: mask[lb == cls] = 1 lb[~mask.astype(np.bool)] = 0 img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) lb = torch.from_numpy( lb ) if self.tile_z_dim: img = img.repeat( [ self.tile_z_dim, 1, 1] ) assert img.ndimension() == 3, f'actual dim {img.ndimension()}' is_start = curr_dict["is_start"] is_end = curr_dict["is_end"] nframe = np.int32(curr_dict["nframe"]) scan_id = curr_dict["scan_id"] z_id = curr_dict["z_id"] sample = {"image": img, "label":lb, "is_start": is_start, "inst": None, "scribble": None, "is_end": is_end, "nframe": nframe, "scan_id": scan_id, "z_id": z_id } concat_buffer.append(sample) out_buffer.append({ "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), }) # do the concat, and add to output_buffer # post-processing, including keeping the foreground and suppressing background. support_images = [] support_mask = [] support_class = [] for itm in out_buffer: support_images.append(itm["image"]) support_class.append(curr_class) # support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) support_mask.append(itm["label"]) return {'class_ids': [support_class], 'support_images': [support_images], # 'support_mask': [support_mask], 'scan_id': scan_id } def get_nii_dataset(config, image_size, **kwargs): print(f"Check config: {config}") organ_mapping = { "sabs":{ "rk": 2, "lk": 3, "liver": 6, "spleen": 1 }, "chaost2":{ "liver": 1, "rk": 2, "lk": 3, "spleen": 4 }} transforms = None data_name = config['dataset'] if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672': baseset_name = 'SABS' max_label = 13 modality="CT" elif data_name == 'C0_Superpix': raise NotImplementedError baseset_name = 'C0' max_label = 3 elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672': baseset_name = 'CHAOST2' max_label = 4 modality="MR" elif 'lits' in data_name.lower(): baseset_name = 'LITS17' max_label = 4 else: raise ValueError(f'Dataset: {data_name} not found') # norm_func = get_normalize_op(modality=modality, fids=None) # TODO add global statistics # norm_func = None test_label = organ_mapping[baseset_name.lower()][config["curr_cls"]] base_dir = config['path'][data_name]['data_dir'] testdataset = ManualAnnoDataset(which_dataset=baseset_name, base_dir=base_dir, idx_split = config['eval_fold'], mode = 'val', scan_per_load = 1, transforms=transforms, min_fg=1, nsup = config["task"]["n_shots"], fix_length=None, image_size=image_size, # extern_normalize_func=norm_func **kwargs) testdataset = ValidationDataset(testdataset, test_classes = [test_label], npart = config["task"]["npart"]) testdataset.set_curr_cls(test_label) traindataset = None # TODO make this the support set later return traindataset, testdataset