import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torch.autograd import grad
from torchvision import transforms, utils
import face_alignment
import lpips
current_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, current_dir)
from pixel2style2pixel.models.stylegan2.model import Generator, get_keys
from nets.feature_style_encoder import *
from arcface.iresnet import *
from face_parsing.model import BiSeNet
from ranger import Ranger
def clip_img(x):
    """Clip a StyleGAN-generated image from [-1, 1] to [0, 1]."""
    img_tmp = x.clone()[0]
    img_tmp = (img_tmp + 1) / 2
    img_tmp = torch.clamp(img_tmp, 0, 1)
    return [img_tmp.detach().cpu()]
def tensor_byte(x):
    """Size of a tensor in bytes."""
    return x.element_size() * x.nelement()
def count_parameters(net):
    """Print and return the total number of parameters in a network."""
    s = sum([np.prod(list(mm.size())) for mm in net.parameters()])
    print(s)
    return s
def stylegan_to_classifier(x, out_size=(224, 224)):
    """Map a StyleGAN output in [-1, 1] to an ImageNet-normalized classifier input."""
    img_tmp = x.clone()
    img_tmp = torch.clamp(0.5 * img_tmp + 0.5, 0, 1)
    img_tmp = F.interpolate(img_tmp, size=out_size, mode='bilinear')
    # Per-channel ImageNet normalization, equivalent to
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img_tmp[:, 0] = (img_tmp[:, 0] - 0.485) / 0.229
    img_tmp[:, 1] = (img_tmp[:, 1] - 0.456) / 0.224
    img_tmp[:, 2] = (img_tmp[:, 2] - 0.406) / 0.225
    return img_tmp
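# Illustrative sanity check (added sketch, not part of the original pipeline):
# a random stand-in for a StyleGAN output in [-1, 1] comes back resized to
# 224x224 and ImageNet-normalized, ready for a torchvision-style classifier.
def _demo_stylegan_to_classifier():
    fake = torch.empty(2, 3, 256, 256).uniform_(-1, 1)  # stand-in for G output
    clf_in = stylegan_to_classifier(fake)
    assert clf_in.shape == (2, 3, 224, 224)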
def downscale(x, scale_times=1, mode='bilinear'):
    for i in range(scale_times):
        x = F.interpolate(x, scale_factor=0.5, mode=mode)
    return x
def upscale(x, scale_times=1, mode='bilinear'):
    for i in range(scale_times):
        x = F.interpolate(x, scale_factor=2, mode=mode)
    return x
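# Illustrative sketch (added for clarity): each step halves or doubles the
# spatial size, so a round trip of equal depth restores the original
# resolution (values are re-interpolated, not exactly recovered).
def _demo_rescale_roundtrip():
    x = torch.rand(1, 3, 256, 256)
    y = upscale(downscale(x, scale_times=2), scale_times=2)
    assert y.shape == x.shape  # 256 -> 64 -> 256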
def hist_transform(source_tensor, target_tensor):
    """Histogram transformation: give the source the target's per-channel value distribution (in place)."""
    c, h, w = source_tensor.size()
    s_t = source_tensor.view(c, -1)
    t_t = target_tensor.view(c, -1)
    s_t_sorted, s_t_indices = torch.sort(s_t)
    t_t_sorted, t_t_indices = torch.sort(t_t)
    for i in range(c):
        # Replace each source value by the target value of the same rank
        s_t[i, s_t_indices[i]] = t_t_sorted[i]
    return s_t.view(c, h, w)
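# Illustrative sketch: after the transform the source carries exactly the
# target's per-channel value distribution, so the sorted values coincide.
# (Note that hist_transform writes into the source tensor in place.)
def _demo_hist_transform():
    src, tgt = torch.rand(3, 8, 8), torch.rand(3, 8, 8)
    out = hist_transform(src, tgt)
    assert torch.allclose(out.view(3, -1).sort()[0], tgt.view(3, -1).sort()[0])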
def init_weights(m):
    """Initialize conv layers with Xavier uniform and linear layers with U(0, 1)."""
    if type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
    elif type(m) == nn.Linear:
        nn.init.uniform_(m.weight, 0.0, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
def total_variation(x, delta=1):
    """Total variation, x: tensor of size (B, C, H, W)"""
    out = torch.mean(torch.abs(x[:, :, :, :-delta] - x[:, :, :, delta:])) \
        + torch.mean(torch.abs(x[:, :, :-delta, :] - x[:, :, delta:, :]))
    return out
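# Illustrative sketch: total variation vanishes on a constant image and is
# positive on noise, which is what makes it usable as a smoothness regularizer.
def _demo_total_variation():
    assert total_variation(torch.ones(1, 3, 32, 32)).item() == 0.0
    assert total_variation(torch.rand(1, 3, 32, 32)).item() > 0.0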
def vgg_transform(x):
    """Adapt an image in [0, 1] for a Caffe-style VGG: RGB -> BGR, resize to 224x224, scale to [0, 255]
    (mean subtraction is assumed to happen inside the VGG itself)."""
    r, g, b = torch.split(x, 1, 1)
    out = torch.cat((b, g, r), dim=1)
    out = F.interpolate(out, size=(224, 224), mode='bilinear')
    out = out * 255.
    return out
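# Illustrative sketch: vgg_transform only reorders channels, resizes and
# rescales, so shape and value range are easy to verify.
def _demo_vgg_transform():
    out = vgg_transform(torch.rand(1, 3, 256, 256))
    assert out.shape == (1, 3, 224, 224) and 0.0 <= out.min() and out.max() <= 255.0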
# Warp image with flow
def normalize_axis(x, L):
    """Map 1-indexed pixel coordinates in [1, L] to [-1, 1]."""
    return (x - 1 - (L - 1) / 2) * 2 / (L - 1)
def unnormalize_axis(x, L):
    """Inverse of normalize_axis: map [-1, 1] back to [1, L]."""
    return x * (L - 1) / 2 + 1 + (L - 1) / 2
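# Illustrative sketch: the two axis helpers are exact inverses on the
# 1-indexed coordinate range [1, L].
def _demo_axis_roundtrip():
    coords = torch.tensor([1.0, 128.5, 256.0])
    assert torch.allclose(unnormalize_axis(normalize_axis(coords, 256), 256), coords)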
def torch_flow_to_th_sampling_grid(flow, h_src, w_src, use_cuda=False):
    b, c, h_tgt, w_tgt = flow.size()
    # Note: the x/y naming below assumes square inputs (h_tgt == w_tgt);
    # indexing='ij' keeps the pre-1.10 torch.meshgrid behavior.
    grid_y, grid_x = torch.meshgrid(torch.arange(1, w_tgt + 1), torch.arange(1, h_tgt + 1), indexing='ij')
    disp_x = flow[:, 0, :, :]
    disp_y = flow[:, 1, :, :]
    source_x = grid_x.unsqueeze(0).repeat(b, 1, 1).type_as(flow) + disp_x
    source_y = grid_y.unsqueeze(0).repeat(b, 1, 1).type_as(flow) + disp_y
    source_x_norm = normalize_axis(source_x, w_src)
    source_y_norm = normalize_axis(source_y, h_src)
    sampling_grid = torch.cat((source_x_norm.unsqueeze(3), source_y_norm.unsqueeze(3)), dim=3)
    if use_cuda:
        sampling_grid = sampling_grid.cuda()
    return sampling_grid
def warp_image_torch(image, flow):
    """
    Warp image (tensor, shape=[b, 3, h_src, w_src]) with flow (tensor, shape=[b, 2, h_tgt, w_tgt])
    """
    b, c, h_src, w_src = image.size()
    sampling_grid_torch = torch_flow_to_th_sampling_grid(flow, h_src, w_src)
    # align_corners=True matches the pre-1.3 grid_sample default this code was written against
    warped_image_torch = F.grid_sample(image, sampling_grid_torch, align_corners=True)
    return warped_image_torch
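# Illustrative sketch: an all-zero flow field should be an identity warp
# (up to interpolation), a quick check of the sampling-grid math above.
def _demo_identity_warp():
    img = torch.rand(1, 3, 64, 64)
    flow = torch.zeros(1, 2, 64, 64)
    assert torch.allclose(warp_image_torch(img, flow), img, atol=1e-5)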
class Trainer(nn.Module):
    def __init__(self, config, opts):
        super(Trainer, self).__init__()
        # Load hyperparameters
        self.config = config
        self.device = torch.device(self.config['device'])
        self.scale = int(np.log2(config['resolution'] / config['enc_resolution']))
        self.scale_mode = 'bilinear'
        self.opts = opts
        self.n_styles = 2 * int(np.log2(config['resolution'])) - 2
        self.idx_k = 5
        if 'idx_k' in self.config:
            self.idx_k = self.config['idx_k']
        if 'stylegan_version' in self.config and self.config['stylegan_version'] == 3:
            self.n_styles = 16
        # Networks
        in_channels = 256
        if 'in_c' in self.config:
            in_channels = config['in_c']
        enc_residual = False
        if 'enc_residual' in self.config:
            enc_residual = self.config['enc_residual']
        enc_residual_coeff = False
        if 'enc_residual_coeff' in self.config:
            enc_residual_coeff = self.config['enc_residual_coeff']
        resnet_layers = [4, 5, 6]
        if 'enc_start_layer' in self.config:
            st_l = self.config['enc_start_layer']
            resnet_layers = [st_l, st_l + 1, st_l + 2]
        if 'scale_mode' in self.config:
            self.scale_mode = self.config['scale_mode']
        # Load encoder
        self.stride = (self.config['fs_stride'], self.config['fs_stride'])
        self.enc = fs_encoder_v2(n_styles=self.n_styles, opts=opts, residual=enc_residual, use_coeff=enc_residual_coeff, resnet_layer=resnet_layers, stride=self.stride)
        ##########################
        # Other nets
        self.StyleGAN = self.init_stylegan(config)
        self.Arcface = iresnet50()
        self.parsing_net = BiSeNet(n_classes=19)
        # Optimizers
        # Latent encoder
        self.enc_params = list(self.enc.parameters())
        if 'freeze_iresnet' in self.config and self.config['freeze_iresnet']:
            self.enc_params = list(self.enc.styles.parameters())
        if 'optimizer' in self.config and self.config['optimizer'] == 'ranger':
            self.enc_opt = Ranger(self.enc_params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']), weight_decay=config['weight_decay'])
        else:
            self.enc_opt = torch.optim.Adam(self.enc_params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']), weight_decay=config['weight_decay'])
        self.enc_scheduler = torch.optim.lr_scheduler.StepLR(self.enc_opt, step_size=config['step_size'], gamma=config['gamma'])
        self.fea_avg = None
    def initialize(self, stylegan_model_path, arcface_model_path, parsing_model_path):
        # Load StyleGAN model
        stylegan_state_dict = torch.load(stylegan_model_path, map_location='cpu')
        self.StyleGAN.load_state_dict(get_keys(stylegan_state_dict, 'decoder'), strict=True)
        self.StyleGAN.to(self.device)
        # Get StyleGAN average latent in w space and the noise inputs
        self.dlatent_avg = stylegan_state_dict['latent_avg'].to(self.device)
        self.noise_inputs = [getattr(self.StyleGAN.noises, f'noise_{i}').to(self.device) for i in range(self.StyleGAN.num_layers)]
        # Load ArcFace weights (use the path argument rather than the previously ignored opts attribute)
        self.Arcface.load_state_dict(torch.load(arcface_model_path))
        self.Arcface.eval()
        # Load face parsing net weights
        self.parsing_net.load_state_dict(torch.load(parsing_model_path))
        self.parsing_net.eval()
        # Load LPIPS net weights (required by the LPIPS loss below)
        self.loss_fn = lpips.LPIPS(net='alex', spatial=False)
        self.loss_fn.to(self.device)
    def init_stylegan(self, config):
        """StyleGAN = G_main(
            truncation_psi=config['truncation_psi'],
            resolution=config['resolution'],
            use_noise=config['use_noise'],
            randomize_noise=config['randomize_noise']
        )"""
        # Use the configured resolution (previously hardcoded to 1024)
        StyleGAN = Generator(config['resolution'], 512, 8)
        return StyleGAN
    def mapping(self, z):
        return self.StyleGAN.get_latent(z).detach()
    def L1loss(self, input, target):
        return nn.L1Loss()(input, target)
    def L2loss(self, input, target):
        return nn.MSELoss()(input, target)
    def CEloss(self, x, target_age):
        return nn.CrossEntropyLoss()(x, target_age)
    def LPIPS(self, input, target, multi_scale=False):
        if multi_scale:
            out = 0
            for k in range(3):
                out += self.loss_fn.forward(downscale(input, k, self.scale_mode), downscale(target, k, self.scale_mode)).mean()
        else:
            out = self.loss_fn.forward(downscale(input, self.scale, self.scale_mode), downscale(target, self.scale, self.scale_mode)).mean()
        return out
    def IDloss(self, input, target):
        # ArcFace expects 112x112 inputs
        x_1 = F.interpolate(input, (112, 112))
        x_2 = F.interpolate(target, (112, 112))
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        if 'multi_layer_idloss' in self.config and self.config['multi_layer_idloss']:
            id_1 = self.Arcface(x_1, return_features=True)
            id_2 = self.Arcface(x_2, return_features=True)
            return sum([1 - cos(id_1[i].flatten(start_dim=1), id_2[i].flatten(start_dim=1)) for i in range(len(id_1))])
        else:
            id_1 = self.Arcface(x_1)
            id_2 = self.Arcface(x_2)
            return 1 - cos(id_1, id_2)
    def landmarkloss(self, input, target):
        # Despite the name, this compares BiSeNet face-parsing outputs of the two images
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        x_1 = stylegan_to_classifier(input, out_size=(512, 512))
        x_2 = stylegan_to_classifier(target, out_size=(512, 512))
        out_1 = self.parsing_net(x_1)
        out_2 = self.parsing_net(x_2)
        parsing_loss = sum([1 - cos(out_1[i].flatten(start_dim=1), out_2[i].flatten(start_dim=1)) for i in range(len(out_1))])
        return parsing_loss.mean()
    def feature_match(self, enc_feat, dec_feat, layer_idx=None):
        loss = []
        if layer_idx is None:
            layer_idx = [i for i in range(len(enc_feat))]
        for i in layer_idx:
            loss.append(self.L1loss(enc_feat[i], dec_feat[i]))
        return loss
    def encode(self, img):
        w_recon, fea = self.enc(downscale(img, self.scale, self.scale_mode))
        w_recon = w_recon + self.dlatent_avg
        return w_recon, fea
    def get_image(self, w=None, img=None, noise=None, zero_noise_input=True, training_mode=True):
        x_1, n_1 = img, noise
        if x_1 is None:
            x_1, _ = self.StyleGAN([w], input_is_latent=True, noise=n_1)
        w_delta = None
        fea = None
        features = None
        return_features = False
        # Reconstruction
        k = 0
        if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']:
            return_features = True
            k = self.idx_k
            w_recon, fea = self.enc(downscale(x_1, self.scale, self.scale_mode))
            w_recon = w_recon + self.dlatent_avg
            # One feature slot per generator layer (18 in total); only slot k is filled
            features = [None] * k + [fea] + [None] * (17 - k)
        else:
            w_recon = self.enc(downscale(x_1, self.scale, self.scale_mode)) + self.dlatent_avg
        # Generate image
        x_1_recon, fea_recon = self.StyleGAN([w_recon], input_is_latent=True, return_features=True, features_in=features, feature_scale=min(1.0, 0.0001 * self.n_iter))
        fea_recon = fea_recon[k].detach()
        return [x_1_recon, x_1[:, :3, :, :], w_recon, w_delta, n_1, fea, fea_recon]
    def compute_loss(self, w=None, img=None, noise=None, real_img=None):
        return self.compute_loss_stylegan2(w=w, img=img, noise=noise, real_img=real_img)
    def compute_loss_stylegan2(self, w=None, img=None, noise=None, real_img=None):
        if img is None:
            # Generate synthetic images
            if noise is None:
                noise = [torch.randn(w.size()[:1] + ee.size()[1:]).to(self.device) for ee in self.noise_inputs]
            img, _ = self.StyleGAN([w], input_is_latent=True, noise=noise)
            img = img.detach()
        if img is not None and real_img is not None:
            # Concat synthetic and real data
            img = torch.cat([img, real_img], dim=0)
            if noise is not None:  # noise stays None when real images are passed in directly
                noise = [torch.cat([ee, ee], dim=0) for ee in noise]
        out = self.get_image(w=w, img=img, noise=noise)
        x_1_recon, x_1, w_recon, w_delta, n_1, fea_1, fea_recon = out
        # Loss setting
        w_l2, w_lpips, w_id = self.config['w']['l2'], self.config['w']['lpips'], self.config['w']['id']
        b = x_1.size(0) // 2
        if 'l2loss_on_real_image' in self.config and self.config['l2loss_on_real_image']:
            b = x_1.size(0)
        self.l2_loss = self.L2loss(x_1_recon[:b], x_1[:b]) if w_l2 > 0 else torch.tensor(0)  # l2 loss only on synthetic data
        # LPIPS
        multiscale_lpips = False if 'multiscale_lpips' not in self.config else self.config['multiscale_lpips']
        self.lpips_loss = self.LPIPS(x_1_recon, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0)
        self.id_loss = self.IDloss(x_1_recon, x_1).mean() if w_id > 0 else torch.tensor(0)
        self.landmark_loss = self.landmarkloss(x_1_recon, x_1) if self.config['w']['landmark'] > 0 else torch.tensor(0)
        if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']:
            k = self.idx_k
            features = [None] * k + [fea_1] + [None] * (17 - k)
            x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True, features_in=features, feature_scale=min(1.0, 0.0001 * self.n_iter))
            self.lpips_loss += self.LPIPS(x_1_recon_2, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0)
            self.id_loss += self.IDloss(x_1_recon_2, x_1).mean() if w_id > 0 else torch.tensor(0)
            self.landmark_loss += self.landmarkloss(x_1_recon_2, x_1) if self.config['w']['landmark'] > 0 else torch.tensor(0)
        # Downscale images
        x_1 = downscale(x_1, self.scale, self.scale_mode)
        x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode)
        # Total loss
        self.loss = w_l2 * self.l2_loss + w_lpips * self.lpips_loss + w_id * self.id_loss
        if 'f_recon' in self.config['w']:
            self.feature_recon_loss = self.L2loss(fea_1, fea_recon)
            self.loss += self.config['w']['f_recon'] * self.feature_recon_loss
        if 'l1' in self.config['w'] and self.config['w']['l1'] > 0:
            self.l1_loss = self.L1loss(x_1_recon, x_1)
            self.loss += self.config['w']['l1'] * self.l1_loss
        if 'landmark' in self.config['w']:
            self.loss += self.config['w']['landmark'] * self.landmark_loss
        return self.loss
    def test(self, w=None, img=None, noise=None, zero_noise_input=True, return_latent=False, training_mode=False):
        if not hasattr(self, 'n_iter'):
            self.n_iter = 1e5
        out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode)
        x_1_recon, x_1, w_recon, w_delta, n_1, fea_1 = out[:6]
        output = [x_1, x_1_recon]
        if return_latent:
            output += [w_recon, fea_1]
        return output
    def log_loss(self, logger, n_iter, prefix='scripts'):
        logger.log_value(prefix + '/l2_loss', self.l2_loss.item(), n_iter + 1)
        logger.log_value(prefix + '/lpips_loss', self.lpips_loss.item(), n_iter + 1)
        logger.log_value(prefix + '/id_loss', self.id_loss.item(), n_iter + 1)
        logger.log_value(prefix + '/total_loss', self.loss.item(), n_iter + 1)
        if 'f_recon' in self.config['w']:
            logger.log_value(prefix + '/feature_recon_loss', self.feature_recon_loss.item(), n_iter + 1)
        if 'l1' in self.config['w'] and self.config['w']['l1'] > 0:
            logger.log_value(prefix + '/l1_loss', self.l1_loss.item(), n_iter + 1)
        if 'landmark' in self.config['w']:
            logger.log_value(prefix + '/landmark_loss', self.landmark_loss.item(), n_iter + 1)
    def save_image(self, log_dir, n_epoch, n_iter, prefix='/scripts/', w=None, img=None, noise=None, training_mode=True):
        return self.save_image_stylegan2(log_dir=log_dir, n_epoch=n_epoch, n_iter=n_iter, prefix=prefix, w=w, img=img, noise=noise, training_mode=training_mode)
    def save_image_stylegan2(self, log_dir, n_epoch, n_iter, prefix='/scripts/', w=None, img=None, noise=None, training_mode=True):
        os.makedirs(log_dir + prefix, exist_ok=True)
        with torch.no_grad():
            out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode)
            x_1_recon, x_1, w_recon, w_delta, n_1, fea_1 = out[:6]
            x_1 = downscale(x_1, self.scale, self.scale_mode)
            x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode)
            out_img = torch.cat((x_1, x_1_recon), dim=3)
            # fs
            if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']:
                k = self.idx_k
                features = [None] * k + [fea_1] + [None] * (17 - k)
                x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True, features_in=features, feature_scale=min(1.0, 0.0001 * self.n_iter))
                x_1_recon_2 = downscale(x_1_recon_2, self.scale, self.scale_mode)
                out_img = torch.cat((x_1, x_1_recon, x_1_recon_2), dim=3)
            utils.save_image(clip_img(out_img[:1]), log_dir + prefix + 'epoch_' + str(n_epoch + 1) + '_iter_' + str(n_iter + 1) + '_0.jpg')
            if out_img.size(0) > 1:
                utils.save_image(clip_img(out_img[1:]), log_dir + prefix + 'epoch_' + str(n_epoch + 1) + '_iter_' + str(n_iter + 1) + '_1.jpg')
    def save_model(self, log_dir):
        torch.save(self.enc.state_dict(), '{:s}/enc.pth.tar'.format(log_dir))
    def save_checkpoint(self, n_epoch, log_dir):
        checkpoint_state = {
            'n_epoch': n_epoch,
            'enc_state_dict': self.enc.state_dict(),
            'enc_opt_state_dict': self.enc_opt.state_dict(),
            'enc_scheduler_state_dict': self.enc_scheduler.state_dict()
        }
        torch.save(checkpoint_state, '{:s}/checkpoint.pth'.format(log_dir))
        if (n_epoch + 1) % 10 == 0:
            torch.save(checkpoint_state, '{:s}/checkpoint_{:d}.pth'.format(log_dir, n_epoch + 1))
    def load_model(self, log_dir):
        self.enc.load_state_dict(torch.load('{:s}/enc.pth.tar'.format(log_dir)))
    def load_checkpoint(self, checkpoint_path):
        state_dict = torch.load(checkpoint_path)
        self.enc.load_state_dict(state_dict['enc_state_dict'])
        self.enc_opt.load_state_dict(state_dict['enc_opt_state_dict'])
        self.enc_scheduler.load_state_dict(state_dict['enc_scheduler_state_dict'])
        return state_dict['n_epoch'] + 1
    def update(self, w=None, img=None, noise=None, real_img=None, n_iter=0):
        self.n_iter = n_iter
        self.enc_opt.zero_grad()
        self.compute_loss(w=w, img=img, noise=noise, real_img=real_img).backward()
        self.enc_opt.step()
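# Usage sketch (hypothetical paths and config values, not shipped with this file):
# build the trainer, restore pretrained weights, then run one optimizer step on
# a batch of synthetic latents. `opts` normally comes from the training script's
# argparse namespace; any object exposing the attributes fs_encoder_v2 expects
# would do, which is why the calls below stay commented out.
if __name__ == '__main__':
    example_config = {
        'device': 'cpu', 'resolution': 1024, 'enc_resolution': 256,
        'fs_stride': 2, 'lr': 1e-4, 'beta_1': 0.9, 'beta_2': 0.999,
        'weight_decay': 0.0, 'step_size': 1000, 'gamma': 0.5,
        'use_fs_encoder': True,
        'w': {'l2': 1.0, 'lpips': 1.0, 'id': 0.1, 'landmark': 0.0},
    }
    # trainer = Trainer(example_config, opts)
    # trainer.initialize('stylegan2-ffhq.pt', 'arcface.pth', 'parsing.pth')
    # z = torch.randn(2, 512, device=trainer.device)
    # trainer.update(w=trainer.mapping(z), n_iter=0)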