Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import os | |
| from utils import call | |
| from diffusers import ( | |
| DDPMScheduler, | |
| DDIMScheduler, | |
| PNDMScheduler, | |
| LMSDiscreteScheduler, | |
| EulerAncestralDiscreteScheduler, | |
| EulerDiscreteScheduler, | |
| DPMSolverMultistepScheduler, | |
| ) | |
| from diffusers.pipelines import StableDiffusionXLPipeline | |
| StableDiffusionXLPipeline.__call__ = call | |
| import os | |
| from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV | |
| from trainscripts.textsliders.demotrain import train_xl | |
| os.environ['CURL_CA_BUNDLE'] = '' | |
| model_map = { | |
| '年龄调整': 'models/age.pt', | |
| '体型丰满': 'models/chubby.pt', | |
| '肌肉感': 'models/muscular.pt', | |
| '惊讶表情': 'models/suprised_look.pt', | |
| '微笑': 'models/smiling.pt', | |
| '职业感': 'models/professional.pt', | |
| '长发': 'models/long_hair.pt', | |
| '卷发': 'models/curlyhair.pt', | |
| 'Pixar风格': 'models/pixar_style.pt', | |
| '雕塑风格': 'models/sculpture_style.pt', | |
| '陶土风格': 'models/clay_style.pt', | |
| '修复图像': 'models/repair_slider.pt', | |
| '修复手部': 'models/fix_hands.pt', | |
| '杂乱房间': 'models/cluttered_room.pt', | |
| '阴暗天气': 'models/dark_weather.pt', | |
| '节日氛围': 'models/festive.pt', | |
| '热带天气': 'models/tropical_weather.pt', | |
| '冬季天气': 'models/winter_weather.pt', | |
| '弯眉': 'models/eyebrow.pt', | |
| '眼睛大小 (使用刻度 -3, -1, 1, 3)': 'models/eyesize.pt', | |
| } | |
| ORIGINAL_SPACE_ID = 'baulab/ConceptSliders' | |
| SPACE_ID = os.getenv('SPACE_ID') | |
| SHARED_UI_WARNING = f'''## 注意 - 在此共享UI中训练可能会很慢。您可以选择复制并使用至少40GB GPU的设备,或克隆此存储库以在自己的机器上运行。 | |
| <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-复制空间-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="复制空间"></a></center> | |
| ''' | |
| class Demo: | |
| def __init__(self) -> None: | |
| self.training = False | |
| self.generating = False | |
| self.device = 'cuda' | |
| self.weight_dtype = torch.bfloat16 | |
| model_id = 'stabilityai/stable-diffusion-xl-base-1.0' | |
| pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device) | |
| pipe = None | |
| del pipe | |
| torch.cuda.empty_cache() | |
| model_id = "stabilityai/sdxl-turbo" | |
| self.current_model = 'SDXL Turbo' | |
| euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") | |
| self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device) | |
| self.pipe.enable_xformers_memory_efficient_attention() | |
| self.guidance_scale = 1 | |
| self.num_inference_steps = 3 | |
| with gr.Blocks() as demo: | |
| self.layout() | |
| demo.queue(max_size=5).launch(share=True, max_threads=2) | |
| def layout(self): | |
| with gr.Row(): | |
| if SPACE_ID == ORIGINAL_SPACE_ID: | |
| self.warning = gr.Markdown(SHARED_UI_WARNING) | |
| with gr.Row(): | |
| with gr.Tab("测试") as inference_column: | |
| with gr.Row(): | |
| self.explain_infr = gr.Markdown(value='这是[概念滑块:用于扩散模型的LoRA适配器](https://sliders.baulab.info/)的演示。要尝试可以控制特定概念的模型,请选择一个模型并输入任何提示词,选择一个种子值,最后选择SDEdit时间步以保持结构。较高的SDEdit时间步会导致更多的结构变化。例如,如果选择“惊讶表情”模型,可以生成提示词“A picture of a person, realistic, 8k”的图像,并将滑块效果与原始模型生成的图像进行比较。我们还提供了几个其他预先微调的模型,如“修复”滑块,用于修复SDXL生成图像中的缺陷(请查看“预训练滑块”下拉菜单)。您还可以训练和运行自己的自定义滑块。请查看“训练”部分以进行自定义概念滑块训练。<b>当前推理正在运行SDXL Turbo!</b>') | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| self.prompt_input_infr = gr.Text( | |
| placeholder="photo of a person, with bokeh street background, realistic, 8k", | |
| label="提示词", | |
| info="生成图像的提示词", | |
| value="photo of a person, with bokeh street background, realistic, 8k" | |
| ) | |
| with gr.Row(): | |
| self.model_dropdown = gr.Dropdown( | |
| label="预训练滑块", | |
| choices= list(model_map.keys()), | |
| value='年龄调整', | |
| interactive=True | |
| ) | |
| self.seed_infr = gr.Number( | |
| label="种子值", | |
| value=42753 | |
| ) | |
| self.slider_scale_infr = gr.Slider( | |
| -4, | |
| 4, | |
| label="滑块刻度", | |
| value=3, | |
| info="较大的滑块刻度会导致更强的编辑效果" | |
| ) | |
| self.start_noise_infr = gr.Slider( | |
| 600, 900, | |
| value=750, | |
| label="SDEdit时间步", | |
| info="选择较小的值以保持更多结构" | |
| ) | |
| self.model_type = gr.Dropdown( | |
| label="模型", | |
| choices= ['SDXL Turbo', 'SDXL'], | |
| value='SDXL Turbo', | |
| interactive=True | |
| ) | |
| with gr.Column(scale=2): | |
| self.infr_button = gr.Button( | |
| value="生成", | |
| interactive=True | |
| ) | |
| with gr.Row(): | |
| self.image_orig = gr.Image( | |
| label="原始SD", | |
| interactive=False, | |
| type='pil', | |
| ) | |
| self.image_new = gr.Image( | |
| label=f"概念滑块", | |
| interactive=False, | |
| type='pil', | |
| ) | |
| with gr.Tab("训练") as training_column: | |
| with gr.Row(): | |
| self.explain_train= gr.Markdown(value='在这一部分,您可以为Stable Diffusion XL训练文本概念滑块。输入您希望进行编辑的目标概念(例如:人)。接下来,输入您希望编辑的属性的增强提示词(例如:控制人的年龄,输入“person, old”)。然后,输入属性的抑制提示词(例如:输入“person, young”)。然后按“训练”按钮。使用默认设置,训练一个滑块大约需要25分钟;然后您可以在上面的“测试”选项卡中尝试推理或下载权重。为了更快的训练,请复制此存储库并使用A100或更大的GPU进行训练。代码和详细信息在[github链接](https://github.com/rohitgandikota/sliders)。') | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| self.target_concept = gr.Text( | |
| placeholder="输入要进行编辑的目标概念...", | |
| label="编辑概念的提示词", | |
| info="对应于要编辑的概念的提示词(例如:“person”)", | |
| value = '' | |
| ) | |
| self.positive_prompt = gr.Text( | |
| placeholder="输入编辑的增强提示词...", | |
| label="增强提示词", | |
| info="对应于要增强的概念的提示词(例如:“person, old”)", | |
| value = '' | |
| ) | |
| self.negative_prompt = gr.Text( | |
| placeholder="输入编辑的抑制提示词...", | |
| label="抑制提示词", | |
| info="对应于要抑制的概念的提示词(例如:“person, young”)", | |
| value = '' | |
| ) | |
| self.attributes_input = gr.Text( | |
| placeholder="输入要保留的概念(用逗号分隔)。如果不需要,请留空...", | |
| label="要保留的概念", | |
| info="要保留/解缠的概念(例如:“male, female”)", | |
| value = '' | |
| ) | |
| self.is_person = gr.Checkbox( | |
| label="人", | |
| info="您是否在为人训练滑块?") | |
| self.rank = gr.Number( | |
| value=4, | |
| label="滑块等级", | |
| info='要训练的滑块等级' | |
| ) | |
| choices = ['xattn', 'noxattn'] | |
| self.train_method_input = gr.Dropdown( | |
| choices=choices, | |
| value='xattn', | |
| label='训练方法', | |
| info='训练方法。如果[* xattn *] - loras将仅在交叉注意层上。如果[* noxattn *](官方实现) - 除交叉注意层外的所有层', | |
| interactive=True | |
| ) | |
| self.iterations_input = gr.Number( | |
| value=500, | |
| precision=0, | |
| label="迭代次数", | |
| info='用于训练的迭代次数 - 最大为1000' | |
| ) | |
| self.lr_input = gr.Number( | |
| value=2e-4, | |
| label="学习率", | |
| info='用于训练的学习率' | |
| ) | |
| with gr.Column(scale=1): | |
| self.train_status = gr.Button(value='', variant='primary', interactive=False) | |
| self.train_button = gr.Button( | |
| value="训练", | |
| ) | |
| self.download = gr.Files() | |
| self.infr_button.click(self.inference, inputs = [ | |
| self.prompt_input_infr, | |
| self.seed_infr, | |
| self.start_noise_infr, | |
| self.slider_scale_infr, | |
| self.model_dropdown, | |
| self.model_type | |
| ], | |
| outputs=[ | |
| self.image_new, | |
| self.image_orig | |
| ] | |
| ) | |
| self.train_button.click(self.train, inputs = [ | |
| self.target_concept, | |
| self.positive_prompt, | |
| self.negative_prompt, | |
| self.rank, | |
| self.iterations_input, | |
| self.lr_input, | |
| self.attributes_input, | |
| self.is_person, | |
| self.train_method_input | |
| ], | |
| outputs=[self.train_button, self.train_status, self.download, self.model_dropdown] | |
| ) | |
| def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)): | |
| iterations_input = min(int(iterations_input),1000) | |
| if attributes_input == '': | |
| attributes_input = None | |
| print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person) | |
| randn = torch.randint(1, 10000000, (1,)).item() | |
| save_name = f"{randn}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:20]}" | |
| save_name += f'_alpha-{1}' | |
| save_name += f'_{train_method_input}' | |
| save_name += f'_rank_{int(rank)}.pt' | |
| # if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40: | |
| # return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()] | |
| if self.training: | |
| return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()] | |
| attributes = attributes_input | |
| if is_person: | |
| attributes = 'white, black, asian, hispanic, indian, male, female' | |
| self.training = True | |
| train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name) | |
| self.training = False | |
| torch.cuda.empty_cache() | |
| model_map[save_name.replace('.pt','')] = f'models/{save_name}' | |
| return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))] | |
| def inference(self, prompt, seed, start_noise, scale, model_name, model, pbar = gr.Progress(track_tqdm=True)): | |
| seed = seed or 42753 | |
| if self.current_model != model: | |
| if model=='SDXL Turbo': | |
| model_id = "stabilityai/sdxl-turbo" | |
| euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") | |
| self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device) | |
| self.pipe.enable_xformers_memory_efficient_attention() | |
| self.guidance_scale = 1 | |
| self.num_inference_steps = 3 | |
| self.current_model = 'SDXL Turbo' | |
| else: | |
| model_id = 'stabilityai/stable-diffusion-xl-base-1.0' | |
| self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device) | |
| self.pipe.enable_xformers_memory_efficient_attention() | |
| self.guidance_scale = 7.5 | |
| self.num_inference_steps = 20 | |
| self.current_model = 'SDXL' | |
| generator = torch.manual_seed(seed) | |
| model_path = model_map[model_name] | |
| unet = self.pipe.unet | |
| network_type = "c3lier" | |
| if 'full' in model_path: | |
| train_method = 'full' | |
| elif 'noxattn' in model_path: | |
| train_method = 'noxattn' | |
| elif 'xattn' in model_path: | |
| train_method = 'xattn' | |
| network_type = 'lierla' | |
| else: | |
| train_method = 'noxattn' | |
| modules = DEFAULT_TARGET_REPLACE | |
| if network_type == "c3lier": | |
| modules += UNET_TARGET_REPLACE_MODULE_CONV | |
| name = os.path.basename(model_path) | |
| rank = 4 | |
| alpha = 1 | |
| if 'rank' in model_path: | |
| rank = int(float(model_path.split('_')[-1].replace('.pt',''))) | |
| if 'alpha1' in model_path: | |
| alpha = 1.0 | |
| network = LoRANetwork( | |
| unet, | |
| rank=rank, | |
| multiplier=1.0, | |
| alpha=alpha, | |
| train_method=train_method, | |
| ).to(self.device, dtype=self.weight_dtype) | |
| network.load_state_dict(torch.load(model_path)) | |
| generator = torch.manual_seed(seed) | |
| edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=int(start_noise), scale=float(scale), unet=unet, guidance_scale=self.guidance_scale).images[0] | |
| generator = torch.manual_seed(seed) | |
| original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0] | |
| del unet, network | |
| unet = None | |
| network = None | |
| torch.cuda.empty_cache() | |
| return edited_image, original_image | |
| demo = Demo() | |