# LoGoSAM_demo / models/backbone/torchvision_backbones.py
# Uploaded by quandn2003 using huggingface_hub (commit 427d150, verified)
"""
Backbones supported by torchvision.
"""
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class TVDeeplabRes101Encoder(nn.Module):
    """
    ResNet-101 encoder taken from torchvision's DeepLabV3-ResNet101.

    By default the ASPP head is bypassed, as we found empirically that it
    hurts performance; the 2048-channel ``'out'`` feature map of the backbone
    is instead reduced to 256 channels with a 1x1 convolution (``localconv``).

    Args:
        use_coco_init: if True, load torchvision's pretrained (MS-COCO)
            weights; otherwise the network is built without pretrained weights.
        aux_dim_keep: number of leading channels of the backbone's ``'aux'``
            feature map returned by ``forward(..., low_level=True)``.
        use_aspp: if True, high-level features go through the pretrained ASPP
            module followed by the first convolution of the DeepLab head
            (``aspp_out``) instead of ``localconv``.
    """

    def __init__(self, use_coco_init, aux_dim_keep=64, use_aspp=False):
        super().__init__()
        # Always request the auxiliary branch. With aux_loss=None torchvision
        # only exposes the backbone's 'aux' feature map when pretrained weights
        # are loaded (it forces aux_loss=True in that case), so training from
        # scratch and calling forward(..., low_level=True) raised
        # KeyError: 'aux'. The pretrained path is unaffected by this change.
        _model = torchvision.models.segmentation.deeplabv3_resnet101(
            pretrained=use_coco_init,
            progress=True,
            num_classes=21,
            aux_loss=True,
        )
        if use_coco_init:
            print("###### NETWORK: Using ms-coco initialization ######")
        else:
            print("###### NETWORK: Training from scratch ######")

        self.aux_dim_keep = aux_dim_keep
        # Feature extractor; its output is a dict with 'out' and 'aux' entries.
        self.backbone = _model.backbone
        # 1x1 projection of the 2048-channel 'out' map down to 256 channels.
        self.localconv = nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False)
        # NOTE(review): not used by forward(); kept because removing a
        # registered submodule would change the state_dict keys (and break
        # loading existing checkpoints).
        self.asppconv = nn.Conv2d(256, 256, kernel_size=1, bias=False)
        # Pretrained ASPP module plus the next layer of the DeepLab head
        # (only the first two entries of the classifier are reused).
        _aspp = _model.classifier[0]
        _conv256 = _model.classifier[1]
        self.aspp_out = nn.Sequential(_aspp, _conv256)
        self.use_aspp = use_aspp

    def forward(self, x_in, low_level):
        """
        Encode a batch of inputs.

        Args:
            x_in: input tensor fed directly to the torchvision backbone
                (presumably (B, 3, H, W) images — TODO confirm against callers).
            low_level: whether to also return the first ``aux_dim_keep``
                channels of the backbone's ``'aux'`` feature map.

        Returns:
            ``high_level_fts`` when ``low_level`` is False, otherwise the
            tuple ``(high_level_fts, low_level_fts)``.
        """
        fts = self.backbone(x_in)
        if self.use_aspp:
            high_level_fts = self.aspp_out(fts['out'])
        else:
            high_level_fts = self.localconv(fts['out'])

        if not low_level:
            return high_level_fts

        # Keep only the leading aux_dim_keep channels of the auxiliary map.
        low_level_fts = fts['aux'][:, : self.aux_dim_keep]
        return high_level_fts, low_level_fts