|
""" PNASNet-5 Large

Implementation adapted from Cadene's pretrained-models.pytorch:
https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py

Additional credit to https://github.com/creafz
"""
|
from collections import OrderedDict |
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .helpers import build_model_with_cfg |
|
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier |
|
from .registry import register_model |
|
|
|
# Public names exported by `from <module> import *`. Only the model class is listed;
# the `pnasnet5large` factory below is exposed through `@register_model` instead.
__all__ = ['PNASNet5Large']
|
|
|
# Per-variant default configuration: pretrained weight URL plus input / preprocessing
# metadata. `_create_pnasnet` passes the entry for its variant to `build_model_with_cfg`
# as `default_cfg`.
default_cfgs = dict(
    pnasnet5large=dict(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
        input_size=(3, 331, 331),
        pool_size=(11, 11),
        crop_pct=0.911,
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_classes=1000,
        # Module names of the first conv and the classifier inside PNASNet5Large.
        first_conv='conv_0.conv',
        classifier='last_linear',
        # NOTE(review): presumably the source checkpoint has one extra leading class
        # (1001 outputs) that is offset away here — confirm against the weight loader.
        label_offset=1,
    ),
)
|
|
|
|
|
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: a depthwise conv followed by a 1x1 pointwise conv.

    Args:
        in_channels: input channels; also used as the depthwise conv's group count.
        out_channels: channels produced by the 1x1 pointwise conv.
        kernel_size: kernel size of the depthwise conv.
        stride: stride of the depthwise conv (no stride is passed to the pointwise conv).
        padding: padding spec forwarded to ``create_conv2d`` for both convs.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
        super().__init__()
        # groups == in_channels -> one filter per input channel (depthwise).
        self.depthwise_conv2d = create_conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # 1x1 conv mixes channels and sets the output width.
        self.pointwise_conv2d = create_conv2d(in_channels, out_channels, kernel_size=1, padding=padding)

    def forward(self, x):
        return self.pointwise_conv2d(self.depthwise_conv2d(x))
|
|
|
|
|
class BranchSeparables(nn.Module):
    """Two stacked ``ReLU -> SeparableConv2d -> BatchNorm2d`` units.

    Only the first separable conv uses ``stride``; the second always uses stride 1.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: depthwise kernel size of both separable convs.
        stride: stride of the first separable conv.
        stem_cell: if True the intermediate width is ``out_channels``, otherwise it is
            ``in_channels``.
        padding: padding spec forwarded to both separable convs.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''):
        super().__init__()
        if stem_cell:
            mid_chs = out_channels
        else:
            mid_chs = in_channels

        # First unit: may downsample via `stride`.
        self.act_1 = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, mid_chs, kernel_size, stride=stride, padding=padding)
        self.bn_sep_1 = nn.BatchNorm2d(mid_chs, eps=0.001)

        # Second unit: fixed stride 1, sets the output width.
        self.act_2 = nn.ReLU()
        self.separable_2 = SeparableConv2d(mid_chs, out_channels, kernel_size, stride=1, padding=padding)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        hidden = self.bn_sep_1(self.separable_1(self.act_1(x)))
        return self.bn_sep_2(self.separable_2(self.act_2(hidden)))
|
|
|
|
|
class ActConvBn(nn.Module):
    """Pre-activation conv unit: ReLU, then a conv from ``create_conv2d``, then BatchNorm (eps=1e-3).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: conv kernel size.
        stride: conv stride.
        padding: padding spec forwarded to ``create_conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
        super().__init__()
        self.act = nn.ReLU()
        self.conv = create_conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        return self.bn(self.conv(self.act(x)))
|
|
|
|
|
class FactorizedReduction(nn.Module):
    """Downsample by 2 with two offset stride-2 paths whose outputs are concatenated.

    After a ReLU, the input goes through two parallel paths. Their outputs are concatenated
    on dim 1 and batch-normalised:

    * ``path_1``: ``AvgPool2d(1, stride=2)`` (a kernel-1 pool, i.e. plain stride-2
      subsampling), then a 1x1 conv.
    * ``path_2``: the same, preceded by ``ZeroPad2d((-1, 1, -1, 1))``. That pad crops the
      first column and row and zero-pads one at the end, shifting the sampling grid by
      one pixel relative to ``path_1``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output. ``path_1`` produces ``out_channels // 2``
            channels and ``path_2`` the remainder, so the concatenation always matches
            ``final_path_bn`` (odd widths included).
        padding: padding spec forwarded to ``create_conv2d`` for both 1x1 convs.
    """

    def __init__(self, in_channels, out_channels, padding=''):
        super(FactorizedReduction, self).__init__()
        # Give path_2 the remainder so the two halves always sum to out_channels;
        # for even widths this is the usual equal split.
        path_1_chs = out_channels // 2
        path_2_chs = out_channels - path_1_chs

        self.act = nn.ReLU()
        self.path_1 = nn.Sequential(OrderedDict([
            ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
            ('conv', create_conv2d(in_channels, path_1_chs, kernel_size=1, padding=padding)),
        ]))
        self.path_2 = nn.Sequential(OrderedDict([
            # (left, right, top, bottom): negative values crop, positive values zero-pad.
            ('pad', nn.ZeroPad2d((-1, 1, -1, 1))),
            ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
            ('conv', create_conv2d(in_channels, path_2_chs, kernel_size=1, padding=padding)),
        ]))
        self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.act(x)
        x_path1 = self.path_1(x)
        x_path2 = self.path_2(x)
        out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
        return out
|
|
|
|
|
class CellBase(nn.Module):
    """Shared branch-combination logic for PNASNet cells.

    Subclasses must define the ten branch ops ``comb_iter_{0..4}_{left,right}``.
    ``comb_iter_4_right`` may be ``None``, in which case the right input is used unchanged.
    """

    def cell_forward(self, x_left, x_right):
        """Sum each of the five branch pairs and concatenate the sums along dim 1.

        * Pair 0 applies both of its ops to ``x_left``.
        * Pairs 1 and 2 apply both of their ops to ``x_right``.
        * Pair 3 applies its left op to pair 2's sum and its right op to ``x_right``.
        * Pair 4 applies its left op to ``x_left`` and its (optional) right op to ``x_right``.
        """
        pair_0 = self.comb_iter_0_left(x_left) + self.comb_iter_0_right(x_left)
        pair_1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_right)
        pair_2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_right)
        # Pair 3 stacks on top of pair 2's output.
        pair_3 = self.comb_iter_3_left(pair_2) + self.comb_iter_3_right(x_right)

        left_4 = self.comb_iter_4_left(x_left)
        if self.comb_iter_4_right is not None:
            right_4 = self.comb_iter_4_right(x_right)
        else:
            # No op configured: identity on the right input.
            right_4 = x_right
        pair_4 = left_4 + right_4

        return torch.cat([pair_0, pair_1, pair_2, pair_3, pair_4], dim=1)
|
|
|
|
|
class CellStem0(CellBase):
    """First PNASNet stem cell.

    Unlike ``Cell``, it takes a single input. The left operand is the input itself, and the
    right operand is its 1x1 ``ActConvBn`` projection (``conv_1x1``). The branch ops defined
    here are combined by ``CellBase.cell_forward``. Every branch that reads ``x_left`` or
    ``x_right`` uses stride 2, so the cell downsamples spatially. The output has
    ``out_chs_left + 4 * out_chs_right`` channels (five concatenated pair sums).

    Args:
        in_chs_left: channels of the input tensor.
        out_chs_left: output channels of the two ops in pair 0.
        in_chs_right: input channels of ``conv_1x1`` and ``comb_iter_4_left``.
        out_chs_right: output channels of ``conv_1x1`` and of pairs 1-4.
        pad_type: padding spec forwarded to every conv / pool factory call.

    NOTE(review): ``conv_1x1`` and ``comb_iter_4_left`` are built for ``in_chs_right``
    channels, but both receive the cell input ``x_left``. The cell therefore only works when
    ``in_chs_left == in_chs_right`` (both are 96 in ``PNASNet5Large``).
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(CellStem0, self).__init__()
        # Derives the right operand from the (single) input.
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)

        # Pair 0 (both ops read x_left): strided separable branch + strided max-pool -> 1x1 conv -> BN.
        self.comb_iter_0_left = BranchSeparables(
            in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type)
        self.comb_iter_0_right = nn.Sequential(OrderedDict([
            ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
            ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)),
            ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
        ]))

        # Pair 1 (both ops read x_right).
        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)

        # Pair 2 (both ops read x_right).
        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type)

        # Pair 3: the left op uses the default stride because it consumes pair 2's
        # already-downsampled sum; the right op reads x_right.
        self.comb_iter_3_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)

        # Pair 4: left op reads x_left (see NOTE in the class docstring), right op reads x_right.
        self.comb_iter_4_left = BranchSeparables(
            in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type)
        self.comb_iter_4_right = ActConvBn(
            out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type)

    def forward(self, x_left):
        # The right operand is a 1x1 projection of the same input.
        x_right = self.conv_1x1(x_left)
        x_out = self.cell_forward(x_left, x_right)
        return x_out
|
|
|
|
|
class Cell(CellBase):
    """Generic two-input PNASNet cell.

    ``forward`` first projects the left input (``conv_prev_1x1``) to ``out_chs_left``
    channels and the right input (``conv_1x1``) to ``out_chs_right`` channels. It then
    combines five branch pairs with ``CellBase.cell_forward``.

    Args:
        in_chs_left: channels of the left input (in ``PNASNet5Large``, the output from two
            steps back).
        out_chs_left: channels the left input is projected to.
        in_chs_right: channels of the right input (in ``PNASNet5Large``, the previous step's
            output).
        out_chs_right: channels the right input is projected to.
        pad_type: padding spec forwarded to every conv / pool factory call.
        is_reduction: if True, the branch ops use stride 2 and pair 4's right op is a
            strided 1x1 ``ActConvBn``; otherwise pair 4's right op is the identity.
        match_prev_layer_dims: if True, the left input is projected with
            ``FactorizedReduction`` (which halves its spatial size) instead of a 1x1
            ``ActConvBn``.

    Note:
        Pair 4 adds an ``out_chs_left``-channel tensor to an ``out_chs_right``-channel one,
        so the two must be equal (they are for every cell in ``PNASNet5Large``).
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='',
                 is_reduction=False, match_prev_layer_dims=False):
        super(Cell, self).__init__()

        # Reduction cells use stride 2 in the conv and pooling branches to roughly halve
        # the spatial size of the cell output.
        stride = 2 if is_reduction else 1

        # A left input that predates a reduction has about twice the spatial size of the
        # right input; FactorizedReduction downsamples it so the branch sums line up.
        self.match_prev_layer_dimensions = match_prev_layer_dims
        if match_prev_layer_dims:
            self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type)
        else:
            self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)

        # Pair 0: both ops read the projected left input.
        self.comb_iter_0_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type)
        self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        # Pair 1: both ops read the projected right input.
        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        # Pair 2: both ops read the projected right input.
        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type)

        # Pair 3: the left op stays at stride 1 even in reduction cells, because it consumes
        # pair 2's already-strided sum. It takes `padding=pad_type` like every other op here.
        self.comb_iter_3_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        # Pair 4: the right op is only needed to downsample in reduction cells;
        # `None` makes `cell_forward` use the right input unchanged.
        self.comb_iter_4_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type)
        if is_reduction:
            self.comb_iter_4_right = ActConvBn(
                out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type)
        else:
            self.comb_iter_4_right = None

    def forward(self, x_left, x_right):
        """Project both inputs, then combine the five branch pairs."""
        x_left = self.conv_prev_1x1(x_left)
        x_right = self.conv_1x1(x_right)
        x_out = self.cell_forward(x_left, x_right)
        return x_out
|
|
|
|
|
class PNASNet5Large(nn.Module):
    """PNASNet-5 Large image classifier.

    Architecture:

    * A 3x3 stride-2 stem conv with BatchNorm and no activation (``conv_0``).
    * Two stem cells, then twelve ``Cell`` blocks.
    * A final ReLU, global pooling, optional dropout and a linear classifier.

    Every cell after ``cell_stem_0`` receives two inputs: the output from two steps back
    (left) and the output of the previous step (right). ``cell_stem_1``, ``cell_4`` and
    ``cell_8`` are reduction cells (stride 2).

    ``match_prev_layer_dims=True`` (``FactorizedReduction`` on the left input) is set where
    the left input predates a stride-2 step:

    * ``cell_stem_1``: its left input is ``conv_0``'s output, taken before ``cell_stem_0``.
    * ``cell_0``, ``cell_5`` and ``cell_9``: the cells right after each reduction cell.

    Args:
        num_classes: number of classifier outputs.
        in_chans: number of input image channels.
        output_stride: only 32 is accepted.
        drop_rate: dropout probability applied to pooled features when > 0.
        global_pool: pooling type passed to ``create_classifier``.
        pad_type: padding spec forwarded to every cell (``conv_0`` always uses ``padding=0``).
    """
    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
        super(PNASNet5Large, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        # Channels of cell_11's output (5 concatenated pair sums x 864), fed to the classifier.
        self.num_features = 4320
        # NOTE(review): `assert` is stripped under `python -O`; consider raising an explicit
        # error if this argument check must always run.
        assert output_stride == 32

        # Stem conv: no padding, BatchNorm only (apply_act=False).
        self.conv_0 = ConvBnAct(
            in_chans, 96, kernel_size=3, stride=2, padding=0,
            norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)

        # Output: 54 + 4 * 54 = 270 channels.
        self.cell_stem_0 = CellStem0(
            in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)

        # Each Cell outputs 5 * out_chs channels: 540 (cell_stem_1), 1080 (cell_0..cell_3),
        # 2160 (cell_4..cell_7) and 4320 (cell_8..cell_11). Each in_chs_* matches the
        # channel count of the tensor passed in forward_features.
        self.cell_stem_1 = Cell(
            in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
            match_prev_layer_dims=True, is_reduction=True)
        self.cell_0 = Cell(
            in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
            match_prev_layer_dims=True)
        self.cell_1 = Cell(
            in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
        self.cell_2 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
        self.cell_3 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)

        # Second reduction stage.
        self.cell_4 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
            is_reduction=True)
        self.cell_5 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
            match_prev_layer_dims=True)
        self.cell_6 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
        self.cell_7 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)

        # Third reduction stage.
        self.cell_8 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
            is_reduction=True)
        self.cell_9 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
            match_prev_layer_dims=True)
        self.cell_10 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
        self.cell_11 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
        # Final activation applied to cell_11's output in forward_features.
        self.act = nn.ReLU()
        # Feature taps (channels, cumulative stride, module name). The `conv_1x1.act`
        # entries name the first ReLU inside the next reduction cell, whose input is the
        # last output at that stride. Presumably consumed by the hook-based feature
        # extractor configured in `_create_pnasnet` (feature_cls='hook').
        self.feature_info = [
            dict(num_chs=96, reduction=2, module='conv_0'),
            dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
            dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
            dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
            dict(num_chs=4320, reduction=32, module='act'),
        ]

        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def get_classifier(self):
        """Return the final classifier module."""
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Rebuild the global pool and classifier for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run the stem and all cells; return the ReLU'd cell_11 output (4320 channels)."""
        x_conv_0 = self.conv_0(x)
        x_stem_0 = self.cell_stem_0(x_conv_0)
        # From here on, each cell gets (output two steps back, previous output).
        x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
        x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
        x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
        x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
        x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
        x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
        x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
        x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
        x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
        x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
        x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
        x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
        x = self.act(x_cell_11)
        return x

    def forward(self, x):
        """Features -> global pool -> dropout (only if drop_rate > 0) -> classifier."""
        x = self.forward_features(x)
        x = self.global_pool(x)
        if self.drop_rate > 0:
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.last_linear(x)
        return x
|
|
|
|
|
def _create_pnasnet(variant, pretrained=False, **kwargs):
    """Build a ``PNASNet5Large`` for ``variant`` through ``build_model_with_cfg``.

    Uses ``default_cfgs[variant]`` as the default config and a hook-based feature config
    (``feature_cls='hook'``, ``no_rewrite=True``). Remaining keyword arguments are
    forwarded unchanged.
    """
    hook_features = dict(feature_cls='hook', no_rewrite=True)
    return build_model_with_cfg(
        PNASNet5Large,
        variant,
        pretrained,
        default_cfg=default_cfgs[variant],
        feature_cfg=hook_features,
        **kwargs)
|
|
|
|
|
@register_model
def pnasnet5large(pretrained=False, **kwargs):
    r"""PNASNet-5 model architecture from the
    `"Progressive Neural Architecture Search"
    <https://arxiv.org/abs/1712.00559>`_ paper.

    Always builds the model with ``pad_type='same'``; passing ``pad_type`` in ``kwargs``
    raises ``TypeError``. Other keyword arguments are forwarded to the model builder.
    """
    return _create_pnasnet('pnasnet5large', pretrained, pad_type='same', **kwargs)
|
|