"""Build the 'Blending dataset' of HairFast intermediate latents.

For randomly sampled (face, shape, color) image triplets, HairFast is run with
its blending stage replaced by a pass-through, and the resulting per-image
embeddings and alignment latents are saved to disk.
"""
import argparse
import functools
import os
import sys
from pathlib import Path

import numpy as np
from tqdm.auto import tqdm

# Make the repository root importable so the project modules below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.train import seed_everything
from utils.image_utils import list_image_files
from hair_swap import get_parser, HairFast
from utils.save_utils import save_latents


def identity_func(align_shape, align_color, name_to_embed, **kwargs):
    """Pass-through replacement for ``HairFast.blend.blend_images``.

    Returns its first three arguments unchanged (extra keyword arguments are
    ignored), so blending is skipped. main() relies on the HairFast call then
    yielding ``(align_shape, align_color, name_to_embed)`` — presumably
    ``HairFast.__call__`` returns whatever ``blend_images`` returns.
    """
    return align_shape, align_color, name_to_embed


def align_instead_shape(hair_fast):
    """Monkey-patch ``hair_fast.align`` to honour an ``align_flag`` keyword.

    After patching:

    * ``shape_module(..., align_flag=True)`` calls the (patched)
      ``align_images`` with the same arguments instead of the original shape
      module;
    * otherwise ``shape_module`` calls the original implementation;
    * both patched methods drop ``align_flag`` before reaching an original
      implementation, since those do not take that keyword.

    Args:
        hair_fast: a ``HairFast`` instance; its ``align`` attribute is modified
            in place.
    """

    def _without_flag(kwargs):
        """Return a copy of *kwargs* without the ``align_flag`` key."""
        return {key: value for key, value in kwargs.items() if key != 'align_flag'}

    def shape_module(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get('align_flag', False):
                # Looked up at call time, so this is the patched align_images,
                # which strips align_flag itself.
                return hair_fast.align.align_images(*args, **kwargs)
            return func(*args, **_without_flag(kwargs))

        return wrapper

    def align_module(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **_without_flag(kwargs))

        return wrapper

    hair_fast.align.shape_module = shape_module(hair_fast.align.shape_module)
    hair_fast.align.align_images = align_module(hair_fast.align.align_images)


def main(args):
    """Sample image triplets and dump HairFast latents for each of them.

    Writes ``args.output / 'dataset.exps'`` with one ``"face shape color"``
    line (image names without extension) per completed triplet. Each triplet
    also gets latents saved via ``save_latents``:

    * ``'FS'``: ``<name>.npz`` holding ``name_to_embed[role]['S']`` for each of
      the three images;
    * ``'Align'``: ``<face>_<shape>.npz`` and ``<face>_<color>.npz`` holding the
      ``latent_F_align`` of the shape and color alignments.

    (Exact on-disk layout is decided by ``save_latents`` — presumably
    ``output/<subset>/<file>``.)

    Args:
        args: parsed CLI namespace with ``FFHQ`` (Path of the source image
            directory), ``seed``, ``size`` (number of triplets) and ``output``
            (Path of the output directory).
    """
    seed_everything(args.seed)

    # HairFast with default options; blending is disabled and the shape step is
    # rerouted through the full alignment when align_flag=True is passed.
    model_args = get_parser().parse_args([])
    hair_fast = HairFast(model_args)
    hair_fast.blend.blend_images = identity_func
    align_instead_shape(hair_fast)

    # Draw 3 * size images (np.random.choice samples with replacement by
    # default) and split them into face / shape / color pools of equal size.
    images = list_image_files(args.FFHQ)
    face, shape, color = np.array_split(np.random.choice(images, size=3 * args.size), 3)

    os.makedirs(args.output, exist_ok=True)
    with open(args.output / 'dataset.exps', 'w', encoding='utf-8') as f_exps:
        for imgs in tqdm(zip(face, shape, color), total=len(face)):
            # Strip only the final extension; dots elsewhere in the name or path
            # are kept so distinct images never map to the same latent file.
            im1, im2, im3 = (os.path.splitext(im)[0] for im in imgs)
            pt1, pt2, pt3 = (args.FFHQ / im for im in imgs)

            align_shape, align_color, name_to_embed = hair_fast(pt1, pt2, pt3, align_flag=True)

            for name, role in ((im1, 'face'), (im2, 'shape'), (im3, 'color')):
                save_latents(args.output, 'FS', f'{name}.npz', latent_in=name_to_embed[role]['S'])
            save_latents(args.output, 'Align', f'{im1}_{im2}.npz', latent_F=align_shape['latent_F_align'])
            save_latents(args.output, 'Align', f'{im1}_{im3}.npz', latent_F=align_color['latent_F_align'])

            # Record the triplet only once its latents exist, and flush so an
            # interrupted run leaves an index consistent with the saved files.
            print(im1, im2, im3, file=f_exps, flush=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Blending dataset')
    parser.add_argument('--FFHQ', type=Path, required=True,
                        help='directory containing the source images')
    parser.add_argument('--seed', type=int, default=3407)
    parser.add_argument('--size', type=int, default=3_000,
                        help='number of (face, shape, color) triplets to sample')
    parser.add_argument('--output', type=Path, default='input/blending_dataset')
    args = parser.parse_args()
    main(args)