# File size: 6,577 Bytes (revisions: 7aefe45, 9a925ec, 7aefe45)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import argparse
from accelerate.utils import set_seed
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
from libs.engine import merge_and_update_config
from libs.utils.argparse import accelerate_parser, base_data_parser
from pipelines.painter.diffsketchedit_pipeline import DiffSketchEditPipeline
class PromptInfo:
    """Bundle of the per-experiment inputs handed to the sketch-editing pipeline.

    The object is a plain record: the constructor stores every argument
    unchanged on an attribute of the same name and performs no validation.

    Args:
        prompts: Sequence of text prompts, one per run stage (the driver
            script runs ``len(prompts)`` stages for each PromptInfo).
        token_ind: Integer token index forwarded to the pipeline.
            NOTE(review): looks like the position of the subject word in the
            prompt -- confirm against the pipeline.
        changing_region_words: Sequence of two-element word pairs; the
            examples in this file supply one pair per prompt, with
            ``["", ""]`` for the first prompt.
        reweight_word: Optional list of words, only supplied by the
            "reweight" examples in this file.
        reweight_weight: Optional list of float weights, supplied together
            with ``reweight_word``.
    """

    def __init__(self, prompts, token_ind, changing_region_words, reweight_word=None, reweight_weight=None):
        self.prompts = prompts
        self.token_ind = token_ind
        self.changing_region_words = changing_region_words
        self.reweight_word = reweight_word
        self.reweight_weight = reweight_weight

    def __repr__(self):
        # Debug-friendly representation listing every stored field.
        return (
            f"{type(self).__name__}("
            f"prompts={self.prompts!r}, "
            f"token_ind={self.token_ind!r}, "
            f"changing_region_words={self.changing_region_words!r}, "
            f"reweight_word={self.reweight_word!r}, "
            f"reweight_weight={self.reweight_weight!r})"
        )
def build_arg_parser():
    """Create the command-line parser for the sketch-editing driver script.

    The parser inherits the options contributed by ``accelerate_parser()``
    and ``base_data_parser()`` and adds the script-specific flags below.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    parser = argparse.ArgumentParser(
        description="vary style and content painterly rendering",
        parents=[accelerate_parser(), base_data_parser()]
    )

    # config
    parser.add_argument("-c", "--config",
                        type=str,
                        default="diffsketchedit.yaml",
                        help="YAML/YML file for configuration.")
    parser.add_argument("-style", "--style_file",
                        default="", type=str,
                        help="the path of style img place.")

    # result path
    parser.add_argument("-respath", "--results_path",
                        type=str, default="./workdir",
                        help="If it is None, it is automatically generated.")
    parser.add_argument("-npt", "--negative_prompt", default="", type=str)
    parser.add_argument("--sd_image_only", default=0, type=int,
                        help="1 for generating the SD images only; 0 for generating the subsequent vector sketches.")
    parser.add_argument("--vector_local_edit", default=1, type=int)
    parser.add_argument("--vector_local_edit_bin_threshold_replace", default=0.3, type=float)
    parser.add_argument("--vector_local_edit_bin_threshold_refine", default=0.3, type=float)
    parser.add_argument("--vector_local_edit_bin_threshold_reweight", default=0.3, type=float)
    parser.add_argument("--vector_local_edit_attn_res", default=16, choices=[16, 32, 64], type=int)

    # DiffSVG
    parser.add_argument("--print_timing", "-timing", action="store_true",
                        help="set print svg rendering timing.")

    # diffuser
    parser.add_argument("--download", default=0, type=int,
                        help="download models from huggingface automatically.")
    parser.add_argument("--force_download", "-download", action="store_true",
                        help="force the models to be downloaded from huggingface.")
    parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
                        help="download the models again from the breakpoint.")

    # rendering quantity
    # like: python main.py -rdbz -srange 100 200
    parser.add_argument("--render_batch", "-rdbz", action="store_true")
    parser.add_argument("-srange", "--seed_range",
                        required=False, nargs='+',
                        help="Sampling quantity.")

    # visual rendering process
    parser.add_argument("-mv", "--make_video", action="store_true",
                        help="make a video of the rendering process.")
    parser.add_argument("-frame_freq", "--video_frame_freq",
                        default=1, type=int,
                        help="video frame control.")
    return parser


def resolve_seeds(render_batch, seed_range, default_seeds):
    """Choose the seeds to render from the batch-rendering CLI flags.

    Args:
        render_batch: Truthy when ``--render_batch`` was given.
        seed_range: Raw ``--seed_range`` values (strings or ints), or None.
        default_seeds: Seeds to use when batch rendering was not requested.

    Returns:
        list: a fresh list of integer seeds. When ``render_batch`` is set and
        ``seed_range`` holds two values START and END, this is
        ``range(START, END)`` (END excluded); otherwise it is a copy of
        ``default_seeds``.

    Raises:
        ValueError: if batch rendering is requested with a seed range that
            does not consist of exactly two integers with START < END.
    """
    # Without both flags the hand-edited defaults stay authoritative.
    if not render_batch or not seed_range:
        return list(default_seeds)
    if len(seed_range) != 2:
        raise ValueError(
            f"--seed_range expects exactly two values (START END), got {len(seed_range)}"
        )
    start, stop = (int(value) for value in seed_range)
    if stop <= start:
        raise ValueError(
            f"--seed_range START must be smaller than END, got {start} {stop}"
        )
    return list(range(start, stop))


def main():
    """Parse the CLI, build the pipeline and render every configured prompt.

    For each seed and each PromptInfo the pipeline is invoked once per prompt
    (one "run stage" per prompt); the current stage index is published on
    ``args.run_stage`` before each call.
    """
    parser = build_arg_parser()
    cli_args = parser.parse_args()
    # Capture the batch flags before the config merge so that seed selection
    # does not depend on what merge_and_update_config returns.
    render_batch, seed_range = cli_args.render_batch, cli_args.seed_range
    args = merge_and_update_config(cli_args)

    ############################### main parameters ###############################
    default_seeds = [25760]
    # default_seeds = [random.randint(1, 65536) for _ in range(100)]  # needs `import random`

    args.edit_type = "replace"  # ["replace", "refine", "reweight"]

    prompt_infos = [
        ## "replace" examples
        PromptInfo(prompts=["A painting of a squirrel eating a burger",
                            "A painting of a rabbit eating a burger",
                            "A painting of a rabbit eating a pumpkin",
                            "A painting of a owl eating a pumpkin"],
                   token_ind=5,
                   changing_region_words=[["", ""], ["squirrel", "rabbit"], ["burger", "pumpkin"], ["rabbit", "owl"]]),

        # PromptInfo(prompts=["A boy wearing a cap",
        #                     "A boy wearing a beanie"],
        #            token_ind=2,
        #            changing_region_words=[["", ""], ["cap", "beanie"]]),

        # PromptInfo(prompts=["A desk near the bookshelf",
        #                     "A chair near the bookshelf"],
        #            token_ind=2,
        #            changing_region_words=[["", ""], ["desk", "chair"]]),

        ## "refine" examples
        # PromptInfo(prompts=["An evening dress",
        #                     "An evening dress with sleeves",
        #                     "An evening dress with sleeves and a belt"],
        #            token_ind=3,
        #            changing_region_words=[["", ""], ["", "sleeves"], ["", "belt"]]),

        ## "reweight" examples
        # PromptInfo(prompts=["An emoji face with moustache and smile"] * 3,
        #            token_ind=3,
        #            changing_region_words=[["", ""], ["moustache", "moustache"], ["smile", "smile"]],
        #            reweight_word=["moustache", "smile"],
        #            reweight_weight=[-1.0, 3.0]),

        # PromptInfo(prompts=["A photo of a birthday cake with candles"] * 2,
        #            token_ind=6,
        #            changing_region_words=[["", ""], ["candles", "candles"]],
        #            reweight_word=["candles"],
        #            reweight_weight=[-5.0])
    ]
    ############################### main parameters (end) ###############################

    # `-rdbz -srange START END` replaces the default seeds with range(START, END).
    try:
        seeds_list = resolve_seeds(render_batch, seed_range, default_seeds)
    except ValueError as err:
        parser.error(str(err))

    args.batch_size = 1  # rendering one SVG at a time

    pipe = DiffSketchEditPipeline(args)

    for seed in seeds_list:
        for prompt_info in prompt_infos:
            run_stages = len(prompt_info.prompts)
            for run_stage in range(run_stages):
                # The pipeline was constructed with this same `args` object;
                # presumably it reads `run_stage` from it -- confirm in the pipeline.
                args.run_stage = run_stage
                # Re-seed before every stage so each stage starts from the same RNG state.
                set_seed(seed)
                pipe.update_info(seed, prompt_info.token_ind, prompt_info.prompts[0])
                pipe.painterly_rendering(prompt_info.prompts,
                                         prompt_info.token_ind, prompt_info.changing_region_words,
                                         reweight_word=prompt_info.reweight_word,
                                         reweight_weight=prompt_info.reweight_weight)

    pipe.close(msg="painterly rendering complete.")


if __name__ == '__main__':
    main()
# End of script.