|
import os |
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
|
import sys |
|
import argparse |
|
|
|
from accelerate.utils import set_seed |
|
|
|
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0]) |
|
|
|
from libs.engine import merge_and_update_config |
|
from libs.utils.argparse import accelerate_parser, base_data_parser |
|
from pipelines.painter.diffsketchedit_pipeline import DiffSketchEditPipeline |
|
|
|
|
|
class PromptInfo: |
|
def __init__(self, prompts, token_ind, changing_region_words, reweight_word=None, reweight_weight=None): |
|
self.prompts = prompts |
|
self.token_ind = token_ind |
|
self.changing_region_words = changing_region_words |
|
self.reweight_word = reweight_word |
|
self.reweight_weight = reweight_weight |
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser( |
|
description="vary style and content painterly rendering", |
|
parents=[accelerate_parser(), base_data_parser()] |
|
) |
|
|
|
parser.add_argument("-c", "--config", |
|
type=str, |
|
default="diffsketchedit.yaml", |
|
help="YAML/YML file for configuration.") |
|
|
|
parser.add_argument("-style", "--style_file", |
|
default="", type=str, |
|
help="the path of style img place.") |
|
|
|
|
|
parser.add_argument("-respath", "--results_path", |
|
type=str, default="./workdir", |
|
help="If it is None, it is automatically generated.") |
|
parser.add_argument("-npt", "--negative_prompt", default="", type=str) |
|
|
|
parser.add_argument("--sd_image_only", default=0, type=int, |
|
help="1 for generating the SD images only; 0 for generating the subsequent vector sketches.") |
|
|
|
parser.add_argument("--vector_local_edit", default=1, type=int) |
|
parser.add_argument("--vector_local_edit_bin_threshold_replace", default=0.3, type=float) |
|
parser.add_argument("--vector_local_edit_bin_threshold_refine", default=0.3, type=float) |
|
parser.add_argument("--vector_local_edit_bin_threshold_reweight", default=0.3, type=float) |
|
parser.add_argument("--vector_local_edit_attn_res", default=16, choices=[16, 32, 64], type=int) |
|
|
|
|
|
parser.add_argument("--print_timing", "-timing", action="store_true", |
|
help="set print svg rendering timing.") |
|
|
|
parser.add_argument("--download", default=0, type=int, |
|
help="download models from huggingface automatically.") |
|
parser.add_argument("--force_download", "-download", action="store_true", |
|
help="force the models to be downloaded from huggingface.") |
|
parser.add_argument("--resume_download", "-dpm_resume", action="store_true", |
|
help="download the models again from the breakpoint.") |
|
|
|
|
|
parser.add_argument("--render_batch", "-rdbz", action="store_true") |
|
parser.add_argument("-srange", "--seed_range", |
|
required=False, nargs='+', |
|
help="Sampling quantity.") |
|
|
|
parser.add_argument("-mv", "--make_video", action="store_true", |
|
help="make a video of the rendering process.") |
|
parser.add_argument("-frame_freq", "--video_frame_freq", |
|
default=1, type=int, |
|
help="video frame control.") |
|
args = parser.parse_args() |
|
|
|
args = merge_and_update_config(args) |
|
|
|
|
|
|
|
seeds_list = [25760] |
|
|
|
args.edit_type = "replace" |
|
prompt_infos = [ |
|
|
|
PromptInfo(prompts=["A painting of a squirrel eating a burger", |
|
"A painting of a rabbit eating a burger", |
|
"A painting of a rabbit eating a pumpkin", |
|
"A painting of a owl eating a pumpkin"], |
|
token_ind=5, |
|
changing_region_words=[["", ""], ["squirrel", "rabbit"], ["burger", "pumpkin"], ["rabbit", "owl"]]), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
] |
|
|
|
|
|
|
|
args.batch_size = 1 |
|
pipe = DiffSketchEditPipeline(args) |
|
|
|
for seed in seeds_list: |
|
for prompt_info in prompt_infos: |
|
run_stages = len(prompt_info.prompts) |
|
for run_stage in range(run_stages): |
|
args.run_stage = run_stage |
|
set_seed(seed) |
|
pipe.update_info(seed, prompt_info.token_ind, prompt_info.prompts[0]) |
|
pipe.painterly_rendering(prompt_info.prompts, |
|
prompt_info.token_ind, prompt_info.changing_region_words, |
|
reweight_word=prompt_info.reweight_word, reweight_weight=prompt_info.reweight_weight) |
|
pipe.close(msg="painterly rendering complete.") |
|
|