File size: 18,985 Bytes
7d5189a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfcfb88
a5812ce
 
 
 
 
 
 
 
 
 
 
 
 
7d5189a
 
 
 
 
 
ee38d7b
7d5189a
 
 
70bea70
 
 
 
7d5189a
 
 
 
 
 
 
 
 
70bea70
 
7d5189a
 
 
 
 
 
 
 
 
 
 
ee38d7b
 
 
7d5189a
ee38d7b
7d5189a
ee38d7b
7d5189a
ee38d7b
7d5189a
ee38d7b
7d5189a
ee38d7b
7d5189a
ee38d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d5189a
6fd659b
ee38d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d5189a
ee38d7b
7d5189a
ee38d7b
 
 
7d5189a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee38d7b
 
 
7d5189a
 
 
 
 
70bea70
 
7d5189a
 
 
 
 
 
70bea70
 
7d5189a
 
 
ee38d7b
 
7d5189a
a819b10
7d5189a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee38d7b
 
 
 
 
 
 
a819b10
a5812ce
 
a819b10
ee38d7b
a5812ce
ee38d7b
 
a5812ce
7d5189a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
import gradio as gr
import torch    
import os
from utils import call
from diffusers import (
    DDPMScheduler,
    DDIMScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from diffusers.pipelines import StableDiffusionXLPipeline
StableDiffusionXLPipeline.__call__ = call
import os
from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
from trainscripts.textsliders.demotrain import train_xl

# NOTE(review): an empty CURL_CA_BUNDLE is a well-known workaround for weakening TLS
# certificate verification of `requests`-based downloads (presumably the Hugging Face
# model downloads below); whether it actually disables verification depends on the
# installed `requests` version. Confirm it is still needed before keeping it.
os.environ['CURL_CA_BUNDLE'] = ''

# Pretrained sliders offered in the "预训练滑块" dropdown: display label -> checkpoint path.
# Insertion order is the dropdown order; Demo.train adds freshly trained sliders at runtime.
model_map = dict([
    ('年龄调整', 'models/age.pt'),                            # age
    ('体型丰满', 'models/chubby.pt'),                         # chubby
    ('肌肉感', 'models/muscular.pt'),                         # muscular
    ('惊讶表情', 'models/suprised_look.pt'),                  # surprised look (file name spelled as on disk)
    ('微笑', 'models/smiling.pt'),                            # smiling
    ('职业感', 'models/professional.pt'),                     # professional
    ('长发', 'models/long_hair.pt'),                          # long hair
    ('卷发', 'models/curlyhair.pt'),                          # curly hair
    ('Pixar风格', 'models/pixar_style.pt'),                   # Pixar style
    ('雕塑风格', 'models/sculpture_style.pt'),                # sculpture style
    ('陶土风格', 'models/clay_style.pt'),                     # clay style
    ('修复图像', 'models/repair_slider.pt'),                  # image repair
    ('修复手部', 'models/fix_hands.pt'),                      # fix hands
    ('杂乱房间', 'models/cluttered_room.pt'),                 # cluttered room
    ('阴暗天气', 'models/dark_weather.pt'),                   # dark weather
    ('节日氛围', 'models/festive.pt'),                        # festive
    ('热带天气', 'models/tropical_weather.pt'),               # tropical weather
    ('冬季天气', 'models/winter_weather.pt'),                 # winter weather
    ('弯眉', 'models/eyebrow.pt'),                            # eyebrow
    ('眼睛大小 (使用刻度 -3, -1, 1, 3)', 'models/eyesize.pt'),  # eye size (use scales -3, -1, 1, 3)
])

# Hugging Face Space id of the upstream demo. Demo.layout shows SHARED_UI_WARNING only
# when the app runs inside exactly this Space.
ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
# Id of the Space this app runs in, read from the SPACE_ID environment variable;
# None when the variable is unset (e.g. when running locally).
SPACE_ID = os.getenv('SPACE_ID')

# Markdown banner (Chinese) warning that training in the shared UI may be slow and
# suggesting a >= 40GB GPU or a local clone; the embedded button links to the
# "duplicate" page of the current SPACE_ID.
SHARED_UI_WARNING = f'''## 注意 - 在此共享UI中训练可能会很慢。您可以选择复制并使用至少40GB GPU的设备,或克隆此存储库以在自己的机器上运行。
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-复制空间-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="复制空间"></a></center>
'''


def merge_lora_networks(networks):
    """Merge several LoRA networks into the first one by summing matching parameters.

    The first network is modified in place and returned. Every parameter of each later
    network is added element-wise to the tensor with the same name in the first
    network's ``state_dict()``. Parameters whose name does not exist in the first
    network are skipped: a module's tensors cannot be extended through ``state_dict()``.
    The old ``state_dict()[name] = ...`` assignment wrote into a throw-away dict and had
    no effect.

    NOTE(review): if each adapter stores separate low-rank factors (e.g. ``up``/``down``
    matrices — not visible here), summing the factors is not the same as summing the
    adapters' products. Confirm this approximation is acceptable.

    Args:
        networks: sequence of ``torch.nn.Module`` objects, presumably ``LoRANetwork``s
            built for the same UNet, so that parameter names and shapes line up.

    Returns:
        The merged first network, or ``None`` if ``networks`` is empty.
    """
    if not networks:
        return None

    base_network = networks[0]
    # Build the name -> tensor mapping once. Calling state_dict() per parameter made
    # this quadratic. The returned tensors share storage with base_network's
    # parameters/buffers, so in-place updates on them modify base_network itself.
    base_state = base_network.state_dict()
    with torch.no_grad():  # pure weight arithmetic; no autograd history needed
        for network in networks[1:]:
            for name, param in network.named_parameters():
                target = base_state.get(name)
                if target is not None:
                    target.add_(param)
    return base_network

class Demo:
    """Gradio demo for Concept Sliders on SDXL / SDXL Turbo.

    Constructing an instance does the following:
    - loads the SDXL Turbo inference pipeline;
    - builds the UI: a "测试" (test) tab for inference and a "训练" (training) tab for
      training custom sliders;
    - calls ``launch`` on the Gradio app.

    Construction therefore typically blocks while the server runs.
    """

    def __init__(self) -> None:
        """Load the default pipeline, build the Gradio UI and launch the server."""
        self.training = False  # busy flag: only one training run at a time
        self.generating = False
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.weight_dtype = torch.bfloat16

        # Pre-fetch the SDXL base weights into the local cache, presumably so that
        # switching to "SDXL" later (and training) does not have to download them. The
        # previous code fully loaded this pipeline onto the device just to delete it again.
        StableDiffusionXLPipeline.download('stabilityai/stable-diffusion-xl-base-1.0')

        self.pipe = None
        self.current_model = None
        self._load_pipeline('SDXL Turbo')

        with gr.Blocks() as demo:
            self.layout()
        # Launch after the Blocks context is closed, i.e. once the whole UI is defined.
        demo.queue(max_size=5).launch(share=True, max_threads=2)

    def _load_pipeline(self, model):
        """Load ``model`` ('SDXL Turbo', anything else means SDXL base) into ``self.pipe``.

        Also sets the guidance scale and step count used for that model.
        """
        # Release the current pipeline first, so the old and new weights are not resident
        # on the device at the same time. Clearing current_model makes a failed load
        # retry on the next request instead of running with ``self.pipe is None``.
        self.pipe = None
        self.current_model = None
        torch.cuda.empty_cache()

        if model == 'SDXL Turbo':
            model_id = "stabilityai/sdxl-turbo"
            euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
            self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
            self.guidance_scale = 1
            self.num_inference_steps = 3
            loaded = 'SDXL Turbo'
        else:
            model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
            self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
            self.guidance_scale = 7.5
            self.num_inference_steps = 20
            loaded = 'SDXL'

        if torch.cuda.is_available():
            self.pipe.enable_xformers_memory_efficient_attention()
        self.current_model = loaded

    def layout(self):
        """Create all Gradio components and wire the button events.

        Must be called inside an open ``gr.Blocks`` context, as ``__init__`` does.
        """
        with gr.Row():
            # The slow-training warning is only relevant in the upstream shared Space.
            if SPACE_ID == ORIGINAL_SPACE_ID:
                self.warning = gr.Markdown(SHARED_UI_WARNING)

        with gr.Row():

            with gr.Tab("测试") as inference_column:

                with gr.Row():

                    self.explain_infr = gr.Markdown(value='这是[概念滑块:用于扩散模型的LoRA适配器](https://sliders.baulab.info/)的演示。要尝试可以控制特定概念的模型,请选择一个模型并输入任何提示词,选择一个种子值,最后选择SDEdit时间步以保持结构。较高的SDEdit时间步会导致更多的结构变化。例如,如果选择“惊讶表情”模型,可以生成提示词“A picture of a person, realistic, 8k”的图像,并将滑块效果与原始模型生成的图像进行比较。我们还提供了几个其他预先微调的模型,如“修复”滑块,用于修复SDXL生成图像中的缺陷(请查看“预训练滑块”下拉菜单)。您还可以训练和运行自己的自定义滑块。请查看“训练”部分以进行自定义概念滑块训练。<b>当前推理正在运行SDXL Turbo!</b>')

                with gr.Row():

                    with gr.Column(scale=1):

                        self.prompt_input_infr = gr.Text(
                            placeholder="photo of a person, with bokeh street background, realistic, 8k",
                            label="提示词",
                            info="生成图像的提示词",
                            value="photo of a person, with bokeh street background, realistic, 8k"
                        )

                        with gr.Row():

                            self.model_dropdown = gr.Dropdown(
                                label="预训练滑块",
                                choices=list(model_map.keys()),
                                value=['年龄调整'],
                                interactive=True,
                                multiselect=True  # several sliders can be combined at once
                            )

                            self.seed_infr = gr.Number(
                                label="种子值",
                                value=42753
                            )

                            self.slider_scale_infr = gr.Slider(
                                -4,
                                4,
                                label="滑块刻度",
                                value=3,
                                info="较大的滑块刻度会导致更强的编辑效果"
                            )

                            self.start_noise_infr = gr.Slider(
                                600, 900,
                                value=750,
                                label="SDEdit时间步",
                                info="选择较小的值以保持更多结构"
                            )
                            self.model_type = gr.Dropdown(
                                label="模型",
                                choices=['SDXL Turbo', 'SDXL'],
                                value='SDXL Turbo',
                                interactive=True
                            )
                    with gr.Column(scale=2):

                        self.infr_button = gr.Button(
                            value="生成",
                            interactive=True
                        )

                        with gr.Row():

                            self.image_orig = gr.Image(
                                label="原始SD",
                                interactive=False,
                                type='pil',
                            )

                            self.image_new = gr.Image(
                                label="概念滑块",
                                interactive=False,
                                type='pil',
                            )

            with gr.Tab("训练") as training_column:

                with gr.Row():

                    self.explain_train = gr.Markdown(value='在这一部分,您可以为Stable Diffusion XL训练文本概念滑块。输入您希望进行编辑的目标概念(例如:人)。接下来,输入您希望编辑的属性的增强提示词(例如:控制人的年龄,输入“person, old”)。然后,输入属性的抑制提示词(例如:输入“person, young”)。然后按“训练”按钮。使用默认设置,训练一个滑块大约需要25分钟;然后您可以在上面的“测试”选项卡中尝试推理或下载权重。为了更快的训练,请复制此存储库并使用A100或更大的GPU进行训练。代码和详细信息在[github链接](https://github.com/rohitgandikota/sliders)。')

                with gr.Row():

                    with gr.Column(scale=3):

                        self.target_concept = gr.Text(
                            placeholder="输入要进行编辑的目标概念...",
                            label="编辑概念的提示词",
                            info="对应于要编辑的概念的提示词(例如:“person”)",
                            value=''
                        )

                        self.positive_prompt = gr.Text(
                            placeholder="输入编辑的增强提示词...",
                            label="增强提示词",
                            info="对应于要增强的概念的提示词(例如:“person, old”)",
                            value=''
                        )

                        self.negative_prompt = gr.Text(
                            placeholder="输入编辑的抑制提示词...",
                            label="抑制提示词",
                            info="对应于要抑制的概念的提示词(例如:“person, young”)",
                            value=''
                        )

                        self.attributes_input = gr.Text(
                            placeholder="输入要保留的概念(用逗号分隔)。如果不需要,请留空...",
                            label="要保留的概念",
                            info="要保留/解缠的概念(例如:“male, female”)",
                            value=''
                        )
                        self.is_person = gr.Checkbox(
                            label="人",
                            info="您是否在为人训练滑块?")

                        self.rank = gr.Number(
                            value=4,
                            label="滑块等级",
                            info='要训练的滑块等级'
                        )
                        choices = ['xattn', 'noxattn']
                        self.train_method_input = gr.Dropdown(
                            choices=choices,
                            value='xattn',
                            label='训练方法',
                            info='训练方法。如果[* xattn *] - loras将仅在交叉注意层上。如果[* noxattn *](官方实现) - 除交叉注意层外的所有层',
                            interactive=True
                        )
                        self.iterations_input = gr.Number(
                            value=500,
                            precision=0,
                            label="迭代次数",
                            info='用于训练的迭代次数 - 最大为1000'
                        )

                        self.lr_input = gr.Number(
                            value=2e-4,
                            label="学习率",
                            info='用于训练的学习率'
                        )

                    with gr.Column(scale=1):

                        self.train_status = gr.Button(value='', variant='primary', interactive=False)

                        self.train_button = gr.Button(
                            value="训练",
                        )

                        self.download = gr.Files()

        self.infr_button.click(self.inference, inputs=[
            self.prompt_input_infr,
            self.seed_infr,
            self.start_noise_infr,
            self.slider_scale_infr,
            self.model_dropdown,
            self.model_type
            ],
            outputs=[
                self.image_new,
                self.image_orig
            ]
        )
        self.train_button.click(self.train, inputs=[
            self.target_concept,
            self.positive_prompt,
            self.negative_prompt,
            self.rank,
            self.iterations_input,
            self.lr_input,
            self.attributes_input,
            self.is_person,
            self.train_method_input
        ],
        outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
        )

    def train(self, target_concept, positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar=gr.Progress(track_tqdm=True)):
        """Train a custom text slider with ``train_xl`` and register it in ``model_map``.

        Args:
            target_concept: prompt of the concept to edit (e.g. "person").
            positive_prompt: prompt of the attribute to enhance (e.g. "person, old").
            negative_prompt: prompt of the attribute to suppress (e.g. "person, young").
            rank: LoRA rank, converted with ``int``.
            iterations_input: training iterations, clamped to the range [1, 1000].
            lr_input: learning rate passed to ``train_xl``.
            attributes_input: comma-separated concepts to preserve; blank means none.
                Ignored when ``is_person`` is set, in which case a fixed list of
                demographic attributes is used instead.
            is_person: whether the slider is trained for people.
            train_method_input: 'xattn' or 'noxattn' (the UI choices).
            pbar: Gradio progress tracker that mirrors tqdm bars.

        Returns:
            A list for ``[train_button, train_status, download, model_dropdown]``:
            - a button update;
            - a status-text update;
            - the checkpoint path, or ``None`` when busy;
            - a dropdown update that lists and selects the new slider.
        """
        if self.training:
            return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]

        iterations_input = max(1, min(int(iterations_input), 1000))
        if not (attributes_input or '').strip():
            attributes_input = None
        print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)

        randn = torch.randint(1, 10000000, (1,)).item()
        # Keep only letters/digits of the prompt, so the name cannot contain path
        # separators or the '_' used to delimit the suffix that _parse_slider_config reads.
        prompt_tag = ''.join(ch for ch in positive_prompt if ch.isalnum())[:20]
        save_name = f"{randn}_{prompt_tag}_alpha-1_{train_method_input}_rank_{int(rank)}.pt"

        # A check that the GPU has >= 40 GB of memory used to live here and is disabled.
        attributes = 'white, black, asian, hispanic, indian, male, female' if is_person else attributes_input

        self.training = True
        try:
            train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
        finally:
            # Always release the training slot. Previously an exception inside train_xl
            # left the flag set and blocked every later training request.
            self.training = False
            torch.cuda.empty_cache()

        slider_name = save_name.replace('.pt', '')
        # Presumably where train_xl writes the checkpoint for save_name; verify in train_xl.
        model_path = f'models/{save_name}'
        model_map[slider_name] = model_path

        # The dropdown is multiselect, so its value is given as a list.
        return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), model_path, gr.update(choices=list(model_map.keys()), value=[slider_name])]

    @staticmethod
    def _parse_slider_config(model_path):
        """Infer ``(train_method, network_type, rank)`` from a slider checkpoint file name.

        Sliders saved by ``train`` end in ``_{train_method}_rank_{rank}.pt``. For those,
        method and rank are read from that fixed suffix, so words inside the prompt part
        of the name (e.g. "full") cannot be mistaken for the training method. Other files
        fall back to matching whole ``_``-separated tokens. The defaults are 'noxattn'
        and rank 4.
        """
        stem = os.path.basename(model_path)
        if stem.endswith('.pt'):
            stem = stem[:-len('.pt')]
        tokens = stem.split('_')

        rank = 4
        if len(tokens) >= 3 and tokens[-2] == 'rank':
            rank = int(float(tokens[-1]))
            method_tokens = {tokens[-3]}
        else:
            method_tokens = set(tokens)

        if 'full' in method_tokens:
            return 'full', 'c3lier', rank
        if 'noxattn' in method_tokens:
            return 'noxattn', 'c3lier', rank
        if 'xattn' in method_tokens:
            return 'xattn', 'lierla', rank
        return 'noxattn', 'c3lier', rank

    def _load_slider(self, model_path, unet):
        """Build a ``LoRANetwork`` on ``unet`` matching ``model_path`` and load its weights."""
        train_method, network_type, rank = self._parse_slider_config(model_path)

        if network_type == "c3lier":
            # NOTE(review): ``modules`` aliases the imported DEFAULT_TARGET_REPLACE.
            # - If that object is a list, ``+=`` extends it in place, so the conv targets
            #   become visible to every later user of DEFAULT_TARGET_REPLACE (presumably
            #   including LoRANetwork).
            # - The checkpoints may rely on this, so the mutation is kept.
            # - The subset check only stops the list from growing on every request.
            modules = DEFAULT_TARGET_REPLACE
            if not set(UNET_TARGET_REPLACE_MODULE_CONV).issubset(modules):
                modules += UNET_TARGET_REPLACE_MODULE_CONV

        alpha = 1.0 if 'alpha1' in model_path else 1
        network = LoRANetwork(
            unet,
            rank=rank,
            multiplier=1.0,
            alpha=alpha,
            train_method=train_method,
        ).to(self.device, dtype=self.weight_dtype)
        # NOTE(review): torch.load unpickles the file. It is only safe because the paths
        # come from model_map (local checkpoints) — never point it at untrusted files.
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        network.load_state_dict(torch.load(model_path, map_location=self.device))
        return network

    def inference(self, prompt, seed, start_noise, scale, model_names, model, pbar=gr.Progress(track_tqdm=True)):
        """Generate a slider-edited image and its unedited counterpart from the same seed.

        Args:
            prompt: text prompt for both images.
            seed: RNG seed; 42753 when the field is empty (``None``).
            start_noise: SDEdit timestep forwarded to the patched pipeline call.
            scale: slider strength for the edited image; the original uses ``scale=0``.
            model_names: labels from ``model_map``; their LoRA networks are summed.
            model: 'SDXL Turbo' or 'SDXL'; the pipeline is reloaded when it changes.
            pbar: Gradio progress tracker that mirrors tqdm bars.

        Returns:
            ``(edited_image, original_image)``: the first image of each pipeline call.

        Raises:
            gr.Error: if no slider is selected. Previously this crashed with a
                ``NameError`` on ``unet``.
        """
        if not model_names:
            raise gr.Error('请至少选择一个预训练滑块')

        # Only an empty field falls back to the default seed; 0 is a valid seed.
        seed = 42753 if seed is None else int(seed)
        start_noise = int(start_noise)  # same value for both calls

        if self.current_model != model:
            self._load_pipeline(model)

        unet = self.pipe.unet
        networks = [self._load_slider(model_map[name], unet) for name in model_names]
        merged_network = merge_lora_networks(networks)

        call_kwargs = dict(
            num_images_per_prompt=1,
            num_inference_steps=self.num_inference_steps,
            network=merged_network,
            start_noise=start_noise,
            unet=unet,
            guidance_scale=self.guidance_scale,
        )
        # Re-seed before each call, so both images start from identical noise.
        edited_image = self.pipe(prompt, generator=torch.manual_seed(seed), scale=float(scale), **call_kwargs).images[0]
        original_image = self.pipe(prompt, generator=torch.manual_seed(seed), scale=0, **call_kwargs).images[0]

        # Drop the per-request networks before releasing cached device memory.
        del unet, networks, merged_network, call_kwargs
        torch.cuda.empty_cache()

        return edited_image, original_image

if __name__ == "__main__":
    # Constructing Demo loads the pipeline and launches the Gradio server. The guard
    # keeps a plain import of this module free of that heavy side effect.
    demo = Demo()