import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn

from models.CtrlHair.shape_branch.config import cfg as cfg_mask
from models.CtrlHair.shape_branch.solver import get_hair_face_code, get_new_shape, Solver as SolverMask
from models.Encoders import RotateModel
from models.Net import Net, get_segmentation
from models.sean_codes.models.pix2pix_model import Pix2PixModel, SEAN_OPT, encode_sean, decode_sean
from utils.image_utils import DilateErosion
from utils.save_utils import save_vis_mask, save_gen_image, save_latents


class Alignment(nn.Module):
    """
    Module for transferring the desired hair shape.

    For two entries of ``name_to_embed`` it rotates the second (hair-shape) image
    toward the first, builds a target segmentation mask from the first image's face
    code and the rotated image's hair code, and blends F-space latents so the first
    image takes on the target hair region.
    """

    # Segmentation label this module treats as hair in every mask.
    HAIR_LABEL = 13
    # Default shape-adaptor (mask generator) weights; override with
    # ``opts.mask_generator_checkpoint`` if that attribute is present.
    MASK_GENERATOR_CHECKPOINT = 'pretrained_models/ShapeAdaptor/mask_generator.pth'

    def __init__(self, opts, latent_encoder=None, net=None):
        """
        Args:
            opts: configuration object; this class reads ``device``, ``rotate_checkpoint``,
                ``smooth``, ``save_all``, ``save_all_dir`` and, optionally,
                ``mask_generator_checkpoint``.
            latent_encoder: callable taking a list of images and returning a dict with
                ``"F"`` and ``"W"`` entries; required by ``align_images`` whenever the
                two images differ.
            net: pre-built ``Net``; a new one is built from ``opts`` when None.
        """
        super().__init__()
        self.opts = opts
        self.latent_encoder = latent_encoder
        # Explicit None check: an nn.Module that defines __len__ can be falsy.
        self.net = Net(self.opts) if net is None else net

        self.sean_model = Pix2PixModel(SEAN_OPT)
        self.sean_model.eval()

        solver_mask = SolverMask(cfg_mask, device=self.opts.device, local_rank=-1, training=False)
        self.mask_generator = solver_mask.gen
        mask_ckpt = getattr(self.opts, 'mask_generator_checkpoint', self.MASK_GENERATOR_CHECKPOINT)
        # map_location='cpu': load_state_dict copies the values into the module's
        # existing parameters, so deserializing on CPU avoids failures on CPU-only
        # hosts (and stray GPU allocations) for checkpoints saved from a GPU.
        self.mask_generator.load_state_dict(torch.load(mask_ckpt, map_location='cpu'))

        self.rotate_model = RotateModel()
        rotate_state = torch.load(self.opts.rotate_checkpoint, map_location='cpu')
        self.rotate_model.load_state_dict(rotate_state['model_state_dict'])
        self.rotate_model.to(self.opts.device).eval()

        self.dilate_erosion = DilateErosion(dilate_erosion=self.opts.smooth, device=self.opts.device)
        # ImageNet mean/std normalization applied before get_segmentation.
        self.to_bisenet = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    @classmethod
    def _hair_mask(cls, mask, device):
        """Return a tensor like ``mask`` (on ``device``): 1 where the label is HAIR_LABEL, else 0."""
        return torch.where(mask == cls.HAIR_LABEL,
                           torch.ones_like(mask, device=device),
                           torch.zeros_like(mask, device=device))

    def _save_dir(self, kwargs):
        """Debug-output directory: ``opts.save_all_dir / exp_name`` (``""`` when missing or None)."""
        exp_name = kwargs.get('exp_name')
        if exp_name is None:
            exp_name = ""
        return self.opts.save_all_dir / exp_name

    @torch.inference_mode()
    def shape_module(self, im_name1: str, im_name2: str, name_to_embed, only_target=True, **kwargs):
        """
        Predict the target mask: face of ``im_name1`` combined with the hair shape of ``im_name2``.

        Args:
            im_name1: key of the image whose face code is kept.
            im_name2: key of the image whose hair code is transferred.
            name_to_embed: mapping name -> dict providing ``'image_256'``, ``'W'`` and ``'mask'``.
            only_target: if True, return only the target hair mask.
            **kwargs: ``exp_name`` selects the sub-directory used when ``opts.save_all`` is set.

        Returns:
            ``{'HM_X': hair_mask_target}`` when ``only_target`` is True, otherwise the tuple
            ``(inp_mask1, hair_mask1, inp_mask2, hair_mask2, target_mask, hair_mask_target)``.

        When both names resolve to the very same ``'image_256'`` object, rotation and
        shape adaptation are skipped and ``im_name1``'s own mask is used as the target.
        """
        device = self.opts.device

        img1_in = name_to_embed[im_name1]['image_256']
        img2_in = name_to_embed[im_name2]['image_256']
        latent_W_1 = name_to_embed[im_name1]["W"]
        latent_W_2 = name_to_embed[im_name2]["W"]
        inp_mask1 = name_to_embed[im_name1]['mask']
        inp_mask2 = name_to_embed[im_name2]['mask']

        if img1_in is img2_in:
            # Same image: nothing to rotate or adapt.
            I_rot = None
            rot_mask = inp_mask2
            target_mask = inp_mask1
        else:
            # Rotate stage: the rotation model replaces the first 6 W entries of image 2
            # (conditioned on image 1's first 6); image 2's remaining entries are kept.
            rotate_to = self.rotate_model(latent_W_2[:, :6], latent_W_1[:, :6])
            rotate_to = torch.cat((rotate_to, latent_W_2[:, 6:]), dim=1)
            I_rot, _ = self.net.generator([rotate_to], input_is_latent=True, return_latents=False)

            # Rescale generator output (presumably in [-1, 1]) to [0, 1], then normalize
            # for segmentation.
            I_rot_to_seg = ((I_rot + 1) / 2).clip(0, 1)
            I_rot_to_seg = self.to_bisenet(I_rot_to_seg)
            rot_mask = get_segmentation(I_rot_to_seg)

            # Shape Adaptor: face code of image 1 + hair code of the rotated image 2.
            face_1, _ = get_hair_face_code(self.mask_generator, inp_mask1[0, 0, ...])
            _, hair_2 = get_hair_face_code(self.mask_generator, rot_mask[0, 0, ...])
            # [None, None] adds two leading singleton dimensions.
            target_mask = get_new_shape(self.mask_generator, face_1, hair_2)[None, None]

        hair_mask_target = self._hair_mask(target_mask, device)

        if self.opts.save_all:
            output_dir = self._save_dir(kwargs)
            if I_rot is not None:
                save_gen_image(output_dir, 'Shape', f'{im_name2}_rotate_to_{im_name1}.png', I_rot)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name1}.png', inp_mask1)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name2}.png', inp_mask2)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name2}_rotate_to_{im_name1}.png', rot_mask)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name1}_{im_name2}_target.png', target_mask)

        if only_target:
            return {'HM_X': hair_mask_target}

        hair_mask1 = self._hair_mask(inp_mask1, device)
        hair_mask2 = self._hair_mask(inp_mask2, device)
        return inp_mask1, hair_mask1, inp_mask2, hair_mask2, target_mask, hair_mask_target

    @torch.inference_mode()
    def align_images(self, im_name1, im_name2, name_to_embed, **kwargs):
        """
        Blend F-space latents so ``im_name1`` takes the hair shape of ``im_name2``.

        Args:
            im_name1, im_name2: keys into ``name_to_embed``; entries must provide ``'S'``
                (``im_name1``) and ``'F'`` latents in addition to what ``shape_module`` reads.
            name_to_embed: mapping name -> embedding dict.
            **kwargs: forwarded to ``shape_module``; ``exp_name`` selects the debug-output
                sub-directory when ``opts.save_all`` is set.

        Returns:
            ``{'latent_F_align': ..., 'HM_X': hair_mask_target}``. When both names resolve
            to the same ``'image_256'`` object, ``latent_F_align`` is ``im_name1``'s own F.
        """
        img1_in = name_to_embed[im_name1]['image_256']
        img2_in = name_to_embed[im_name2]['image_256']

        latent_S_1, latent_F_1 = name_to_embed[im_name1]["S"], name_to_embed[im_name1]["F"]
        latent_F_2 = name_to_embed[im_name2]["F"]

        # Shape Module
        if img1_in is img2_in:
            hair_mask_target = self.shape_module(im_name1, im_name2, name_to_embed, only_target=True, **kwargs)['HM_X']
            return {'latent_F_align': latent_F_1, 'HM_X': hair_mask_target}

        inp_mask1, hair_mask1, inp_mask2, hair_mask2, target_mask, hair_mask_target = (
            self.shape_module(im_name1, im_name2, name_to_embed, only_target=False, **kwargs)
        )

        # SEAN for inpaint: encode each image with its own mask, then decode both
        # style codes onto the shared target mask.
        images = torch.cat([img1_in, img2_in], dim=0)
        labels = torch.cat([inp_mask1, inp_mask2], dim=0)
        img1_code, img2_code = encode_sean(self.sean_model, images, labels)
        gen1_sean = decode_sean(self.sean_model, img1_code.unsqueeze(0), target_mask)
        gen2_sean = decode_sean(self.sean_model, img2_code.unsqueeze(0), target_mask)

        # Encoding result in F from E4E
        enc_imgs = self.latent_encoder([gen1_sean, gen2_sean])
        intermediate_align, latent_inter = enc_imgs["F"][0].unsqueeze(0), enc_imgs["W"][0].unsqueeze(0)
        latent_F_out_new, latent_out = enc_imgs["F"][1].unsqueeze(0), enc_imgs["W"][1].unsqueeze(0)

        # Alignment of F space
        masks = torch.cat([
            1 - (1 - hair_mask1) * (1 - hair_mask_target),  # union of image-1 hair and target hair
            hair_mask_target,                                # target hair
            hair_mask2 * hair_mask_target,                   # overlap of image-2 hair and target hair
        ], dim=0)
        dilate, erosion = self.dilate_erosion.mask(masks)
        free_mask = torch.stack([dilate[0], erosion[1], erosion[2]], dim=0)
        # NOTE(review): bicubic resampling can overshoot [0, 1], so the blend weights
        # below may fall slightly outside [0, 1].
        free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')
        interpolation_low = 1 - free_mask_down_32

        # Three successive blends; where free_mask ~ 1 the first operand wins:
        # 1) SEAN(image 1) F inside the dilated union, image 1's own F outside it;
        # 2) SEAN(image 2) F inside the eroded target hair;
        # 3) image 2's own F inside the eroded overlap of its hair and the target hair.
        latent_F_align = intermediate_align + interpolation_low[0] * (latent_F_1 - intermediate_align)
        latent_F_align = latent_F_out_new + interpolation_low[1] * (latent_F_align - latent_F_out_new)
        latent_F_align = latent_F_2 + interpolation_low[2] * (latent_F_align - latent_F_2)

        if self.opts.save_all:
            output_dir = self._save_dir(kwargs)
            save_gen_image(output_dir, 'Align', f'{im_name1}_{im_name2}_SEAN.png', gen1_sean)
            save_gen_image(output_dir, 'Align', f'{im_name2}_{im_name1}_SEAN.png', gen2_sean)

            img1_e4e = self.net.generator([latent_inter], input_is_latent=True, return_latents=False,
                                          start_layer=4, end_layer=8, layer_in=intermediate_align)[0]
            img2_e4e = self.net.generator([latent_out], input_is_latent=True, return_latents=False,
                                          start_layer=4, end_layer=8, layer_in=latent_F_out_new)[0]
            save_gen_image(output_dir, 'Align', f'{im_name1}_{im_name2}_e4e.png', img1_e4e)
            save_gen_image(output_dir, 'Align', f'{im_name2}_{im_name1}_e4e.png', img2_e4e)

            gen_im, _ = self.net.generator([latent_S_1], input_is_latent=True, return_latents=False,
                                           start_layer=4, end_layer=8, layer_in=latent_F_align)
            save_gen_image(output_dir, 'Align', f'{im_name1}_{im_name2}_output.png', gen_im)
            save_latents(output_dir, 'Align', f'{im_name1}_{im_name2}_F.npz', latent_F_align=latent_F_align)

        return {'latent_F_align': latent_F_align, 'HM_X': hair_mask_target}