""" | |
Backbones supported by torchvison. | |
""" | |
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class TVDeeplabRes101Encoder(nn.Module):
    """
    FCN-ResNet101 backbone from torchvision's DeepLabV3.
    No ASPP is used, as we found empirically that it hurts performance.
    """
    def __init__(self, use_coco_init, aux_dim_keep=64, use_aspp=False):
        super().__init__()
        _model = torchvision.models.segmentation.deeplabv3_resnet101(
            pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None)
        if use_coco_init:
            print("###### NETWORK: Using ms-coco initialization ######")
        else:
            print("###### NETWORK: Training from scratch ######")

        _model_list = list(_model.children())
        self.aux_dim_keep = aux_dim_keep
        self.backbone = _model_list[0]  # ResNet-101 feature extractor
        self.localconv = nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False)  # reduce feature map dimension
        self.asppconv = nn.Conv2d(256, 256, kernel_size=1, bias=False)

        _aspp = _model_list[1][0]     # ASPP module from the DeepLabV3 head
        _conv256 = _model_list[1][1]  # conv projecting the ASPP output to 256 channels
        self.aspp_out = nn.Sequential(*[_aspp, _conv256])
        self.use_aspp = use_aspp
    def forward(self, x_in, low_level):
        """
        Args:
            x_in: input image batch of shape (B, 3, H, W)
            low_level: whether to also return aggregated low-level features from the FCN backbone
        """
        fts = self.backbone(x_in)
        if self.use_aspp:
            fts256 = self.aspp_out(fts['out'])
            high_level_fts = fts256
        else:
            fts2048 = fts['out']
            high_level_fts = self.localconv(fts2048)
        if low_level:
            low_level_fts = fts['aux'][:, :self.aux_dim_keep]
            return high_level_fts, low_level_fts
        else:
            return high_level_fts
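

# --- Minimal usage sketch (illustrative; not part of the original module) ---
# Shows the expected call signature and output of TVDeeplabRes101Encoder.
# The batch size, image resolution, and use_coco_init=False below are
# arbitrary assumptions chosen so the sketch runs without downloading
# pretrained COCO weights.
if __name__ == "__main__":
    encoder = TVDeeplabRes101Encoder(use_coco_init=False).eval()
    dummy = torch.randn(2, 3, 256, 256)  # (batch, channels, height, width)
    with torch.no_grad():
        high_level = encoder(dummy, low_level=False)
    # 256-channel high-level features, typically at 1/8 of the input resolution
    print(high_level.shape)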