Spaces:
Sleeping
Sleeping
''' | |
Utilities for data augmentation. Partly credited to Dr. Jo Schlemper
''' | |
from os.path import join | |
import torch | |
import numpy as np | |
import torchvision.transforms as deftfx | |
import dataloaders.image_transforms as myit | |
import copy | |
from util.consts import IMG_SIZE | |
import time | |
import functools | |
def get_sabs_aug(input_size, use_3d=False):
    """Return the mild 'sabs_aug' augmentation config.

    Args:
        input_size: value stored under the ``'patch'`` key.
        use_3d: value stored under the ``'3d'`` key.

    Returns:
        A freshly built dict with flip, affine, elastic and gamma settings.
    """
    # Flipping stays disabled: medical data has fixed orientations.
    flip_cfg = dict(v=False, h=False, t=False, p=0.25)
    affine_cfg = dict(rotate=5, shift=(5, 5), shear=5, scale=(0.9, 1.2))
    elastic_cfg = dict(alpha=10, sigma=5)
    return {
        'flip': flip_cfg,
        'affine': affine_cfg,
        'elastic': elastic_cfg,
        'patch': input_size,
        'reduce_2d': True,
        '3d': use_3d,
        'gamma_range': (0.5, 1.5),
    }
def get_sabs_augv3(input_size):
    """Return the more aggressive 'aug_v3' augmentation config.

    Compared with ``get_sabs_aug`` it uses larger rotation/shift/shear, wider
    scale and gamma ranges and a stronger elastic alpha. Unlike
    ``get_sabs_aug`` it does not set a ``'3d'`` key.

    Args:
        input_size: value stored under the ``'patch'`` key.

    Returns:
        A freshly built dict with flip, affine, elastic and gamma settings.
    """
    cfg = {}
    cfg['flip'] = {'v': False, 'h': False, 't': False, 'p': 0.25}
    cfg['affine'] = dict(rotate=30, shift=(30, 30), shear=30, scale=(0.8, 1.3))
    cfg['elastic'] = dict(alpha=20, sigma=5)
    cfg['patch'] = input_size
    cfg['reduce_2d'] = True
    cfg['gamma_range'] = (0.2, 1.8)
    return cfg
def get_aug(which_aug, input_size):
    """Return an augmentation config by preset name.

    Args:
        which_aug: ``'sabs_aug'`` (mild) or ``'aug_v3'`` (more aggressive).
        input_size: forwarded to the preset builder (stored under ``'patch'``).

    Returns:
        The config dict produced by the selected builder.

    Raises:
        NotImplementedError: if ``which_aug`` is not a known preset name.
    """
    if which_aug == 'sabs_aug':
        return get_sabs_aug(input_size)
    if which_aug == 'aug_v3':
        return get_sabs_augv3(input_size)  # more aggressive
    raise NotImplementedError(
        "Unknown augmentation preset {!r}; expected 'sabs_aug' or 'aug_v3'".format(
            which_aug))
def get_geometric_transformer(aug, order=3, use_3d=False):
    """Build the geometric augmentation pipeline described by ``aug['aug']``.

    Stages are appended in a fixed order (flip, affine, elastic). Each one is
    included only when its key is present in ``aug['aug']``.

    Args:
        aug: mapping whose ``'aug'`` entry is an augmentation config, e.g. a
            dict returned by ``get_sabs_aug`` / ``get_sabs_augv3``.
        order: interpolation degree passed to ``RandomAffine``. Select
            ``order=0`` for augmenting segmentation.
        use_3d: when True, ``RandomAffine`` is built with ``use_3d=True``, the
            same as ``get_geometric_transformer_3d``. The default (False)
            keeps the 2D behaviour.

    Returns:
        A ``torchvision.transforms.Compose`` of the selected stages.
    """
    cfg = aug['aug']
    stages = []
    # Settings are read only inside the branch that uses them, so no fallback
    # defaults are needed (any such default could never take effect).
    if 'flip' in cfg:
        stages.append(myit.RandomFlip3D(**cfg['flip']))
    if 'affine' in cfg:
        affine = cfg['affine']
        affine_kwargs = {'order': order}
        if use_3d:
            # Only pass use_3d when requested, so the 2D call is unchanged.
            affine_kwargs['use_3d'] = True
        stages.append(myit.RandomAffine(affine.get('rotate'),
                                        affine.get('shift'),
                                        affine.get('shear'),
                                        affine.get('scale'),
                                        affine.get('scale_iso', True),
                                        **affine_kwargs))
    if 'elastic' in cfg:
        elastic = cfg['elastic']
        stages.append(myit.ElasticTransform(elastic['alpha'], elastic['sigma']))
    return deftfx.Compose(stages)
def get_geometric_transformer_3d(aug, order=3):
    """Build the 3D geometric augmentation pipeline described by ``aug['aug']``.

    Stages are appended in a fixed order (flip, affine, elastic). Each one is
    included only when its key is present in ``aug['aug']``. ``RandomAffine``
    is constructed with ``use_3d=True``.

    Args:
        aug: mapping whose ``'aug'`` entry is an augmentation config.
        order: interpolation degree passed to ``RandomAffine``. Select
            ``order=0`` for augmenting segmentation.

    Returns:
        A ``torchvision.transforms.Compose`` of the selected stages.
    """
    cfg = aug['aug']
    stages = []
    # Settings are read only inside the branch that uses them, so no fallback
    # defaults are needed (any such default could never take effect).
    if 'flip' in cfg:
        stages.append(myit.RandomFlip3D(**cfg['flip']))
    if 'affine' in cfg:
        affine = cfg['affine']
        stages.append(myit.RandomAffine(affine.get('rotate'),
                                        affine.get('shift'),
                                        affine.get('shear'),
                                        affine.get('scale'),
                                        affine.get('scale_iso', True),
                                        order=order,
                                        use_3d=True))
    if 'elastic' in cfg:
        elastic = cfg['elastic']
        stages.append(myit.ElasticTransform(elastic['alpha'], elastic['sigma']))
    return deftfx.Compose(stages)
def gamma_transform(img, aug):
    """Apply a random gamma (power-law) intensity transform to ``img``.

    Computes::

        out = irange * ((img - cmin + 1e-5) / irange) ** gamma + cmin

    where ``cmin = img.min()`` and ``irange = img.max() - cmin + 1e-5``.
    ``gamma`` is drawn uniformly from ``[lo, hi)`` with NumPy's global RNG.

    Args:
        img: numpy array of intensities.
        aug: mapping whose ``aug['aug']['gamma_range']`` is a ``(lo, hi)``
            tuple or list, or ``False`` to disable the transform.

    Returns:
        The transformed array, or ``img`` itself (unchanged) when disabled.

    Raises:
        ValueError: if ``gamma_range`` is neither a tuple/list nor ``False``.
    """
    gamma_range = aug['aug']['gamma_range']
    # Lists are accepted as well as tuples: configs deserialized from e.g.
    # JSON/YAML represent (lo, hi) pairs as lists.
    if isinstance(gamma_range, (tuple, list)):
        lo, hi = gamma_range[0], gamma_range[1]
        gamma = np.random.rand() * (hi - lo) + lo
        cmin = img.min()
        irange = (img.max() - cmin + 1e-5)
        # The +1e-5 keeps the base strictly positive before the power.
        shifted = img - cmin + 1e-5
        return irange * np.power(shifted * 1.0 / irange, gamma) + cmin
    if gamma_range == False:  # noqa: E712 -- equality on purpose (also 0 / np.False_)
        return img
    raise ValueError(
        "Cannot identify gamma transform range {}".format(gamma_range))
def get_intensity_transformer(aug):
    """Return a callable ``f(img)`` that runs ``gamma_transform`` with ``aug`` bound.

    The result is a ``functools.partial``, so ``aug`` is passed as a keyword.
    """
    bound_gamma = functools.partial(gamma_transform, aug=aug)
    return bound_gamma
def transform_with_label(aug):
    """Build a joint geometric + intensity augmentation for an image/label pair.

    The returned ``transform`` expects a numpy array laid out as
    ``[H x W x C + CL]``. The first ``C`` channels are the image. The
    remaining ``CL`` channels hold a compact label (class indices, NOT
    one-hot). Only ``CL == 1`` is currently supported.

    Args:
        aug: augmentation config passed to ``get_geometric_transformer`` and
            ``get_intensity_transformer``.

    Returns:
        The ``transform(comp, c_label, c_img, use_onehot, nclass, **kwargs)``
        closure documented below.
    """
    # NOTE(review): the geometric pipeline uses the default interpolation
    # order (3). It is also applied to the one-hot label channels, which are
    # rounded afterwards, whereas order=0 is recommended for segmentation.
    # Kept as-is to preserve behaviour.
    geometric_tfx = get_geometric_transformer(aug)
    intensity_tfx = get_intensity_transformer(aug)

    def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs):
        """Augment one stacked image/label array.

        Args:
            comp: numpy array with shape [H x W x c_img + c_label].
            c_label: number of compact-label channels. Only 1 is supported
                (a single H x W x 1 label slice).
            c_img: number of image channels at the front of ``comp``.
            use_onehot: if True, return the label one-hot encoded with
                ``nclass`` channels. Otherwise return class indices with a
                trailing singleton channel axis.
            nclass: number of label classes used for the one-hot encoding.
            **kwargs: accepted and ignored.

        Returns:
            ``(t_img, t_label)``: the first ``c_img`` channels after the
            geometric and intensity transforms, and the transformed label.

        Raises:
            NotImplementedError: if ``use_onehot is True`` and ``c_label != 1``.
            AssertionError: if ``comp`` does not have exactly ``c_img + 1``
                channels, or a rounded label value exceeds 1.
        """
        if (use_onehot is True) and (c_label != 1):
            raise NotImplementedError(
                "Only allow compact label, also the label can only be 2d")
        assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label"
        # One-hot encode the label so it goes through the same geometric_tfx
        # call as the image. np.concatenate allocates a new array, so the
        # caller's ``comp`` is never mutated and no defensive copy is needed.
        _label = comp[..., c_img]
        _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
        comp = np.concatenate([comp[..., :c_img], _h_label], -1)
        comp = geometric_tfx(comp)
        # Round the interpolated one-hot labels back to 0 or 1.
        t_label_h = np.rint(comp[..., c_img:])
        assert t_label_h.max() <= 1
        # The intensity transform applies to the image channels only.
        t_img = intensity_tfx(comp[..., 0: c_img])
        if use_onehot is True:
            t_label = t_label_h
        else:
            t_label = np.expand_dims(np.argmax(t_label_h, axis=-1), -1)
        return t_img, t_label

    return transform
def transform(scan, label, nclass, geometric_tfx, intensity_tfx):
    """Jointly augment a 3D scan and its label volume.

    Steps:
    1. One-hot encode the label into ``nclass`` channels and stack them
       behind the scan channels.
    2. Transpose to ``[H x W x D x ...]`` and pass everything through
       ``geometric_tfx`` in a single call.
    3. Round the label channels back to {0, 1} and apply ``intensity_tfx``
       to the scan channels only.

    Args
        scan: a numpy array with shape [D x H x W x C]
        label: a numpy array with shape [D x H x W x 1] (or [D x H x W])
        nclass: number of label classes for the one-hot encoding.
        geometric_tfx: callable applied to the stacked
            [H x W x D x C + nclass] array.
        intensity_tfx: callable applied to the image channels.

    Returns:
        ``(t_img, t_label_h)``: the transformed scan channels (C of them)
        and the rounded one-hot label (nclass channels). Both are in
        H x W x D order, assuming ``geometric_tfx`` preserves the layout.
    """
    assert scan.ndim == 4, "Input scan must be 4D"
    if label.ndim == 3:
        label = np.expand_dims(label, -1)
    n_img = scan.shape[-1]
    # np.concatenate already returns a fresh array; no extra copy is needed.
    comp = np.concatenate([scan, label], -1)  # [D x H x W x C + 1]
    _label = comp[..., -1]
    _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
    comp = np.concatenate([comp[..., :-1], _h_label], -1)  # [D x H x W x C + nclass]
    # change comp to be H x W x D x (C + nclass)
    comp = np.transpose(comp, (1, 2, 0, 3))
    comp = geometric_tfx(comp)
    # Split by the real number of scan channels. Hard-coding a single image
    # channel would leak image channels into the label when C > 1.
    t_label_h = np.rint(comp[..., n_img:])
    assert t_label_h.max() <= 1
    # intensity transform
    t_img = intensity_tfx(comp[..., :n_img])
    return t_img, t_label_h
def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx):
    """Forward every argument unchanged to ``transform`` and return its result.

    NOTE(review): presumably kept as a separately named top-level entry point
    (e.g. for use from worker processes or partials). Confirm with callers.
    """
    result = transform(scan=scan,
                       label=label,
                       nclass=nclass,
                       geometric_tfx=geometric_tfx,
                       intensity_tfx=intensity_tfx)
    return result