# newTryOn/scripts/pp_gen.py
# Last commit: 6d314be by amanSethSmava ("new commit")
import argparse
import os
import sys
import tempfile
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image
from tqdm.auto import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from hair_swap import get_parser, HairFast
from scripts.pp_train import Trainer
from utils.train import seed_everything
from utils.image_utils import list_image_files
class ImageException(Exception):
    """Exception used to carry an image out of a call stack early.

    It is raised on purpose rather than to report an error: the raiser attaches
    an image, and a caller further up catches the exception and uses the
    attached image as its result.

    Attributes:
        image: The payload carried by the exception.
        message: Human-readable message, also passed to ``Exception``.
    """

    def __init__(self, image, message="Return image before PP"):
        self.image, self.message = image, message
        super().__init__(message)
def hairfast_wo_pp(hair_fast):
    """Patch ``hair_fast`` in place so blending returns its image before the PP stage.

    ``hair_fast.blend.downsample_256`` is replaced by a module that raises
    ``ImageException`` carrying its (rescaled) input. ``hair_fast.blend.blend_images``
    is wrapped so that this exception is caught and its image is returned
    instead. If ``blend_images`` never reaches ``downsample_256``, its normal
    return value is passed through unchanged.

    Args:
        hair_fast: A ``HairFast`` instance. Its ``blend`` attribute is mutated.
    """
    import functools

    class RaiseDownsample(nn.Module):
        """Stand-in for ``downsample_256`` that aborts blending with its input."""

        def forward(self, image):
            # Index 0 along the first dim (presumably the batch), then map
            # [-1, 1] -> [0, 1]; assumes generator output is in [-1, 1] — TODO confirm.
            image = ((image[0] + 1) / 2).clip(0, 1)
            raise ImageException(image)

    def blend_images(func):
        """Wrap ``func`` so that an ``ImageException`` becomes its return value."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                # Previously the result was discarded, so a normal return gave None.
                return func(*args, **kwargs)
            except ImageException as e:
                return e.image

        return wrapper

    hair_fast.blend.downsample_256 = RaiseDownsample()
    hair_fast.blend.blend_images = blend_images(hair_fast.blend.blend_images)
def load_image(path):
    """Read the image at ``path`` and convert it with torchvision's ``to_tensor``.

    ``Image.open`` is lazy and keeps the file open until the image object is
    closed. Opening it in a ``with`` block releases the handle as soon as the
    pixels have been decoded. This matters because the function is fanned out
    over many files with ``joblib``.

    Args:
        path: Path of the image file to read.

    Returns:
        The tensor produced by ``torchvision.transforms.functional.to_tensor``.
    """
    with Image.open(path) as img:
        # to_tensor reads the pixel data, so decoding completes before the file closes.
        return T.functional.to_tensor(img)
@torch.no_grad()
def load_dataset_images(imgs, dataset_path, ffhq_path=None, device='cuda'):
    """Build post-processing training records for one chunk of experiments.

    Args:
        imgs: Sequence of experiment entries. ``img[0]`` is a source file name
            relative to ``ffhq_path``, and ``img[1]`` is a generated file name
            relative to ``dataset_path``. It is iterated more than once, so it
            must not be a one-shot iterator.
        dataset_path: Directory holding the generated (target) images.
        ffhq_path: Directory holding the source images. When ``None``, the
            module-level ``args.FFHQ`` set by the ``__main__`` block is used,
            which preserves the original script behaviour.
        device: Device the mask network input is moved to.

    Returns:
        A list with one ``(source_file, downsampled_target, target_mask, HT_E)``
        tuple per entry of ``imgs``, with the tensors moved to CPU.
    """
    if ffhq_path is None:
        # Original behaviour: fall back to the CLI namespace created under __main__.
        ffhq_path = args.FFHQ

    net_trainer = Trainer()
    source_files = [os.path.join(ffhq_path, img[0]) for img in imgs]
    source_images = Parallel(n_jobs=-1)(delayed(load_image)(path) for path in tqdm(source_files))
    target_images = Parallel(n_jobs=-1)(
        delayed(load_image)(os.path.join(dataset_path, img[1])) for img in tqdm(imgs)
    )
    tensors_dataloader = DataLoader(list(zip(source_images, target_images)), batch_size=64,
                                    pin_memory=False, shuffle=False, drop_last=False)

    data = []
    total = 0  # number of records already emitted; aligns batches with source_files
    for batch in tqdm(tensors_dataloader):
        source, target = [elem.to(device) for elem in batch]
        HS_D, _ = net_trainer.generate_mask(source)
        HT_D, HT_E = net_trainer.generate_mask(target)
        # Product of the complements of both masks (assumes masks lie in [0, 1] — TODO confirm).
        target_mask = (1 - HS_D) * (1 - HT_D)
        n = len(source)
        data.extend(zip(
            source_files[total:total + n],
            net_trainer.downsample_256(target).clip(0, 1).cpu(),
            target_mask.cpu(),
            HT_E.cpu(),
        ))
        total += n
    return data
def main(args):
    """Generate the PP dataset and write it to ``args.output`` in 5,000-item parts.

    The steps are:

    1. Draw ``args.size`` random (face, shape, color) triples from ``args.FFHQ``.
    2. Run HairFast, patched by ``hairfast_wo_pp``, on each triple.
    3. Save each result into a temporary directory.
    4. Convert each chunk with ``load_dataset_images``.
    5. Save each chunk as ``pp_part_{idx}.dataset``, with ``idx`` starting at 1.

    Args:
        args: Parsed CLI namespace with ``FFHQ`` (``Path``), ``seed``, ``size``
            and ``output``.
    """
    seed_everything(args.seed)

    # init HairFast with its default options, then cut off the PP stage
    model_parser = get_parser()
    model_args = model_parser.parse_args([])
    hair_fast = HairFast(model_args)
    hairfast_wo_pp(hair_fast)

    # generate dataset
    os.makedirs(args.output, exist_ok=True)
    images = list_image_files(args.FFHQ)
    # np.random.choice samples with replacement by default.
    face, shape, color = np.array_split(np.random.choice(images, size=3 * args.size), 3)

    exps = []
    for exp in zip(face, shape, color):
        # splitext removes only the extension. split('.')[0] truncated names
        # containing extra dots and could make distinct triples share an output file.
        stems = (os.path.splitext(im)[0] for im in exp)
        exps.append([exp[0], f"{'_'.join(stems)}.png", exp])

    batch = 5_000
    for idx, left in enumerate(range(0, len(exps), batch), start=1):
        chunk = exps[left:left + batch]
        with tempfile.TemporaryDirectory() as temp_dir:
            for exp in tqdm(chunk):
                im1, im2, im3 = exp[-1]
                image = hair_fast(args.FFHQ / im1, args.FFHQ / im2, args.FFHQ / im3)
                save_image(image, os.path.join(temp_dir, exp[1]))
            # Must run before temp_dir is deleted, since it reads the saved images.
            batch_data = load_dataset_images(chunk, temp_dir)
            torch.save(batch_data, os.path.join(args.output, f'pp_part_{idx}.dataset'))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Blending dataset')
    # main() lists images in --FFHQ and joins paths onto it with `/`, so a
    # missing value would crash deep inside the run; fail early with a usage error instead.
    parser.add_argument('--FFHQ', type=Path, required=True,
                        help='directory of source images (FFHQ)')
    parser.add_argument('--seed', type=int, default=3407, help='random seed')
    parser.add_argument('--size', type=int, default=10_000,
                        help='number of (face, shape, color) experiments to generate')
    parser.add_argument('--output', type=Path, default='input/pp_dataset',
                        help='directory for the pp_part_*.dataset files')
    # Kept as the module-level name `args`: load_dataset_images falls back to it.
    args = parser.parse_args()
    main(args)