newTryOn / losses /pp_losses.py
amanSethSmava
new commit
6d314be
raw
history blame
22.3 kB
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from utils.bicubic import BicubicDownSample
normalize = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
@dataclass
class DefaultPaths:
psp_path: str = "pretrained_models/psp_ffhq_encode.pt"
ir_se50_path: str = "pretrained_models/ArcFace/ir_se50.pth"
stylegan_weights: str = "pretrained_models/stylegan2-ffhq-config-f.pt"
stylegan_car_weights: str = "pretrained_models/stylegan2-car-config-f-new.pkl"
stylegan_weights_pkl: str = (
"pretrained_models/stylegan2-ffhq-config-f.pkl"
)
arcface_model_path: str = "pretrained_models/ArcFace/backbone_ir50.pth"
moco: str = "pretrained_models/moco_v2_800ep_pretrain.pt"
from collections import namedtuple
from torch.nn import (
Conv2d,
BatchNorm2d,
PReLU,
ReLU,
Sigmoid,
MaxPool2d,
AdaptiveAvgPool2d,
Sequential,
Module,
Dropout,
Linear,
BatchNorm1d,
)
"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)
def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
"""A named tuple describing a ResNet block."""
def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + [
Bottleneck(depth, depth, 1) for i in range(num_units - 1)
]
def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3),
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3),
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3),
]
else:
raise ValueError(
"Invalid number of layers: {}. Must be one of [50, 100, 152]".format(
num_layers
)
)
return blocks
class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels, channels // reduction, kernel_size=1, padding=0, bias=False
)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction, channels, kernel_size=1, padding=0, bias=False
)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth),
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth),
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
SEModule(depth, 16),
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Backbone(Module):
def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True):
super(Backbone, self).__init__()
assert input_size in [112, 224], "input_size should be 112 or 224"
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
blocks = get_blocks(num_layers)
if mode == "ir":
unit_module = bottleneck_IR
elif mode == "ir_se":
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
)
if input_size == 112:
self.output_layer = Sequential(
BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 7 * 7, 512),
BatchNorm1d(512, affine=affine),
)
else:
self.output_layer = Sequential(
BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 14 * 14, 512),
BatchNorm1d(512, affine=affine),
)
modules = []
for block in blocks:
for bottleneck in block:
modules.append(
unit_module(
bottleneck.in_channel, bottleneck.depth, bottleneck.stride
)
)
self.body = Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
def IR_50(input_size):
"""Constructs a ir-50 model."""
model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False)
return model
def IR_101(input_size):
"""Constructs a ir-101 model."""
model = Backbone(
input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False
)
return model
def IR_152(input_size):
"""Constructs a ir-152 model."""
model = Backbone(
input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False
)
return model
def IR_SE_50(input_size):
"""Constructs a ir_se-50 model."""
model = Backbone(
input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False
)
return model
def IR_SE_101(input_size):
"""Constructs a ir_se-101 model."""
model = Backbone(
input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False
)
return model
def IR_SE_152(input_size):
"""Constructs a ir_se-152 model."""
model = Backbone(
input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False
)
return model
class IDLoss(nn.Module):
def __init__(self):
super(IDLoss, self).__init__()
print("Loading ResNet ArcFace")
self.facenet = Backbone(
input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
)
self.facenet.load_state_dict(torch.load(DefaultPaths.ir_se50_path))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
def extract_feats(self, x):
x = x[:, :, 35:223, 32:220] # Crop interesting region
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats
def forward(self, y_hat, y):
n_samples = y.shape[0]
y_feats = self.extract_feats(y)
y_hat_feats = self.extract_feats(y_hat)
y_feats = y_feats.detach()
loss = 0
count = 0
for i in range(n_samples):
diff_target = y_hat_feats[i].dot(y_feats[i])
loss += 1 - diff_target
count += 1
return loss / count
class FeatReconLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.MSELoss()
def forward(self, recon_1, recon_2):
return self.loss_fn(recon_1, recon_2).mean()
class EncoderAdvLoss:
def __call__(self, fake_preds):
loss_G_adv = F.softplus(-fake_preds).mean()
return loss_G_adv
class AdvLoss:
def __init__(self, coef=0.0):
self.coef = coef
def __call__(self, disc, real_images, generated_images):
fake_preds = disc(generated_images, None)
real_preds = disc(real_images, None)
loss = self.d_logistic_loss(real_preds, fake_preds)
return {'disc adv': loss}
def d_logistic_loss(self, real_preds, fake_preds):
real_loss = F.softplus(-real_preds)
fake_loss = F.softplus(fake_preds)
return (real_loss.mean() + fake_loss.mean()) / 2
from models.face_parsing.model import BiSeNet, seg_mean, seg_std
class DiceLoss(nn.Module):
def __init__(self, gamma=2):
super().__init__()
self.gamma = gamma
self.seg = BiSeNet(n_classes=16)
self.seg.to('cuda')
self.seg.load_state_dict(torch.load('pretrained_models/BiSeNet/seg.pth'))
for param in self.seg.parameters():
param.requires_grad = False
self.seg.eval()
self.downsample_512 = BicubicDownSample(factor=2)
def calc_landmark(self, x):
IM = (self.downsample_512(x) - seg_mean) / seg_std
out, _, _ = self.seg(IM)
return out
def dice_loss(self, input, target):
smooth = 1.
iflat = input.view(input.size(0), -1)
tflat = target.view(target.size(0), -1)
intersection = (iflat * tflat).sum(dim=1)
fn = torch.sum((tflat * (1-iflat))**self.gamma, dim=1)
fp = torch.sum(((1-tflat) * iflat)**self.gamma, dim=1)
return 1 - ((2. * intersection + smooth) /
(iflat.sum(dim=1) + tflat.sum(dim=1) + fn + fp + smooth))
def __call__(self, in_logit, tg_logit):
probs1 = F.softmax(in_logit, dim=1)
probs2 = F.softmax(tg_logit, dim=1)
return self.dice_loss(probs1, probs2).mean()
from typing import Sequence
from itertools import chain
import torch
import torch.nn as nn
from torchvision import models
def get_network(net_type: str):
if net_type == "alex":
return AlexNet()
elif net_type == "squeeze":
return SqueezeNet()
elif net_type == "vgg":
return VGG16()
else:
raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")
class LinLayers(nn.ModuleList):
def __init__(self, n_channels_list: Sequence[int]):
super(LinLayers, self).__init__(
[
nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))
for nc in n_channels_list
]
)
for param in self.parameters():
param.requires_grad = False
class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()
# register buffer
self.register_buffer(
"mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
"std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
def set_requires_grad(self, state: bool):
for param in chain(self.parameters(), self.buffers()):
param.requires_grad = state
def z_score(self, x: torch.Tensor):
return (x - self.mean) / self.std
def forward(self, x: torch.Tensor):
x = self.z_score(x)
output = []
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
x = layer(x)
if i in self.target_layers:
output.append(normalize_activation(x))
if len(output) == len(self.target_layers):
break
return output
class SqueezeNet(BaseNet):
def __init__(self):
super(SqueezeNet, self).__init__()
self.layers = models.squeezenet1_1(True).features
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
self.set_requires_grad(False)
class AlexNet(BaseNet):
def __init__(self):
super(AlexNet, self).__init__()
self.layers = models.alexnet(True).features
self.target_layers = [2, 5, 8, 10, 12]
self.n_channels_list = [64, 192, 384, 256, 256]
self.set_requires_grad(False)
class VGG16(BaseNet):
def __init__(self):
super(VGG16, self).__init__()
self.layers = models.vgg16(True).features
self.target_layers = [4, 9, 16, 23, 30]
self.n_channels_list = [64, 128, 256, 512, 512]
self.set_requires_grad(False)
from collections import OrderedDict
import torch
def normalize_activation(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def get_state_dict(net_type: str = "alex", version: str = "0.1"):
# build url
url = (
"https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
+ f"master/lpips/weights/v{version}/{net_type}.pth"
)
# download
old_state_dict = torch.hub.load_state_dict_from_url(
url,
progress=True,
map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)
# rename keys
new_state_dict = OrderedDict()
for key, val in old_state_dict.items():
new_key = key
new_key = new_key.replace("lin", "")
new_key = new_key.replace("model.", "")
new_state_dict[new_key] = val
return new_state_dict
class LPIPS(nn.Module):
r"""Creates a criterion that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
def __init__(self, net_type: str = "alex", version: str = "0.1"):
assert version in ["0.1"], "v0.1 is only supported now"
super(LPIPS, self).__init__()
# pretrained network
self.net = get_network(net_type).to("cuda")
# linear layers
self.lin = LinLayers(self.net.n_channels_list).to("cuda")
self.lin.load_state_dict(get_state_dict(net_type, version))
def forward(self, x: torch.Tensor, y: torch.Tensor):
feat_x, feat_y = self.net(x), self.net(y)
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
return torch.sum(torch.cat(res, 0)) / x.shape[0]
class LPIPSLoss(LPIPS):
pass
class LPIPSScaleLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = LPIPSLoss()
def forward(self, x, y):
out = 0
for res in [256, 128, 64]:
x_scale = F.interpolate(x, size=(res, res), mode="bilinear", align_corners=False)
y_scale = F.interpolate(y, size=(res, res), mode="bilinear", align_corners=False)
out += self.loss_fn.forward(x_scale, y_scale).mean()
return out
class SyntMSELoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.MSELoss()
def forward(self, im1, im2):
return self.loss_fn(im1, im2).mean()
class R1Loss:
def __init__(self, coef=10.0):
self.coef = coef
def __call__(self, disc, real_images):
real_images.requires_grad = True
real_preds = disc(real_images, None)
real_preds = real_preds.view(real_images.size(0), -1)
real_preds = real_preds.mean(dim=1).unsqueeze(1)
r1_loss = self.d_r1_loss(real_preds, real_images)
loss_D_R1 = self.coef / 2 * r1_loss * 16 + 0 * real_preds[0]
return {'disc r1 loss': loss_D_R1}
def d_r1_loss(self, real_pred, real_img):
(grad_real,) = torch.autograd.grad(
outputs=real_pred.sum(), inputs=real_img, create_graph=True
)
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
class DilatedMask:
def __init__(self, kernel_size=5):
self.kernel_size = kernel_size
cords_x = torch.arange(0, kernel_size).view(1, -1).expand(kernel_size, -1) - kernel_size // 2
cords_y = cords_x.clone().permute(1, 0)
self.kernel = torch.as_tensor((cords_x ** 2 + cords_y ** 2) <= (kernel_size // 2) ** 2, dtype=torch.float).view(1, 1, kernel_size, kernel_size).cuda()
self.kernel /= self.kernel.sum()
def __call__(self, mask):
smooth_mask = F.conv2d(mask, self.kernel, padding=self.kernel_size // 2)
return smooth_mask ** 0.25
class LossBuilder:
def __init__(self, losses_dict, device='cuda'):
self.losses_dict = losses_dict
self.device = device
self.EncoderAdvLoss = EncoderAdvLoss()
self.AdvLoss = AdvLoss()
self.R1Loss = R1Loss()
self.FeatReconLoss = FeatReconLoss().to(device).eval()
self.IDLoss = IDLoss().to(device).eval()
self.LPIPS = LPIPSScaleLoss().to(device).eval()
self.SyntMSELoss = SyntMSELoss().to(device).eval()
self.downsample_256 = BicubicDownSample(factor=4)
def CalcAdvLoss(self, disc, gen_F):
fake_preds_F = disc(gen_F, None)
return {'adv': self.losses_dict['adv'] * self.EncoderAdvLoss(fake_preds_F)}
def CalcDisLoss(self, disc, real_images, generated_images):
return self.AdvLoss(disc, real_images, generated_images)
def CalcR1Loss(self, disc, real_images):
return self.R1Loss(disc, real_images)
def __call__(self, source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen, **kwargs):
losses = {}
gen_w_256 = self.downsample_256(gen_w)
gen_F_256 = self.downsample_256(gen_F)
# ID loss
losses['rec id'] = self.losses_dict['id'] * (self.IDLoss(normalize(source), gen_w_256) + self.IDLoss(normalize(source), gen_F_256))
# Feat Recons Loss
losses['rec feat_rec'] = self.losses_dict['feat_rec'] * self.FeatReconLoss(F_w.detach(), F_gen)
# LPIPS loss
losses['rec lpips_scale'] = self.losses_dict['lpips_scale'] * (self.LPIPS(normalize(source), gen_w_256) + self.LPIPS(normalize(source), gen_F_256))
# Synt loss
# losses['l2_synt'] = self.losses_dict['l2_synt'] * self.SyntMSELoss(target * HT_E, (gen_F_256 + 1) / 2 * HT_E)
return losses
class LossBuilderMulti(LossBuilder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.DiceLoss = DiceLoss().to(kwargs.get('device', 'cuda')).eval()
self.dilated = DilatedMask(25)
def __call__(self, source, target, target_mask, HT_E, gen_w, F_w, gen_F, F_gen, **kwargs):
losses = {}
gen_w_256 = self.downsample_256(gen_w)
gen_F_256 = self.downsample_256(gen_F)
# Dice loss
with torch.no_grad():
target_512 = F.interpolate(target, size=(512, 512), mode='bilinear').clip(0, 1)
seg_target = self.DiceLoss.calc_landmark(target_512)
seg_target = F.interpolate(seg_target, size=(256, 256), mode='nearest')
seg_gen = F.interpolate(self.DiceLoss.calc_landmark((gen_F + 1) / 2), size=(256, 256), mode='nearest')
losses['DiceLoss'] = self.losses_dict['landmark'] * self.DiceLoss(seg_gen, seg_target)
# ID loss
losses['id'] = self.losses_dict['id'] * (self.IDLoss(normalize(source) * target_mask, gen_w_256 * target_mask) +
self.IDLoss(normalize(source) * target_mask, gen_F_256 * target_mask))
# Feat Recons loss
losses['feat_rec'] = self.losses_dict['feat_rec'] * self.FeatReconLoss(F_w.detach(), F_gen)
# LPIPS loss
losses['lpips_face'] = 0.5 * self.losses_dict['lpips_scale'] * (self.LPIPS(normalize(source) * target_mask, gen_w_256 * target_mask) +
self.LPIPS(normalize(source) * target_mask, gen_F_256 * target_mask))
losses['lpips_hair'] = 0.5 * self.losses_dict['lpips_scale'] * (self.LPIPS(normalize(target) * HT_E, gen_w_256 * HT_E) +
self.LPIPS(normalize(target) * HT_E, gen_F_256 * HT_E))
# Inpaint loss
if self.losses_dict['inpaint'] != 0.:
M_Inp = (1 - target_mask) * (1 - HT_E)
Smooth_M = self.dilated(M_Inp)
losses['inpaint'] = 0.5 * self.losses_dict['inpaint'] * self.LPIPS(normalize(target) * Smooth_M, gen_F_256 * Smooth_M)
losses['inpaint'] += 0.5 * self.losses_dict['inpaint'] * self.LPIPS(gen_w_256.detach() * Smooth_M * (1 - HT_E), gen_F_256 * Smooth_M * (1 - HT_E))
return losses