Spaces:
Running
on
Zero
Running
on
Zero
# Source: https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
# MIT License
# Copyright (c) 2018 Adrian Wolny
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from functools import partial | |
import torch | |
from torch import nn as nn | |
from torch.nn import functional as F | |
# from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D | |
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
                dropout_prob, is3d):
    """
    Create a list of modules which together constitute a single conv layer with a
    non-linearity and optional batchnorm/groupnorm/dropout.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): order of things, e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            'bcr' -> batchnorm + conv + ReLU
            'cbrd' -> conv + batchnorm + ReLU + dropout
            'cbrD' -> conv + batchnorm + ReLU + channel-wise dropout
                      (Dropout3d when is3d, Dropout2d otherwise)
        num_groups (int): number of groups for the GroupNorm; a GroupNorm whose channel
            count is smaller than this falls back to a single group
        padding (int or tuple): zero-padding added to every spatial side of the input
        dropout_prob (float): dropout probability used by 'd' and 'D'
        is3d (bool): if True use the 3D layer variants (Conv3d, BatchNorm3d, Dropout3d),
            otherwise the 2D ones

    Return:
        list of tuple (name, module), in the order given by `order`

    Raises:
        AssertionError: if `order` has no conv, starts with a non-linearity, or a
            GroupNorm's channel count is not divisible by its number of groups
        ValueError: if `order` contains an unsupported character
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    conv_pos = order.index('c')
    # add learnable bias only in the absence of batchnorm/groupnorm
    bias = not ('g' in order or 'b' in order)

    modules = []
    for i, char in enumerate(order):
        # a normalization layer before the conv sees the input channels, after it the output channels
        norm_channels = in_channels if i < conv_pos else out_channels
        if char == 'r':
            modules.append(('ReLU', nn.ReLU(inplace=True)))
        elif char == 'l':
            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
        elif char == 'e':
            modules.append(('ELU', nn.ELU(inplace=True)))
        elif char == 'c':
            conv_cls = nn.Conv3d if is3d else nn.Conv2d
            conv = conv_cls(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
            modules.append(('conv', conv))
        elif char == 'g':
            # use a single group when there are fewer channels than requested groups; computed
            # per layer so one GroupNorm's fallback does not leak into another one in `order`
            groups = 1 if norm_channels < num_groups else num_groups
            assert norm_channels % groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={norm_channels}, num_groups={groups}'
            modules.append(('groupnorm', nn.GroupNorm(num_groups=groups, num_channels=norm_channels)))
        elif char == 'b':
            bn_cls = nn.BatchNorm3d if is3d else nn.BatchNorm2d
            modules.append(('batchnorm', bn_cls(norm_channels)))
        elif char == 'd':
            modules.append(('dropout', nn.Dropout(p=dropout_prob)))
        elif char == 'D':
            # channel-wise dropout matching the tensor rank: Dropout2d on 5D input is deprecated
            # by PyTorch, while Dropout3d zeroes whole channels of 5D input without the warning.
            # The historical module name 'dropout2d' is kept for attribute-name compatibility.
            dropout_cls = nn.Dropout3d if is3d else nn.Dropout2d
            modules.append(('dropout2d', dropout_cls(p=dropout_prob)))
        else:
            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")
    return modules
class SingleConv(nn.Sequential):
    """
    Basic convolutional module: one convolution (Conv3d or Conv2d), a non-linearity and
    optional batchnorm/groupnorm/dropout, laid out in the sequence given by `order`.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            (see `create_conv` for the full set of supported characters)
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to the input
        dropout_prob (float): dropout probability, default 0.1
        is3d (bool): if True use Conv3d, otherwise use Conv2d
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
                 padding=1, dropout_prob=0.1, is3d=True):
        super().__init__()
        layers = create_conv(in_channels, out_channels, kernel_size, order,
                             num_groups, padding, dropout_prob, is3d)
        # register each layer under its descriptive name; nn.Sequential runs them in insertion order
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)
class DoubleConv(nn.Sequential):
    """
    A module consisting of two consecutive convolution layers (SingleConv1, SingleConv2).
    With the default order='gcr' each layer is GroupNorm + Conv3d + ReLU.
    This can be changed however by providing the 'order' argument, e.g. in order
    to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
    Use padded convolutions to make sure that the output (H_out, W_out) is the same
    as (H_in, W_in), so that you don't have to crop in the decoder path.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all spatial sides of the input
        upscale (int): encoder path only. With 1 the first conv maps straight to
            `out_channels`; any other value makes it map to `out_channels // 2` (but never
            fewer than `in_channels`) and the second conv finishes the expansion. Default: 2
        dropout_prob (float or tuple): dropout probability for each convolution; a list/tuple
            is read as (conv1 probability, conv2 probability). Default 0.1
        is3d (bool): if True use Conv3d instead of Conv2d layers
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
        super(DoubleConv, self).__init__()
        if encoder:
            # we're in the encoder path
            conv1_in_channels = in_channels
            if upscale == 1:
                conv1_out_channels = out_channels
            else:
                # halve the target width for the first conv, but never shrink below the input
                conv1_out_channels = max(out_channels // 2, in_channels)
            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
        else:
            # we're in the decoder path, decrease the number of channels in the 1st convolution
            conv1_in_channels, conv1_out_channels = in_channels, out_channels
            conv2_in_channels, conv2_out_channels = out_channels, out_channels

        # a list/tuple gives separate dropout probabilities for the two convolutions
        if isinstance(dropout_prob, (list, tuple)):
            dropout_prob1, dropout_prob2 = dropout_prob[0], dropout_prob[1]
        else:
            dropout_prob1 = dropout_prob2 = dropout_prob

        # conv1
        self.add_module('SingleConv1',
                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
                                   padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
        # conv2
        self.add_module('SingleConv2',
                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
                                   padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
class ResNetBlock(nn.Module):
    """
    Residual block that can be used instead of the standard DoubleConv in the Encoder module.
    Motivated by: https://arxiv.org/pdf/1706.00120.pdf

    Layout: x -> conv2 -> conv3 (non-linearity stripped from its order), summed with a skip
    branch (1x1 conv when the channel count changes, identity otherwise), then a single
    non-linearity. Notice we use ELU instead of ReLU (order='cge') and put the
    non-linearity after the groupnorm.

    Extra keyword arguments (e.g. `encoder`, `padding`, `upscale`, `dropout_prob` passed by
    Encoder/Decoder) are accepted and ignored; no padding is forwarded to the convolutions,
    so they use SingleConv's default.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
        super().__init__()
        # skip branch: a 1x1 conv projects to out_channels only when the widths differ
        if in_channels == out_channels:
            self.conv1 = nn.Identity()
        else:
            projection_cls = nn.Conv3d if is3d else nn.Conv2d
            self.conv1 = projection_cls(in_channels, out_channels, 1)

        self.conv2 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order,
                                num_groups=num_groups, is3d=is3d)

        # the 3rd conv must end without a non-linearity: it is applied after the residual sum
        linear_order = ''.join(ch for ch in order if ch not in 'rel')
        self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=linear_order,
                                num_groups=num_groups, is3d=is3d)

        # standalone non-linearity: 'l' wins over 'e', ReLU is the fallback
        if 'l' in order:
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif 'e' in order:
            activation = nn.ELU(inplace=True)
        else:
            activation = nn.ReLU(inplace=True)
        self.non_linearity = activation

    def forward(self, x):
        # skip branch brought to out_channels
        identity = self.conv1(x)
        y = self.conv3(self.conv2(x))
        y += identity
        return self.non_linearity(y)
class Encoder(nn.Module):
    """
    One stage of the encoder path: an optional pooling layer followed by a basic module
    (DoubleConv or ResNetBlock). The pooling kernel may differ from the standard 2, e.g.
    for anisotropic volumetric data (then use a complementary scale_factor in the decoder).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        apply_pooling (bool): if True pool before the basic module
        pool_kernel_size (int or tuple): the size of the pooling window
        pool_type (str): pooling layer: 'max' or 'avg'
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all spatial sides of the input
        upscale (int): forwarded to the basic module (see `DoubleConv`), default: 2
        dropout_prob (float or tuple): dropout probability, default 0.1
        is3d (bool): use 3d or 2d convolutions/pooling operation
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
        super().__init__()
        assert pool_type in ['max', 'avg']

        if not apply_pooling:
            self.pooling = None
        else:
            if is3d:
                pool_cls = nn.MaxPool3d if pool_type == 'max' else nn.AvgPool3d
            else:
                pool_cls = nn.MaxPool2d if pool_type == 'max' else nn.AvgPool2d
            self.pooling = pool_cls(kernel_size=pool_kernel_size)

        self.basic_module = basic_module(
            in_channels,
            out_channels,
            encoder=True,
            kernel_size=conv_kernel_size,
            order=conv_layer_order,
            num_groups=num_groups,
            padding=padding,
            upscale=upscale,
            dropout_prob=dropout_prob,
            is3d=is3d,
        )

    def forward(self, x):
        pooled = x if self.pooling is None else self.pooling(x)
        return self.basic_module(pooled)
class Decoder(nn.Module):
    """
    A single module for the decoder path: an upsampling layer (learned transposed
    convolution or interpolation), a joining step with the encoder features
    (channel concatenation or summation), and a basic module (DoubleConv or ResNetBlock).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        scale_factor (int or tuple): stride of the transposed convolution when
            upsample='deconv'; must reverse the pooling of the corresponding encoder.
            Interpolation ignores it and resizes directly to the encoder features' size.
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all spatial sides of the input
        upsample (str): algorithm used for upsampling:
            InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
            TransposeConvUpsampling: 'deconv'
            No upsampling: None or 'none'
            Default: 'default' (chooses automatically: DoubleConv -> 'nearest' with
            concatenation; ResNetBlock -> 'deconv' with summation and the basic module's
            in_channels set to out_channels; other basic modules need an explicit mode)
        dropout_prob (float or tuple): dropout probability, default 0.1
        is3d (bool): if True use 3D transposed convolutions / basic modules, otherwise 2D
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
                 dropout_prob=0.1, is3d=True):
        super(Decoder, self).__init__()
        # perform concat joining per default
        concat = True
        # don't adapt channels after join operation
        adapt_channels = False

        if upsample is not None and upsample != 'none':
            if upsample == 'default':
                if basic_module == DoubleConv:
                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
                    concat = True  # use concat joining
                    adapt_channels = False  # don't adapt channels
                elif basic_module == ResNetBlock:
                    upsample = 'deconv'  # use deconvolution upsampling
                    concat = False  # use summation joining
                    adapt_channels = True  # adapt channels after joining

            # perform deconvolution upsampling if mode is deconv
            if upsample == 'deconv':
                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
                                                          is3d=is3d)
            else:
                self.upsampling = InterpolateUpsampling(mode=upsample)
        else:
            # no upsampling; joining keeps the default (concat)
            self.upsampling = NoUpsampling()

        # perform joining operation; `_joining` is a staticmethod so the partial does not bind `self`
        self.joining = partial(self._joining, concat=concat)

        # adapt the number of in_channels for the ResNetBlock
        if adapt_channels is True:
            in_channels = out_channels

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         dropout_prob=dropout_prob,
                                         is3d=is3d)

    def forward(self, encoder_features, x):
        x = self.upsampling(encoder_features=encoder_features, x=x)
        x = self.joining(encoder_features, x)
        x = self.basic_module(x)
        return x

    @staticmethod
    def _joining(encoder_features, x, concat):
        # concatenate along dim 1 (channels), or sum element-wise
        if concat:
            return torch.cat((encoder_features, x), dim=1)
        else:
            return encoder_features + x
def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
                    conv_upscale, dropout_prob,
                    layer_order, num_groups, pool_kernel_size, is3d):
    """
    Build the encoder path as an nn.ModuleList with one Encoder per entry of `f_maps`
    (the encoder depth equals len(f_maps)).

    Level 0 maps `in_channels` -> f_maps[0] without pooling; every later level pools
    with `pool_kernel_size` and maps f_maps[level - 1] -> f_maps[level].
    """

    def build_level(level, level_channels):
        is_first = level == 0
        return Encoder(in_channels if is_first else f_maps[level - 1], level_channels,
                       apply_pooling=not is_first,  # the first encoder keeps full resolution
                       basic_module=basic_module,
                       conv_layer_order=layer_order,
                       conv_kernel_size=conv_kernel_size,
                       num_groups=num_groups,
                       pool_kernel_size=pool_kernel_size,  # unused when apply_pooling is False
                       padding=conv_padding,
                       upscale=conv_upscale,
                       dropout_prob=dropout_prob,
                       is3d=is3d)

    return nn.ModuleList([build_level(level, channels) for level, channels in enumerate(f_maps)])
def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
                    num_groups, upsample, dropout_prob, is3d):
    """
    Build the decoder path as an nn.ModuleList of Decoder modules.

    Only f_maps[1:] is walked, deepest level first, so the path contains
    len(f_maps) - 2 decoders (none when len(f_maps) <= 2) and the last one outputs
    f_maps[1] channels.
    NOTE(review): a decoder per encoder level would use all of `f_maps`
    (len(f_maps) - 1 decoders ending at f_maps[0]); confirm skipping the first level
    is intentional.

    With DoubleConv and a non-'deconv' upsampling, a decoder's input width is the sum of
    the deeper and the shallower level (concatenated skip features); otherwise it is
    the deeper level's width.
    """
    levels = list(reversed(f_maps[1:]))
    sums_skip_width = basic_module == DoubleConv and upsample != 'deconv'

    decoders = []
    for deeper, shallower in zip(levels, levels[1:]):
        in_feature_num = deeper + shallower if sums_skip_width else deeper
        decoders.append(Decoder(in_feature_num, shallower,
                                basic_module=basic_module,
                                conv_layer_order=layer_order,
                                conv_kernel_size=conv_kernel_size,
                                num_groups=num_groups,
                                padding=conv_padding,
                                upsample=upsample,
                                dropout_prob=dropout_prob,
                                is3d=is3d))
    return nn.ModuleList(decoders)
class AbstractUpsampling(nn.Module):
    """
    Base class for upsampling layers. Subclasses pass in `upsample`, a callable
    `upsample(x, size)` that resizes `x` to the spatial `size` of the matching encoder
    features (by interpolation or a learned transposed convolution). Works for 4D (2D)
    or 5D (3D) inputs, since only the trailing dimensions are used as the target size.
    """

    def __init__(self, upsample):
        super().__init__()
        self.upsample = upsample

    def forward(self, encoder_features, x):
        # target = every dimension after the first two (presumably batch and channel)
        target_size = encoder_features.shape[2:]
        return self.upsample(x, target_size)
class InterpolateUpsampling(AbstractUpsampling):
    """
    Upsample by interpolating `x` to the spatial size of the encoder features.

    Args:
        mode (str): algorithm used for upsampling, passed to `F.interpolate`:
            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
    """

    def __init__(self, mode='nearest'):
        # `_interpolate` is a staticmethod, so the partial does not capture `self`
        upsample = partial(self._interpolate, mode=mode)
        super().__init__(upsample)

    @staticmethod
    def _interpolate(x, size, mode):
        return F.interpolate(x, size=size, mode=mode)
class TransposeConvUpsampling(AbstractUpsampling):
    """
    Upsample with a learned transposed convolution, then resize to the encoder features' size.

    Args:
        in_channels (int): number of input channels for the transposed conv
        out_channels (int): number of output channels for the transposed conv
        kernel_size (int or tuple): size of the convolving kernel
        scale_factor (int or tuple): stride of the transposed convolution
        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
    """

    class Upsample(nn.Module):
        """
        Workaround for the 'ValueError: requested an output size...' raised by
        `_output_padding` in transposed convolutions: run the transposed conv without an
        explicit output size, then interpolate (F.interpolate's default mode) to `size`.
        """

        def __init__(self, conv_transposed, is3d):
            super().__init__()
            self.conv_transposed = conv_transposed
            self.is3d = is3d  # stored for reference; not used by forward

        def forward(self, x, size):
            x = self.conv_transposed(x)
            return F.interpolate(x, size=size)

    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
        # the stride should reverse the pooling of the corresponding encoder; padding is fixed
        # at 1 and any remaining size mismatch is fixed by the interpolation in Upsample.forward
        conv_cls = nn.ConvTranspose3d if is3d else nn.ConvTranspose2d
        conv_transposed = conv_cls(in_channels, out_channels, kernel_size=kernel_size,
                                   stride=scale_factor, padding=1, bias=False)
        upsample = self.Upsample(conv_transposed, is3d)
        super().__init__(upsample)
class NoUpsampling(AbstractUpsampling):
    """Pass-through upsampling: returns the input unchanged and ignores the target size."""

    def __init__(self):
        # `_no_upsampling` is a staticmethod, so the callable takes exactly (x, size)
        super().__init__(self._no_upsampling)

    @staticmethod
    def _no_upsampling(x, size):
        return x