# Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn.functional as F
import torch.utils.model_zoo
import torchvision
from torch import nn


def apply_imagenet_normalization(input):
    r"""Normalize using ImageNet mean and std.

    Args:
        input (4D tensor NxCxHxW): The input images, assumed to be in
            the range [-1, 1].

    Returns:
        Normalized inputs using the ImageNet normalization.
    """
    # Map the input back to [0, 1].
    normalized_input = (input + 1) / 2
    # Standardize with the ImageNet mean and std.
    mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    output = (normalized_input - mean) / std
    return output
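
# A minimal usage sketch (illustrative, not part of the original module):
# a batch in [-1, 1] comes out standardized with the ImageNet statistics
# that the torchvision backbones below expect.
#
#     x = torch.rand(2, 3, 64, 64) * 2 - 1  # fake images in [-1, 1]
#     y = apply_imagenet_normalization(x)   # per-channel (x/2+0.5-mean)/std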


class PerceptualHashValue(nn.Module):
    r"""Perceptual hash value (PHV) metric.

    For each requested layer, compares the channel-wise mean feature
    activations of the two images and reports the fraction of channels
    whose means differ by more than the threshold ``T``.

    Args:
        T (float): Threshold on the absolute difference between the
            channel-wise mean activations of input and target.
        network (str): The name of the feature network: 'vgg16' |
            'vgg19' | 'alexnet' | 'inception_v3' | 'resnet50' |
            'robust_resnet50' | 'vgg_face_dag'.
        layers (str or list of str): The layers used to compute the metric.
        resize (bool): If ``True``, resize the input images to 224x224.
        resize_mode (str): Algorithm used for resizing.
        instance_normalized (bool): If ``True``, applies instance
            normalization to the feature maps before computing the
            distance.
    """

    def __init__(self, T=0.005, network='vgg19', layers='relu_4_1',
                 resize=False, resize_mode='bilinear',
                 instance_normalized=False):
        super().__init__()
        if isinstance(layers, str):
            layers = [layers]
        if network == 'vgg19':
            self.model = _vgg19(layers)
        elif network == 'vgg16':
            self.model = _vgg16(layers)
        elif network == 'alexnet':
            self.model = _alexnet(layers)
        elif network == 'inception_v3':
            self.model = _inception_v3(layers)
        elif network == 'resnet50':
            self.model = _resnet50(layers)
        elif network == 'robust_resnet50':
            self.model = _robust_resnet50(layers)
        elif network == 'vgg_face_dag':
            self.model = _vgg_face_dag(layers)
        else:
            raise ValueError('Network %s is not recognized' % network)
        self.T = T
        self.layers = layers
        self.resize = resize
        self.resize_mode = resize_mode
        self.instance_normalized = instance_normalized
        print('Perceptual Hash Value:')
        print('\tNetwork: {}'.format(network))

    def forward(self, inp, target):
        r"""Perceptual hash value forward.

        Args:
            inp (4D tensor): Input tensor.
            target (4D tensor): Ground truth tensor, same shape as the
                input.

        Returns:
            (list of float): Per-layer hash values, each the fraction of
                channels whose mean activations differ by more than ``T``.
        """
        # The feature network should operate in eval mode by default.
        self.model.eval()
        inp, target = \
            apply_imagenet_normalization(inp), \
            apply_imagenet_normalization(target)
        if self.resize:
            inp = F.interpolate(
                inp, mode=self.resize_mode, size=(224, 224),
                align_corners=False)
            target = F.interpolate(
                target, mode=self.resize_mode, size=(224, 224),
                align_corners=False)
        # Extract features for both images and compare them layer by layer.
        input_features, target_features = \
            self.model(inp), self.model(target)
        hpv_list = []
        for layer in self.layers:
            input_feature = input_features[layer]
            target_feature = target_features[layer].detach()
            if self.instance_normalized:
                input_feature = F.instance_norm(input_feature)
                target_feature = F.instance_norm(target_feature)
            # Ignore the spatial dimensions: average each channel down to
            # a single value per sample.
            B, C = input_feature.shape[:2]
            inp_avg = torch.mean(input_feature.view(B, C, -1), -1)
            tgt_avg = torch.mean(target_feature.view(B, C, -1), -1)
            # Fraction of (sample, channel) pairs whose mean activations
            # differ by more than the threshold T.
            abs_dif = torch.abs(inp_avg - tgt_avg)
            hpv = torch.sum(abs_dif > self.T).item() / (B * C)
            hpv_list.append(hpv)
        return hpv_list
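
# Hedged usage sketch (not from the original file): layer names follow the
# backbone mappings defined below, e.g. 'relu_4_1' for VGG-19.
#
#     phv = PerceptualHashValue(
#         T=0.005, network='vgg19', layers=['relu_3_1', 'relu_4_1'])
#     fake = torch.rand(1, 3, 256, 256) * 2 - 1
#     real = torch.rand(1, 3, 256, 256) * 2 - 1
#     values = phv(fake, real)  # one float in [0, 1] per requested layer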


class _PerceptualNetwork(nn.Module):
    r"""The network that extracts features to compute the perceptual loss.

    Args:
        network (nn.Sequential): The network that extracts features.
        layer_name_mapping (dict): The dictionary that maps a layer's
            index to its name.
        layers (list of str): The list of layer names that we are using.
    """

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        assert isinstance(network, nn.Sequential), \
            'The network needs to be of type "nn.Sequential".'
        self.network = network
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        # The extractor is a frozen metric network, not a trained model.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        r"""Extract perceptual features."""
        output = {}
        for i, layer in enumerate(self.network):
            x = layer(x)
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in self.layers:
                # Record the activation if this layer was requested.
                output[layer_name] = x
        return output
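
# Sketch of the call contract (illustrative): the extractor returns a dict
# keyed by the requested layer names.
#
#     net = _vgg19(['relu_4_1'])
#     feats = net(torch.randn(1, 3, 224, 224))  # {'relu_4_1': 4D tensor}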


def _vgg19(layers):
    r"""Get vgg19 layers"""
    network = torchvision.models.vgg19(pretrained=True).features
    layer_name_mapping = {1: 'relu_1_1',
                          3: 'relu_1_2',
                          6: 'relu_2_1',
                          8: 'relu_2_2',
                          11: 'relu_3_1',
                          13: 'relu_3_2',
                          15: 'relu_3_3',
                          17: 'relu_3_4',
                          20: 'relu_4_1',
                          22: 'relu_4_2',
                          24: 'relu_4_3',
                          26: 'relu_4_4',
                          29: 'relu_5_1'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg16(layers):
    r"""Get vgg16 layers"""
    network = torchvision.models.vgg16(pretrained=True).features
    layer_name_mapping = {1: 'relu_1_1',
                          3: 'relu_1_2',
                          6: 'relu_2_1',
                          8: 'relu_2_2',
                          11: 'relu_3_1',
                          13: 'relu_3_2',
                          15: 'relu_3_3',
                          18: 'relu_4_1',
                          20: 'relu_4_2',
                          22: 'relu_4_3',
                          25: 'relu_5_1'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _alexnet(layers):
    r"""Get alexnet layers"""
    network = torchvision.models.alexnet(pretrained=True).features
    layer_name_mapping = {0: 'conv_1',
                          1: 'relu_1',
                          3: 'conv_2',
                          4: 'relu_2',
                          6: 'conv_3',
                          7: 'relu_3',
                          8: 'conv_4',
                          9: 'relu_4',
                          10: 'conv_5',
                          11: 'relu_5'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _inception_v3(layers):
    r"""Get inception v3 layers"""
    inception = torchvision.models.inception_v3(pretrained=True)
    network = nn.Sequential(inception.Conv2d_1a_3x3,
                            inception.Conv2d_2a_3x3,
                            inception.Conv2d_2b_3x3,
                            nn.MaxPool2d(kernel_size=3, stride=2),
                            inception.Conv2d_3b_1x1,
                            inception.Conv2d_4a_3x3,
                            nn.MaxPool2d(kernel_size=3, stride=2),
                            inception.Mixed_5b,
                            inception.Mixed_5c,
                            inception.Mixed_5d,
                            inception.Mixed_6a,
                            inception.Mixed_6b,
                            inception.Mixed_6c,
                            inception.Mixed_6d,
                            inception.Mixed_6e,
                            inception.Mixed_7a,
                            inception.Mixed_7b,
                            inception.Mixed_7c,
                            nn.AdaptiveAvgPool2d(output_size=(1, 1)))
    layer_name_mapping = {3: 'pool_1',
                          6: 'pool_2',
                          14: 'mixed_6e',
                          18: 'pool_3'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _resnet50(layers):
    r"""Get resnet50 layers"""
    resnet50 = torchvision.models.resnet50(pretrained=True)
    network = nn.Sequential(resnet50.conv1,
                            resnet50.bn1,
                            resnet50.relu,
                            resnet50.maxpool,
                            resnet50.layer1,
                            resnet50.layer2,
                            resnet50.layer3,
                            resnet50.layer4,
                            resnet50.avgpool)
    layer_name_mapping = {4: 'layer_1',
                          5: 'layer_2',
                          6: 'layer_3',
                          7: 'layer_4'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)
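
# Note on the index mapping (added for clarity): positions 0-3 of the
# Sequential above are conv1, bn1, relu, and maxpool, so index 4 is
# resnet50's layer1, 5 is layer2, and so on. The robust variant below
# reuses the same mapping.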


def _robust_resnet50(layers):
    r"""Get robust resnet50 layers"""
    resnet50 = torchvision.models.resnet50(pretrained=False)
    state_dict = torch.utils.model_zoo.load_url(
        'http://andrewilyas.com/ImageNet.pt')
    # The checkpoint stores weights under a 'module.model.' prefix;
    # strip it so the keys match torchvision's resnet50.
    new_state_dict = {}
    for k, v in state_dict['model'].items():
        if k.startswith('module.model.'):
            new_state_dict[k[13:]] = v
    resnet50.load_state_dict(new_state_dict)
    network = nn.Sequential(resnet50.conv1,
                            resnet50.bn1,
                            resnet50.relu,
                            resnet50.maxpool,
                            resnet50.layer1,
                            resnet50.layer2,
                            resnet50.layer3,
                            resnet50.layer4,
                            resnet50.avgpool)
    layer_name_mapping = {4: 'layer_1',
                          5: 'layer_2',
                          6: 'layer_3',
                          7: 'layer_4'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg_face_dag(layers):
    r"""Get vgg face layers"""
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.utils.model_zoo.load_url(
        'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
        'vgg_face_dag.pth')
    # Remap the vgg_face_dag weights onto torchvision's vgg16 layout.
    feature_layer_name_mapping = {
        0: 'conv1_1',
        2: 'conv1_2',
        5: 'conv2_1',
        7: 'conv2_2',
        10: 'conv3_1',
        12: 'conv3_2',
        14: 'conv3_3',
        17: 'conv4_1',
        19: 'conv4_2',
        21: 'conv4_3',
        24: 'conv5_1',
        26: 'conv5_2',
        28: 'conv5_3'}
    new_state_dict = {}
    for k, v in feature_layer_name_mapping.items():
        new_state_dict['features.' + str(k) + '.weight'] = \
            state_dict[v + '.weight']
        new_state_dict['features.' + str(k) + '.bias'] = \
            state_dict[v + '.bias']
    classifier_layer_name_mapping = {
        0: 'fc6',
        3: 'fc7',
        6: 'fc8'}
    for k, v in classifier_layer_name_mapping.items():
        new_state_dict['classifier.' + str(k) + '.weight'] = \
            state_dict[v + '.weight']
        new_state_dict['classifier.' + str(k) + '.bias'] = \
            state_dict[v + '.bias']
    network.load_state_dict(new_state_dict)

    class Flatten(nn.Module):
        r"""Flatten the tensor"""

        def forward(self, x):
            r"""Flatten it"""
            return x.view(x.shape[0], -1)

    # Rebuild the model as a flat nn.Sequential so _PerceptualNetwork can
    # index it; indices 3-9 correspond to the classifier layers.
    layer_name_mapping = {
        1: 'avgpool',
        3: 'fc6',
        4: 'relu_6',
        6: 'fc7',
        7: 'relu_7',
        9: 'fc8'}
    seq_layers = [network.features, network.avgpool, Flatten()]
    for i in range(7):
        seq_layers += [network.classifier[i]]
    network = nn.Sequential(*seq_layers)
    return _PerceptualNetwork(network, layer_name_mapping, layers)
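

if __name__ == '__main__':
    # Smoke test (an added sketch, not part of the original file). The
    # first run downloads the pretrained VGG-19 weights via torchvision.
    phv = PerceptualHashValue(T=0.005, network='vgg19',
                              layers=['relu_3_1', 'relu_4_1'])
    fake = torch.rand(1, 3, 256, 256) * 2 - 1
    real = torch.rand(1, 3, 256, 256) * 2 - 1
    print(phv(fake, real))  # one value in [0, 1] per requested layer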