""" Normalization + Activation Layers
"""
import torch
from torch import nn as nn
from torch.nn import functional as F

from .create_act import get_act_layer


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(BatchNormAct2d, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def _forward_jit(self, x):
        """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
        """
        # exponential_average_factor is set to self.momentum (when it is available)
        # only so that it gets updated in the ONNX graph when this node is exported to ONNX
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        x = F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
        return x
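
    # Sketch (assumption, not part of the original module): the _forward_jit /
    # _forward_python split is what keeps this module scriptable, e.g.
    #
    #   scripted = torch.jit.script(BatchNormAct2d(32))
    #   y = scripted(torch.randn(1, 32, 4, 4))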

    @torch.jit.ignore
    def _forward_python(self, x):
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        # when scripting, use the inlined copy of the BN forward; otherwise call
        # nn.BatchNorm2d.forward via _forward_python (ignored by the scripter)
        if torch.jit.is_scripting():
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        x = self.act(x)
        return x
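
# Usage sketch (assumption, not part of the original module): BatchNormAct2d keeps
# the same state_dict keys as a plain nn.BatchNorm2d (the activation has no params),
# so checkpoints trained with a separate bn + act stack still load, e.g.
#
#   bn_act = BatchNormAct2d(64, act_layer='relu')
#   y = bn_act(torch.randn(2, 64, 8, 8))                      # BN then ReLU
#   bn_act.load_state_dict(nn.BatchNorm2d(64).state_dict())   # keys match plain BN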


class GroupNormAct(nn.GroupNorm):
    """GroupNorm + Activation

    NOTE: takes num_channels first, then num_groups, the reverse of nn.GroupNorm's
    argument order.
    """
    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.act(x)
        return x
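
# Usage sketch (assumption, not part of the original module): mind the flipped
# (num_channels, num_groups) order relative to nn.GroupNorm, e.g.
#
#   gn_act = GroupNormAct(64, num_groups=8)
#   y = gn_act(torch.randn(2, 64, 8, 8))    # GroupNorm over 8 groups, then ReLU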