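import os
# Pin the run to GPU 0; this must be set before CUDA is initialized (hence before the imports below).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'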

import sys
import argparse

from accelerate.utils import set_seed

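# Make the repository root (the parent of this script's directory) importable,
# so that the `libs` and `pipelines` packages below resolve.
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])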

from libs.engine import merge_and_update_config
from libs.utils.argparse import accelerate_parser, base_data_parser
from pipelines.painter.diffsketchedit_pipeline import DiffSketchEditPipeline
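

# Hypothetical helper, not part of the original pipeline: a minimal sketch of
# expanding the `-srange/--seed_range` values (e.g. `-srange 100 200`) into a
# seed list like `seeds_list` below. It ASSUMES a half-open range [start, end).
# int() accepts both strings and ints, so it works whether or not the option
# declares a `type`.
def seeds_from_range(seed_range):
    """Return list(range(start, end)) for seed_range == [start, end]."""
    if seed_range is None or len(seed_range) != 2:
        raise ValueError("expected exactly two values: start end")
    start, end = (int(v) for v in seed_range)
    if start >= end:
        raise ValueError("start must be smaller than end")
    return list(range(start, end))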


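class PromptInfo:
    """One editing sequence; the main loop runs one stage per prompt.

    prompts: prompts[0] is the original prompt; each later prompt is one edit.
    token_ind: position of the target word in prompts[0]; in the examples
        below it is the 1-based word index (e.g. 5 for "squirrel" in
        "A painting of a squirrel eating a burger").
    changing_region_words: one [old, new] pair per prompt. In the examples
        below the first pair is ["", ""] and pair k names the word changed
        between prompts[k-1] and prompts[k]; "" means no word on that side
        (as in the "refine" example).
    reweight_word, reweight_weight: words to reweight and their weights
        ("reweight" edits only).
    """

    def __init__(self, prompts, token_ind, changing_region_words, reweight_word=None, reweight_weight=None):
        self.prompts = prompts
        self.token_ind = token_ind
        self.changing_region_words = changing_region_words
        self.reweight_word = reweight_word
        self.reweight_weight = reweight_weight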
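

# Hypothetical helper, not part of the original pipeline: in every example
# below, `token_ind` equals the 1-based position of the edited subject word in
# prompts[0] (index 0 is presumably the text encoder's start token). This sketch
# recovers that number by splitting on whitespace, so it ASSUMES each word maps
# to exactly one token, which may not hold for rare words split into sub-words.
def word_position(prompt, word):
    """Return the 1-based position of `word` in `prompt` (case-insensitive)."""
    words = prompt.lower().split()
    try:
        return words.index(word.lower()) + 1
    except ValueError:
        raise ValueError(f"{word!r} not found in {prompt!r}") from None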


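# Hypothetical sanity check, inferred from the examples below rather than taken
# from the pipeline: one [old, new] pair per prompt, the first pair is ["", ""],
# pair k describes the change from prompts[k-1] to prompts[k] ("" means no word
# on that side), and reweight words/weights come as matching lists.
# Usage sketch: check_prompt_info(p.prompts, p.changing_region_words,
#                                 p.reweight_word, p.reweight_weight)
def check_prompt_info(prompts, changing_region_words, reweight_word=None, reweight_weight=None):
    """Raise ValueError on an inconsistent edit sequence; return True otherwise."""
    if len(changing_region_words) != len(prompts):
        raise ValueError("need one [old, new] pair per prompt")
    if list(changing_region_words[0]) != ["", ""]:
        raise ValueError('the first pair should be ["", ""]')
    for k in range(1, len(prompts)):
        old, new = changing_region_words[k]
        if old and old not in prompts[k - 1].split():
            raise ValueError(f"{old!r} missing from prompt {k - 1}")
        if new and new not in prompts[k].split():
            raise ValueError(f"{new!r} missing from prompt {k}")
    if (reweight_word is None) != (reweight_weight is None):
        raise ValueError("reweight_word and reweight_weight must be given together")
    if reweight_word is not None and len(reweight_word) != len(reweight_weight):
        raise ValueError("need one weight per reweighted word")
    return True

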
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="painterly rendering with varying style and content",
        parents=[accelerate_parser(), base_data_parser()]
    )
    # config
    parser.add_argument("-c", "--config",
                        type=str,
                        default="diffsketchedit.yaml",
                        help="YAML/YML file for configuration.")

    parser.add_argument("-style", "--style_file",
                        default="", type=str,
                        help="path to the style image.")

    # result path
    parser.add_argument("-respath", "--results_path",
                        type=str, default="./workdir",
                        help="output directory; if set to None, one is generated automatically.")
    parser.add_argument("-npt", "--negative_prompt", default="", type=str)

    parser.add_argument("--sd_image_only", default=0, type=int,
                        help="1 for generating the SD images only; 0 for generating the subsequent vector sketches.")

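    # vector local editing: on/off switch, per-edit-type mask binarization thresholds, attention-map resolution
    parser.add_argument("--vector_local_edit", default=1, type=int)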
    parser.add_argument("--vector_local_edit_bin_threshold_replace", default=0.3, type=float)
    parser.add_argument("--vector_local_edit_bin_threshold_refine", default=0.3, type=float)
    parser.add_argument("--vector_local_edit_bin_threshold_reweight", default=0.3, type=float)
    parser.add_argument("--vector_local_edit_attn_res", default=16, choices=[16, 32, 64], type=int)

    # DiffVG
    parser.add_argument("--print_timing", "-timing", action="store_true",
                        help="print SVG rendering timing.")
    # diffusers
    parser.add_argument("--download", default=0, type=int,
                        help="download models from Hugging Face automatically.")
    parser.add_argument("--force_download", "-download", action="store_true",
                        help="force the models to be downloaded from Hugging Face.")
    parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
                        help="resume an interrupted model download.")
    # rendering quantity
    # like: python main.py -rdbz -srange 100 200
    parser.add_argument("--render_batch", "-rdbz", action="store_true")
    parser.add_argument("-srange", "--seed_range",
                        required=False, nargs='+',
                        help="seed range for batch rendering, e.g. -srange 100 200.")
    # visual rendering process
    parser.add_argument("-mv", "--make_video", action="store_true",
                        help="make a video of the rendering process.")
    parser.add_argument("-frame_freq", "--video_frame_freq",
                        default=1, type=int,
                        help="how often (in rendering steps) a video frame is saved.")
    args = parser.parse_args()

    args = merge_and_update_config(args)

    ############################### main parameters ###############################

    seeds_list = [25760]
    # seeds_list = [random.randint(1, 65536) for _ in range(100)]  # requires `import random`
    args.edit_type = "replace"  # ["replace", "refine", "reweight"]
    prompt_infos = [
        ## "replace" examples
        PromptInfo(prompts=["A painting of a squirrel eating a burger",
                            "A painting of a rabbit eating a burger",
                            "A painting of a rabbit eating a pumpkin",
                            "A painting of a owl eating a pumpkin"],
                   token_ind=5,
                   changing_region_words=[["", ""], ["squirrel", "rabbit"], ["burger", "pumpkin"], ["rabbit", "owl"]]),

        # PromptInfo(prompts=["A boy wearing a cap",
        #                     "A boy wearing a beanie"],
        #            token_ind=2,
        #            changing_region_words=[["", ""], ["cap", "beanie"]]),

        # PromptInfo(prompts=["A desk near the bookshelf",
        #                     "A chair near the bookshelf"],
        #            token_ind=2,
        #            changing_region_words=[["", ""], ["desk", "chair"]]),

        ## "refine" examples
        # PromptInfo(prompts=["An evening dress",
        #                     "An evening dress with sleeves",
        #                     "An evening dress with sleeves and a belt"],
        #            token_ind=3,
        #            changing_region_words=[["", ""], ["", "sleeves"], ["", "belt"]]),

        ## "reweight" examples
        # PromptInfo(prompts=["An emoji face with moustache and smile"] * 3,
        #            token_ind=3,
        #            changing_region_words=[["", ""], ["moustache", "moustache"], ["smile", "smile"]],
        #            reweight_word=["moustache", "smile"],
        #            reweight_weight=[-1.0, 3.0]),

        # PromptInfo(prompts=["A photo of a birthday cake with candles"] * 2,
        #            token_ind=6,
        #            changing_region_words=[["", ""], ["candles", "candles"]],
        #            reweight_word=["candles"],
        #            reweight_weight=[-5.0])
    ]

    ############################### main parameters (end) ###############################

    args.batch_size = 1  # rendering one SVG at a time
    pipe = DiffSketchEditPipeline(args)

    for seed in seeds_list:
        for prompt_info in prompt_infos:
            run_stages = len(prompt_info.prompts)
            for run_stage in range(run_stages):
                args.run_stage = run_stage
                set_seed(seed)
                pipe.update_info(seed, prompt_info.token_ind, prompt_info.prompts[0])
                pipe.painterly_rendering(prompt_info.prompts,
                                         prompt_info.token_ind, prompt_info.changing_region_words,
                                         reweight_word=prompt_info.reweight_word, reweight_weight=prompt_info.reweight_weight)
                pipe.close(msg="painterly rendering complete.")