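# Page header captured with this file (author, commit message, commit hash, view links, size):
# amanSethSmava · "new commit" · 6d314be · raw · history · blame · 19.5 kB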
import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torch.autograd import grad
from torchvision import transforms, utils
import face_alignment
import lpips
current_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, current_dir)
from pixel2style2pixel.models.stylegan2.model import Generator, get_keys
from nets.feature_style_encoder import *
from arcface.iresnet import *
from face_parsing.model import BiSeNet
from ranger import Ranger
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torch.autograd import grad
def clip_img(x):
    """Map the first image of a StyleGAN batch from [-1, 1] into [0, 1].

    Only ``x[0]`` is used; the rest of the batch is ignored.

    Returns:
        A one-element list holding the clamped image, detached and moved to CPU
        (a list so it can be handed straight to image-saving utilities).
    """
    first = x[0].clone()
    rescaled = (first + 1) / 2
    clamped = rescaled.clamp(0, 1)
    return [clamped.detach().cpu()]
def tensor_byte(x):
    """Return how many bytes the elements of tensor ``x`` occupy (numel * itemsize)."""
    n_elements = x.nelement()
    bytes_per_element = x.element_size()
    return n_elements * bytes_per_element
def count_parameters(net):
    """Print and return the total number of parameter elements in ``net``.

    Args:
        net: any ``nn.Module``.

    Returns:
        The parameter count as an int. It is also printed, as before; returning
        it lets callers use the value without scraping stdout.
    """
    total = sum(p.numel() for p in net.parameters())
    print(total)
    return total
def stylegan_to_classifier(x, out_size=(224, 224)):
    """Convert a StyleGAN image batch into an ImageNet-normalized classifier input.

    Steps:
      1. map values from [-1, 1] to [0, 1] and clamp;
      2. bilinearly resize to ``out_size``;
      3. normalize the first three channels with the ImageNet RGB mean/std.

    Channels beyond the third are passed through after step 2. The input tensor
    is not modified.

    Args:
        x: tensor of shape (B, C, H, W) with C >= 3.
        out_size: target (height, width) for ``F.interpolate``.

    Returns:
        A new tensor of shape (B, C, out_size[0], out_size[1]).

    Raises:
        ValueError: if ``x`` has fewer than 3 channels.
    """
    if x.size(1) < 3:
        raise ValueError(f"expected at least 3 channels, got {x.size(1)}")
    img = torch.clamp(0.5 * x + 0.5, 0, 1)
    img = F.interpolate(img, size=out_size, mode='bilinear')
    # ImageNet statistics, broadcast over (B, 3, H, W).
    mean = img.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = img.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    normalized = (img[:, :3] - mean) / std
    if img.size(1) > 3:
        # Keep any extra channels untouched, as the per-channel version did.
        normalized = torch.cat([normalized, img[:, 3:]], dim=1)
    return normalized
def downscale(x, scale_times=1, mode='bilinear'):
    """Halve the spatial size of ``x`` ``scale_times`` times.

    Each step is a separate ``F.interpolate(..., scale_factor=0.5, mode=mode)``
    call. With ``scale_times == 0``, ``x`` itself is returned.
    """
    out = x
    for _step in range(scale_times):
        out = F.interpolate(out, scale_factor=0.5, mode=mode)
    return out
def upscale(x, scale_times=1, mode='bilinear'):
    """Double the spatial size of ``x`` ``scale_times`` times.

    Each step is a separate ``F.interpolate(..., scale_factor=2, mode=mode)``
    call. With ``scale_times == 0``, ``x`` itself is returned.
    """
    out = x
    for _step in range(scale_times):
        out = F.interpolate(out, scale_factor=2, mode=mode)
    return out
def hist_transform(source_tensor, target_tensor):
    """Per-channel histogram matching of ``source_tensor`` to ``target_tensor``.

    Within each channel, the k-th smallest source pixel receives the k-th
    smallest target value. The result keeps the source's spatial ordering and
    takes the target's value distribution. Neither input is modified; the
    previous version wrote into a view of ``source_tensor``.

    Args:
        source_tensor: tensor of shape (C, H, W); need not be contiguous.
        target_tensor: tensor with the same number of elements, laid out as
            C channels of H*W values.

    Returns:
        A new (C, H, W) tensor with the dtype of ``source_tensor``.

    Raises:
        ValueError: if the two tensors have different numbers of elements.
    """
    c, h, w = source_tensor.size()
    if target_tensor.numel() != source_tensor.numel():
        raise ValueError(
            f"target has {target_tensor.numel()} elements, expected {source_tensor.numel()}")
    src = source_tensor.reshape(c, -1)
    tgt = target_tensor.reshape(c, -1)
    _, src_order = torch.sort(src, dim=1)
    tgt_sorted, _ = torch.sort(tgt, dim=1)
    # Place the k-th smallest target value where the k-th smallest source pixel sits.
    matched = torch.empty_like(src).scatter_(1, src_order, tgt_sorted.to(src.dtype))
    return matched.view(c, h, w)
def init_weights(m):
    """Initialize the weights of a single module; suitable for ``net.apply(init_weights)``.

    Only modules whose exact type is ``nn.Conv2d`` or ``nn.Linear`` are
    touched. Subclasses are skipped because the check is ``type(m) == ...``.

    * ``nn.Conv2d``: weight ~ Xavier uniform; bias left unchanged.
    * ``nn.Linear``: weight ~ U(0, 1) (not Xavier); bias, if present, set to 0.01.
    """
    if type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
    elif type(m) == nn.Linear:
        nn.init.uniform_(m.weight, 0.0, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
def total_variation(x, delta=1):
    """Mean-based anisotropic total variation of a (B, C, H, W) tensor.

    Returns the mean absolute difference between pixels ``delta`` apart along
    the width axis, plus the same quantity along the height axis.
    """
    along_width = torch.abs(x[:, :, :, :-delta] - x[:, :, :, delta:])
    along_height = torch.abs(x[:, :, :-delta, :] - x[:, :, delta:, :])
    return torch.mean(along_width) + torch.mean(along_height)
def vgg_transform(x):
    """Reorder an RGB batch to BGR, resize it to 224x224 and scale it by 255.

    Args:
        x: tensor of shape (B, 3, H, W); exactly three channels are required
            (they are unpacked as r, g, b). Values are presumably in [0, 1].

    Returns:
        Tensor of shape (B, 3, 224, 224) in BGR order, multiplied by 255.

    NOTE(review): no ImageNet mean is subtracted here. If the consuming VGG
    expects mean-subtracted input, that must happen elsewhere; confirm.
    """
    r, g, b = torch.split(x, 1, dim=1)
    bgr = torch.cat((b, g, r), dim=1)
    bgr = F.interpolate(bgr, size=(224, 224), mode='bilinear')
    return bgr * 255.
# warp image with flow
def normalize_axis(x, L):
    """Map 1-based pixel coordinates in [1, L] to [-1, 1] for ``F.grid_sample``.

    Pixel 1 maps to -1 and pixel ``L`` to +1, so -1/+1 are the centers of the
    border pixels. This is the ``align_corners=True`` convention. Requires ``L > 1``.
    """
    return (x - 1 - (L - 1) / 2) * 2 / (L - 1)


def unnormalize_axis(x, L):
    """Inverse of ``normalize_axis``: map [-1, 1] back to 1-based pixel coordinates."""
    return x * (L - 1) / 2 + 1 + (L - 1) / 2


def torch_flow_to_th_sampling_grid(flow, h_src, w_src, use_cuda=False):
    """Convert a dense displacement field into a ``F.grid_sample`` sampling grid.

    Args:
        flow: tensor of shape (B, 2, H_tgt, W_tgt). Channel 0 is the x (column)
            displacement and channel 1 the y (row) displacement, in pixels.
        h_src, w_src: height and width of the image that will be sampled.
        use_cuda: if True, the returned grid is moved with ``.cuda()``.

    Returns:
        Tensor of shape (B, H_tgt, W_tgt, 2) holding normalized (x, y) source
        coordinates, with the dtype and device of ``flow``.
    """
    _, _, h_tgt, w_tgt = flow.size()
    # 1-based coordinates of every target pixel. They are built by broadcasting
    # (rows along dim 1, columns along dim 2), so non-square targets work.
    rows = torch.arange(1, h_tgt + 1, dtype=flow.dtype, device=flow.device).view(1, h_tgt, 1)
    cols = torch.arange(1, w_tgt + 1, dtype=flow.dtype, device=flow.device).view(1, 1, w_tgt)
    source_x = cols + flow[:, 0, :, :]
    source_y = rows + flow[:, 1, :, :]
    sampling_grid = torch.stack(
        (normalize_axis(source_x, w_src), normalize_axis(source_y, h_src)), dim=3)
    if use_cuda:
        sampling_grid = sampling_grid.cuda()
    return sampling_grid


def warp_image_torch(image, flow):
    """Warp ``image`` with a dense displacement field.

    Args:
        image: tensor of shape (B, C, H_src, W_src).
        flow: tensor of shape (B, 2, H_tgt, W_tgt) of per-pixel (dx, dy)
            displacements. Output pixel (y, x) samples the image at (y + dy, x + dx).

    Returns:
        Tensor of shape (B, C, H_tgt, W_tgt). Samples outside the image use
        ``grid_sample``'s default zero padding.
    """
    _, _, h_src, w_src = image.size()
    sampling_grid = torch_flow_to_th_sampling_grid(flow, h_src, w_src)
    # normalize_axis puts pixel centers at +-1 (align_corners=True convention).
    # grid_sample defaults to align_corners=False, so pass it explicitly;
    # otherwise zero flow would not reproduce the input.
    return F.grid_sample(image, sampling_grid, align_corners=True)
class Trainer(nn.Module):
    """Trains the feature-style encoder against a pretrained StyleGAN2 generator.

    The trainer holds:
      * the encoder (``self.enc``) with its optimizer (``enc_opt``) and StepLR
        scheduler (``enc_scheduler``);
      * the StyleGAN2 generator;
      * an ArcFace network (identity loss);
      * a BiSeNet face-parsing network (landmark loss).

    Pretrained weights are loaded by ``initialize``. The LPIPS network is
    created lazily on first use.

    Args:
        config: mapping of hyperparameters. Optional keys are only read when
            present (see ``_cfg``).
        opts: options object forwarded to the encoder. It is also used as a
            fallback source of model paths in ``initialize``.
    """

    def __init__(self, config, opts):
        super(Trainer, self).__init__()
        # Load hyperparameters
        self.config = config
        self.device = torch.device(self.config['device'])
        # Number of 2x downscalings from the generator resolution to the encoder input.
        self.scale = int(np.log2(config['resolution'] / config['enc_resolution']))
        self.scale_mode = self._cfg('scale_mode', 'bilinear')
        self.opts = opts
        # StyleGAN2 takes 2*log2(resolution) - 2 style vectors (18 at 1024px).
        self.n_styles = 2 * int(np.log2(config['resolution'])) - 2
        if self._cfg('stylegan_version') == 3:
            self.n_styles = 16
        # Generator layer index at which encoder features are injected.
        self.idx_k = self._cfg('idx_k', 5)

        # Encoder
        enc_residual = self._cfg('enc_residual', False)
        enc_residual_coeff = self._cfg('enc_residual_coeff', False)
        start_layer = self._cfg('enc_start_layer', 4)
        resnet_layers = [start_layer, start_layer + 1, start_layer + 2]
        self.stride = (self.config['fs_stride'], self.config['fs_stride'])
        self.enc = fs_encoder_v2(n_styles=self.n_styles, opts=opts, residual=enc_residual,
                                 use_coeff=enc_residual_coeff, resnet_layer=resnet_layers,
                                 stride=self.stride)

        # Other nets
        self.StyleGAN = self.init_stylegan(config)
        self.Arcface = iresnet50()
        self.parsing_net = BiSeNet(n_classes=19)

        # Optimizer and scheduler for the encoder. With 'freeze_iresnet', only
        # the parameters of ``enc.styles`` are optimized.
        if self._cfg('freeze_iresnet'):
            self.enc_params = list(self.enc.styles.parameters())
        else:
            self.enc_params = list(self.enc.parameters())
        optimizer_cls = Ranger if self._cfg('optimizer') == 'ranger' else torch.optim.Adam
        self.enc_opt = optimizer_cls(self.enc_params, lr=config['lr'],
                                     betas=(config['beta_1'], config['beta_2']),
                                     weight_decay=config['weight_decay'])
        self.enc_scheduler = torch.optim.lr_scheduler.StepLR(
            self.enc_opt, step_size=config['step_size'], gamma=config['gamma'])
        self.fea_avg = None

    def _cfg(self, key, default=None):
        """Return ``self.config[key]`` if the key is present, else ``default``."""
        return self.config[key] if key in self.config else default

    def _feature_scale(self):
        """Strength of injected encoder features: ramps linearly to 1.0 over 10k iterations."""
        return min(1.0, 0.0001 * self.n_iter)

    def _features_in(self, fea):
        """Build the generator's ``features_in`` list: ``fea`` at ``self.idx_k``, None elsewhere (18 slots)."""
        k = self.idx_k
        return [None] * k + [fea] + [None] * (17 - k)

    def initialize(self, stylegan_model_path, arcface_model_path=None, parsing_model_path=None):
        """Load pretrained weights for StyleGAN2, ArcFace and the face-parsing net.

        Args:
            stylegan_model_path: checkpoint holding the generator weights
                (selected with ``get_keys(..., 'decoder')``) and ``'latent_avg'``.
            arcface_model_path: ArcFace state-dict file. Falls back to
                ``self.opts.arcface_model_path`` when None.
            parsing_model_path: BiSeNet state-dict file. Falls back to
                ``self.opts.parsing_model_path`` when None.
        """
        # The path arguments used to be ignored in favour of opts; honour them now.
        if arcface_model_path is None:
            arcface_model_path = self.opts.arcface_model_path
        if parsing_model_path is None:
            parsing_model_path = self.opts.parsing_model_path
        # load StyleGAN model
        stylegan_state_dict = torch.load(stylegan_model_path, map_location='cpu')
        self.StyleGAN.load_state_dict(get_keys(stylegan_state_dict, 'decoder'), strict=True)
        self.StyleGAN.to(self.device)
        # Average W latent (added to encoder outputs) and the generator's noise buffers.
        self.dlatent_avg = stylegan_state_dict['latent_avg'].to(self.device)
        self.noise_inputs = [getattr(self.StyleGAN.noises, f'noise_{i}').to(self.device)
                             for i in range(self.StyleGAN.num_layers)]
        # load ArcFace weights (CPU-mapped; load_state_dict copies into the module)
        self.Arcface.load_state_dict(torch.load(arcface_model_path, map_location='cpu'))
        self.Arcface.eval()
        # load face parsing net weights
        self.parsing_net.load_state_dict(torch.load(parsing_model_path, map_location='cpu'))
        self.parsing_net.eval()

    def init_stylegan(self, config):
        """Build the StyleGAN2 generator with a fixed architecture: ``Generator(1024, 512, 8)``.

        ``config`` is accepted for interface compatibility but is not used.
        NOTE(review): presumably (size=1024, style_dim=512, n_mlp=8). Confirm
        this matches ``self.n_styles``, which is derived from ``config['resolution']``.
        """
        return Generator(1024, 512, 8)

    def mapping(self, z):
        """Map latents ``z`` with ``StyleGAN.get_latent`` and detach the result."""
        return self.StyleGAN.get_latent(z).detach()

    def L1loss(self, input, target):
        """Mean absolute error."""
        return nn.L1Loss()(input, target)

    def L2loss(self, input, target):
        """Mean squared error."""
        return nn.MSELoss()(input, target)

    def CEloss(self, x, target_age):
        """Cross-entropy between logits ``x`` and class targets."""
        return nn.CrossEntropyLoss()(x, target_age)

    def LPIPS(self, input, target, multi_scale=False):
        """Mean LPIPS distance between ``input`` and ``target``.

        With ``multi_scale``, distances at 3 scales (full, 1/2, 1/4) are summed.
        Otherwise both images are first downscaled ``self.scale`` times.
        """
        if getattr(self, 'loss_fn', None) is None:
            # The eager construction in ``initialize`` was disabled; build on first use.
            self.loss_fn = lpips.LPIPS(net='alex', spatial=False).to(self.device)
        if multi_scale:
            out = 0
            for k in range(3):
                out += self.loss_fn.forward(downscale(input, k, self.scale_mode),
                                            downscale(target, k, self.scale_mode)).mean()
        else:
            out = self.loss_fn.forward(downscale(input, self.scale, self.scale_mode),
                                       downscale(target, self.scale, self.scale_mode)).mean()
        return out

    def IDloss(self, input, target):
        """Per-sample identity loss ``1 - cos`` between ArcFace outputs.

        Both images are resized to 112x112 (``F.interpolate`` default mode).
        With ``multi_layer_idloss``, the loss is summed over every item returned
        by ``Arcface(..., return_features=True)``.
        """
        x_1 = F.interpolate(input, (112, 112))
        x_2 = F.interpolate(target, (112, 112))
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        if self._cfg('multi_layer_idloss'):
            id_1 = self.Arcface(x_1, return_features=True)
            id_2 = self.Arcface(x_2, return_features=True)
            return sum(1 - cos(f1.flatten(start_dim=1), f2.flatten(start_dim=1))
                       for f1, f2 in zip(id_1, id_2))
        id_1 = self.Arcface(x_1)
        id_2 = self.Arcface(x_2)
        return 1 - cos(id_1, id_2)

    def landmarkloss(self, input, target):
        """Face-parsing consistency loss.

        Both images are converted with ``stylegan_to_classifier`` at 512x512 and
        passed through the parsing net. The loss is ``1 - cos`` per sample,
        summed over the items of the parsing net's output (presumably its main
        and auxiliary heads) and averaged over the batch.
        """
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        x_1 = stylegan_to_classifier(input, out_size=(512, 512))
        x_2 = stylegan_to_classifier(target, out_size=(512, 512))
        out_1 = self.parsing_net(x_1)
        out_2 = self.parsing_net(x_2)
        parsing_loss = sum(1 - cos(o1.flatten(start_dim=1), o2.flatten(start_dim=1))
                           for o1, o2 in zip(out_1, out_2))
        return parsing_loss.mean()

    def feature_match(self, enc_feat, dec_feat, layer_idx=None):
        """Return a list of L1 losses between ``enc_feat[i]`` and ``dec_feat[i]``.

        ``layer_idx`` selects the indices; by default all of ``enc_feat`` is used.
        """
        if layer_idx is None:
            layer_idx = range(len(enc_feat))
        return [self.L1loss(enc_feat[i], dec_feat[i]) for i in layer_idx]

    def encode(self, img):
        """Encode ``img`` (after downscaling) into W+ latents offset by the average latent.

        Returns:
            ``(w_recon, fea)`` as produced by the encoder, with ``dlatent_avg`` added to ``w_recon``.
        """
        w_recon, fea = self.enc(downscale(img, self.scale, self.scale_mode))
        w_recon = w_recon + self.dlatent_avg
        return w_recon, fea

    def get_image(self, w=None, img=None, noise=None, zero_noise_input=True, training_mode=True):
        """Encode an image (or one synthesized from ``w``) and reconstruct it with StyleGAN.

        Args:
            w: latent used to synthesize the input when ``img`` is None.
            img: input image batch.
            noise: StyleGAN noise used for synthesis; returned unchanged.
            zero_noise_input, training_mode: unused; kept for API compatibility.

        Returns:
            ``[x_recon, x[:, :3], w_recon, w_delta, noise, fea, fea_recon]``:
              * ``w_delta`` is always None;
              * ``fea`` is the encoder feature map (None unless ``use_fs_encoder``);
              * ``fea_recon`` is the detached generator feature at index ``k``
                (``self.idx_k`` with ``use_fs_encoder``, else 0).
        """
        x_1 = img
        if x_1 is None:
            x_1, _ = self.StyleGAN([w], input_is_latent=True, noise=noise)
        w_delta = None
        fea = None
        features = None
        k = 0
        x_small = downscale(x_1, self.scale, self.scale_mode)
        if self._cfg('use_fs_encoder'):
            k = self.idx_k
            w_recon, fea = self.enc(x_small)
            w_recon = w_recon + self.dlatent_avg
            features = self._features_in(fea)
        else:
            w_recon = self.enc(x_small) + self.dlatent_avg
        # generate image
        x_1_recon, fea_recon = self.StyleGAN([w_recon], input_is_latent=True, return_features=True,
                                             features_in=features, feature_scale=self._feature_scale())
        fea_recon = fea_recon[k].detach()
        return [x_1_recon, x_1[:, :3, :, :], w_recon, w_delta, noise, fea, fea_recon]

    def compute_loss(self, w=None, img=None, noise=None, real_img=None):
        """Compute the training loss (delegates to ``compute_loss_stylegan2``)."""
        return self.compute_loss_stylegan2(w=w, img=img, noise=noise, real_img=real_img)

    def compute_loss_stylegan2(self, w=None, img=None, noise=None, real_img=None):
        """Compute the weighted reconstruction loss for one batch.

        When ``img`` is None, images are synthesized from ``w`` (with fresh
        random noise unless ``noise`` is given) and detached. ``real_img``, if
        given, is appended to the batch, and the noise list is concatenated with
        itself to cover the enlarged batch.

        The individual terms are stored on ``self`` for ``log_loss``:
        ``l2_loss``, ``lpips_loss``, ``id_loss``, ``landmark_loss`` and, when
        configured, ``feature_recon_loss`` and ``l1_loss``.

        Returns:
            The total loss, also stored as ``self.loss``.
        """
        if img is None:
            # generate synthetic images
            if noise is None:
                noise = [torch.randn(w.size()[:1] + ee.size()[1:]).to(self.device) for ee in self.noise_inputs]
            img, _ = self.StyleGAN([w], input_is_latent=True, noise=noise)
            img = img.detach()
        n_synth = img.size(0)
        if real_img is not None:
            # concat synthetic and real data
            img = torch.cat([img, real_img], dim=0)
            if noise is not None:
                noise = [torch.cat([ee, ee], dim=0) for ee in noise]
        x_1_recon, x_1, w_recon, _, n_1, fea_1, fea_recon = self.get_image(w=w, img=img, noise=noise)

        # Loss setting
        weights = self.config['w']
        w_l2, w_lpips, w_id = weights['l2'], weights['lpips'], weights['id']
        w_landmark = weights['landmark'] if 'landmark' in weights else 0
        multiscale_lpips = self._cfg('multiscale_lpips', False)

        # L2 is computed on synthetic data only, unless 'l2loss_on_real_image'.
        # A purely synthetic batch uses every sample; halving it would drop
        # samples and give NaN for a batch of one.
        if real_img is None or self._cfg('l2loss_on_real_image'):
            b = x_1.size(0)
        else:
            b = n_synth
        self.l2_loss = self.L2loss(x_1_recon[:b], x_1[:b]) if w_l2 > 0 else torch.tensor(0)
        # LPIPS / identity / landmark losses on the whole batch
        self.lpips_loss = self.LPIPS(x_1_recon, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0)
        self.id_loss = self.IDloss(x_1_recon, x_1).mean() if w_id > 0 else torch.tensor(0)
        self.landmark_loss = self.landmarkloss(x_1_recon, x_1) if w_landmark > 0 else torch.tensor(0)
        if self._cfg('use_fs_encoder'):
            # Second reconstruction with the encoder feature map injected into the generator.
            x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True,
                                           features_in=self._features_in(fea_1),
                                           feature_scale=self._feature_scale())
            if w_lpips > 0:
                self.lpips_loss = self.lpips_loss + self.LPIPS(x_1_recon_2, x_1, multi_scale=multiscale_lpips).mean()
            if w_id > 0:
                self.id_loss = self.id_loss + self.IDloss(x_1_recon_2, x_1).mean()
            if w_landmark > 0:
                self.landmark_loss = self.landmark_loss + self.landmarkloss(x_1_recon_2, x_1)

        # The L1 loss below compares images at encoder resolution.
        x_1 = downscale(x_1, self.scale, self.scale_mode)
        x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode)
        # Total loss
        self.loss = w_l2 * self.l2_loss + w_lpips * self.lpips_loss + w_id * self.id_loss
        if 'f_recon' in weights:
            self.feature_recon_loss = self.L2loss(fea_1, fea_recon)
            self.loss += weights['f_recon'] * self.feature_recon_loss
        if 'l1' in weights and weights['l1'] > 0:
            self.l1_loss = self.L1loss(x_1_recon, x_1)
            self.loss += weights['l1'] * self.l1_loss
        if 'landmark' in weights:
            self.loss += w_landmark * self.landmark_loss
        return self.loss

    def test(self, w=None, img=None, noise=None, zero_noise_input=True, return_latent=False, training_mode=False):
        """Run a reconstruction without an optimizer step.

        If no training iteration has been recorded yet, ``n_iter`` is set to
        1e5, so injected features use full strength.

        Returns:
            ``[x, x_recon]``, extended with ``[w_recon, fea]`` when ``return_latent``.
        """
        if 'n_iter' not in self.__dict__:
            self.n_iter = 1e5
        out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode)
        x_1_recon, x_1, w_recon, _, _, fea_1 = out[:6]
        output = [x_1, x_1_recon]
        if return_latent:
            output += [w_recon, fea_1]
        return output

    def log_loss(self, logger, n_iter, prefix='scripts'):
        """Log the loss terms from the last ``compute_loss`` call at step ``n_iter + 1``."""
        step = n_iter + 1
        logger.log_value(prefix + '/l2_loss', self.l2_loss.item(), step)
        logger.log_value(prefix + '/lpips_loss', self.lpips_loss.item(), step)
        logger.log_value(prefix + '/id_loss', self.id_loss.item(), step)
        logger.log_value(prefix + '/total_loss', self.loss.item(), step)
        weights = self.config['w']
        if 'f_recon' in weights:
            logger.log_value(prefix + '/feature_recon_loss', self.feature_recon_loss.item(), step)
        if 'l1' in weights and weights['l1'] > 0:
            logger.log_value(prefix + '/l1_loss', self.l1_loss.item(), step)
        if 'landmark' in weights:
            logger.log_value(prefix + '/landmark_loss', self.landmark_loss.item(), step)

    def save_image(self, log_dir, n_epoch, n_iter, prefix='/scripts/', w=None, img=None, noise=None, training_mode=True):
        """Save input/reconstruction comparison images (delegates to ``save_image_stylegan2``)."""
        return self.save_image_stylegan2(log_dir=log_dir, n_epoch=n_epoch, n_iter=n_iter, prefix=prefix,
                                         w=w, img=img, noise=noise, training_mode=training_mode)

    def save_image_stylegan2(self, log_dir, n_epoch, n_iter, prefix='/scripts/', w=None, img=None, noise=None, training_mode=True):
        """Write side-by-side input | reconstruction(s) images at encoder resolution.

        Samples 0 and 1 of the batch are written to ``..._0.jpg`` and, if
        present, ``..._1.jpg`` under ``log_dir + prefix``.
        """
        os.makedirs(log_dir + prefix, exist_ok=True)
        with torch.no_grad():
            out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode)
            x_1_recon, x_1, w_recon, _, n_1, fea_1 = out[:6]
            x_1 = downscale(x_1, self.scale, self.scale_mode)
            x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode)
            out_img = torch.cat((x_1, x_1_recon), dim=3)
            if self._cfg('use_fs_encoder'):
                # Also show the reconstruction with injected encoder features.
                x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True,
                                               features_in=self._features_in(fea_1),
                                               feature_scale=self._feature_scale())
                x_1_recon_2 = downscale(x_1_recon_2, self.scale, self.scale_mode)
                out_img = torch.cat((x_1, x_1_recon, x_1_recon_2), dim=3)
            stem = f'{log_dir}{prefix}epoch_{n_epoch + 1}_iter_{n_iter + 1}'
            utils.save_image(clip_img(out_img[:1]), stem + '_0.jpg')
            if out_img.size(0) > 1:
                utils.save_image(clip_img(out_img[1:]), stem + '_1.jpg')

    def save_model(self, log_dir):
        """Save the encoder weights to ``<log_dir>/enc.pth.tar``."""
        torch.save(self.enc.state_dict(), f'{log_dir}/enc.pth.tar')

    def save_checkpoint(self, n_epoch, log_dir):
        """Save encoder, optimizer and scheduler state.

        The checkpoint goes to ``<log_dir>/checkpoint.pth``. Every 10th epoch
        an extra copy is written to ``<log_dir>/checkpoint_<n_epoch+1>.pth``.
        """
        checkpoint_state = {
            'n_epoch': n_epoch,
            'enc_state_dict': self.enc.state_dict(),
            'enc_opt_state_dict': self.enc_opt.state_dict(),
            'enc_scheduler_state_dict': self.enc_scheduler.state_dict(),
        }
        torch.save(checkpoint_state, f'{log_dir}/checkpoint.pth')
        if (n_epoch + 1) % 10 == 0:
            torch.save(checkpoint_state, f'{log_dir}/checkpoint_{n_epoch + 1}.pth')

    def load_model(self, log_dir):
        """Load encoder weights from ``<log_dir>/enc.pth.tar``."""
        self.enc.load_state_dict(torch.load(f'{log_dir}/enc.pth.tar', map_location='cpu'))

    def load_checkpoint(self, checkpoint_path):
        """Restore encoder, optimizer and scheduler state from ``checkpoint_path``.

        The checkpoint is mapped to CPU first. ``load_state_dict`` then copies
        the state onto the modules' devices, so GPU checkpoints also load on
        CPU-only hosts.

        Returns:
            The epoch to resume from (saved epoch + 1).
        """
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.enc.load_state_dict(state_dict['enc_state_dict'])
        self.enc_opt.load_state_dict(state_dict['enc_opt_state_dict'])
        self.enc_scheduler.load_state_dict(state_dict['enc_scheduler_state_dict'])
        return state_dict['n_epoch'] + 1

    def update(self, w=None, img=None, noise=None, real_img=None, n_iter=0):
        """Run one optimization step of the encoder at iteration ``n_iter``."""
        self.n_iter = n_iter
        self.enc_opt.zero_grad()
        self.compute_loss(w=w, img=img, noise=noise, real_img=real_img).backward()
        self.enc_opt.step()