import argparse import os import sys from argparse import Namespace from tempfile import TemporaryDirectory import numpy as np import torch import torch.nn.functional as F import wandb from sklearn.model_selection import train_test_split from torch.utils.data import Dataset, DataLoader from torchvision import transforms as T from tqdm.auto import tqdm sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from models.STAR.lib import utility from models.Encoders import RotateModel from models.Net import Net from models.Net import iresnet100 from models.encoder4editing.utils.model_utils import setup_model, get_latents from utils.bicubic import BicubicDownSample from utils.train import image_grid, WandbLogger, seed_everything, toggle_grad class MovingAverageLoss: def __init__(self, weights: dict, alpha=0.02): self.alpha = alpha self.weights = weights self.vals = {} def reset(self): self.vals = {} def update(self, cur_vals): for key, val in cur_vals.items(): self.vals[key] = self.alpha * val + (1 - self.alpha) * self.vals.get(key, val) def calc_loss(self, losses): loss = 0. for key, val in losses.items(): loss += self.weights.get(key, 1) * val / self.vals.get(key, 1) return loss class Trainer: def __init__(self, model=None, args=None, optimizer=None, scheduler=None, train_dataloader=None, test_dataloader=None, logger=None ): self.model = model self.args = args self.optimizer = optimizer self.scheduler = scheduler self.train_dataloader = train_dataloader self.test_dataloader = test_dataloader self.logger = logger self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.net = Net(Namespace(size=1024, ckpt='pretrained_models/StyleGAN/ffhq.pt', channel_multiplier=2, latent=512, n_mlp=8, device=self.device)) self.e4e = setup_model('pretrained_models/encoder4editing/e4e_ffhq_encode.pt', 'cuda')[0] self.arc_face = iresnet100() self.arc_face.load_state_dict(torch.load("pretrained_models/ArcFace/backbone_r100.pth")) self.arc_face.eval().cuda() self.toArcface = T.Compose([ T.Resize((112, 112)), T.Normalize(0.5, 0.5) ]) # init landmarks config = utility.get_config(utility.landmarks_arg) self.kp_extractor = utility.get_net(config) model_path = utility.landmarks_arg.pretrained_weight checkpoint = torch.load(model_path) self.kp_extractor.load_state_dict(checkpoint["net"]) self.kp_extractor = self.kp_extractor.float().to('cuda') self.kp_extractor.eval() self.toLandmarks = T.Compose([ T.Resize((256, 256)), T.Normalize(0.5, 0.5) ]) toggle_grad(self.arc_face, False) toggle_grad(self.kp_extractor, False) toggle_grad(self.net.generator, False) toggle_grad(self.e4e.encoder, False) self.downsample_512 = BicubicDownSample(factor=2) self.downsample_256 = BicubicDownSample(factor=4) self.downsample_128 = BicubicDownSample(factor=8) self.MAL = MovingAverageLoss({'mse points to': 6, 'mse latents': 2}) self.best_loss = float('+inf') def generate_key_points(self, batch): _, _, landmarks = self.kp_extractor(self.toLandmarks(batch)) final_marks_2D = (landmarks[:, :76] + 1) / 2 * torch.tensor([256 - 1, 256 - 1]).to('cuda').view(1, 1, 2) return final_marks_2D @torch.no_grad() def generate_latents(self, batch): return get_latents(self.e4e, batch) def save_model(self, name, save_online=True): with TemporaryDirectory() as tmp_dir: model_state_dict = self.model.state_dict() # delete pretrained clip for key in list(model_state_dict.keys()): if key.startswith("clip_model."): del model_state_dict[key] torch.save({'model_state_dict': model_state_dict}, f'{tmp_dir}/{name}.pth') self.logger.save(f'{tmp_dir}/{name}.pth', save_online) def load_model(self, checkpoint_path): self.model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'], strict=False) def calc_loss(self, I_to, I_from, key_points_to, latents_from, latents_to, ret_images=False, normalize=True ): # rotate rotate_to = self.model(latents_from[:, :6], latents_to[:, :6]) latent_in = torch.cat((rotate_to, latents_from[:, 6:]), axis=1) I_G_to, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False) I_G_to_0_1 = ((I_G_to + 1) / 2) I_gen_to = self.downsample_256(I_G_to_0_1).clip(0, 1) # key_point_loss key_points_gen_to = self.generate_key_points(I_gen_to) key_point_loss_to = F.mse_loss(key_points_gen_to, key_points_to) # arcface loss gen_embed = self.arc_face(self.toArcface(I_gen_to)) gt_embed = self.arc_face(self.toArcface(I_from)) arc_face_loss = 20 * (1 - F.cosine_similarity(gen_embed, gt_embed)).mean() losses = { 'mse points to': key_point_loss_to, 'arc face': arc_face_loss } if normalize: losses['loss'] = self.MAL.calc_loss(losses) else: losses['loss'] = sum(losses.values()) if ret_images: return losses['loss'], {key: val.item() for key, val in losses.items()}, I_gen_to, latent_in else: return losses['loss'], {key: val.item() for key, val in losses.items()} def calc_hair_loss(self, latents_from, latents_to, ret_images=False, normalize=True ): # rotate rotate_to = self.model(latents_from[:, :6], latents_to[:, :6]) mse_latents = 300 * F.mse_loss(rotate_to, latents_to[:, :6]) losses = { 'mse latents': mse_latents } if normalize: losses['loss'] = self.MAL.calc_loss(losses) else: losses['loss'] = sum(losses.values()) if ret_images: latent_in = torch.cat((rotate_to, latents_from[:, 6:]), axis=1) I_G_to, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False) I_G_to_0_1 = ((I_G_to + 1) / 2) I_gen_to = self.downsample_256(I_G_to_0_1).clip(0, 1) return losses['loss'], {key: val.item() for key, val in losses.items()}, I_gen_to else: return losses['loss'], {key: val.item() for key, val in losses.items()} def train_one_epoch(self): self.model.to(self.device).train() sum_losses = lambda x, y: {key: y.get(key, 0) + x.get(key, 0) for key in set(x.keys()) | set(y.keys())} dataloader_to = iter(self.train_dataloader) for batch in tqdm(self.train_dataloader): I_from, key_points_from, latents_from = map(lambda x: x.to(self.device), batch) I_to, key_points_to, latents_to = map(lambda x: x.to(self.device), next(dataloader_to)) self.optimizer.zero_grad() loss, info, _, gen_latent = self.calc_loss( I_to, I_from, key_points_to, latents_from, latents_to, ret_images=True ) if self.args.use_hair_loss: hair_loss, info2 = self.calc_hair_loss( gen_latent, latents_from ) loss += hair_loss info = sum_losses(info, info2) loss.backward() self.MAL.update(info) total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5) self.optimizer.step() self.logger.next_step() for key, val in info.items(): self.logger.log(key, val) self.logger.log('grad', total_norm.item()) @torch.no_grad() def validate(self): self.model.to(self.device).eval() sum_losses = lambda x, y: {key: y.get(key, 0) + x.get(key, 0) for key in set(x.keys()) | set(y.keys())} files = [] losses = {} for batch in tqdm(self.test_dataloader): I_from, key_points_from, latents_from, \ I_to, key_points_to, latents_to, = map(lambda x: x.to(self.device), batch) bsz = I_from.size(0) loss, info, I_gen_to, gen_latent = self.calc_loss( I_to, I_from, key_points_to, latents_from, latents_to, ret_images=True, normalize=False ) if args.use_hair_loss: loss, info2, I_gen_to_rec = self.calc_hair_loss( gen_latent, latents_from, ret_images=True, normalize=False ) losses = sum_losses(losses, info2) else: I_G_from, _ = self.net.generator([latents_from], input_is_latent=True, return_latents=False) I_G_from_0_1 = ((I_G_from + 1) / 2) I_gen_to_rec = self.downsample_256(I_G_from_0_1).clip(0, 1) losses = sum_losses(losses, info) for k in range(bsz): files.append([I_from[k].cpu(), I_gen_to_rec[k].cpu(), I_gen_to[k].cpu(), I_to[k].cpu()]) for key, val in losses.items(): val /= len(self.test_dataloader) self.logger.log(f'val {key}', val) np.random.seed(1927) idxs = np.random.choice(len(files), size=min(len(files), 100), replace=False) images_to_log = [image_grid(list(map(T.functional.to_pil_image, files[idx])), 1, 4) for idx in idxs] self.logger.log('val images', [wandb.Image(image) for image in images_to_log]) return losses['loss'] / len(self.test_dataloader) def train_loop(self, epochs): # self.validate() for epoch in range(epochs): self.train_one_epoch() loss = self.validate() self.save_model(f'rotate_{epoch}', save_online=False) self.save_model('last') if loss <= self.best_loss: self.best_loss = loss self.save_model(f'best', save_online=False) class Rotate_dataset(Dataset): def __init__(self, tensors_images, key_points, latents, is_test=False): super().__init__() self.tensors_images = tensors_images self.key_points = key_points self.latents = latents self.is_test = is_test def __len__(self): return len(self.tensors_images) def __get_elem__(self, idx): return self.tensors_images[idx], self.key_points[idx], self.latents[idx] def __getitem__(self, idx): if self.is_test: return *self.__get_elem__(idx), *self.__get_elem__(-idx) else: return self.__get_elem__(idx) def main(args): seed_everything() data = list(torch.load(args.dataset).values()) X_train, X_test = train_test_split(list(zip(data[0], data[1], data[2])), test_size=512, random_state=42) train_dataset = Rotate_dataset(*list(zip(*X_train))) test_dataset = Rotate_dataset(*list(zip(*X_test)), is_test=True) train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, shuffle=True, drop_last=True, num_workers=4) test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, pin_memory=True, shuffle=False, num_workers=4) logger = WandbLogger(name=args.name_run, project='HairFast-Rotate') logger.start_logging() logger.save(__file__) model = RotateModel() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.000001) trainer = Trainer(model, args, optimizer, None, train_dataloader, test_dataloader, logger) trainer.train_loop(1000) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Rotate trainer') parser.add_argument('--name_run', type=str, default='test') parser.add_argument('--dataset', type=str, default='input/rotate_dataset.pkl') parser.add_argument('--use_hair_loss', action='store_false') parser.add_argument('--batch_size', type=int, default=16) args = parser.parse_args() main(args)