|
from __future__ import absolute_import |
|
|
|
import os |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from . import pretrained_networks as pretrained_torch_models |
|
|
|
|
|
def spatial_average(x, keepdim=True):
    """Average a tensor over its spatial axes (dims 2 and 3).

    Args:
        x: tensor whose dimensions 2 and 3 are reduced by the mean.
        keepdim: when True (default) the reduced axes are kept as size-1
            dimensions, otherwise they are dropped.

    Returns:
        The mean of ``x`` over dims 2 and 3.
    """
    spatial_dims = (2, 3)
    return torch.mean(x, dim=spatial_dims, keepdim=keepdim)
|
|
|
|
|
def upsample(x, out_HW=None):
    """Bilinearly resample the spatial dimensions (dims 2 and 3) of a tensor.

    Args:
        x: input tensor; ``nn.Upsample`` in bilinear mode requires it to be
            4-D, and dims 2 and 3 are the ones resized.
        out_HW: target spatial size ``(H, W)``. Defaults to ``x.shape[2:]``,
            the input's own size, which reproduces the previous
            single-argument behaviour (a same-size resample).

    Returns:
        ``x`` resampled so that its dims 2 and 3 equal ``out_HW``.
    """
    if out_HW is None:
        # Backward-compatible default: resample to the input's own size.
        out_HW = x.shape[2:]
    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(x)
|
|
|
|
|
def normalize_tensor(in_feat, eps=1e-10):
    """Divide every vector along dim 1 by its L2 norm.

    Args:
        in_feat: tensor whose dim-1 vectors are normalized.
        eps: added to the norm before dividing so all-zero vectors yield
            zeros instead of NaN.

    Returns:
        ``in_feat`` scaled to (approximately) unit L2 norm along dim 1.
    """
    channel_norm = in_feat.pow(2).sum(dim=1, keepdim=True).sqrt()
    return in_feat / (channel_norm + eps)
|
|
|
|
|
|
|
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) as a torch.nn.Module.

    Features from a pretrained trunk (AlexNet / VGG16 / SqueezeNet) are
    unit-normalized per layer. Their squared differences are then either
    weighted by learned 1x1 convolutions (``lpips=True``) or summed over
    channels (baseline), and the per-layer results are added together.
    """

    def __init__(self,
                 pretrained=True,
                 net='alex',
                 version='0.1',
                 lpips=True,
                 spatial=False,
                 pnet_rand=False,
                 pnet_tune=False,
                 use_dropout=True,
                 model_path=None,
                 eval_mode=True,
                 verbose=True):
        """ Initializes a perceptual loss torch.nn.Module

        Parameters (default listed first)
        ---------------------------------
        lpips : bool
            [True] use linear layers on top of base/trunk network
            [False] means no linear layers; each layer is averaged together
        pretrained : bool
            This flag controls the linear layers, which are only in effect when lpips=True above
            [True] means linear layers are calibrated with human perceptual judgments
            [False] means linear layers are randomly initialized
        pnet_rand : bool
            [False] means trunk loaded with ImageNet classification weights
            [True] means randomly initialized trunk
        net : str
            ['alex','vgg','squeeze'] are the base/trunk networks available ('vgg16' is an alias of 'vgg')
        version : str
            ['0.1'] is the default and latest
            ['0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
        spatial : bool
            [False] return one distance per image (spatially averaged)
            [True] return a per-pixel distance map at the input resolution
        model_path : str
            [None] is default and loads the linear-layer weights from
            weights/v{version}/{net}.pth next to this file (paper: https://arxiv.org/abs/1801.03924v1)
        verbose : bool
            [True] print setup and weight-loading messages

        The following parameters should only be changed if training the network:

        eval_mode : bool
            [True] is for test mode (default)
            [False] is for training mode
        pnet_tune : bool
            [False] keep base/trunk frozen
            [True] tune the base/trunk network
        use_dropout : bool
            [True] to use dropout when training linear layers
            [False] for no dropout when training linear layers

        Raises
        ------
        ValueError
            If ``net`` is not one of 'alex', 'vgg', 'vgg16', 'squeeze'.
        """
        super(LPIPS, self).__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Select the trunk and the channel count of each feature map it returns.
        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pretrained_torch_models.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pretrained_torch_models.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pretrained_torch_models.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            # Previously this fell through to an UnboundLocalError on net_type.
            raise ValueError(
                "Unsupported trunk network %r; expected one of "
                "'alex', 'vgg', 'vgg16', 'squeeze'" % (net,))
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # Keep the lin0..lin{L-1} attribute names: they define the
            # state_dict keys (e.g. 'lin0.model.1.weight') that the
            # load_state_dict call below matches. `lins` is an ordered
            # ModuleList over the same module objects.
            for kk, chn in enumerate(self.chns):
                setattr(self, 'lin%d' % kk, NetLinLayer(chn, use_dropout=use_dropout))
            self.lins = nn.ModuleList([getattr(self, 'lin%d' % kk) for kk in range(self.L)])

            # Pretrained weights only cover the linear heads, so loading is
            # only meaningful when they exist.
            if pretrained:
                if model_path is None:
                    model_path = os.path.join(
                        os.path.dirname(os.path.abspath(__file__)),
                        f"weights/v{version}/{net}.pth"
                    )
                if verbose:
                    print('Loading model from: %s' % model_path)
                # NOTE(review): torch.load may unpickle arbitrary objects; only
                # point model_path at trusted files (consider weights_only=True
                # on torch versions that support it).
                self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, return_per_layer=False, normalize=False):
        """Compute the perceptual distance between two image batches.

        Parameters
        ----------
        in0, in1 : torch.Tensor
            Image batches of the same shape, indexed (N, C, H, W). With
            ``version='0.1'`` the scaling layer broadcasts over C == 3.
            Inputs are presumably in [-1, 1]; pass ``normalize=True`` for
            inputs in [0, 1].
        return_per_layer : bool
            If True, also return the list of per-layer contributions.
        normalize : bool
            If True, map the inputs from [0, 1] to [-1, 1] via ``2 * x - 1``.

        Returns
        -------
        torch.Tensor, or (torch.Tensor, list of torch.Tensor)
            The distance: shape (N, 1, 1, 1) when ``spatial`` is False, or
            (N, 1, H, W) at the input resolution when ``spatial`` is True.
            With ``return_per_layer`` a tuple ``(loss, per_layer)`` is
            returned whose list entries have the same shape as ``loss``.
        """
        if normalize:
            # Map inputs from [0, 1] to [-1, 1].
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.1 standardizes inputs with the scaling layer; v0.0 feeds them as-is.
        if self.version == '0.1':
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1

        # Call the module (not .forward) so registered hooks run.
        outs0, outs1 = self.net(in0_input), self.net(in1_input)

        # Squared difference of unit-normalized features, one entry per trunk layer.
        diffs = [(normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2
                 for kk in range(self.L)]

        if self.lpips:
            # Learned 1x1 conv collapses each layer's channels to one map.
            per_layer = [self.lins[kk](diffs[kk]) for kk in range(self.L)]
        else:
            # Baseline: unweighted sum over channels.
            per_layer = [diff.sum(dim=1, keepdim=True) for diff in diffs]

        if self.spatial:
            # Trunk layers have different resolutions; resize every map to the
            # input's H x W so they can be summed element-wise.
            resize = nn.Upsample(size=in0.shape[2:], mode='bilinear', align_corners=False)
            res = [resize(layer_map) for layer_map in per_layer]
        else:
            res = [spatial_average(layer_map, keepdim=True) for layer_map in per_layer]

        loss = sum(res)
        if return_per_layer:
            return loss, res
        return loss
|
|
|
|
|
class ScalingLayer(nn.Module):
    """Fixed per-channel standardization: ``(inp - shift) / scale``.

    ``shift`` and ``scale`` are 3-element constants registered as buffers of
    shape (1, 3, 1, 1). They broadcast over dim 1 of a 4-D input and are
    included in the module's state_dict.
    """

    def __init__(self):
        super().__init__()
        channel_shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        channel_scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', channel_shift)
        self.register_buffer('scale', channel_scale)

    def forward(self, inp):
        """Return ``inp`` shifted and rescaled channel-wise."""
        centered = inp - self.shift
        return centered / self.scale
|
|
|
|
|
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv, optionally preceded by dropout.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels (default 1).
        use_dropout: if True, an ``nn.Dropout()`` (PyTorch default p=0.5)
            runs before the convolution.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        # Sequential indices are part of the state_dict keys: with dropout the
        # conv is 'model.1', without it the conv is 'model.0'.
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        stages.append(nn.Conv2d(chn_in, chn_out, kernel_size=1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the (dropout +) 1x1 convolution stack to ``x``."""
        out = self.model(x)
        return out
|
|