File size: 2,958 Bytes
6d314be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import functools
import os
import sys
from pathlib import Path

import numpy as np
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.train import seed_everything
from utils.image_utils import list_image_files
from hair_swap import get_parser, HairFast
from utils.save_utils import save_latents


def identity_func(align_shape, align_color, name_to_embed, **kwargs):
    """Stand-in for a blending step that performs no blending.

    The three positional inputs are handed back untouched, packed in a
    tuple in the order they were received; any keyword arguments are
    accepted and ignored.
    """
    passthrough = (align_shape, align_color, name_to_embed)
    return passthrough


def align_instead_shape(hair_fast):
    """Patch ``hair_fast.align`` in place so the shape stage can run alignment instead.

    After patching:

    * ``hair_fast.align.shape_module(*args, align_flag=True, **kw)`` calls
      ``hair_fast.align.align_images(*args, **kw)`` instead of the original
      shape module.
    * Without a truthy ``align_flag`` the original ``shape_module`` is called
      with exactly the arguments it received (a falsy ``align_flag`` keyword
      is forwarded unchanged).
    * ``hair_fast.align.align_images`` never receives an ``align_flag``
      keyword; it is stripped before the original is called.

    Args:
        hair_fast: object exposing an ``align`` attribute with
            ``shape_module`` and ``align_images`` callables.

    Returns:
        None. The patch is applied to ``hair_fast.align`` directly.
    """
    def shape_module(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get('align_flag', False):
                # Looked up at call time, so this dispatches to the patched
                # align_images below, which removes 'align_flag'.
                return hair_fast.align.align_images(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    def align_module(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # **kwargs is a fresh dict on every call, so popping here never
            # mutates a mapping owned by the caller.
            kwargs.pop('align_flag', None)
            return func(*args, **kwargs)

        return wrapper

    hair_fast.align.shape_module = shape_module(hair_fast.align.shape_module)
    hair_fast.align.align_images = align_module(hair_fast.align.align_images)


def main(args):
    """Generate the blending dataset of FS latents and aligned F latents.

    Samples ``3 * args.size`` image names from ``args.FFHQ`` and splits them
    into ``args.size`` (face, shape, color) triples. Each triple is run through
    HairFast with two patches applied:

    * ``blend.blend_images`` is replaced by ``identity_func``;
    * ``align.shape_module`` is redirected to ``align.align_images`` when
      called with ``align_flag=True`` (see ``align_instead_shape``).

    For every triple the following are written under ``args.output``:

    * one line ``<face> <shape> <color>`` (extension stripped) appended to
      ``dataset.exps``;
    * ``latent_in`` = ``name_to_embed[role]['S']`` for each of the three
      images, saved via ``save_latents(output, 'FS', '<name>.npz', ...)``;
    * ``latent_F`` = ``align_shape['latent_F_align']`` and
      ``align_color['latent_F_align']``, saved via
      ``save_latents(output, 'Align', '<face>_<shape|color>.npz', ...)``.

    Args:
        args: namespace with ``FFHQ`` (directory of source images), ``seed``
            (int), ``size`` (number of triples) and ``output`` (output
            directory). ``FFHQ`` and ``output`` may be ``str`` or ``Path``.
    """
    seed_everything(args.seed)

    # init HairFast with its parser's default configuration
    model_parser = get_parser()
    model_args = model_parser.parse_args([])
    hair_fast = HairFast(model_args)
    # NOTE(review): presumably HairFast returns blend_images' result as-is;
    # the 3-way unpacking of hair_fast(...) below relies on that.
    hair_fast.blend.blend_images = identity_func
    align_instead_shape(hair_fast)

    ffhq_dir = Path(args.FFHQ)
    output_dir = Path(args.output)

    # generate dataset
    # NOTE(review): np.random.choice samples WITH replacement, so an image may
    # appear in several triples, or twice within one triple.
    images = list_image_files(ffhq_dir)
    face, shape, color = np.array_split(np.random.choice(images, size=3 * args.size), 3)

    os.makedirs(output_dir, exist_ok=True)
    with open(output_dir / 'dataset.exps', 'w') as f_exps:
        for imgs in tqdm(zip(face, shape, color), total=len(face)):
            # Strip only the final extension so dotted names ("a.v2.png")
            # don't collapse onto the same key as "a.png".
            im1, im2, im3 = (os.path.splitext(im)[0] for im in imgs)
            # Flushed per line so the experiment list survives an interrupted run.
            print(im1, im2, im3, file=f_exps, flush=True)

            pt1, pt2, pt3 = (ffhq_dir / im for im in imgs)
            align_shape, align_color, name_to_embed = hair_fast(pt1, pt2, pt3, align_flag=True)
            save_latents(output_dir, 'FS', f'{im1}.npz', latent_in=name_to_embed['face']['S'])
            save_latents(output_dir, 'FS', f'{im2}.npz', latent_in=name_to_embed['shape']['S'])
            save_latents(output_dir, 'FS', f'{im3}.npz', latent_in=name_to_embed['color']['S'])
            save_latents(output_dir, 'Align', f'{im1}_{im2}.npz', latent_F=align_shape['latent_F_align'])
            save_latents(output_dir, 'Align', f'{im1}_{im3}.npz', latent_F=align_color['latent_F_align'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Blending dataset')
    # Required: main() lists images from this directory and cannot work without it.
    parser.add_argument('--FFHQ', type=Path, required=True,
                        help='directory containing the source FFHQ images')
    parser.add_argument('--seed', type=int, default=3407,
                        help='random seed passed to seed_everything')
    parser.add_argument('--size', type=int, default=3_000,
                        help='number of (face, shape, color) triples to generate')
    parser.add_argument('--output', type=Path, default=Path('input/blending_dataset'),
                        help='output directory for dataset.exps and the saved latents')
    args = parser.parse_args()

    main(args)