File size: 7,466 Bytes
427d150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
'''

Utilities for data augmentation. Partly credited to Dr. Jo Schlemper.

'''
from os.path import join

import torch
import numpy as np
import torchvision.transforms as deftfx
import dataloaders.image_transforms as myit
import copy
from util.consts import IMG_SIZE
import time
import functools


def get_sabs_aug(input_size, use_3d=False):
    """Return the default ("sabs_aug") augmentation configuration.

    Args:
        input_size: stored verbatim under the ``'patch'`` key.
        use_3d: stored verbatim under the ``'3d'`` key.

    Returns:
        A freshly built dict (nested dicts are new objects on every call)
        with keys ``flip``, ``affine``, ``elastic``, ``patch``, ``reduce_2d``,
        ``3d`` and ``gamma_range``.
    """
    # Medical data has fixed orientations, so every flip axis stays disabled.
    flip_cfg = dict(v=False, h=False, t=False, p=0.25)
    affine_cfg = dict(rotate=5, shift=(5, 5), shear=5, scale=(0.9, 1.2))
    elastic_cfg = dict(alpha=10, sigma=5)

    return {
        'flip': flip_cfg,
        'affine': affine_cfg,
        'elastic': elastic_cfg,
        'patch': input_size,
        'reduce_2d': True,
        '3d': use_3d,
        'gamma_range': (0.5, 1.5),
    }


def get_sabs_augv3(input_size, use_3d=False):
    """Return the more aggressive ("aug_v3") augmentation configuration.

    Uses the same keys as ``get_sabs_aug`` but with larger affine and
    elastic magnitudes and a wider gamma range.

    Args:
        input_size: stored verbatim under the ``'patch'`` key.
        use_3d: stored verbatim under the ``'3d'`` key (default False). This
            keeps the key set identical to ``get_sabs_aug`` so the two
            factories can be used interchangeably.

    Returns:
        A freshly built dict on every call.
    """
    sabs_augv3 = {
        # flipping disabled on every axis, as in get_sabs_aug
        'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
        'affine': {
            'rotate': 30,
            'shift': (30, 30),
            'shear': 30,
            'scale': (0.8, 1.3),
        },
        'elastic': {'alpha': 20, 'sigma': 5},
        'patch': input_size,
        'reduce_2d': True,
        '3d': use_3d,
        'gamma_range': (0.2, 1.8),
    }
    return sabs_augv3


def get_aug(which_aug, input_size):
    """Build the augmentation config selected by name.

    Args:
        which_aug: ``'sabs_aug'`` or ``'aug_v3'``.
        input_size: forwarded to the selected config factory.

    Returns:
        The dict produced by the matching factory.

    Raises:
        NotImplementedError: if ``which_aug`` matches neither name.
    """
    factories = (
        ('sabs_aug', get_sabs_aug),
        ('aug_v3', get_sabs_augv3),
    )
    for name, factory in factories:
        if which_aug == name:
            return factory(input_size)
    raise NotImplementedError

# augs = {
#     'sabs_aug': get_sabs_aug,
#     'aug_v3': get_sabs_augv3, # more aggressive
# }


def get_geometric_transformer(aug, order=3):
    """Build the 2D geometric augmentation pipeline described by ``aug['aug']``.

    Args:
        aug: dict whose ``'aug'`` entry holds the settings. Each optional key
            ``'flip'``, ``'affine'`` and ``'elastic'`` enables one stage.
        order: interpolation degree, forwarded to ``myit.RandomAffine`` only
            (the elastic stage receives no order argument). Select order=0
            for augmenting segmentation.

    Returns:
        A ``torchvision.transforms.Compose`` that applies the enabled stages
        in the order flip -> affine -> elastic.
    """
    settings = aug['aug']
    # Elastic parameters are looked up before any stage is constructed;
    # both fall back to 0 when the 'elastic' entry is absent.
    elastic = settings.get('elastic', {'alpha': 0, 'sigma': 0})
    elastic_alpha = elastic['alpha']
    elastic_sigma = elastic['sigma']

    stages = []
    if 'flip' in settings:
        stages.append(myit.RandomFlip3D(**settings['flip']))

    if 'affine' in settings:
        aff = settings['affine']
        stages.append(myit.RandomAffine(
            aff.get('rotate'),
            aff.get('shift'),
            aff.get('shear'),
            aff.get('scale'),
            aff.get('scale_iso', True),
            order=order,
        ))

    if 'elastic' in settings:
        stages.append(myit.ElasticTransform(elastic_alpha, elastic_sigma))

    return deftfx.Compose(stages)


def get_geometric_transformer_3d(aug, order=3):
    """Build the volumetric geometric augmentation pipeline from ``aug['aug']``.

    Same stages as ``get_geometric_transformer``, except that the affine
    stage is constructed with ``use_3d=True``.

    Args:
        aug: dict whose ``'aug'`` entry holds the settings. Each optional key
            ``'flip'``, ``'affine'`` and ``'elastic'`` enables one stage.
        order: interpolation degree, forwarded to ``myit.RandomAffine`` only.
            Select order=0 for augmenting segmentation.

    Returns:
        A ``torchvision.transforms.Compose`` that applies the enabled stages
        in the order flip -> affine -> elastic.
    """
    settings = aug['aug']
    # Elastic parameters are looked up before any stage is constructed;
    # both fall back to 0 when the 'elastic' entry is absent.
    elastic = settings.get('elastic', {'alpha': 0, 'sigma': 0})
    elastic_alpha = elastic['alpha']
    elastic_sigma = elastic['sigma']

    stages = []
    if 'flip' in settings:
        stages.append(myit.RandomFlip3D(**settings['flip']))

    if 'affine' in settings:
        aff = settings['affine']
        stages.append(myit.RandomAffine(
            aff.get('rotate'),
            aff.get('shift'),
            aff.get('shear'),
            aff.get('scale'),
            aff.get('scale_iso', True),
            order=order,
            use_3d=True,
        ))

    if 'elastic' in settings:
        stages.append(myit.ElasticTransform(elastic_alpha, elastic_sigma))

    return deftfx.Compose(stages)


def gamma_transform(img, aug):
    """Apply a random gamma (power-law) intensity transform to ``img``.

    The exponent is drawn uniformly from ``aug['aug']['gamma_range']`` with
    NumPy's global RNG. Intensities are shifted so the minimum sits at ~0,
    normalised by the intensity range, raised to the power gamma, rescaled,
    and shifted back by the original minimum.

    Args:
        img: numpy array of intensities (any shape).
        aug: config dict. ``aug['aug']['gamma_range']`` is either a
            ``(low, high)`` pair given as a tuple or a list, or ``False`` to
            disable the transform.

    Returns:
        A new transformed array, or ``img`` itself when the transform is
        disabled.

    Raises:
        ValueError: if ``gamma_range`` is neither a sequence pair nor False.
    """
    gamma_range = aug['aug']['gamma_range']
    # Accept lists as well as tuples: configs that are serialised and
    # reloaded (JSON/YAML, experiment-config frameworks) often turn tuples
    # into lists, which previously fell through to the ValueError below.
    if isinstance(gamma_range, (tuple, list)):
        gamma = np.random.rand() * \
            (gamma_range[1] - gamma_range[0]) + gamma_range[0]
        cmin = img.min()
        # epsilon keeps the division well-defined for constant images
        irange = (img.max() - cmin + 1e-5)

        img = img - cmin + 1e-5
        img = irange * np.power(img * 1.0 / irange, gamma)
        img = img + cmin

    elif gamma_range == False:  # noqa: E712 -- equality (not identity) also accepts 0
        pass
    else:
        raise ValueError(
            "Cannot identify gamma transform range {}".format(gamma_range))
    return img


def get_intensity_transformer(aug):
    """Return a one-argument callable ``f(img)`` applying ``gamma_transform``.

    The ``aug`` config is bound by keyword via ``functools.partial``. The
    result is a partial object rather than a closure, so it can be pickled.
    """
    apply_gamma = functools.partial(gamma_transform, aug=aug)
    return apply_gamma


def transform_with_label(aug):
    """Build a 2D joint image/label augmentation function from ``aug``.

    The geometric pipeline comes from ``get_geometric_transformer(aug)``,
    called with its default interpolation order. The intensity pipeline
    comes from ``get_intensity_transformer(aug)``. Both are created once here
    and reused by every call of the returned function.

    The returned function expects a stacked array laid out as
    [H x W x C + CL]: the image channels followed by CL label channels.
    The label is compact (class indices), NOT one-hot.
    """
    geometric_tfx = get_geometric_transformer(aug)
    intensity_tfx = get_intensity_transformer(aug)

    def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs):
        """Augment one stacked image/label sample.

        Args:
            comp: numpy array [H x W x C + 1]. It holds ``c_img`` image
                channels followed by exactly one compact label channel
                (enforced by the assert below).
            c_label: number of label channels. Only checked when
                ``use_onehot`` is True, where it must equal 1.
            c_img: number of image channels at the front of ``comp``.
            use_onehot: if exactly ``True``, return the one-hot label
                [... x nclass]; otherwise return the argmax label [... x 1].
            nclass: number of classes used for the one-hot expansion.
            **kwargs: accepted and ignored.

        Returns:
            (t_img, t_label): the intensity-augmented image channels and the
            label in the representation chosen by ``use_onehot``.
        """
        data = copy.deepcopy(comp)
        want_onehot = use_onehot is True
        if want_onehot and c_label != 1:
            raise NotImplementedError(
                "Only allow compact label, also the label can only be 2d")
        assert c_img + 1 == data.shape[-1], "only allow single slice 2D label"

        # Expand the compact label into one-hot planes so that image and label
        # are warped by the very same random geometric transform.
        image_part = data[..., :c_img]
        label_plane = data[..., c_img]
        onehot = np.float32(label_plane[..., None] == np.arange(nclass))
        warped = geometric_tfx(np.concatenate([image_part, onehot], -1))

        # One-hot planes are interpolated along with the image, so snap them
        # back to integers before use.
        onehot_out = np.rint(warped[..., c_img:])
        assert onehot_out.max() <= 1
        img_out = intensity_tfx(warped[..., :c_img])

        if want_onehot:
            return img_out, onehot_out
        return img_out, np.argmax(onehot_out, axis=-1)[..., None]

    return transform


def transform(scan, label, nclass, geometric_tfx, intensity_tfx):
    """Jointly augment a 3D scan and its label volume.

    The label is expanded to ``nclass`` one-hot channels and stacked after the
    image channels, so a single call to ``geometric_tfx`` warps both
    consistently. ``intensity_tfx`` is then applied to the image channels
    only.

    Args:
        scan: numpy array of shape [D x H x W x C]; any number of channels C.
        label: numpy array of shape [D x H x W] or [D x H x W x 1] holding
            class indices. Values outside 0..nclass-1 yield an all-zero
            one-hot vector.
        nclass: number of classes for the one-hot expansion.
        geometric_tfx: callable applied to the [H x W x D x (C + nclass)]
            stack (presumably shape-preserving -- the channel split below
            relies on the last axis keeping C + nclass entries).
        intensity_tfx: callable applied to the transformed image channels.

    Returns:
        (t_img, t_label_h): ``t_img`` holds the C image channels and
        ``t_label_h`` the nclass one-hot channels rounded to integers. Both
        are in H x W x D axis order, not the D x H x W order of the inputs.
    """
    assert len(scan.shape) == 4, "Input scan must be 4D"
    if len(label.shape) == 3:
        label = np.expand_dims(label, -1)
    n_img_ch = scan.shape[-1]

    # geometric transform
    # np.concatenate returns a new array, so the caller's inputs are never
    # mutated and no extra deep copy is needed.
    comp = np.concatenate([scan, label], -1)  # [D x H x W x C + 1]
    _label = comp[..., -1]
    _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
    comp = np.concatenate([comp[..., :-1], _h_label], -1)  # [D x H x W x C + nclass]
    # reorder to H x W x D x (C + nclass) for the geometric transform
    comp = np.transpose(comp, (1, 2, 0, 3))
    comp = geometric_tfx(comp)
    # Split on the real image channel count (previously hard-coded to 1, which
    # leaked image channels into the label when C > 1), then round the
    # interpolated one-hot planes back to integers.
    t_label_h = np.rint(comp[..., n_img_ch:])
    assert t_label_h.max() <= 1
    t_img = comp[..., :n_img_ch]

    # intensity transform
    t_img = intensity_tfx(t_img)
    return t_img, t_label_h


def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx):
    """Alias for the module-level ``transform``; forwards every argument as-is
    and returns its result unchanged."""
    return transform(
        scan=scan,
        label=label,
        nclass=nclass,
        geometric_tfx=geometric_tfx,
        intensity_tfx=intensity_tfx,
    )