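"""Generate the blending post-processing (PP) dataset for HairFast.

Runs HairFast on random FFHQ (face, shape, color) triples with the final
post-processing step disabled, then stores per-example records of
(source path, 256x256 target image, blend mask, hair mask) in batched
.dataset files for later PP training.
"""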
import argparse
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from hair_swap import get_parser, HairFast
from scripts.pp_train import Trainer
from utils.train import seed_everything
from utils.image_utils import list_image_files
class ImageException(Exception):
    """Carries an intermediate image out of the blending pipeline.

    Raised to short-circuit HairFast right before its post-processing (PP)
    step, so the pre-PP blend can be returned as the result.
    """

    def __init__(self, image, message="Return image before PP"):
        self.image = image
        self.message = message
        super().__init__(self.message)
def hairfast_wo_pp(hair_fast):
    """Monkey-patch a HairFast instance to skip its post-processing step.

    Replaces blend.downsample_256 with a module that raises ImageException
    carrying the current blend, and wraps blend.blend_images to catch that
    exception and return the carried image directly.
    """

    class RaiseDownsample(nn.Module):
        def forward(self, image):
            # Map the blend from [-1, 1] to [0, 1] and bail out with it.
            image = ((image[0] + 1) / 2).clip(0, 1)
            raise ImageException(image)

    def blend_images(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except ImageException as e:
                return e.image

        return wrapper

    hair_fast.blend.downsample_256 = RaiseDownsample()
    hair_fast.blend.blend_images = blend_images(hair_fast.blend.blend_images)
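# Sketch of the patched flow (paths are illustrative): after the patch, a swap
# call returns the pre-PP blend as a tensor in [0, 1] instead of the PP output.
#   hairfast_wo_pp(hair_fast)
#   image = hair_fast(face_path, shape_path, color_path)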
def load_image(path):
    # PIL image -> float tensor in [0, 1], shape (C, H, W).
    return T.functional.to_tensor(Image.open(path))
def load_dataset_images(imgs, dataset_path):
    """Build PP training records for a chunk of experiments.

    Note: reads the script-level `args` for the FFHQ root. Each record is
    (source file path, 256x256 target image, blend mask, target hair mask).
    """
    net_trainer = Trainer()

    # Load source (FFHQ) and target (generated) images in parallel.
    source_images = Parallel(n_jobs=-1)(delayed(load_image)(os.path.join(args.FFHQ, img[0])) for img in tqdm(imgs))
    target_images = Parallel(n_jobs=-1)(delayed(load_image)(os.path.join(dataset_path, img[1])) for img in tqdm(imgs))
    tensors_dataloader = DataLoader(list(zip(source_images, target_images)), batch_size=64, pin_memory=False,
                                    shuffle=False, drop_last=False)

    source_files = [os.path.join(args.FFHQ, img[0]) for img in imgs]
    data = []
    total = 0
    for batch in tqdm(tensors_dataloader):
        source, target = [elem.to('cuda') for elem in batch]

        # Hair masks for source and target; the blend mask keeps only the
        # regions outside the hair of both images.
        HS_D, _ = net_trainer.generate_mask(source)
        HT_D, HT_E = net_trainer.generate_mask(target)
        target_mask = (1 - HS_D) * (1 - HT_D)

        data.extend(list(zip(
            [source_files[total + i] for i in range(len(source))],
            net_trainer.downsample_256(target).clip(0, 1).cpu(),
            target_mask.cpu(),
            HT_E.cpu(),
        )))
        total += len(source)
    return data
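# Illustrative reload of a saved part; each part is a plain list of tuples
# whose field order follows load_dataset_images:
#   records = torch.load('input/pp_dataset/pp_part_1.dataset')
#   source_path, target_256, blend_mask, hair_mask = records[0]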
def main(args):
    seed_everything(args.seed)

    # Init HairFast with default model arguments and disable its PP step.
    model_parser = get_parser()
    model_args = model_parser.parse_args([])
    hair_fast = HairFast(model_args)
    hairfast_wo_pp(hair_fast)

    # Sample random (face, shape, color) triples from FFHQ.
    os.makedirs(args.output, exist_ok=True)
    images = list_image_files(args.FFHQ)
    face, shape, color = np.array_split(np.random.choice(images, size=3 * args.size), 3)

    exps = []
    for exp in zip(face, shape, color):
        imgs = map(lambda im: im.split('.')[0], exp)
        exps.append([exp[0], f"{'_'.join(imgs)}.png", exp])

    # Process experiments in chunks of 5,000: generate swaps into a temporary
    # directory, then pack each chunk into its own .dataset file.
    batch = 5_000
    left, right, idx = 0, min(len(exps), batch), 1
    while left < len(exps):
        with tempfile.TemporaryDirectory() as temp_dir:
            for exp in tqdm(exps[left:right]):
                im1, im2, im3 = exp[-1]
                image = hair_fast(args.FFHQ / im1, args.FFHQ / im2, args.FFHQ / im3)
                save_image(image, os.path.join(temp_dir, exp[1]))

            batch_data = load_dataset_images(exps[left:right], temp_dir)
            torch.save(batch_data, os.path.join(args.output, f'pp_part_{idx}.dataset'))

        left = right
        right = min(len(exps), right + batch)
        idx += 1
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Blending dataset')
    parser.add_argument('--FFHQ', type=Path, help='Path to the FFHQ image directory')
    parser.add_argument('--seed', type=int, default=3407, help='Random seed')
    parser.add_argument('--size', type=int, default=10_000, help='Number of (face, shape, color) experiments to generate')
    parser.add_argument('--output', type=Path, default='input/pp_dataset', help='Directory for the .dataset parts')
    args = parser.parse_args()
    main(args)
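# Illustrative invocation (the script path and FFHQ location are assumptions):
#   python scripts/pp_dataset.py --FFHQ /data/FFHQ --size 10000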