"""
Copied from https://github.com/talshaharabany/AutoSAM
"""
import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
import random
import torch
from dataloaders.PolypTransforms import get_polyp_transform
import cv2

KVASIR = "Kvasir"
CLINIC_DB = "CVC-ClinicDB"
COLON_DB = "CVC-ColonDB"
ETIS_DB = "ETIS-LaribPolypDB"
CVC300 = "CVC-300"

DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB)
EXCLUDE_DS = (CVC300, )

# File suffixes accepted for input images and for ground-truth masks.
_IMAGE_SUFFIXES = ('.jpg', '.png')
_MASK_SUFFIXES = ('.png', )

# Default locations used by the two offline helper scripts below.
_DEFAULT_TRAIN_ROOT = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset"
_DEFAULT_POLYP_ROOT = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/"


def _list_files(directory, suffixes):
    """Return the (unsorted) full paths of the files in `directory` whose name ends with one of `suffixes`."""
    return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(suffixes)]


def create_suppport_set_for_polyps(n_support=10, root_dir=None):
    """
    Write `<root_dir>/support.txt` listing `n_support` randomly chosen image/mask pairs.

    Each output row is "<image_path> <mask_path>". Images are read from
    `<root_dir>/images` and masks from `<root_dir>/masks`; the i-th image (sorted
    by path) is paired with the i-th mask (sorted by path).

    n_support: number of distinct pairs to sample (without replacement)
    root_dir: dataset directory; defaults to the original hard-coded TrainDataset path

    Raises ValueError if the image and mask counts differ (index pairing would be
    meaningless) or if fewer than `n_support` images are available.
    """
    if root_dir is None:
        root_dir = _DEFAULT_TRAIN_ROOT
    image_dir = os.path.join(root_dir, "images")
    mask_dir = os.path.join(root_dir, "masks")
    image_paths = sorted(_list_files(image_dir, _IMAGE_SUFFIXES))
    mask_paths = sorted(_list_files(mask_dir, _MASK_SUFFIXES))

    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"found {len(image_paths)} images in {image_dir} but {len(mask_paths)} masks in {mask_dir}")
    if n_support > len(image_paths):
        # The previous rejection-sampling loop never terminated in this situation.
        raise ValueError(
            f"n_support ({n_support}) is larger than the number of images ({len(image_paths)})")

    # Sampling indices without replacement guarantees distinct pairs.
    indices = random.sample(range(len(image_paths)), max(n_support, 0))

    with open(os.path.join(root_dir, "support.txt"), 'w') as file:
        for index in indices:
            file.write(f"{image_paths[index]} {mask_paths[index]}\n")


def create_train_val_test_split_for_polyps(root_dir=None):
    """
    Create a `split.txt` file in every known dataset sub-directory of `root_dir`.

    root_dir: directory whose sub-directories each contain an `images` folder;
              defaults to the original hard-coded PolypDataset path

    Sub-directories that have no entry in the size tables below are skipped.
    """
    if root_dir is None:
        root_dir = _DEFAULT_POLYP_ROOT
    num_train_images_per_dataset = {
        "CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0}
    num_test_images_per_dataset = {
        "CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380}

    for subdir in os.listdir(root_dir):
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        if subdir not in num_train_images_per_dataset or subdir not in num_test_images_per_dataset:
            # Previously a KeyError aborted the whole run at the first unknown folder.
            print(f"Skipping {subdir_path}: no split sizes configured")
            continue
        create_train_val_test_split(
            os.path.join(subdir_path, "images"),
            os.path.join(subdir_path, "split.txt"),
            train_number=num_train_images_per_dataset[subdir],
            test_number=num_test_images_per_dataset[subdir])


def create_train_val_test_split(root, split_file, train_number=100, test_number=20):
    """
    Create a train / val / test split file for a dataset.

    The image names (without extension) found in `root` are shuffled; the first
    `train_number` go to train, the last `test_number` to test and whatever lies
    in between to val.

    root: directory containing the image files (.jpg / .png)
    split_file: path of the split file to create
    train_number: number of images assigned to the train split
    test_number: number of images assigned to the test split

    The file contains a "train" header line followed by one name per line, then
    the same for "val" and "test".
    """
    # splitext keeps dots that are part of the name ("a.b.png" -> "a.b").
    files = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith(_IMAGE_SUFFIXES)]
    random.shuffle(files)

    num_train = train_number
    num_test = test_number
    # Clamp at zero: a negative count would make the slices below overlap.
    num_val = max(len(files) - num_train - num_test, 0)
    print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")

    train = files[:num_train]
    val = files[num_train:num_train + num_val]
    test = files[num_train + num_val:]

    with open(split_file, 'w') as file:
        for header, names in (("train", train), ("val", val), ("test", test)):
            file.write(header + "\n")
            file.writelines(name + "\n" for name in names)


class PolypDataset(data.Dataset):
    """
    dataloader for polyp segmentation tasks
    """

    def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True,
                 sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
        """
        root: dataset root; used to locate per-dataset `split.txt` files when
              `image_root` / `gt_root` are not both given. Dataset statistics are
              only computed when 'VPS' is not part of this string.
        image_root, gt_root: when both are given, images and masks are listed
              directly from these directories and their immediate sub-directories
        trainsize: minimum side length used by `resize`
        augmentations: callable invoked as `augmentations(image, mask)` and
              expected to return an `(image, mask)` pair
        train: selects the "train" (True) or "test" (False) split of `split.txt`
        sam_trans: optional object exposing `apply_image_torch` and `preprocess`;
              when given it is responsible for normalisation
        datasets: names of the dataset folders to load from `root`
        image_size: output (height, width), or a single int for a square output
        ds_mean, ds_std: optional normalisation statistics overriding the computed ones
        """
        self.trainsize = trainsize
        self.augmentations = augmentations
        self.datasets = datasets
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

        if image_root is not None and gt_root is not None:
            self.images, self.gts = self._collect_pairs_from_roots(image_root, gt_root)
        else:
            self.images, self.gts = self.get_image_gt_pairs(
                root, split="train" if train else "test", datasets=self.datasets)

        # Images and masks are paired by position after sorting each list.
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        if 'VPS' not in root:
            self.filter_files_and_get_ds_mean_and_std()
        if ds_mean is not None and ds_std is not None:
            self.mean, self.std = ds_mean, ds_std
        self.size = len(self.images)
        self.train = train
        self.sam_trans = sam_trans
        if self.sam_trans is not None:
            # sam trans takes care of norm
            self.mean, self.std = 0, 1

    def _is_top_level_gt_file(self, filename):
        """Return True if `filename`, found directly inside `gt_root`, is a ground-truth file."""
        return filename.endswith(_MASK_SUFFIXES)

    def _collect_pairs_from_roots(self, image_root, gt_root):
        """
        List image and ground-truth paths in `image_root` / `gt_root` and in their
        immediate sub-directories (sub-directory names are taken from `image_root`).
        Returns two unsorted lists: (image_paths, gt_paths).
        """
        images = _list_files(image_root, _IMAGE_SUFFIXES)
        gts = [os.path.join(gt_root, f) for f in os.listdir(gt_root) if self._is_top_level_gt_file(f)]
        # also look in subdirectories
        for subdir in os.listdir(image_root):
            subdir_image_root = os.path.join(image_root, subdir)
            if not os.path.isdir(subdir_image_root):
                continue
            subdir_gt_root = os.path.join(gt_root, subdir)
            images.extend(_list_files(subdir_image_root, _IMAGE_SUFFIXES))
            gts.extend(_list_files(subdir_gt_root, _MASK_SUFFIXES))
        return images, gts

    def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS):
        """
        for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file

        dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders
        split: train, val, or test
        datasets: only folders whose name is listed here are considered

        Returns (image_paths, gt_paths). Folders without a split.txt are reported and skipped.
        """
        image_paths = []
        gt_paths = []
        for folder in os.listdir(dir_root):
            if folder not in datasets:
                continue
            folder_path = os.path.join(dir_root, folder)
            split_file = os.path.join(folder_path, "split.txt")
            if not os.path.isfile(split_file):
                print(f"No split.txt file found in {folder_path}")
                continue
            folder_images, folder_gts = self.get_image_gt_pairs_from_text_file(
                os.path.join(folder_path, "images"),
                os.path.join(folder_path, "masks"),
                split_file,
                split=split)
            image_paths.extend(folder_images)
            gt_paths.extend(folder_gts)
        return image_paths, gt_paths

    def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"):
        """
        image_root: root directory of images
        gt_root: root directory of ground truth
        text_file: text file containing train, val, test split with the following format:
            train
            image1
            image2
            ...
            val
            image1
            ...
            test
            image1
            ...
        split: train, val, or test

        Returns (image_paths, gt_paths) for the requested split. Masks are always
        `<name>.png`; images are `<name>.png`, or `<name>.jpg` when only the .jpg
        file exists on disk.
        """
        splits = {"train": [], "val": [], "test": []}
        current_split = None
        with open(text_file, 'r') as file:
            for line in file:
                line = line.strip()
                if line in splits:
                    # A header line switches the split that subsequent names belong to.
                    current_split = line
                elif line and current_split:
                    splits[current_split].append(line)

        image_paths = []
        gt_paths = []
        for name in splits[split]:
            image_path = os.path.join(image_root, name + '.png')
            if not os.path.isfile(image_path):
                # The split writer strips both .jpg and .png, so the image may be a .jpg.
                jpg_path = os.path.join(image_root, name + '.jpg')
                if os.path.isfile(jpg_path):
                    image_path = jpg_path
            image_paths.append(image_path)
            gt_paths.append(os.path.join(gt_root, name + '.png'))
        return image_paths, gt_paths

    def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1):
        """
        Sample `n_support` image/mask pairs (with replacement) from two directories.

        The i-th image (sorted by path) is paired with the i-th mask (sorted by path).
        Returns (stacked_images, stacked_masks) as two tensors.
        """
        support_image_paths = sorted(_list_files(support_image_dir, _IMAGE_SUFFIXES))
        support_mask_paths = sorted(_list_files(support_mask_dir, _MASK_SUFFIXES))

        support_images = []
        support_labels = []
        for _ in range(n_support):
            index = random.randint(0, len(support_image_paths) - 1)
            support_images.append(self.cv2_loader(support_image_paths[index], is_mask=False))
            support_labels.append(self.cv2_loader(support_mask_paths[index], is_mask=True))

        if self.augmentations:
            # Augment each pair exactly once so that the image and its mask receive
            # the same transform (previously the mask came from a second, independent
            # call that was fed the already-augmented image).
            augmented = [self.augmentations(img, mask)
                         for img, mask in zip(support_images, support_labels)]
            support_images = [pair[0] for pair in augmented]
            support_labels = [pair[1] for pair in augmented]

        # Only images spanning the full 0..255 range are normalised here.
        support_images = [(support_image - self.mean) / self.std
                          if support_image.max() == 255 and support_image.min() == 0 else support_image
                          for support_image in support_images]

        if self.sam_trans is not None:
            support_images = [self.sam_trans.preprocess(img).squeeze(0) for img in support_images]
            support_labels = [self.sam_trans.preprocess(mask) for mask in support_labels]
        else:
            image_size = self.image_size
            support_images = [
                torch.nn.functional.interpolate(
                    img.unsqueeze(0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
                for img in support_images]
            support_labels = [
                torch.nn.functional.interpolate(
                    mask.unsqueeze(0).unsqueeze(0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
                for mask in support_labels]

        return torch.stack(support_images), torch.stack(support_labels)

    def get_support_from_text_file(self, text_file, n_support=1):
        """
        each row in the file has 2 paths divided by space, the first is the image path and the second is the mask path

        Returns the first `n_support` (image_paths, mask_paths) listed in the file.
        Blank lines are ignored. Raises ValueError if the file lists fewer than
        `n_support` pairs.
        """
        support_images = []
        support_labels = []
        with open(text_file, 'r') as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue
                image_path, mask_path = fields
                support_images.append(image_path)
                support_labels.append(mask_path)

        if n_support > len(support_images):
            raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})")

        return support_images[:n_support], support_labels[:n_support]

    def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None):
        """
        Get support set from specified directories, text file or from the dataset itself

        With `support_image_dir` and `support_mask_dir` the result of
        `get_support_from_dirs` is returned unchanged (two stacked tensors).
        Otherwise returns (images, masks, case): two lists holding one processed
        tensor per support item, each with an extra leading dimension, and the
        'case' value of the last processed item ("" if nothing was processed).
        """
        if support_image_dir is not None and support_mask_dir:
            return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support)

        if text_file is not None:
            support_image_paths, support_gt_paths = self.get_support_from_text_file(text_file, n_support=n_support)
        else:
            # randomly sample n_support images and masks from the dataset
            indices = random.choices(range(self.size), k=n_support)
            print(f"support indices:{indices}")
            support_image_paths = [self.images[index] for index in indices]
            support_gt_paths = [self.gts[index] for index in indices]

        support_images = []
        support_gts = []
        case = ""
        for image_path, gt_path in zip(support_image_paths, support_gt_paths):
            support_img = self.cv2_loader(image_path, is_mask=False)
            support_mask = self.cv2_loader(gt_path, is_mask=True)
            out = self.process_image_gt(support_img, support_mask)
            support_images.append(out['image'].unsqueeze(0))
            support_gts.append(out['label'].unsqueeze(0))
            case = out['case']
            if len(support_images) >= n_support:
                break
        return support_images, support_gts, case

    @staticmethod
    def _spatial_size(image):
        """Return (height, width) of a cv2-style (H, W, C) array or of a (..., H, W) array/tensor."""
        if isinstance(image, np.ndarray) and image.ndim == 3:
            # cv2 stores colour images channels-last, so the spatial dims come first.
            return tuple(image.shape[:2])
        return tuple(image.shape[-2:])

    def process_image_gt(self, image, gt, dataset=""):
        """
        image and gt are expected to be output from self.cv2_loader

        Applies the augmentations, then either the SAM transform or mean/std
        normalisation, binarises the mask at 0.5 and (without a SAM transform)
        resizes both to `self.image_size`.

        Returns a dict with keys 'image', 'label', 'original_size' (height and
        width of the input image), 'image_size' and 'case' (the `dataset` argument).
        """
        original_size = self._spatial_size(image)
        if self.augmentations:
            image, mask = self.augmentations(image, gt)
        else:
            # Without an augmentation pipeline nothing converts the cv2 arrays, yet
            # the code below needs a (C, H, W) image tensor and an (H, W) mask tensor
            # (see the interpolate calls). NOTE(review): confirm this matches what
            # the transforms from get_polyp_transform produce.
            mask = gt
            if isinstance(image, np.ndarray) and image.ndim == 3:
                image = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1).contiguous()
            if isinstance(mask, np.ndarray):
                # astype copies, so binarising below never writes into the caller's array.
                mask = torch.from_numpy(mask.astype(np.float32))

        if self.sam_trans:
            image, mask = self.sam_trans.apply_image_torch(
                image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask)
        elif image.max() <= 255 and image.min() >= 0:
            image = (image - self.mean) / self.std
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0

        image_size = self.image_size
        if self.sam_trans is None:
            image = torch.nn.functional.interpolate(
                image.unsqueeze(0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
            mask = torch.nn.functional.interpolate(
                mask.unsqueeze(0).unsqueeze(0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
            # TODO: min-max rescale the image here if results get worse

        if self.sam_trans:
            image = self.sam_trans.preprocess(image).squeeze(0)
            mask = self.sam_trans.preprocess(mask)
        return {'image': image,
                'label': mask,
                'original_size': torch.Tensor(original_size),
                'image_size': torch.Tensor(image_size),
                'case': dataset}  # case to be compatible with polyp video dataset

    def get_dataset_name_from_path(self, path):
        """Return the first name in `self.datasets` that occurs in `path`, or "" if none does."""
        return next((dataset for dataset in self.datasets if dataset in path), "")

    def __getitem__(self, index):
        """Load the `index`-th image/mask pair and return the dict built by `process_image_gt`."""
        image = self.cv2_loader(self.images[index], is_mask=False)
        gt = self.cv2_loader(self.gts[index], is_mask=True)
        dataset = self.get_dataset_name_from_path(self.images[index])
        return self.process_image_gt(image, gt, dataset)

    def filter_files_and_get_ds_mean_and_std(self):
        """
        Drop pairs that belong to an excluded dataset or whose image and mask sizes
        differ, and set `self.mean` / `self.std` to the average of the per-image
        pixel mean / standard deviation over the remaining images.

        Raises ValueError if no pair survives the filtering.
        """
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        ds_mean = 0
        ds_std = 0
        for img_path, gt_path in zip(self.images, self.gts):
            if any(ex_ds in img_path for ex_ds in EXCLUDE_DS):
                continue
            # Context managers close both files promptly; the mask is only opened
            # to read its size and would otherwise stay open until collected.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                if img.size != gt.size:
                    continue
                pixels = np.array(img)
            images.append(img_path)
            gts.append(gt_path)
            ds_mean += pixels.mean()
            ds_std += pixels.std()

        if not images:
            raise ValueError("no usable image/mask pairs found; cannot compute dataset mean and std")
        self.images = images
        self.gts = gts
        self.mean = ds_mean / len(self.images)
        self.std = ds_std / len(self.images)

    def rgb_loader(self, path):
        """Open `path` with PIL and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Read `path` with OpenCV as a single-channel image (None if it cannot be read)."""
        return cv2.imread(path, 0)

    def cv2_loader(self, path, is_mask):
        """
        Read `path` with OpenCV.

        is_mask: if True, return a single-channel array in which every non-zero
                 pixel is set to 1; otherwise return the colour image in RGB order.

        Raises OSError if OpenCV cannot read the file.
        """
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if is_mask else cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2 could not read image file: {path}")
        if is_mask:
            img[img > 0] = 1
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def resize(self, img, gt):
        """
        Enlarge a PIL image/mask pair so that neither side is smaller than
        `self.trainsize`; pairs that are already large enough are returned unchanged.
        """
        assert img.size == gt.size
        w, h = img.size
        if h >= self.trainsize and w >= self.trainsize:
            return img, gt
        h = max(h, self.trainsize)
        w = max(w, self.trainsize)
        return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return self.size


class SuperpixPolypDataset(PolypDataset):
    """
    Polyp dataset whose ground truth is a superpixel map instead of a polyp mask.

    The constructor is inherited from PolypDataset; the only difference is that
    files found directly inside `gt_root` must contain 'superpix' in their name.
    For every `superpix-MIDDLE_<id>.png` a `fgmask_<id>.png` is expected in the
    same directory.
    """

    def _is_top_level_gt_file(self, filename):
        """Accept only the superpixel maps, not the fgmask files stored next to them."""
        return filename.endswith(_MASK_SUFFIXES) and 'superpix' in filename

    def __getitem__(self, index):
        """
        Build one self-supervised episode from the `index`-th image.

        A superpixel lying inside the foreground mask is picked at random and used
        as the query label. Returns a dict with keys 'support_images',
        'support_mask', 'query_images', 'query_labels' and 'scan_id'.
        """
        image = self.cv2_loader(self.images[index], is_mask=False)
        gt_path = self.gts[index]
        gt = self.cv2_loader(gt_path, is_mask=False)[:, :, 0]
        case_id = os.path.basename(gt_path).split('.png')[0].split('superpix-MIDDLE_')[1]
        fg_path = os.path.join(os.path.dirname(gt_path), 'fgmask_' + case_id + '.png')
        fg = self.cv2_loader(fg_path, is_mask=True)
        dataset = self.get_dataset_name_from_path(self.images[index])

        # Zero the superpixels outside the foreground. A boolean mask is required:
        # indexing with the integer array `1 - fg` selects rows 0 and 1 instead.
        masked_gt = gt.copy()
        masked_gt[fg == 0] = 0
        candidate_ids = np.unique(masked_gt)
        candidate_ids = candidate_ids[candidate_ids != 0]
        if candidate_ids.size:
            gt = masked_gt
        else:
            # No superpixel inside the foreground: fall back to the whole map
            # rather than failing on an empty choice.
            candidate_ids = np.unique(gt)
            candidate_ids = candidate_ids[candidate_ids != 0]

        # randomly choose a superpixel from the gt
        sp_id = random.choice(candidate_ids.tolist())
        sp = (gt == sp_id).astype(np.uint8)

        # NOTE(review): the support label is built from the whole superpixel map
        # (binarised in process_image_gt) while the query uses the single chosen
        # superpixel - confirm this asymmetry is intended.
        out = self.process_image_gt(image, gt, dataset)
        support_image, support_sp, dataset = out["image"], out["label"], out["case"]

        out = self.process_image_gt(image, sp, dataset)
        query_image, query_sp, dataset = out["image"], out["label"], out["case"]

        # TODO tile the masks to have 3 channels?
        support_bg_mask = 1 - support_sp
        support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask}

        batch = {"support_images": [[support_image]],
                 "support_mask": [[support_masks]],
                 "query_images": [query_image],
                 "query_labels": [query_sp],
                 "scan_id": [dataset]
                 }
        return batch


def get_superpix_polyp_dataset(image_size:tuple=(1024,1024), sam_trans=None):
    """Return the superpixel training dataset built from ./data/PolypDataset/TrainDataset."""
    transform_train, _ = get_polyp_transform()
    image_root = './data/PolypDataset/TrainDataset/images/'
    gt_root = './data/PolypDataset/TrainDataset/superpixels/'
    ds_train = SuperpixPolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
                                    augmentations=transform_train,
                                    sam_trans=sam_trans,
                                    image_size=image_size)
    return ds_train


def get_polyp_dataset(image_size, sam_trans=None):
    """
    Return (ds_train, ds_test) built from ./data/PolypDataset.

    NOTE(review): the training dataset also receives the test-time transform;
    confirm this is intentional.
    """
    _, transform_test = get_polyp_transform()

    image_root = './data/PolypDataset/TrainDataset/images/'
    gt_root = './data/PolypDataset/TrainDataset/masks/'
    ds_train = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
                            augmentations=transform_test, sam_trans=sam_trans, train=True,
                            image_size=image_size)

    image_root = './data/PolypDataset/TestDataset/test/images/'
    gt_root = './data/PolypDataset/TestDataset/test/masks/'
    ds_test = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, train=False,
                           augmentations=transform_test, sam_trans=sam_trans,
                           image_size=image_size)
    return ds_train, ds_test


def get_tests_polyp_dataset(sam_trans):
    """
    Return the four test datasets (Kvasir, CVC-ClinicDB, CVC-ColonDB, ETIS-LaribPolypDB), in that order.
    """
    _, transform_test = get_polyp_transform()

    test_sets = []
    for name in (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB):
        image_root = f'./data/polyp/TestDataset/{name}/images/'
        gt_root = f'./data/polyp/TestDataset/{name}/masks/'
        # Keyword arguments are required: the first positional parameter of
        # PolypDataset is `root`, so passing (image_root, gt_root) positionally
        # left `gt_root` unset and produced an empty dataset.
        test_sets.append(PolypDataset(
            root=image_root, image_root=image_root, gt_root=gt_root,
            augmentations=transform_test, train=False, sam_trans=sam_trans))
    return tuple(test_sets)


if __name__ == '__main__':
    # create_train_val_test_split_for_polyps()
    create_suppport_set_for_polyps()