|
""" Test Time Pooling (Average-Max Pool) |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
|
|
import logging |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d |
|
|
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
class TestTimePoolHead(nn.Module): |
|
def __init__(self, base, original_pool=7): |
|
super(TestTimePoolHead, self).__init__() |
|
self.base = base |
|
self.original_pool = original_pool |
|
base_fc = self.base.get_classifier() |
|
if isinstance(base_fc, nn.Conv2d): |
|
self.fc = base_fc |
|
else: |
|
self.fc = nn.Conv2d( |
|
self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) |
|
self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) |
|
self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) |
|
self.base.reset_classifier(0) |
|
|
|
def forward(self, x): |
|
x = self.base.forward_features(x) |
|
x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) |
|
x = self.fc(x) |
|
x = adaptive_avgmax_pool2d(x, 1) |
|
return x.view(x.size(0), -1) |
|
|
|
|
|
def apply_test_time_pool(model, config, use_test_size=True): |
|
test_time_pool = False |
|
if not hasattr(model, 'default_cfg') or not model.default_cfg: |
|
return model, False |
|
if use_test_size and 'test_input_size' in model.default_cfg: |
|
df_input_size = model.default_cfg['test_input_size'] |
|
else: |
|
df_input_size = model.default_cfg['input_size'] |
|
if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: |
|
_logger.info('Target input size %s > pretrained default %s, using test time pooling' % |
|
(str(config['input_size'][-2:]), str(df_input_size[-2:]))) |
|
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) |
|
test_time_pool = True |
|
return model, test_time_pool |
|
|