# LoGoSAM_demo / dataloaders / PolypTransforms.py
# Uploaded by quandn2003 via huggingface_hub (commit 427d150).
from __future__ import division
import torch
import math
import sys
import random
from PIL import Image
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
import collections
import warnings
from torchvision.transforms import functional as F
# Container ABCs used for argument validation (e.g. Resize, Pad).
# `import collections` alone does not guarantee that the `collections.abc`
# submodule is loaded as an attribute, so import it explicitly on 3.3+.
if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    import collections.abc
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable
# Public, torchvision-style transform names exported by `from ... import *`.
# Helpers such as Lambda_image and the get_*_transform factories are not listed.
__all__ = [
    "Compose",
    "ToTensor",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "RandomPerspective",
]
# Maps each PIL resampling-filter constant to a readable name such as
# 'PIL.Image.BILINEAR' (not referenced elsewhere in the visible code).
_pil_interpolation_to_str = {
    getattr(Image, _filter_name): 'PIL.Image.' + _filter_name
    for _filter_name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS', 'HAMMING', 'BOX')
}
class Compose(object):
    """Chain paired transforms.

    Each element of ``transforms`` is called as ``t(img, mask)`` and must
    return a new ``(img, mask)`` pair, which is passed to the next element.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        current_img, current_mask = img, mask
        for step in self.transforms:
            current_img, current_mask = step(current_img, current_mask)
        return current_img, current_mask
class ToTensor(object):
    """Convert an (image, mask) pair to float tensors.

    The image becomes a ``(C, H, W)`` float tensor; a 2-D (single-channel)
    image gets a channel axis of size 1. The mask becomes a float tensor with
    the same shape as ``np.array(mask)``.

    Pixel values are NOT rescaled to [0, 1], unlike
    ``torchvision.transforms.ToTensor``. Callers that normalize afterwards
    pass statistics on the 0-255 scale (e.g. get_cub_transform in this module).
    """

    def __call__(self, img, mask):
        img = np.array(img)
        if img.ndim == 2:
            # Grayscale input: add a trailing channel axis so the HWC -> CHW
            # permute below works instead of raising.
            img = img[:, :, np.newaxis]
        # TODO: divide by 255 to match torchvision's ToTensor? That would also
        # require rescaling the Normalize statistics used in this module.
        img = torch.from_numpy(img).permute(2, 0, 1).float()
        mask = torch.from_numpy(np.array(mask)).float()
        return img, mask
class ToPILImage(object):
    """Convert an (image, mask) pair to PIL images with ``F.to_pil_image``.

    Args:
        mode: PIL mode forwarded for BOTH the image and the mask. ``None``
            lets torchvision infer the mode of each input separately.
            NOTE(review): an explicit mode suited to the image (e.g. 'RGB')
            is also applied to the mask; confirm that is intended.
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, img, mask):
        pil_img = F.to_pil_image(img, self.mode)
        pil_mask = F.to_pil_image(mask, self.mode)
        return pil_img, pil_mask
class Normalize(object):
    """Normalize the image channel-wise with ``F.normalize``; the mask is untouched.

    Args:
        mean: per-channel means, on the same scale as the image values.
        std: per-channel standard deviations.
        inplace: forwarded to ``F.normalize``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, img, mask):
        normalized = F.normalize(img, self.mean, self.std, self.inplace)
        return normalized, mask
class Resize(object):
    """Resize an (image, mask) pair with ``F.resize``.

    Args:
        size: int or 2-element iterable, forwarded to ``F.resize``.
        interpolation: PIL resampling filter used for the image.
        do_mask: when False the mask is returned unchanged.

    The mask is always resized with nearest-neighbour sampling. A smoothing
    filter (bilinear/bicubic) would blend neighbouring label values and
    create values that are not real labels. This matches the other geometric
    transforms in this module.
    """

    def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation
        self.do_mask = do_mask

    def __call__(self, img, mask):
        # Honour the configured filter (previously hard-coded to BICUBIC).
        img = F.resize(img, self.size, self.interpolation)
        if self.do_mask:
            mask = F.resize(mask, self.size, Image.NEAREST)
        return img, mask
class CenterCrop(object):
    """Crop the central ``size`` region from both the image and the mask.

    Args:
        size: a number (square crop, truncated to int) or an ``(h, w)`` pair.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img, mask):
        crop_hw = self.size
        return F.center_crop(img, crop_hw), F.center_crop(mask, crop_hw)
class Pad(object):
    """Pad the image and the mask with identical ``F.pad`` arguments.

    Args:
        padding: int, or a 2- or 4-element tuple, forwarded to ``F.pad``.
        fill: constant fill value; note it is applied to the mask as well.
        padding_mode: one of 'constant', 'edge', 'reflect', 'symmetric'.

    Raises:
        ValueError: if ``padding`` is a sequence whose length is not 2 or 4.
    """

    def __init__(self, padding, fill=0, padding_mode='constant'):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        bad_length = isinstance(padding, Sequence) and len(padding) not in [2, 4]
        if bad_length:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        self.padding, self.fill, self.padding_mode = padding, fill, padding_mode

    def __call__(self, img, mask):
        pad_args = (self.padding, self.fill, self.padding_mode)
        return F.pad(img, *pad_args), F.pad(mask, *pad_args)
class Lambda(object):
    """Apply one user-supplied callable to the image and, separately, to the mask."""

    def __init__(self, lambd):
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img, mask):
        fn = self.lambd
        new_img = fn(img)
        new_mask = fn(mask)
        return new_img, new_mask
class Lambda_image(object):
    """Apply a user-supplied callable to the image only; the mask passes through."""

    def __init__(self, lambd):
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img, mask):
        transformed = self.lambd(img)
        return transformed, mask
class RandomTransforms(object):
    """Base class for containers that apply a list/tuple of paired transforms randomly.

    Subclasses must override ``__call__``.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        # Abstract hook: concrete containers decide how to pick/order transforms.
        raise NotImplementedError()
class RandomApply(RandomTransforms):
    """Apply all ``transforms`` in order with probability ``p``, else none of them."""

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, img, mask):
        # A single draw decides for the whole group; draws above p skip it.
        if random.random() > self.p:
            return img, mask
        for transform in self.transforms:
            img, mask = transform(img, mask)
        return img, mask
class RandomOrder(RandomTransforms):
    """Apply every transform exactly once, in a freshly shuffled order per call."""

    def __call__(self, img, mask):
        # Shuffling a copy consumes the same random draws as shuffling indices.
        shuffled = list(self.transforms)
        random.shuffle(shuffled)
        for transform in shuffled:
            img, mask = transform(img, mask)
        return img, mask
class RandomChoice(RandomTransforms):
    """Apply one transform picked uniformly at random from ``transforms``."""

    def __call__(self, img, mask):
        chosen = random.choice(self.transforms)
        return chosen(img, mask)
class RandomCrop(object):
    """Crop the same random ``size`` window from the image and the mask.

    Args:
        size: a number (square crop, truncated to int) or an ``(h, w)`` pair.
        padding: optional padding applied before cropping, forwarded to ``F.pad``.
        pad_if_needed: pad when the (padded) image is smaller than ``size``.
        fill: constant fill value used for padding.
        padding_mode: padding mode forwarded to ``F.pad``.

    Every padding step is applied to the mask as well. Otherwise the crop
    window, which is computed from the padded image, would land on shifted
    pixels in the unpadded mask.
    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Return ``(i, j, h, w)`` (top, left, height, width) of a random crop.

        ``img.size`` is read as ``(width, height)``; ``output_size`` is ``(h, w)``.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def _pad_pair(self, img, mask, padding):
        """Pad image and mask identically so they stay pixel-aligned."""
        return (F.pad(img, padding, self.fill, self.padding_mode),
                F.pad(mask, padding, self.fill, self.padding_mode))

    def __call__(self, img, mask):
        if self.padding is not None:
            img, mask = self._pad_pair(img, mask, self.padding)
        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img, mask = self._pad_pair(img, mask, (self.size[1] - img.size[0], 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img, mask = self._pad_pair(img, mask, (0, self.size[0] - img.size[1]))
        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
class RandomHorizontalFlip(object):
    """Mirror the image and the mask left-right together, with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, mask):
        should_flip = random.random() < self.p
        if should_flip:
            img, mask = F.hflip(img), F.hflip(mask)
        return img, mask
class RandomVerticalFlip(object):
    """Mirror the image and the mask top-bottom together, with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, mask):
        should_flip = random.random() < self.p
        if should_flip:
            img, mask = F.vflip(img), F.vflip(mask)
        return img, mask
class RandomPerspective(object):
    """With probability ``p``, apply one random perspective warp to image and mask.

    Args:
        distortion_scale: fraction of half the width/height by which each
            corner may move inward.
        p: probability of applying the warp.
        interpolation: PIL resampling filter for the image; the mask always
            uses nearest-neighbour sampling.

    Raises:
        TypeError: if ``img`` is not a PIL image.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

    def __call__(self, img, mask):
        if not F._is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
        if not random.random() < self.p:
            return img, mask
        width, height = img.size
        src, dst = self.get_params(width, height, self.distortion_scale)
        warped_img = F.perspective(img, src, dst, self.interpolation)
        warped_mask = F.perspective(mask, src, dst, Image.NEAREST)
        return warped_img, warped_mask

    @staticmethod
    def get_params(width, height, distortion_scale):
        """Return ``(startpoints, endpoints)`` for a random perspective warp.

        ``startpoints`` are the four image corners (top-left, top-right,
        bottom-right, bottom-left). Each ``endpoint`` is that corner moved
        inward by a random offset of at most
        ``distortion_scale * (size // 2)`` pixels along each axis.
        """
        max_dx = int(distortion_scale * int(width / 2))
        max_dy = int(distortion_scale * int(height / 2))
        # Draw order (x then y, corner by corner) matches the original sequence.
        topleft = (random.randint(0, max_dx),
                   random.randint(0, max_dy))
        topright = (random.randint(width - max_dx - 1, width - 1),
                    random.randint(0, max_dy))
        botright = (random.randint(width - max_dx - 1, width - 1),
                    random.randint(height - max_dy - 1, height - 1))
        botleft = (random.randint(0, max_dx),
                   random.randint(height - max_dy - 1, height - 1))
        corners = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
        return corners, [topleft, topright, botright, botleft]
class RandomResizedCrop(object):
    """Crop a random region of the pair and resize it to ``size``.

    The region's area is a random fraction (``scale``) of the input area, and
    its aspect ratio (width / height) is drawn log-uniformly from ``ratio``.
    After 10 failed attempts a centred crop is used instead. The image is
    resized with ``interpolation``; the mask uses nearest-neighbour sampling.

    Args:
        size: int (square output) or an ``(h, w)`` tuple/list.
        scale: ``(min, max)`` crop area as a fraction of the input area.
        ratio: ``(min, max)`` crop aspect ratio.
        interpolation: PIL resampling filter for the image.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        if isinstance(size, (tuple, list)):
            # Lists used to be wrapped as (list, list); accept them as (h, w).
            self.size = tuple(size)
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Return integer ``(i, j, h, w)`` (top, left, height, width) for a random crop.

        ``img.size`` is read as ``(width, height)``.
        """
        width, height = img.size
        area = width * height
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            # Reject zero-sized candidates; they cannot be resized.
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to a central crop, clamping the aspect ratio into ``ratio``.
        # Sizes are rounded to ints so the crop box is valid integer pixels.
        in_ratio = width / height
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img, mask):
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
            F.resized_crop(mask, i, j, h, w, self.size, Image.NEAREST)
class FiveCrop(object):
    """Apply ``F.five_crop`` to the image and to the mask.

    Returns a pair ``(image_crops, mask_crops)``; each element is whatever
    ``F.five_crop`` returns (the four corner crops plus the centre crop).

    Args:
        size: a number (square crop, truncated to int) or an ``(h, w)`` pair.
    """

    def __init__(self, size):
        if not isinstance(size, numbers.Number):
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        else:
            self.size = (int(size), int(size))

    def __call__(self, img, mask):
        img_crops = F.five_crop(img, self.size)
        mask_crops = F.five_crop(mask, self.size)
        return img_crops, mask_crops
class TenCrop(object):
    """Apply ``F.ten_crop`` to the image and to the mask.

    Returns a pair ``(image_crops, mask_crops)``; each element is whatever
    ``F.ten_crop`` returns for the given ``vertical_flip`` setting.

    Args:
        size: a number (square crop, truncated to int) or an ``(h, w)`` pair.
        vertical_flip: forwarded to ``F.ten_crop``.
    """

    def __init__(self, size, vertical_flip=False):
        if not isinstance(size, numbers.Number):
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        else:
            self.size = (int(size), int(size))
        self.vertical_flip = vertical_flip

    def __call__(self, img, mask):
        img_crops = F.ten_crop(img, self.size, self.vertical_flip)
        mask_crops = F.ten_crop(mask, self.size, self.vertical_flip)
        return img_crops, mask_crops
class ColorJitter(object):
    """Randomly change brightness, contrast, saturation and hue of the image.

    Only the image is modified; the mask passes through unchanged. Each
    argument is one of:

    - a non-negative number ``x``, giving the range
      ``[center - x, center + x]``. The center is 1 for brightness, contrast
      and saturation (lower end clipped at 0) and 0 for hue.
    - an explicit ``(min, max)`` pair.

    Hue ranges must lie within [-0.5, 0.5]. A factor whose range collapses to
    its center is disabled and stored as ``None``.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Validate one jitter argument and return its ``[min, max]`` range, or ``None``.

        Raises:
            ValueError: a negative single number, or a range outside ``bound``.
            TypeError: ``value`` is neither a number nor a 2-element list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
        # Check the range for BOTH input forms. Previously a single number such
        # as hue=0.6 was accepted here and failed only later, when a sampled
        # factor fell outside the range F.adjust_hue accepts.
        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))
        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Sample one factor per enabled attribute; return them as a shuffled Compose.

        Each adjustment is wrapped in Lambda_image, so it touches the image only.
        Each lambda closes over its own, once-assigned factor variable, so there
        is no late-binding issue.
        """
        transforms = []
        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor)))
        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor)))
        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor)))
        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor)))
        random.shuffle(transforms)
        transform = Compose(transforms)
        return transform

    def __call__(self, img, mask):
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img, mask)
class RandomRotation(object):
    """Rotate image and mask together by one angle drawn uniformly from ``degrees``.

    The image is resampled bilinearly and the mask with nearest neighbour.

    Args:
        degrees: number ``d`` (range ``(-d, d)``) or a 2-element ``(min, max)``.
        resample: stored but NOT used; NOTE(review) the image filter is
            hard-coded to bilinear.
        expand: forwarded to ``F.rotate``.
        center: forwarded to ``F.rotate``.

    Raises:
        ValueError: negative single number, or a sequence whose length is not 2.
    """

    def __init__(self, degrees, resample=False, expand=False, center=None):
        if not isinstance(degrees, numbers.Number):
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees
        elif degrees < 0:
            raise ValueError("If degrees is a single number, it must be positive.")
        else:
            self.degrees = (-degrees, degrees)
        self.resample, self.expand, self.center = resample, expand, center

    @staticmethod
    def get_params(degrees):
        """Draw a rotation angle uniformly from ``[degrees[0], degrees[1]]``."""
        low, high = degrees[0], degrees[1]
        return random.uniform(low, high)

    def __call__(self, img, mask):
        angle = self.get_params(self.degrees)
        rotated_img = F.rotate(img, angle, Image.BILINEAR, self.expand, self.center)
        rotated_mask = F.rotate(mask, angle, Image.NEAREST, self.expand, self.center)
        return rotated_img, rotated_mask
class RandomAffine(object):
    """Apply one random affine transform (rotate/translate/scale/shear) to image and mask.

    The image is resampled bilinearly and the mask with nearest neighbour.

    Args:
        degrees: number ``d`` (range ``(-d, d)``) or a 2-element ``(min, max)``.
        translate: optional ``(fx, fy)`` maximum shift as fractions of the
            width and height, each in [0, 1].
        scale: optional ``(min, max)`` scale-factor range; both must be > 0.
        shear: optional number ``s`` (range ``(-s, s)``) or ``(min, max)``.
        resample: stored but NOT used; NOTE(review) the image filter is
            hard-coded to bilinear.
        fillcolor: fill value for uncovered pixels; applied to the mask too.

    Raises:
        ValueError: negative single ``degrees``/``shear``, translate outside
            [0, 1], or non-positive scale values.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            if any(not (0.0 <= t <= 1.0) for t in translate):
                raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            if any(s <= 0 for s in scale):
                raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None and isinstance(shear, numbers.Number):
            if shear < 0:
                raise ValueError("If shear is a single number, it must be positive.")
            shear = (-shear, shear)
        elif shear is not None:
            assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
                "shear should be a list or tuple and it must be of length 2."
        self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Sample ``(angle, (tx, ty), scale, shear)`` for ``F.affine``.

        ``img_size[0]`` scales the x shift and ``img_size[1]`` the y shift.
        Shifts are rounded with ``np.round``. Disabled components default to
        no translation, scale 1.0 and shear 0.0.
        """
        angle = random.uniform(degrees[0], degrees[1])
        translations = (0, 0)
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        scale = 1.0 if scale_ranges is None else random.uniform(scale_ranges[0], scale_ranges[1])
        shear = 0.0 if shears is None else random.uniform(shears[0], shears[1])
        return angle, translations, scale, shear

    def __call__(self, img, mask):
        params = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        warped_img = F.affine(img, *params, interpolation=Image.BILINEAR, fill=self.fillcolor)
        warped_mask = F.affine(mask, *params, interpolation=Image.NEAREST, fill=self.fillcolor)
        return warped_img, warped_mask
def get_cub_transform():
    """Build the (train, test) paired pipelines for the 'cub' setting.

    Both pipelines do the following:
    - convert to PIL and resize to 256x256;
    - convert to tensors on the 0-255 scale;
    - normalize with mean (0.485, 0.456, 0.406) and std
      (0.229, 0.224, 0.225), each multiplied by 255.

    Training also applies a random horizontal flip and a random affine
    (rotation in [-22, 22] degrees, scale in [0.75, 1.25]).
    """
    def _normalize():
        # A fresh Normalize (with its own mean/std lists) per pipeline.
        return Normalize(mean=[255*0.485, 255*0.456, 255*0.406],
                         std=[255*0.229, 255*0.224, 255*0.225])

    train_steps = [
        ToPILImage(),
        Resize((256, 256)),
        RandomHorizontalFlip(),
        RandomAffine(22, scale=(0.75, 1.25)),
        ToTensor(),
        _normalize(),
    ]
    test_steps = [
        ToPILImage(),
        Resize((256, 256)),
        ToTensor(),
        _normalize(),
    ]
    return Compose(train_steps), Compose(test_steps)
def get_glas_transform():
    """Build the (train, test) paired pipelines for the 'glas' setting.

    Training does the following:
    - converts to PIL;
    - applies colour jitter (brightness/contrast/saturation 0.2, hue 0.1);
    - applies a random horizontal flip;
    - applies a random affine (rotation in [-5, 5] degrees, scale in
      [0.75, 1.25]);
    - converts to tensors.

    Testing only converts to PIL and then to tensors. Resizing and
    normalization are disabled (left commented out below).
    """
    train_steps = [
        ToPILImage(),
        # Resize((256, 256)),
        ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        RandomHorizontalFlip(),
        RandomAffine(5, scale=(0.75, 1.25)),
        ToTensor(),
        # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
    ]
    test_steps = [
        ToPILImage(),
        # Resize((256, 256)),
        ToTensor(),
        # Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
    ]
    return Compose(train_steps), Compose(test_steps)
# def get_glas_transform():
# transform_train = Compose([
# ToPILImage(),
# Resize((256, 256)),
# ColorJitter(brightness=0.2,
# contrast=0.2,
# saturation=0.2,
# hue=0.1),
# RandomHorizontalFlip(),
# RandomAffine(5, scale=(0.75, 1.25)),
# ToTensor(),
# Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
# ])
# transform_test = Compose([
# ToPILImage(),
# Resize((256, 256)),
# ToTensor(),
# Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
# ])
# return transform_train, transform_test
def get_monu_transform(args):
    """Build the (train, test) paired pipelines for the 'monu' setting.

    ``args`` is subscripted with the keys 'Idim', 'rotate', 'scale1' and
    'scale2'. Each value is converted with int()/float(), so numeric strings
    are accepted.

    Training does the following:
    - converts to PIL;
    - applies colour jitter (brightness/contrast/saturation 0.4, hue 0.1);
    - applies a random horizontal flip;
    - applies a random affine with ``int(args['rotate'])`` degrees and scale
      range ``(args['scale1'], args['scale2'])``;
    - converts to tensors.

    Testing only converts to PIL and then to tensors.
    """
    # 'Idim' is still parsed (so a missing or invalid value raises), although
    # the Resize step that used it is disabled.
    idim = int(args['Idim'])
    max_rotation = int(args['rotate'])
    scale_range = (float(args['scale1']), float(args['scale2']))

    train_steps = [
        ToPILImage(),
        # Resize((idim, idim)),
        ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        RandomHorizontalFlip(),
        RandomAffine(max_rotation, scale=scale_range),
        ToTensor(),
        # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
    ]
    test_steps = [
        ToPILImage(),
        # Resize((idim, idim)),
        ToTensor(),
        # Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
    ]
    return Compose(train_steps), Compose(test_steps)
def get_polyp_transform():
    """Build the (train, test) paired pipelines for polyp data.

    Training does the following:
    - converts to PIL;
    - applies colour jitter (brightness/contrast/saturation 0.4, hue 0.1);
    - applies random vertical and horizontal flips;
    - applies a random affine (rotation in [-90, 90] degrees, scale in
      [0.75, 1.25]);
    - converts to tensors.

    Testing only converts to PIL and then to tensors. Resizing to 352x352
    and normalization are disabled (left commented out below).
    """
    train_steps = [
        # Resize((352, 352)),
        ToPILImage(),
        ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        RandomVerticalFlip(),
        RandomHorizontalFlip(),
        RandomAffine(90, scale=(0.75, 1.25)),
        ToTensor(),
        # Normalize([105.61, 63.69, 45.67],
        #           [83.08, 55.86, 42.59])
    ]
    test_steps = [
        # Resize((352, 352)),
        ToPILImage(),
        ToTensor(),
        # Normalize([105.61, 63.69, 45.67],
        #           [83.08, 55.86, 42.59])
    ]
    return Compose(train_steps), Compose(test_steps)
def get_polyp_support_train_transform():
    """Build the training augmentation for polyp support images.

    The steps are colour jitter (brightness/contrast/saturation 0.4, hue
    0.1), random vertical and horizontal flips, and a random affine (rotation
    in [-90, 90] degrees, scale in [0.75, 1.25]).

    Unlike get_polyp_transform, there is no PIL or tensor conversion.
    RandomAffine reads ``img.size`` as a ``(width, height)`` pair, so inputs
    are presumably already PIL images. NOTE(review): confirm at the call site.
    """
    augmentations = [
        ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        RandomVerticalFlip(),
        RandomHorizontalFlip(),
        RandomAffine(90, scale=(0.75, 1.25)),
    ]
    return Compose(augmentations)