Spaces:
Sleeping
Sleeping
""" | |
Copied from https://github.com/talshaharabany/AutoSAM | |
""" | |
import os | |
from PIL import Image | |
import torch.utils.data as data | |
import torchvision.transforms as transforms | |
import numpy as np | |
import random | |
import torch | |
from dataloaders.PolypTransforms import get_polyp_transform | |
import cv2 | |
# Names of the polyp sub-datasets. Elsewhere in this module they are compared
# against folder names and searched for as substrings of image paths.
KVASIR = "Kvasir"
CLINIC_DB = "CVC-ClinicDB"
COLON_DB = "CVC-ColonDB"
ETIS_DB = "ETIS-LaribPolypDB"
CVC300 = "CVC-300"
# Default dataset selection used by the dataset classes (CVC-300 is not included).
DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB)
# Images whose path contains one of these names are dropped when files are filtered.
EXCLUDE_DS = (CVC300, )
def create_suppport_set_for_polyps(n_support=10, root_dir=None):
    """
    Create a text file containing n_support randomly chosen image/mask pairs.

    The file ``support.txt`` is written into ``root_dir``; every row holds an
    image path and a mask path separated by a single space.

    n_support: number of distinct image/mask pairs to sample
    root_dir: directory holding the ``images`` and ``masks`` folders; when
        None the original hard-coded training directory is used

    Raises ValueError when the number of images and masks differ (pairs are
    matched by sorted position) or when n_support exceeds the number of images.
    """
    if root_dir is None:
        root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset"
    image_dir = os.path.join(root_dir, "images")
    mask_dir = os.path.join(root_dir, "masks")
    image_paths = sorted(
        os.path.join(image_dir, f) for f in os.listdir(image_dir)
        if f.endswith(('.jpg', '.png')))
    mask_paths = sorted(
        os.path.join(mask_dir, f) for f in os.listdir(mask_dir)
        if f.endswith('.png'))
    # Pairs are matched by position in the sorted lists, so differing counts
    # would silently pair an image with the wrong mask.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Found {len(image_paths)} images but {len(mask_paths)} masks in {root_dir}")
    # The previous rejection-sampling loop never terminated in this case.
    if n_support > len(image_paths):
        raise ValueError(
            f"n_support ({n_support}) is larger than the number of images ({len(image_paths)})")
    # Sample distinct indices in one go (no duplicates by construction).
    indices = random.sample(range(len(image_paths)), max(n_support, 0))
    with open(os.path.join(root_dir, "support.txt"), 'w') as file:
        for index in indices:
            file.write(f"{image_paths[index]} {mask_paths[index]}\n")
def create_train_val_test_split_for_polyps(root_dir=None):
    """
    Write a ``split.txt`` file into every configured dataset folder of root_dir.

    root_dir: directory whose sub-folders are the individual datasets, each
        with an ``images`` folder; when None the original hard-coded dataset
        directory is used

    Sub-folders that have no configured train/test sizes, or that have no
    ``images`` folder, are reported and skipped instead of raising.
    """
    if root_dir is None:
        root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/"
    # Fixed number of train / test images per dataset; the rest becomes val.
    num_train_images_per_dataset = {
        "CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0}
    num_test_images_per_dataset = {
        "CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380}
    # Sorted so that a seeded run visits the datasets in a stable order.
    for subdir in sorted(os.listdir(root_dir)):
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        if subdir not in num_train_images_per_dataset or subdir not in num_test_images_per_dataset:
            # Previously this raised KeyError for any unlisted folder.
            print(f"No split sizes configured for {subdir_path}, skipping")
            continue
        image_dir = os.path.join(subdir_path, "images")
        if not os.path.isdir(image_dir):
            print(f"No images folder found in {subdir_path}, skipping")
            continue
        split_file = os.path.join(subdir_path, "split.txt")
        create_train_val_test_split(
            image_dir, split_file,
            train_number=num_train_images_per_dataset[subdir],
            test_number=num_test_images_per_dataset[subdir])
def create_train_val_test_split(root, split_file, train_number=100, test_number=20):
    """
    Create a train, val, test split file for a dataset

    root: directory containing the image files (.jpg / .png)
    split_file: path of the split file to create
    train_number: number of images assigned to the train split
    test_number: number of images assigned to the test split

    Every image that is in neither train nor test goes to val. The file lists
    image names without extension under the headers ``train``, ``val`` and
    ``test``, one name per line.

    Raises ValueError when a count is negative or when train_number +
    test_number exceeds the number of images (the splits would overlap).
    """
    # Keep image files only and drop the extension. splitext keeps dots that
    # are part of the name itself ("a.b.png" -> "a.b").
    # Sorting first makes the shuffle reproducible under a seeded `random`.
    files = sorted(
        os.path.splitext(f)[0] for f in os.listdir(root)
        if f.endswith(('.jpg', '.png')))
    random.shuffle(files)
    num_files = len(files)
    num_train = train_number
    num_test = test_number
    if num_train < 0 or num_test < 0:
        raise ValueError(
            f"train_number ({num_train}) and test_number ({num_test}) must be non-negative")
    if num_train + num_test > num_files:
        raise ValueError(
            f"train_number + test_number ({num_train + num_test}) exceeds the "
            f"number of images ({num_files}) in {root}")
    num_val = num_files - num_train - num_test
    print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
    # Consecutive, non-overlapping slices of the shuffled list
    train = files[:num_train]
    val = files[num_train:num_train + num_val]
    test = files[num_train + num_val:]
    with open(split_file, 'w') as file:
        for header, names in (("train", train), ("val", val), ("test", test)):
            file.write(header + "\n")
            file.writelines(name + "\n" for name in names)
class PolypDataset(data.Dataset):
    """
    dataloader for polyp segmentation tasks

    Image/mask paths come either from an explicit pair of directories
    (``image_root`` / ``gt_root`` plus their immediate sub-directories) or,
    when those are not both given, from the ``split.txt`` files of the
    dataset folders found under ``root``.
    """

    def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
        """
        root: directory of dataset folders (used when image_root / gt_root are
            not both given); also checked for the substring 'VPS'
        image_root, gt_root: directories of images and masks
        trainsize: minimum side length used by ``resize``
        augmentations: callable ``(image, mask) -> (image, mask)``
        train: selects the "train" split (otherwise "test") of split.txt
        sam_trans: optional transform object providing ``apply_image_torch``
            and ``preprocess``
        datasets: dataset folder names to load from ``root``
        image_size: output size as int or (h, w); used when sam_trans is None
        ds_mean, ds_std: optional normalisation statistics overriding the
            computed ones
        """
        self.trainsize = trainsize
        self.augmentations = augmentations
        self.datasets = datasets
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        # Neutral defaults so the attributes always exist, e.g. when the
        # statistics pass is skipped for 'VPS' roots and no stats are passed.
        self.mean, self.std = 0, 1
        if image_root is not None and gt_root is not None:
            self.images = self._list_files(image_root, ('.jpg', '.png'))
            self.gts = self._list_files(gt_root, ('.png',))
            # also look in subdirectories (one level deep)
            for subdir in os.listdir(image_root):
                subdir_image_root = os.path.join(image_root, subdir)
                if not os.path.isdir(subdir_image_root):
                    continue
                subdir_gt_root = os.path.join(gt_root, subdir)
                self.images.extend(
                    self._list_files(subdir_image_root, ('.jpg', '.png')))
                self.gts.extend(self._list_files(subdir_gt_root, ('.png',)))
        else:
            self.images, self.gts = self.get_image_gt_pairs(
                root, split="train" if train else "test", datasets=self.datasets)
        # Images and masks are paired by their position after sorting.
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        if 'VPS' not in root:
            self.filter_files_and_get_ds_mean_and_std()
        if ds_mean is not None and ds_std is not None:
            self.mean, self.std = ds_mean, ds_std
        self.size = len(self.images)
        self.train = train
        self.sam_trans = sam_trans
        if self.sam_trans is not None:
            # sam trans takes care of norm
            self.mean, self.std = 0, 1

    @staticmethod
    def _list_files(directory, suffixes):
        """Return the paths of the entries of `directory` ending with one of `suffixes`."""
        return [os.path.join(directory, f)
                for f in os.listdir(directory) if f.endswith(suffixes)]

    def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS):
        """
        for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file
        dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders
        split: train, val, or test
        datasets: only folders whose name is in this collection are read

        Returns (image_paths, gt_paths). Folders without a split.txt are
        reported with a printed message and skipped.
        """
        image_paths = []
        gt_paths = []
        for folder in os.listdir(dir_root):
            if folder not in datasets:
                continue
            split_file = os.path.join(dir_root, folder, "split.txt")
            if not os.path.isfile(split_file):
                print(
                    f"No split.txt file found in {os.path.join(dir_root, folder)}")
                continue
            image_root = os.path.join(dir_root, folder, "images")
            gt_root = os.path.join(dir_root, folder, "masks")
            image_paths_tmp, gt_paths_tmp = self.get_image_gt_pairs_from_text_file(
                image_root, gt_root, split_file, split=split)
            image_paths.extend(image_paths_tmp)
            gt_paths.extend(gt_paths_tmp)
        return image_paths, gt_paths

    def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"):
        """
        image_root: root directory of images
        gt_root: root directory of ground truth
        text_file: text file containing train, val, test split with the following format
            (the header lines may be written with or without the trailing colon):
            train:
            image1
            image2
            ...
            val:
            image1
            image2
            ...
            test:
            image1
            image2
            ...
        split: train, val, or test

        Returns (image_paths, gt_paths); both are built as ``<name>.png``
        inside image_root / gt_root.
        """
        splits = {"train": [], "val": [], "test": []}
        current_split = None
        with open(text_file, 'r') as file:
            for line in file:
                line = line.strip()
                # Accept both "train" (as written by the split generator in
                # this module) and "train:" (as shown in the format above).
                header = line.rstrip(':')
                if header in splits:
                    current_split = header
                elif line and current_split:
                    splits[current_split].append(line)
        file_names = splits[split]
        image_paths = [os.path.join(image_root, name + '.png')
                       for name in file_names]
        gt_paths = [os.path.join(gt_root, name + '.png')
                    for name in file_names]
        return image_paths, gt_paths

    def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1):
        """
        Sample n_support image/mask pairs (with replacement) from two directories.

        Images and masks are paired by their position in the sorted listings.
        Returns (stacked images, stacked masks).
        """
        support_image_paths = sorted(
            self._list_files(support_image_dir, ('.jpg', '.png')))
        support_mask_paths = sorted(
            self._list_files(support_mask_dir, ('.png',)))
        support_images = []
        support_labels = []
        for _ in range(n_support):
            index = random.randint(0, len(support_image_paths) - 1)
            support_img = self.cv2_loader(
                support_image_paths[index], is_mask=False)
            support_mask = self.cv2_loader(
                support_mask_paths[index], is_mask=True)
            if self.augmentations:
                # A single call per pair: image and mask must go through the
                # same transform. Calling it once for the image and once more
                # for the mask (as before) could give them different random
                # parameters and applied the transform twice.
                support_img, support_mask = self.augmentations(
                    support_img, support_mask)
            support_images.append(support_img)
            support_labels.append(support_mask)
        # NOTE(review): normalisation only triggers when the value range is
        # exactly [0, 255]; process_image_gt uses a looser check - confirm.
        support_images = [
            (img - self.mean) / self.std
            if img.max() == 255 and img.min() == 0 else img
            for img in support_images]
        if self.sam_trans is not None:
            support_images = [self.sam_trans.preprocess(
                img).squeeze(0) for img in support_images]
            support_labels = [self.sam_trans.preprocess(
                mask) for mask in support_labels]
        else:
            image_size = self.image_size
            support_images = [torch.nn.functional.interpolate(img.unsqueeze(
                0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) for img in support_images]
            support_labels = [torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
                0), size=image_size, mode='nearest').squeeze(0).squeeze(0) for mask in support_labels]
        return torch.stack(support_images), torch.stack(support_labels)

    def get_support_from_text_file(self, text_file, n_support=1):
        """
        each row in the file has 2 paths divided by space, the first is the image path and the second is the mask path

        Returns the first n_support (image_paths, mask_paths) of the file.
        Raises ValueError when the file holds fewer than n_support rows.
        """
        support_images = []
        support_labels = []
        with open(text_file, 'r') as file:
            for line in file:
                fields = line.split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                image_path, mask_path = fields
                support_images.append(image_path)
                support_labels.append(mask_path)
        if n_support > len(support_images):
            raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})")
        return support_images[:n_support], support_labels[:n_support]

    def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None):
        """
        Get support set from specified directories, text file or from the dataset itself

        With directories, the 2-tuple of ``get_support_from_dirs`` is returned.
        Otherwise returns (list of images, list of masks, case), where each
        list entry has a leading dimension of size 1 added.
        """
        if support_image_dir is not None and support_mask_dir:
            return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support)
        if text_file is not None:
            support_image_paths, support_gt_paths = self.get_support_from_text_file(
                text_file, n_support=n_support)
        else:
            # randomly sample n_support images and masks from the dataset
            indices = random.choices(range(self.size), k=n_support)
            print(f"support indices:{indices}")
            support_image_paths = [self.images[index] for index in indices]
            support_gt_paths = [self.gts[index] for index in indices]

        support_images = []
        support_gts = []
        # Defined up front so an empty support set no longer raises NameError.
        case = ""
        for image_path, gt_path in zip(support_image_paths, support_gt_paths):
            support_img = self.cv2_loader(image_path, is_mask=False)
            support_mask = self.cv2_loader(gt_path, is_mask=True)
            out = self.process_image_gt(support_img, support_mask)
            support_images.append(out['image'].unsqueeze(0))
            support_gts.append(out['label'].unsqueeze(0))
            case = out['case']
            if len(support_images) >= n_support:
                break
        return support_images, support_gts, case

    def process_image_gt(self, image, gt, dataset=""):
        """
        image and gt are expected to be output from self.cv2_loader

        Applies the augmentations, the SAM transform or mean/std
        normalisation, binarises the mask and resizes to ``self.image_size``
        when no SAM transform is set.

        Returns a dict with keys 'image', 'label', 'original_size',
        'image_size' and 'case'.
        """
        if isinstance(image, np.ndarray):
            # cv2 arrays are (H, W, C): the spatial size is the two leading
            # axes. shape[-2:] gave (W, 3) here.
            # NOTE(review): this is the size before augmentation - confirm
            # that is what consumers of 'original_size' expect.
            original_size = tuple(image.shape[:2])
        else:
            original_size = tuple(image.shape[-2:])
        # Without augmentations the mask is the input itself (this name was
        # previously undefined in that case).
        # NOTE(review): the steps below call tensor methods (unsqueeze), so
        # the augmentations presumably convert the arrays to tensors.
        mask = gt
        if self.augmentations:
            image, mask = self.augmentations(image, gt)
        if self.sam_trans:
            image, mask = self.sam_trans.apply_image_torch(
                image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask)
        elif image.max() <= 255 and image.min() >= 0:
            image = (image - self.mean) / self.std
        # binarise the mask in place
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
        image_size = self.image_size
        if self.sam_trans is None:
            image = torch.nn.functional.interpolate(image.unsqueeze(
                0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
                0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
        return {'image': self.sam_trans.preprocess(image).squeeze(0) if self.sam_trans else image,
                'label': self.sam_trans.preprocess(mask) if self.sam_trans else mask,
                'original_size': torch.Tensor(original_size),
                'image_size': torch.Tensor(image_size),
                'case': dataset}  # case to be compatible with polyp video dataset

    def get_dataset_name_from_path(self, path):
        """Return the first name in self.datasets that occurs in path, else ""."""
        for dataset in self.datasets:
            if dataset in path:
                return dataset
        return ""

    def __getitem__(self, index):
        """Load the image/mask pair at index and return the processed sample dict."""
        image = self.cv2_loader(self.images[index], is_mask=False)
        gt = self.cv2_loader(self.gts[index], is_mask=True)
        dataset = self.get_dataset_name_from_path(self.images[index])
        return self.process_image_gt(image, gt, dataset)

    def filter_files_and_get_ds_mean_and_std(self):
        """
        Drop excluded / size-mismatched pairs and compute normalisation stats.

        A pair is kept when its image path contains no name from EXCLUDE_DS
        and image and mask have the same size. ``self.mean`` / ``self.std``
        become the averages of the per-image pixel mean / std of the kept
        images, or 0 / 1 when nothing is kept.
        """
        if len(self.images) != len(self.gts):
            raise ValueError(
                f"Number of images ({len(self.images)}) and masks ({len(self.gts)}) differ")
        images = []
        gts = []
        ds_mean = 0
        ds_std = 0
        for img_path, gt_path in zip(self.images, self.gts):
            if any(ex_ds in img_path for ex_ds in EXCLUDE_DS):
                continue
            # Context managers close the underlying files right away.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                if img.size != gt.size:
                    continue
                pixels = np.array(img)
            images.append(img_path)
            gts.append(gt_path)
            ds_mean += pixels.mean()
            ds_std += pixels.std()
        self.images = images
        self.gts = gts
        if not images:
            # Avoid a ZeroDivisionError on an empty dataset.
            print("No valid image/mask pairs found, using mean=0 and std=1")
            self.mean, self.std = 0, 1
            return
        self.mean = ds_mean / len(images)
        self.std = ds_std / len(images)

    def rgb_loader(self, path):
        """Open the image at path with PIL and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Read the image at path with cv2 using flag 0 and return the result."""
        img = cv2.imread(path, 0)
        return img

    def cv2_loader(self, path, is_mask):
        """
        Read an image with cv2.

        Masks are read with flag 0 and every value > 0 is set to 1; other
        images are read in colour and converted from BGR to RGB.

        Raises OSError when cv2 cannot read the file.
        """
        img = cv2.imread(path, 0 if is_mask else cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising
            raise OSError(f"Could not read image file: {path}")
        if is_mask:
            img[img > 0] = 1
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def resize(self, img, gt):
        """Upscale a PIL image/mask pair so that neither side is below self.trainsize."""
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        """Number of image/mask pairs counted at construction time."""
        return self.size
class SuperpixPolypDataset(PolypDataset):
    """
    Polyp dataset whose labels are superpixel maps.

    Each item is a support/query episode built from one image and one
    randomly chosen superpixel of its superpixel map.
    """

    def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
        """
        Same parameters as PolypDataset. ``gt_root`` holds the superpixel
        maps: only .png files with 'superpix' in their name are taken from
        its top level.
        """
        self.trainsize = trainsize
        self.augmentations = augmentations
        self.datasets = datasets
        # accept a single int like the parent class does
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        # Neutral defaults so the attributes always exist, e.g. when the
        # statistics pass is skipped for 'VPS' roots and no stats are passed.
        self.mean, self.std = 0, 1
        if image_root is not None and gt_root is not None:
            self.images = [
                os.path.join(image_root, f) for f in os.listdir(image_root)
                if f.endswith(('.jpg', '.png'))]
            self.gts = [
                os.path.join(gt_root, f) for f in os.listdir(gt_root)
                if f.endswith('.png') and 'superpix' in f]
            # also look in subdirectories (one level deep)
            for subdir in os.listdir(image_root):
                subdir_image_root = os.path.join(image_root, subdir)
                if not os.path.isdir(subdir_image_root):
                    continue
                subdir_gt_root = os.path.join(gt_root, subdir)
                self.images.extend(
                    os.path.join(subdir_image_root, f)
                    for f in os.listdir(subdir_image_root)
                    if f.endswith(('.jpg', '.png')))
                self.gts.extend(
                    os.path.join(subdir_gt_root, f)
                    for f in os.listdir(subdir_gt_root) if f.endswith('.png'))
        else:
            self.images, self.gts = self.get_image_gt_pairs(
                root, split="train" if train else "test", datasets=self.datasets)
        # Images and superpixel maps are paired by position after sorting.
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        if 'VPS' not in root:
            self.filter_files_and_get_ds_mean_and_std()
        if ds_mean is not None and ds_std is not None:
            self.mean, self.std = ds_mean, ds_std
        self.size = len(self.images)
        self.train = train
        self.sam_trans = sam_trans
        if self.sam_trans is not None:
            # sam trans takes care of norm
            self.mean, self.std = 0, 1

    def __getitem__(self, index):
        """
        Build one episode from the image at index.

        The superpixel map ``superpix-MIDDLE_<id>.png`` is restricted to the
        foreground mask ``fgmask_<id>.png`` from the same directory, one
        non-zero superpixel is chosen at random, and its binary mask is used
        as the label of both the support and the query sample.

        Returns a dict with keys 'support_images', 'support_mask',
        'query_images', 'query_labels' and 'scan_id'.
        Raises ValueError when no superpixel is left inside the foreground.
        """
        image = self.cv2_loader(self.images[index], is_mask=False)
        gt_path = self.gts[index]
        # the superpixel ids are read from the first channel
        gt = self.cv2_loader(gt_path, is_mask=False)[:, :, 0]
        name_parts = os.path.basename(gt_path).split('.png')[0].split('superpix-MIDDLE_')
        fgpath = os.path.join(os.path.dirname(gt_path), 'fgmask_' + name_parts[1] + '.png')
        fg = self.cv2_loader(fgpath, is_mask=True)
        dataset = self.get_dataset_name_from_path(self.images[index])
        # Remove superpixels outside the foreground. A boolean mask is needed
        # here: `gt[1 - fg]` was integer-array indexing and zeroed rows 0/1.
        gt[fg == 0] = 0
        # randomly choose a superpixel from the gt; 0 is treated as background
        # and excluded explicitly rather than assumed to be the first value
        sp_ids = np.unique(gt)
        sp_ids = sp_ids[sp_ids != 0]
        if sp_ids.size == 0:
            raise ValueError(f"No superpixel inside the foreground mask for {gt_path}")
        sp_id = random.choice(sp_ids)
        sp = (gt == sp_id).astype(np.uint8)
        # NOTE(review): the support label used to be built from the whole
        # superpixel map `gt`, which binarises to every non-zero superpixel;
        # it now uses the chosen superpixel like the query - confirm.
        # A copy is passed because process_image_gt binarises its mask in place.
        out = self.process_image_gt(image, sp.copy(), dataset)
        support_image, support_sp, dataset = out["image"], out["label"], out["case"]
        out = self.process_image_gt(image, sp, dataset)
        query_image, query_sp, dataset = out["image"], out["label"], out["case"]
        # TODO tile the masks to have 3 channels?
        support_bg_mask = 1 - support_sp
        support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask}
        batch = {"support_images": [[support_image]],
                 "support_mask": [[support_masks]],
                 "query_images": [query_image],
                 "query_labels": [query_sp],
                 "scan_id": [dataset]
                 }
        return batch
def get_superpix_polyp_dataset(image_size:tuple=(1024,1024), sam_trans=None):
    """
    Build the superpixel training dataset from the default TrainDataset folders.

    image_size: output size forwarded to the dataset
    sam_trans: optional SAM transform forwarded to the dataset
    Returns the SuperpixPolypDataset using the training transform.
    """
    # only the training transform is needed here
    train_transform, _ = get_polyp_transform()
    images_dir = './data/PolypDataset/TrainDataset/images/'
    superpixels_dir = './data/PolypDataset/TrainDataset/superpixels/'
    return SuperpixPolypDataset(
        root=images_dir,
        image_root=images_dir,
        gt_root=superpixels_dir,
        augmentations=train_transform,
        sam_trans=sam_trans,
        image_size=image_size,
    )
def get_polyp_dataset(image_size, sam_trans=None):
    """
    Build the train and test PolypDataset from the default data folders.

    image_size: output size forwarded to both datasets
    sam_trans: optional SAM transform forwarded to both datasets
    Returns (ds_train, ds_test).
    """
    _, eval_transform = get_polyp_transform()
    # options shared by both datasets
    # NOTE(review): the train dataset also receives the test transform -
    # confirm this is intentional.
    common = dict(augmentations=eval_transform,
                  sam_trans=sam_trans, image_size=image_size)

    train_images = './data/PolypDataset/TrainDataset/images/'
    train_masks = './data/PolypDataset/TrainDataset/masks/'
    ds_train = PolypDataset(root=train_images, image_root=train_images,
                            gt_root=train_masks, train=True, **common)

    test_images = './data/PolypDataset/TestDataset/test/images/'
    test_masks = './data/PolypDataset/TestDataset/test/masks/'
    ds_test = PolypDataset(root=test_images, image_root=test_images,
                           gt_root=test_masks, train=False, **common)
    return ds_train, ds_test
def get_tests_polyp_dataset(sam_trans):
    """
    Build one test PolypDataset per benchmark folder under ./data/polyp/TestDataset.

    sam_trans: SAM transform forwarded to every dataset
    Returns (ds_Kvasir, ds_ClinicDB, ds_ColonDB, ds_ETIS).
    """
    _, transform_test = get_polyp_transform()
    test_sets = []
    for name in ("Kvasir", "CVC-ClinicDB", "CVC-ColonDB", "ETIS-LaribPolypDB"):
        image_root = f'./data/polyp/TestDataset/{name}/images/'
        gt_root = f'./data/polyp/TestDataset/{name}/masks/'
        # Keyword arguments are required: PolypDataset's first positional
        # parameter is `root`, so the former positional call bound the image
        # folder to `root`, the mask folder to `image_root` and left `gt_root`
        # as None, which skipped the directory-based loading altogether.
        test_sets.append(PolypDataset(
            root=image_root, image_root=image_root, gt_root=gt_root,
            augmentations=transform_test, train=False, sam_trans=sam_trans))
    return tuple(test_sets)
if __name__ == '__main__':
    # Script entry point: only the support-set file is generated; the
    # split-file generation is left disabled.
    # create_train_val_test_split_for_polyps()
    create_suppport_set_for_polyps()