import argparse
import os
import random
import sys
from argparse import Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from losses.pp_losses import LossBuilder, LossBuilderMulti
from models.Encoders import ModulationModule, FeatureiResnet, FeatureEncoderMult
from models.Net import Net
from models.face_parsing.model import BiSeNet, seg_mean, seg_std
from models.stylegan2 import dnnlib
from models.stylegan2.model import PixelNorm
from utils.bicubic import BicubicDownSample
from utils.image_utils import DilateErosion
from utils.train import image_grid, WandbLogger, toggle_grad, _LegacyUnpickler, seed_everything, get_fid_calc

class Trainer:
    def __init__(self,
                 model=None,
                 args=None,
                 optimizer=None,
                 scheduler=None,
                 train_dataloader=None,
                 test_dataloader=None,
                 logger=None
                 ):
        self.model = model
        self.args = args
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.logger = logger

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dilate_erosion = DilateErosion(device=self.device)
        self.normalize = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

        if self.model is not None:
            self.fid_calc = get_fid_calc('input/fid.pkl', args.fid_dataset)

        self.net = Net(Namespace(size=1024, ckpt='pretrained_models/StyleGAN/ffhq.pt', channel_multiplier=2,
                                 latent=512, n_mlp=8, device=self.device))

        with dnnlib.util.open_url("pretrained_models/StyleGAN/ffhq.pkl") as f:
            data = _LegacyUnpickler(f).load()
        self.discriminator = data['D'].cuda().eval()
        self.disc_optim = torch.optim.Adam(self.discriminator.parameters(), lr=3e-4, betas=(0.9, 0.999),
                                           amsgrad=False, weight_decay=0)

        self.seg = BiSeNet(n_classes=16)
        self.seg.to(self.device)
        self.seg.load_state_dict(torch.load('pretrained_models/BiSeNet/seg.pth'))
        self.seg.eval()

        # the generator, discriminator and segmenter start frozen; only self.model is trained
        toggle_grad(self.discriminator, False)
        toggle_grad(self.net.generator, False)
        toggle_grad(self.seg, False)

        self.downsample_512 = BicubicDownSample(factor=2)
        self.downsample_256 = BicubicDownSample(factor=4)
        self.downsample_128 = BicubicDownSample(factor=8)

        self.best_loss = float('+inf')
        if self.args is not None:
            if self.args.pretrain:
                self.LossBuilder = LossBuilder(
                    {'lpips_scale': 0.8, 'id': 0.1, 'landmark': 0, 'feat_rec': 0.01, 'adv': self.args.adv_coef})
            else:
                self.LossBuilder = LossBuilderMulti(
                    {'lpips_scale': 0.8, 'id': 0.1, 'landmark': 0.1, 'feat_rec': 0.01, 'adv': self.args.adv_coef,
                     'inpaint': self.args.inpaint})
        self.cur_iter = 1

    def generate_mask(self, I):
        # BiSeNet expects 512x512 inputs normalized with its own statistics
        IM = (self.downsample_512(I) - seg_mean) / seg_std
        down_seg, _, _ = self.seg(IM)
        current_mask = torch.argmax(down_seg, dim=1).long().float()

        # class 10 is the hair label in this parser; binarize, then dilate/erode for soft boundaries
        HM_X = torch.where(current_mask == 10, torch.ones_like(current_mask), torch.zeros_like(current_mask))
        HM_X = F.interpolate(HM_X.unsqueeze(1), size=(256, 256), mode='nearest')
        HM_XD, HM_XE = self.dilate_erosion.mask(HM_X)
        return HM_XD, HM_XE
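    # Usage sketch for generate_mask (hypothetical tensors; assumes a batch of
    # 1024x1024 RGB images in [0, 1]):
    #
    #   I = torch.rand(4, 3, 1024, 1024, device=trainer.device)
    #   hair_dilated, hair_eroded = trainer.generate_mask(I)  # two (4, 1, 256, 256) masks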

    def save_model(self, name, save_online=True):
        with TemporaryDirectory() as tmp_dir:
            model_state_dict = self.model.state_dict()

            # delete pretrained clip weights; they are frozen and restored from the original checkpoint
            for key in list(model_state_dict.keys()):
                if key.startswith("clip_model."):
                    del model_state_dict[key]

            torch.save(
                {'model_state_dict': model_state_dict, 'D': self.discriminator.state_dict(),
                 'cur_iter': self.cur_iter},
                f'{tmp_dir}/{name}.pth')
            self.logger.save(f'{tmp_dir}/{name}.pth', save_online)

    def load_model(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        if 'D' in checkpoint:
            self.discriminator.load_state_dict(checkpoint['D'], strict=False)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    def train_one_epoch(self):
        self.model.to(self.device).train()

        for batch in tqdm(self.train_dataloader):
            source, target, target_mask, HT_E = map(lambda x: x.to(self.device), batch)
            source, source_1024 = self.downsample_256(source).clip(0, 1), self.normalize(source)

            latent_s, latent_f = self.model(self.normalize(source), self.normalize(target), target_mask, HT_E)
            gen_im_W, _ = self.net.generator([latent_s], input_is_latent=True, return_latents=False)
            F_w, _ = self.net.generator([latent_s], input_is_latent=True, return_latents=False,
                                        start_layer=0, end_layer=4)

            if self.args.pretrain:
                # warm-up: linearly blend the predicted F tensor with the pure W-space features
                alpha = min(1, self.cur_iter / self.args.iter_before)
                latent_f_gen = alpha * latent_f + (1 - alpha) * F_w
            else:
                latent_f_gen = latent_f
            gen_im_F, _ = self.net.generator([latent_s], input_is_latent=True, return_latents=False,
                                             start_layer=5, end_layer=8, layer_in=latent_f_gen)

            losses = self.LossBuilder(source, target, target_mask, HT_E, gen_im_W, F_w, gen_im_F, latent_f)
            if self.args.use_adv and self.cur_iter >= self.args.iter_before:
                losses.update(self.LossBuilder.CalcAdvLoss(self.discriminator, gen_im_F))
            losses['loss'] = sum(losses.values())

            self.optimizer.zero_grad()
            losses['loss'].backward()
            total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

            if self.args.use_adv and self.cur_iter >= self.args.iter_before:
                if self.cur_iter == self.args.iter_before:
                    print('Start training discriminator')

                toggle_grad(self.discriminator, True)
                self.discriminator.train()

                disc_loss = self.LossBuilder.CalcDisLoss(self.discriminator, source_1024, gen_im_F.detach())
                # lazy R1 regularization: apply the gradient penalty only every d_reg_every steps
                if self.cur_iter % self.args.d_reg_every == 0:
                    disc_loss.update(self.LossBuilder.CalcR1Loss(self.discriminator, source_1024))
                total_loss = sum(disc_loss.values())

                self.disc_optim.zero_grad()
                total_loss.backward()
                total_norm_d = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 0.5)
                disc_loss['grad disc'] = total_norm_d
                self.disc_optim.step()

                toggle_grad(self.discriminator, False)
                self.discriminator.eval()
                losses.update(disc_loss)

            losses['train grad'] = total_norm
            self.logger.next_step()
            self.logger.log_scalars({f'train {key}': val for key, val in losses.items()})
            self.cur_iter += 1
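    # Note on the warm-up above: with t = cur_iter and T = args.iter_before, the
    # pretraining branch feeds the generator F_t = a * F_pred + (1 - a) * F_w with
    # a = min(1, t / T), so early iterations stay anchored to the W-space features
    # while the predicted F tensor is phased in; the adversarial terms only switch
    # on once t >= T.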

    @torch.no_grad()
    def validate(self):
        self.model.to(self.device).eval()
        # merge two loss dicts key-wise, treating missing keys as zero
        sum_losses = lambda x, y: {key: y.get(key, 0) + x.get(key, 0) for key in set(x.keys()) | set(y.keys())}

        files = []
        val_losses = {}
        to_299 = T.Resize((299, 299))
        images_to_fid = []

        for batch in tqdm(self.test_dataloader):
            source, target, target_mask, HT_E = map(lambda x: x.to(self.device), batch)
            source = self.downsample_256(source).clip(0, 1)
            bsz = source.size(0)

            latent_s, latent_f = self.model(self.normalize(source), self.normalize(target), target_mask, HT_E)
            gen_im_W, _ = self.net.generator([latent_s], input_is_latent=True, return_latents=False)
            F_w, _ = self.net.generator([latent_s], input_is_latent=True, return_latents=False,
                                        start_layer=0, end_layer=4)
            gen_im_F, _ = self.net.generator([latent_s], input_is_latent=True, return_latents=False,
                                             start_layer=5, end_layer=8, layer_in=latent_f)

            losses = self.LossBuilder(source, target, target_mask, HT_E, gen_im_W, F_w, gen_im_F, latent_f)
            losses['loss'] = sum(losses.values())

            # map generator output from [-1, 1] to [0, 1] before resizing
            gen_w_256 = self.downsample_256((gen_im_W + 1) / 2).clip(0, 1)
            gen_f_256 = self.downsample_256((gen_im_F + 1) / 2).clip(0, 1)
            images_to_fid.append(to_299((gen_im_F + 1) / 2).clip(0, 1))

            val_losses = sum_losses(val_losses, losses)
            for k in range(bsz):
                files.append([source[k].cpu(), target[k].cpu(), gen_w_256[k].cpu(), gen_f_256[k].cpu()])

        val_losses['FID CLIP'] = self.fid_calc(torch.cat(images_to_fid))
        for key, val in val_losses.items():
            if key != 'FID CLIP':
                val = val.item() / len(self.test_dataloader)
            self.logger.log_scalars({f'val {key}': val})

        np.random.seed(1927)
        idxs = np.random.choice(len(files), size=min(len(files), 100), replace=False)
        images_to_log = [image_grid(list(map(T.functional.to_pil_image, files[idx])), 1, len(files[idx]))
                         for idx in idxs]
        self.logger.log_scalars({'val images': [wandb.Image(image) for image in images_to_log]})
        return val_losses['loss']

    def train_loop(self, epochs):
        self.validate()
        for epoch in range(epochs):
            self.train_one_epoch()
            loss = self.validate()

            self.save_model('last', save_online=False)
            if loss <= self.best_loss:
                self.best_loss = loss
                self.save_model(f'best_{epoch}', save_online=False)


class PP_dataset(Dataset):
    def __init__(self, source, target, target_mask, HT_E, is_test=False):
        super().__init__()
        self.source = source
        self.target = target
        self.target_mask = target_mask
        self.HT_E = HT_E
        self.is_test = is_test

    def __len__(self):
        return len(self.source)

    def load_image(self, path):
        return T.functional.to_tensor(Image.open(path))

    def __transform__(self, img1, img2, mask1, mask2):
        if self.is_test:
            return img1, img2, mask1, mask2

        # paired augmentation: flip source, target and both masks together
        if random.random() > 0.5:
            img1 = T.functional.hflip(img1)
            img2 = T.functional.hflip(img2)
            mask1 = T.functional.hflip(mask1)
            mask2 = T.functional.hflip(mask2)
        return img1, img2, mask1, mask2

    def __getitem__(self, idx):
        # source images are stored as paths; targets and masks are pre-computed tensors
        return self.__transform__(self.load_image(self.source[idx]), self.target[idx], self.target_mask[idx],
                                  self.HT_E[idx])
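# Usage sketch for PP_dataset (hypothetical data; assumes each entry of a
# pp_part_*.dataset shard is a (source_path, target, target_mask, HT_E) tuple,
# matching how main() unpacks the shards below):
#
#   entries = torch.load('input/pp_dataset/pp_part_1.dataset')
#   ds = PP_dataset(*list(zip(*entries)), is_test=True)
#   source, target, target_mask, HT_E = ds[0]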


class PostProcessModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.encoder_face = FeatureEncoderMult(fs_layers=[9], opts=argparse.Namespace(
            **{'arcface_model_path': "pretrained_models/ArcFace/backbone_ir50.pth"}))
        if not self.args.finetune:
            toggle_grad(self.encoder_face, False)

        self.latent_avg = torch.load('pretrained_models/PostProcess/latent_avg.pt', map_location=torch.device('cuda'))
        self.to_feature = FeatureiResnet([[1024, 2], [768, 2], [512, 2]])

        if self.args.use_mod:
            self.to_latent_1 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
            self.to_latent_2 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
            self.pixelnorm = PixelNorm()
        else:
            self.to_latent = nn.Sequential(nn.Linear(1024, 1024), nn.LayerNorm([1024]), nn.LeakyReLU(),
                                           nn.Linear(1024, 512))

    def forward(self, source, target, target_mask=None, *args, **kwargs):
        s_face, [f_face] = self.encoder_face(source)
        if self.args.pretrain:
            # pretraining reconstructs the source alone; no hair branch is needed
            return self.latent_avg + s_face, f_face
        s_hair, [f_hair] = self.encoder_face(target)

        if self.args.use_mod:
            # cross-modulate the two W+ codes with each other before blending
            dt_latent_face = self.pixelnorm(s_face)
            dt_latent_hair = self.pixelnorm(s_hair)
            for mod_module in self.to_latent_1:
                dt_latent_face = mod_module(dt_latent_face, s_hair)
            for mod_module in self.to_latent_2:
                dt_latent_hair = mod_module(dt_latent_hair, s_face)
            final_s = self.latent_avg + 0.1 * (dt_latent_face + dt_latent_hair)
        else:
            cat_s = torch.cat((s_face, s_hair), dim=-1)
            final_s = self.latent_avg + self.to_latent(cat_s)

        if self.args.use_full:
            cat_f = torch.cat((f_face, f_hair), dim=1)
        else:
            # gate each branch's feature tensor by the target mask before fusion
            t_mask = F.interpolate(target_mask, size=(64, 64), mode='nearest')
            cat_f = torch.cat((f_face * t_mask, f_hair * (1 - t_mask)), dim=1)
        final_f = self.to_feature(cat_f)
        return final_s, final_f
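# Shape sketch for the forward pass (hypothetical; assumes the usual 18 x 512 W+
# codes and a 512 x 64 x 64 feature tensor from the fs_layers=[9] encoder):
#
#   model = PostProcessModel(args).cuda().eval()
#   src = torch.randn(1, 3, 256, 256, device='cuda')  # normalized source face
#   tgt = torch.randn(1, 3, 256, 256, device='cuda')  # normalized blending target
#   mask = torch.ones(1, 1, 256, 256, device='cuda')  # hair mask for feature mixing
#   with torch.no_grad():
#       latent_s, latent_f = model(src, tgt, mask)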


def main(args):
    seed_everything()

    # the dataset is sharded into consecutively numbered files: pp_part_1.dataset, pp_part_2.dataset, ...
    dataset = []
    idx = 1
    while os.path.isfile(args.dataset / f'pp_part_{idx}.dataset'):
        batch_data = torch.load(args.dataset / f'pp_part_{idx}.dataset')
        dataset.extend(batch_data)
        idx += 1

    X_train, X_test = train_test_split(dataset, test_size=1024, random_state=42)
    train_dataset = PP_dataset(*list(zip(*X_train)))
    test_dataset = PP_dataset(*list(zip(*X_test)), is_test=True)

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=10, pin_memory=True,
                                  shuffle=True, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=10, pin_memory=True,
                                 shuffle=False)

    logger = WandbLogger(name=args.name_run, project='HairFast-PostProcess')
    logger.start_logging()
    logger.save(__file__)

    model = PostProcessModel(args)
    if args.pretrain:
        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=0)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)

    trainer = Trainer(model, args, optimizer, None, train_dataloader, test_dataloader, logger)
    if not args.pretrain:
        trainer.load_model(args.checkpoint)
    trainer.train_loop(1000)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Post Process trainer')
    parser.add_argument('--name_run', type=str, default='test')
    parser.add_argument('--FFHQ', type=Path)
    parser.add_argument('--dataset', type=Path, default='input/pp_dataset')
    parser.add_argument('--fid_dataset', type=str, default='input')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--iter_before', type=int, default=10_000)
    parser.add_argument('--d_reg_every', type=int, default=16)
    parser.add_argument('--inpaint', type=float, default=0.)
    parser.add_argument('--use_adv', action='store_true')
    parser.add_argument('--adv_coef', type=float, default=0.05)
    parser.add_argument('--checkpoint', type=str)
    parser.add_argument('--use_mod', action='store_true')
    parser.add_argument('--use_full', action='store_true')
    parser.add_argument('--pretrain', action='store_true')
    parser.add_argument('--finetune', action='store_true')

    args = parser.parse_args()
    main(args)
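# Example invocations (illustrative; the script name and checkpoint path are
# hypothetical, the flags are the ones defined above):
#
#   # stage 1: reconstruction pretraining of the encoder
#   python pp_train.py --name_run pp_pretrain --pretrain --use_adv
#
#   # stage 2: full post-processing training from the pretrained checkpoint
#   python pp_train.py --name_run pp_train --use_mod --use_adv --checkpoint best.pth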