Vanisper committed

Commit 7d5189a · 0 Parent(s)

init: initial commit

Files changed (41)
  1. .gitattributes +35 -0
  2. .gitignore +11 -0
  3. README.md +18 -0
  4. __init__.py +2 -0
  5. _app.py +376 -0
  6. app-v1.py +368 -0
  7. app.py +426 -0
  8. demo.py +38 -0
  9. install_deps.bat +4 -0
  10. requirements.txt +27 -0
  11. trainscripts/__init__.py +1 -0
  12. trainscripts/imagesliders/config_util.py +104 -0
  13. trainscripts/imagesliders/data/config-xl.yaml +28 -0
  14. trainscripts/imagesliders/data/config.yaml +28 -0
  15. trainscripts/imagesliders/data/prompts-xl.yaml +275 -0
  16. trainscripts/imagesliders/data/prompts.yaml +174 -0
  17. trainscripts/imagesliders/debug_util.py +16 -0
  18. trainscripts/imagesliders/lora.py +256 -0
  19. trainscripts/imagesliders/model_util.py +283 -0
  20. trainscripts/imagesliders/prompt_util.py +174 -0
  21. trainscripts/imagesliders/train_lora-scale-xl.py +548 -0
  22. trainscripts/imagesliders/train_lora-scale.py +501 -0
  23. trainscripts/imagesliders/train_util.py +458 -0
  24. trainscripts/textsliders/__init__.py +0 -0
  25. trainscripts/textsliders/config_util.py +104 -0
  26. trainscripts/textsliders/data/config-xl.yaml +28 -0
  27. trainscripts/textsliders/data/config.yaml +28 -0
  28. trainscripts/textsliders/data/prompts-xl.yaml +486 -0
  29. trainscripts/textsliders/data/prompts.yaml +193 -0
  30. trainscripts/textsliders/debug_util.py +16 -0
  31. trainscripts/textsliders/demotrain.py +437 -0
  32. trainscripts/textsliders/flush.py +5 -0
  33. trainscripts/textsliders/generate_images_xl.py +513 -0
  34. trainscripts/textsliders/lora.py +258 -0
  35. trainscripts/textsliders/model_util.py +278 -0
  36. trainscripts/textsliders/prompt_util.py +183 -0
  37. trainscripts/textsliders/ptp_utils.py +295 -0
  38. trainscripts/textsliders/train_lora.py +419 -0
  39. trainscripts/textsliders/train_lora_xl.py +463 -0
  40. trainscripts/textsliders/train_util.py +419 -0
  41. utils.py +391 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
+ # Ignore all __pycache__ directories
+ **/__pycache__/
+
+ # Ignore Python virtual environments
+ .conda
+ .env/
+ .venv/
+ env/
+ venv/
+ models
+ .gradio
README.md ADDED
@@ -0,0 +1,18 @@
+ ---
+ title: ConceptSliders
+ emoji: 🏃
+ colorFrom: yellow
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ python_version: 3.10
+ license: mit
+ ---
+
+ See the Spaces configuration reference: https://huggingface.co/docs/hub/spaces-config-reference
+
+ Paper page: https://huggingface.co/papers/2311.12092
+
+ Create the repository: `huggingface-cli repo create ConceptSliders --type space --space_sdk gradio`
__init__.py ADDED
@@ -0,0 +1,2 @@
+ from trainscripts.textsliders import lora
+ from trainscripts.textsliders import demotrain
_app.py ADDED
@@ -0,0 +1,376 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from utils import call
5
+ from diffusers import (
6
+ DDPMScheduler,
7
+ DDIMScheduler,
8
+ PNDMScheduler,
9
+ LMSDiscreteScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler,
12
+ DPMSolverMultistepScheduler,
13
+ )
14
+ from diffusers.pipelines import StableDiffusionXLPipeline
15
+ StableDiffusionXLPipeline.__call__ = call
16
+ import os
17
+ from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
18
+ from trainscripts.textsliders.demotrain import train_xl
19
+
20
+ os.environ['CURL_CA_BUNDLE'] = ''
21
+
22
+ model_map = {
23
+ 'Age' : 'models/age.pt',
24
+ 'Chubby': 'models/chubby.pt',
25
+ 'Muscular': 'models/muscular.pt',
26
+ 'Surprised Look': 'models/suprised_look.pt',
27
+ 'Smiling' : 'models/smiling.pt',
28
+ 'Professional': 'models/professional.pt',
29
+
30
+
31
+
32
+ 'Long Hair' : 'models/long_hair.pt',
33
+ 'Curly Hair' : 'models/curlyhair.pt',
34
+
35
+ 'Pixar Style' : 'models/pixar_style.pt',
36
+ 'Sculpture Style': 'models/sculpture_style.pt',
37
+ 'Clay Style': 'models/clay_style.pt',
38
+
39
+ 'Repair Images': 'models/repair_slider.pt',
40
+ 'Fix Hands': 'models/fix_hands.pt',
41
+
42
+ 'Cluttered Room': 'models/cluttered_room.pt',
43
+
44
+ 'Dark Weather': 'models/dark_weather.pt',
45
+ 'Festive': 'models/festive.pt',
46
+ 'Tropical Weather': 'models/tropical_weather.pt',
47
+ 'Winter Weather': 'models/winter_weather.pt',
48
+
49
+ 'Wavy Eyebrows': 'models/eyebrow.pt',
50
+ 'Small Eyes (use scales -3, -1, 1, 3)': 'models/eyesize.pt',
51
+ }
52
+
53
+ ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
54
+ SPACE_ID = os.getenv('SPACE_ID')
55
+
56
+ SHARED_UI_WARNING = f'''## Attention - Training could be slow in this shared UI. You can alternatively duplicate and use it with a gpu with at least 40GB, or clone this repository to run on your own machine.
57
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
58
+ '''
59
+
60
+
61
+ class Demo:
62
+
63
+ def __init__(self) -> None:
64
+
65
+ self.training = False
66
+ self.generating = False
67
+ self.device = 'cuda'
68
+ self.weight_dtype = torch.bfloat16
69
+
70
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
71
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
72
+ pipe = None
73
+ del pipe
74
+ torch.cuda.empty_cache()
75
+
76
+ model_id = "stabilityai/sdxl-turbo"
77
+ self.current_model = 'SDXL Turbo'
78
+ euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
79
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
80
+ self.pipe.enable_xformers_memory_efficient_attention()
81
+
82
+ self.guidance_scale = 1
83
+ self.num_inference_steps = 3
84
+
85
+ with gr.Blocks() as demo:
86
+ self.layout()
87
+ demo.queue(max_size=5).launch(share=True, max_threads=2)
88
+
89
+
90
+ def layout(self):
91
+
92
+ with gr.Row():
93
+
94
+ if SPACE_ID == ORIGINAL_SPACE_ID:
95
+
96
+ self.warning = gr.Markdown(SHARED_UI_WARNING)
97
+
98
+ with gr.Row():
99
+
100
+ with gr.Tab("Test") as inference_column:
101
+
102
+ with gr.Row():
103
+
104
+ self.explain_infr = gr.Markdown(value='This is a demo of [Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models](https://sliders.baulab.info/). To try out a model that can control a particular concept, select a model, enter any prompt, choose a seed, and finally choose the SDEdit timestep for structural preservation. Higher SDEdit timesteps result in more structural change. For example, if you select the "Surprised Look" model, you can generate images for the prompt "A picture of a person, realistic, 8k" and compare the slider effect to the image generated by the original model. We also provide several other pre-fine-tuned models, such as "repair" sliders that fix flaws in SDXL-generated images (check out the "Pretrained Sliders" drop-down). You can also train and run your own custom sliders; see the "Train" tab for custom concept slider training. <b>Current inference is running on SDXL Turbo!</b>')
105
+
106
+ with gr.Row():
107
+
108
+ with gr.Column(scale=1):
109
+
110
+ self.prompt_input_infr = gr.Text(
111
+ placeholder="photo of a person, with bokeh street background, realistic, 8k",
112
+ label="Prompt",
113
+ info="Prompt to generate",
114
+ value="photo of a person, with bokeh street background, realistic, 8k"
115
+ )
116
+
117
+ with gr.Row():
118
+
119
+ self.model_dropdown = gr.Dropdown(
120
+ label="Pretrained Sliders",
121
+ choices= list(model_map.keys()),
122
+ value='Age',
123
+ interactive=True
124
+ )
125
+
126
+ self.seed_infr = gr.Number(
127
+ label="Seed",
128
+ value=42753
129
+ )
130
+
131
+ self.slider_scale_infr = gr.Slider(
132
+ -4,
133
+ 4,
134
+ label="Slider Scale",
135
+ value=3,
136
+ info="Larger slider scale result in stronger edit"
137
+ )
138
+
139
+
140
+ self.start_noise_infr = gr.Slider(
141
+ 600, 900,
142
+ value=750,
143
+ label="SDEdit Timestep",
144
+ info="Choose smaller values for more structural preservation"
145
+ )
146
+ self.model_type = gr.Dropdown(
147
+ label="Model",
148
+ choices= ['SDXL Turbo', 'SDXL'],
149
+ value='SDXL Turbo',
150
+ interactive=True
151
+ )
152
+ with gr.Column(scale=2):
153
+
154
+ self.infr_button = gr.Button(
155
+ value="Generate",
156
+ interactive=True
157
+ )
158
+
159
+ with gr.Row():
160
+
161
+ self.image_orig = gr.Image(
162
+ label="Original SD",
163
+ interactive=False,
164
+ type='pil',
165
+ )
166
+
167
+ self.image_new = gr.Image(
168
+ label=f"Concept Slider",
169
+ interactive=False,
170
+ type='pil',
171
+ )
172
+
173
+ with gr.Tab("Train") as training_column:
174
+
175
+ with gr.Row():
176
+
177
+ self.explain_train = gr.Markdown(value='In this part you can train a textual concept slider for Stable Diffusion XL. Enter a target concept you wish to edit (e.g. person). Next, enter an enhance prompt for the attribute you wish to edit (to control the age of a person, enter "person, old"). Then type the suppress prompt for the attribute (for our example, enter "person, young") and press the "Train" button. With default settings, it takes about 25 minutes to train a slider; you can then try inference above or download the weights. For faster training, please duplicate the repo and train with an A100 or larger GPU. Code and details are at the [github link](https://github.com/rohitgandikota/sliders).')
178
+
179
+ with gr.Row():
180
+
181
+ with gr.Column(scale=3):
182
+
183
+ self.target_concept = gr.Text(
184
+ placeholder="Enter target concept to make edit on ...",
185
+ label="Prompt of concept on which edit is made",
186
+ info="Prompt corresponding to concept to edit (eg: 'person')",
187
+ value = ''
188
+ )
189
+
190
+ self.positive_prompt = gr.Text(
191
+ placeholder="Enter the enhance prompt for the edit ...",
192
+ label="Prompt to enhance",
193
+ info="Prompt corresponding to concept to enhance (eg: 'person, old')",
194
+ value = ''
195
+ )
196
+
197
+ self.negative_prompt = gr.Text(
198
+ placeholder="Enter the suppress prompt for the edit ...",
199
+ label="Prompt to suppress",
200
+ info="Prompt corresponding to concept to supress (eg: 'person, young')",
201
+ value = ''
202
+ )
203
+
204
+ self.attributes_input = gr.Text(
205
+ placeholder="Enter the concepts to preserve (comma seperated). Leave empty if not required ...",
206
+ label="Concepts to Preserve",
207
+ info="Comma seperated concepts to preserve/disentangle (eg: 'male, female')",
208
+ value = ''
209
+ )
210
+ self.is_person = gr.Checkbox(
211
+ label="Person",
212
+ info="Are you training a slider for person?")
213
+
214
+ self.rank = gr.Number(
215
+ value=4,
216
+ label="Rank of the Slider",
217
+ info='Slider Rank to train'
218
+ )
219
+ choices = ['xattn', 'noxattn']
220
+ self.train_method_input = gr.Dropdown(
221
+ choices=choices,
222
+ value='xattn',
223
+ label='Train Method',
224
+ info='Training method. With [* xattn *], LoRAs are applied to cross-attention layers only; with [* noxattn *] (the official implementation), to all layers except cross-attention.',
225
+ interactive=True
226
+ )
227
+ self.iterations_input = gr.Number(
228
+ value=500,
229
+ precision=0,
230
+ label="Iterations",
231
+ info='iterations used to train - maximum of 1000'
232
+ )
233
+
234
+ self.lr_input = gr.Number(
235
+ value=2e-4,
236
+ label="Learning Rate",
237
+ info='Learning rate used to train'
238
+ )
239
+
240
+ with gr.Column(scale=1):
241
+
242
+ self.train_status = gr.Button(value='', variant='primary', interactive=False)
243
+
244
+ self.train_button = gr.Button(
245
+ value="Train",
246
+ )
247
+
248
+ self.download = gr.Files()
249
+
250
+ self.infr_button.click(self.inference, inputs = [
251
+ self.prompt_input_infr,
252
+ self.seed_infr,
253
+ self.start_noise_infr,
254
+ self.slider_scale_infr,
255
+ self.model_dropdown,
256
+ self.model_type
257
+ ],
258
+ outputs=[
259
+ self.image_new,
260
+ self.image_orig
261
+ ]
262
+ )
263
+ self.train_button.click(self.train, inputs = [
264
+ self.target_concept,
265
+ self.positive_prompt,
266
+ self.negative_prompt,
267
+ self.rank,
268
+ self.iterations_input,
269
+ self.lr_input,
270
+ self.attributes_input,
271
+ self.is_person,
272
+ self.train_method_input
273
+ ],
274
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
275
+ )
276
+
277
+ def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
278
+ iterations_input = min(int(iterations_input),1000)
279
+ if attributes_input == '':
280
+ attributes_input = None
281
+ print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
282
+
283
+ randn = torch.randint(1, 10000000, (1,)).item()
284
+ save_name = f"{randn}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:20]}"
285
+ save_name += f'_alpha-{1}'
286
+ save_name += f'_{train_method_input}'
287
+ save_name += f'_rank_{int(rank)}.pt'
288
+
289
+ # if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
290
+ # return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
291
+ if self.training:
292
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
293
+
294
+ attributes = attributes_input
295
+ if is_person:
296
+ attributes = 'white, black, asian, hispanic, indian, male, female'
297
+
298
+ self.training = True
299
+ train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
300
+ self.training = False
301
+
302
+ torch.cuda.empty_cache()
303
+ model_map[save_name.replace('.pt','')] = f'models/{save_name}'
304
+
305
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
306
+
307
+
308
+ def inference(self, prompt, seed, start_noise, scale, model_name, model, pbar = gr.Progress(track_tqdm=True)):
309
+
310
+ seed = seed or 42753
311
+ if self.current_model != model:
312
+ if model=='SDXL Turbo':
313
+ model_id = "stabilityai/sdxl-turbo"
314
+ euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
315
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
316
+ self.pipe.enable_xformers_memory_efficient_attention()
317
+ self.guidance_scale = 1
318
+ self.num_inference_steps = 3
319
+ self.current_model = 'SDXL Turbo'
320
+ else:
321
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
322
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
323
+ self.pipe.enable_xformers_memory_efficient_attention()
324
+ self.guidance_scale = 7.5
325
+ self.num_inference_steps = 20
326
+ self.current_model = 'SDXL'
327
+ generator = torch.manual_seed(seed)
328
+
329
+ model_path = model_map[model_name]
330
+ unet = self.pipe.unet
331
+ network_type = "c3lier"
332
+ if 'full' in model_path:
333
+ train_method = 'full'
334
+ elif 'noxattn' in model_path:
335
+ train_method = 'noxattn'
336
+ elif 'xattn' in model_path:
337
+ train_method = 'xattn'
338
+ network_type = 'lierla'
339
+ else:
340
+ train_method = 'noxattn'
341
+
342
+ modules = DEFAULT_TARGET_REPLACE
343
+ if network_type == "c3lier":
344
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
345
+
346
+ name = os.path.basename(model_path)
347
+ rank = 4
348
+ alpha = 1
349
+ if 'rank' in model_path:
350
+ rank = int(float(model_path.split('_')[-1].replace('.pt','')))
351
+ if 'alpha1' in model_path:
352
+ alpha = 1.0
353
+ network = LoRANetwork(
354
+ unet,
355
+ rank=rank,
356
+ multiplier=1.0,
357
+ alpha=alpha,
358
+ train_method=train_method,
359
+ ).to(self.device, dtype=self.weight_dtype)
360
+ network.load_state_dict(torch.load(model_path))
361
+
362
+
363
+ generator = torch.manual_seed(seed)
364
+ edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=int(start_noise), scale=float(scale), unet=unet, guidance_scale=self.guidance_scale).images[0]
365
+
366
+ generator = torch.manual_seed(seed)
367
+ original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0]
368
+
369
+ del unet, network
370
+ unet = None
371
+ network = None
372
+ torch.cuda.empty_cache()
373
+
374
+ return edited_image, original_image
375
+
376
+ demo = Demo()
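For reference, the age example described in the training-tab text above maps directly onto the train_xl call that Demo.train makes. A minimal sketch using the UI defaults from this file (rank 4, 500 iterations, lr 2e-4, xattn); the save_name value is illustrative, not a file produced by this commit:

from trainscripts.textsliders.demotrain import train_xl

train_xl(
    target="person",                  # concept the edit is made on
    positive="person, old",           # attribute to enhance
    negative="person, young",         # attribute to suppress
    lr=2e-4,
    iterations=500,
    config_file="trainscripts/textsliders/data/config-xl.yaml",
    rank=4,
    train_method="xattn",
    device="cuda",
    attributes="white, black, asian, hispanic, indian, male, female",  # preservation set applied when "Person" is checked
    save_name="example_personold_alpha-1_xattn_rank_4.pt",  # illustrative file name
)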
app-v1.py ADDED
@@ -0,0 +1,368 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from utils import call
5
+ from diffusers import (
6
+ DDPMScheduler,
7
+ DDIMScheduler,
8
+ PNDMScheduler,
9
+ LMSDiscreteScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler,
12
+ DPMSolverMultistepScheduler,
13
+ )
14
+ from diffusers.pipelines import StableDiffusionXLPipeline
15
+ StableDiffusionXLPipeline.__call__ = call
16
+ import os
17
+ from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
18
+ from trainscripts.textsliders.demotrain import train_xl
19
+
20
+ os.environ['CURL_CA_BUNDLE'] = ''
21
+
22
+ model_map = {
23
+ '年龄调整': 'models/age.pt',
24
+ '体型丰满': 'models/chubby.pt',
25
+ '肌肉感': 'models/muscular.pt',
26
+ '惊讶表情': 'models/suprised_look.pt',
27
+ '微笑': 'models/smiling.pt',
28
+ '职业感': 'models/professional.pt',
29
+ '长发': 'models/long_hair.pt',
30
+ '卷发': 'models/curlyhair.pt',
31
+ 'Pixar风格': 'models/pixar_style.pt',
32
+ '雕塑风格': 'models/sculpture_style.pt',
33
+ '陶土风格': 'models/clay_style.pt',
34
+ '修复图像': 'models/repair_slider.pt',
35
+ '修复手部': 'models/fix_hands.pt',
36
+ '杂乱房间': 'models/cluttered_room.pt',
37
+ '阴暗天气': 'models/dark_weather.pt',
38
+ '节日氛围': 'models/festive.pt',
39
+ '热带天气': 'models/tropical_weather.pt',
40
+ '冬季天气': 'models/winter_weather.pt',
41
+ '弯眉': 'models/eyebrow.pt',
42
+ '眼睛大小 (使用刻度 -3, -1, 1, 3)': 'models/eyesize.pt',
43
+ }
44
+
45
+ ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
46
+ SPACE_ID = os.getenv('SPACE_ID')
47
+
48
+ SHARED_UI_WARNING = f'''## 注意 - 在此共享UI中训练可能会很慢。您可以选择复制并使用至少40GB GPU的设备,或克隆此存储库以在自己的机器上运行。
49
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-复制空间-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="复制空间"></a></center>
50
+ '''
51
+
52
+
53
+ class Demo:
54
+
55
+ def __init__(self) -> None:
56
+
57
+ self.training = False
58
+ self.generating = False
59
+ self.device = 'cuda'
60
+ self.weight_dtype = torch.bfloat16
61
+
62
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
63
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
64
+ pipe = None
65
+ del pipe
66
+ torch.cuda.empty_cache()
67
+
68
+ model_id = "stabilityai/sdxl-turbo"
69
+ self.current_model = 'SDXL Turbo'
70
+ euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
71
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
72
+ self.pipe.enable_xformers_memory_efficient_attention()
73
+
74
+ self.guidance_scale = 1
75
+ self.num_inference_steps = 3
76
+
77
+ with gr.Blocks() as demo:
78
+ self.layout()
79
+ demo.queue(max_size=5).launch(share=True, max_threads=2)
80
+
81
+
82
+ def layout(self):
83
+
84
+ with gr.Row():
85
+
86
+ if SPACE_ID == ORIGINAL_SPACE_ID:
87
+
88
+ self.warning = gr.Markdown(SHARED_UI_WARNING)
89
+
90
+ with gr.Row():
91
+
92
+ with gr.Tab("测试") as inference_column:
93
+
94
+ with gr.Row():
95
+
96
+ self.explain_infr = gr.Markdown(value='这是[概念滑块:用于扩散模型的LoRA适配器](https://sliders.baulab.info/)的演示。要尝试可以控制特定概念的模型,请选择一个模型并输入任何提示词,选择一个种子值,最后选择SDEdit时间步以保持结构。较高的SDEdit时间步会导致更多的结构变化。例如,如果选择“惊讶表情”模型,可以生成提示词“A picture of a person, realistic, 8k”的图像,并将滑块效果与原始模型生成的图像进行比较。我们还提供了几个其他预先微调的模型,如“修复”滑块,用于修复SDXL生成图像中的缺陷(请查看“预训练滑块”下拉菜单)。您还可以训练和运行自己的自定义滑块。请查看“训练”部分以进行自定义概念滑块训练。<b>当前推理正在运行SDXL Turbo!</b>')
97
+
98
+ with gr.Row():
99
+
100
+ with gr.Column(scale=1):
101
+
102
+ self.prompt_input_infr = gr.Text(
103
+ placeholder="photo of a person, with bokeh street background, realistic, 8k",
104
+ label="提示词",
105
+ info="生成图像的提示词",
106
+ value="photo of a person, with bokeh street background, realistic, 8k"
107
+ )
108
+
109
+ with gr.Row():
110
+
111
+ self.model_dropdown = gr.Dropdown(
112
+ label="预训练滑块",
113
+ choices= list(model_map.keys()),
114
+ value='年龄调整',
115
+ interactive=True
116
+ )
117
+
118
+ self.seed_infr = gr.Number(
119
+ label="种子值",
120
+ value=42753
121
+ )
122
+
123
+ self.slider_scale_infr = gr.Slider(
124
+ -4,
125
+ 4,
126
+ label="滑块刻度",
127
+ value=3,
128
+ info="较大的滑块刻度会导致更强的编辑效果"
129
+ )
130
+
131
+
132
+ self.start_noise_infr = gr.Slider(
133
+ 600, 900,
134
+ value=750,
135
+ label="SDEdit时间步",
136
+ info="选择较小的值以保持更多结构"
137
+ )
138
+ self.model_type = gr.Dropdown(
139
+ label="模型",
140
+ choices= ['SDXL Turbo', 'SDXL'],
141
+ value='SDXL Turbo',
142
+ interactive=True
143
+ )
144
+ with gr.Column(scale=2):
145
+
146
+ self.infr_button = gr.Button(
147
+ value="生成",
148
+ interactive=True
149
+ )
150
+
151
+ with gr.Row():
152
+
153
+ self.image_orig = gr.Image(
154
+ label="原始SD",
155
+ interactive=False,
156
+ type='pil',
157
+ )
158
+
159
+ self.image_new = gr.Image(
160
+ label=f"概念滑块",
161
+ interactive=False,
162
+ type='pil',
163
+ )
164
+
165
+ with gr.Tab("训练") as training_column:
166
+
167
+ with gr.Row():
168
+
169
+ self.explain_train= gr.Markdown(value='在这一部分,您可以为Stable Diffusion XL训练文本概念滑块。输入您希望进行编辑的目标概念(例如:人)。接下来,输入您希望编辑的属性的增强提示词(例如:控制人的年龄,输入“person, old”)。然后,输入属性的抑制提示词(例如:输入“person, young”)。然后按“训练”按钮。使用默认设置,训练一个滑块大约需要25分钟;然后您可以在上面的“测试”选项卡中尝试推理或下载权重。为了更快的训练,请复制此存储库并使用A100或更大的GPU进行训练。代码和详细信息在[github链接](https://github.com/rohitgandikota/sliders)。')
170
+
171
+ with gr.Row():
172
+
173
+ with gr.Column(scale=3):
174
+
175
+ self.target_concept = gr.Text(
176
+ placeholder="输入要进行编辑的目标概念...",
177
+ label="编辑概念的提示词",
178
+ info="对应于要编辑的概念的提示词(例如:“person”)",
179
+ value = ''
180
+ )
181
+
182
+ self.positive_prompt = gr.Text(
183
+ placeholder="输入编辑的增强提示词...",
184
+ label="增强提示词",
185
+ info="对应于要增强的概念的提示词(例如:“person, old”)",
186
+ value = ''
187
+ )
188
+
189
+ self.negative_prompt = gr.Text(
190
+ placeholder="输入编辑的抑制提示词...",
191
+ label="抑制提示词",
192
+ info="对应于要抑制的概念的提示词(例如:“person, young”)",
193
+ value = ''
194
+ )
195
+
196
+ self.attributes_input = gr.Text(
197
+ placeholder="输入要保留的概念(用逗号分隔)。如果不需要,请留空...",
198
+ label="要保留的概念",
199
+ info="要保留/解缠的概念(例如:“male, female”)",
200
+ value = ''
201
+ )
202
+ self.is_person = gr.Checkbox(
203
+ label="人",
204
+ info="您是否在为人训练滑块?")
205
+
206
+ self.rank = gr.Number(
207
+ value=4,
208
+ label="滑块等级",
209
+ info='要训练的滑块等级'
210
+ )
211
+ choices = ['xattn', 'noxattn']
212
+ self.train_method_input = gr.Dropdown(
213
+ choices=choices,
214
+ value='xattn',
215
+ label='训练方法',
216
+ info='训练方法。如果[* xattn *] - loras将仅在交叉注意层上。如果[* noxattn *](官方实现) - 除交叉注意层外的所有层',
217
+ interactive=True
218
+ )
219
+ self.iterations_input = gr.Number(
220
+ value=500,
221
+ precision=0,
222
+ label="迭代次数",
223
+ info='用于训练的迭代次数 - 最大为1000'
224
+ )
225
+
226
+ self.lr_input = gr.Number(
227
+ value=2e-4,
228
+ label="学习率",
229
+ info='用于训练的学习率'
230
+ )
231
+
232
+ with gr.Column(scale=1):
233
+
234
+ self.train_status = gr.Button(value='', variant='primary', interactive=False)
235
+
236
+ self.train_button = gr.Button(
237
+ value="训练",
238
+ )
239
+
240
+ self.download = gr.Files()
241
+
242
+ self.infr_button.click(self.inference, inputs = [
243
+ self.prompt_input_infr,
244
+ self.seed_infr,
245
+ self.start_noise_infr,
246
+ self.slider_scale_infr,
247
+ self.model_dropdown,
248
+ self.model_type
249
+ ],
250
+ outputs=[
251
+ self.image_new,
252
+ self.image_orig
253
+ ]
254
+ )
255
+ self.train_button.click(self.train, inputs = [
256
+ self.target_concept,
257
+ self.positive_prompt,
258
+ self.negative_prompt,
259
+ self.rank,
260
+ self.iterations_input,
261
+ self.lr_input,
262
+ self.attributes_input,
263
+ self.is_person,
264
+ self.train_method_input
265
+ ],
266
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
267
+ )
268
+
269
+ def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
270
+ iterations_input = min(int(iterations_input),1000)
271
+ if attributes_input == '':
272
+ attributes_input = None
273
+ print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
274
+
275
+ randn = torch.randint(1, 10000000, (1,)).item()
276
+ save_name = f"{randn}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:20]}"
277
+ save_name += f'_alpha-{1}'
278
+ save_name += f'_{train_method_input}'
279
+ save_name += f'_rank_{int(rank)}.pt'
280
+
281
+ # if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
282
+ # return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
283
+ if self.training:
284
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
285
+
286
+ attributes = attributes_input
287
+ if is_person:
288
+ attributes = 'white, black, asian, hispanic, indian, male, female'
289
+
290
+ self.training = True
291
+ train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
292
+ self.training = False
293
+
294
+ torch.cuda.empty_cache()
295
+ model_map[save_name.replace('.pt','')] = f'models/{save_name}'
296
+
297
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
298
+
299
+
300
+ def inference(self, prompt, seed, start_noise, scale, model_name, model, pbar = gr.Progress(track_tqdm=True)):
301
+
302
+ seed = seed or 42753
303
+ if self.current_model != model:
304
+ if model=='SDXL Turbo':
305
+ model_id = "stabilityai/sdxl-turbo"
306
+ euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
307
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
308
+ self.pipe.enable_xformers_memory_efficient_attention()
309
+ self.guidance_scale = 1
310
+ self.num_inference_steps = 3
311
+ self.current_model = 'SDXL Turbo'
312
+ else:
313
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
314
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
315
+ self.pipe.enable_xformers_memory_efficient_attention()
316
+ self.guidance_scale = 7.5
317
+ self.num_inference_steps = 20
318
+ self.current_model = 'SDXL'
319
+ generator = torch.manual_seed(seed)
320
+
321
+ model_path = model_map[model_name]
322
+ unet = self.pipe.unet
323
+ network_type = "c3lier"
324
+ if 'full' in model_path:
325
+ train_method = 'full'
326
+ elif 'noxattn' in model_path:
327
+ train_method = 'noxattn'
328
+ elif 'xattn' in model_path:
329
+ train_method = 'xattn'
330
+ network_type = 'lierla'
331
+ else:
332
+ train_method = 'noxattn'
333
+
334
+ modules = DEFAULT_TARGET_REPLACE
335
+ if network_type == "c3lier":
336
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
337
+
338
+ name = os.path.basename(model_path)
339
+ rank = 4
340
+ alpha = 1
341
+ if 'rank' in model_path:
342
+ rank = int(float(model_path.split('_')[-1].replace('.pt','')))
343
+ if 'alpha1' in model_path:
344
+ alpha = 1.0
345
+ network = LoRANetwork(
346
+ unet,
347
+ rank=rank,
348
+ multiplier=1.0,
349
+ alpha=alpha,
350
+ train_method=train_method,
351
+ ).to(self.device, dtype=self.weight_dtype)
352
+ network.load_state_dict(torch.load(model_path))
353
+
354
+
355
+ generator = torch.manual_seed(seed)
356
+ edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=int(start_noise), scale=float(scale), unet=unet, guidance_scale=self.guidance_scale).images[0]
357
+
358
+ generator = torch.manual_seed(seed)
359
+ original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0]
360
+
361
+ del unet, network
362
+ unet = None
363
+ network = None
364
+ torch.cuda.empty_cache()
365
+
366
+ return edited_image, original_image
367
+
368
+ demo = Demo()
app.py ADDED
@@ -0,0 +1,426 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from utils import call
5
+ from diffusers import (
6
+ DDPMScheduler,
7
+ DDIMScheduler,
8
+ PNDMScheduler,
9
+ LMSDiscreteScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler,
12
+ DPMSolverMultistepScheduler,
13
+ )
14
+ from diffusers.pipelines import StableDiffusionXLPipeline
15
+ StableDiffusionXLPipeline.__call__ = call
16
+ import os
17
+ from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
18
+ from trainscripts.textsliders.demotrain import train_xl
19
+
20
+ os.environ['CURL_CA_BUNDLE'] = ''
21
+
22
+ model_map = {
23
+ '年龄调整': 'models/age.pt',
24
+ '体型丰满': 'models/chubby.pt',
25
+ '肌肉感': 'models/muscular.pt',
26
+ '惊讶表情': 'models/suprised_look.pt',
27
+ '微笑': 'models/smiling.pt',
28
+ '职业感': 'models/professional.pt',
29
+ '长发': 'models/long_hair.pt',
30
+ '卷发': 'models/curlyhair.pt',
31
+ 'Pixar风格': 'models/pixar_style.pt',
32
+ '雕塑风格': 'models/sculpture_style.pt',
33
+ '陶土风格': 'models/clay_style.pt',
34
+ '修复图像': 'models/repair_slider.pt',
35
+ '修复手部': 'models/fix_hands.pt',
36
+ '杂乱房间': 'models/cluttered_room.pt',
37
+ '阴暗天气': 'models/dark_weather.pt',
38
+ '节日氛围': 'models/festive.pt',
39
+ '热带天气': 'models/tropical_weather.pt',
40
+ '冬季天气': 'models/winter_weather.pt',
41
+ '弯眉': 'models/eyebrow.pt',
42
+ '眼睛大小 (使用刻度 -3, -1, 1, 3)': 'models/eyesize.pt',
43
+ }
44
+
45
+ ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
46
+ SPACE_ID = os.getenv('SPACE_ID')
47
+
48
+ SHARED_UI_WARNING = f'''## 注意 - 在此共享UI中训练可能会很慢。您可以选择复制并使用至少40GB GPU的设备,或克隆此存储库以在自己的机器上运行。
49
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-复制空间-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="复制空间"></a></center>
50
+ '''
51
+
52
+ def merge_lora_networks(networks):
53
+ if not networks:
54
+ return None
55
+
56
+ base_network = networks[0]
57
+ for network in networks[1:]:
58
+ for name, param in network.named_parameters():
59
+ if name in base_network.state_dict():
60
+ base_network.state_dict()[name].add_(param)
61
+ else:
62
+ base_network.state_dict()[name] = param.clone()
63
+ return base_network
64
+
65
+ class Demo:
66
+
67
+ def __init__(self) -> None:
68
+
69
+ self.training = False
70
+ self.generating = False
71
+ self.device = 'cuda'
72
+ self.weight_dtype = torch.bfloat16
73
+ self.model_sections = []
74
+ self.model_sections_count = gr.State(0)
75
+
76
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
77
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
78
+ pipe = None
79
+ del pipe
80
+ torch.cuda.empty_cache()
81
+
82
+ model_id = "stabilityai/sdxl-turbo"
83
+ self.current_model = 'SDXL Turbo'
84
+ euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
85
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
86
+ self.pipe.enable_xformers_memory_efficient_attention()
87
+
88
+ self.guidance_scale = 1
89
+ self.num_inference_steps = 3
90
+
91
+ with gr.Blocks() as demo:
92
+ @gr.render(inputs=self.model_sections_count)
93
+ def render_model_sections(count):
94
+ print("Adding model section")
95
+ for i in range(count):
96
+ # 创建新的模块布局
97
+ with gr.Row(visible=True) as new_section:
98
+ with gr.Column():
99
+ gr.Markdown(f"### 模块 {i + 1}")
100
+ model_dropdown = gr.Dropdown(
101
+ label="预训练模型选择",
102
+ choices=list(model_map.keys()),
103
+ value="年龄调整",
104
+ interactive=True,
105
+ )
106
+ seed_infr = gr.Number(
107
+ label="种子值",
108
+ value=42753,
109
+ )
110
+ slider_scale_infr = gr.Slider(
111
+ minimum=-4,
112
+ maximum=4,
113
+ label="编辑强度",
114
+ value=3,
115
+ info="较大的值会导致更强的编辑效果",
116
+ )
117
+ # del_btn = gr.Button(value="删除", elem_classes="del-btn")
118
+
119
+ # # 删除按钮的逻辑
120
+ # def delete_section():
121
+ # if new_section in self.model_sections:
122
+ # self.model_sections.remove(new_section)
123
+ # return gr.update(visible=False)
124
+
125
+ # del_btn.click(delete_section, outputs=new_section)
126
+ # 添加新模块到 sections 列表
127
+ self.model_sections.append(new_section)
128
+
129
+ self.layout()
130
+ demo.queue(max_size=5).launch(share=True, max_threads=2)
131
+
132
+
133
+ def layout(self):
134
+ with gr.Row():
135
+
136
+ if SPACE_ID == ORIGINAL_SPACE_ID:
137
+
138
+ self.warning = gr.Markdown(SHARED_UI_WARNING)
139
+
140
+ with gr.Row():
141
+
142
+ with gr.Tab("测试") as inference_column:
143
+
144
+ with gr.Row():
145
+
146
+ self.explain_infr = gr.Markdown(value='这是[概念滑块:用于扩散模型的LoRA适配器](https://sliders.baulab.info/)的演示。要尝试可以控制特定概念的模型,请选择一个模型并输入任何提示词,选择一个种子值,最后选择SDEdit时间步以保持结构。较高的SDEdit时间步会导致更多的结构变化。例如,如果选择“惊讶表情”模型,可以生成提示词“A picture of a person, realistic, 8k”的图像,并将滑块效果与原始模型生成的图像进行比较。我们还提供了几个其他预先微调的模型,如“修复”滑块,用于修复SDXL生成图像中的缺陷(请查看“预训练滑块”下拉菜单)。您还可以训练和运行自己的自定义滑块。请查看“训练”部分以进行自定义概念滑块训练。<b>当前推理正在运行SDXL Turbo!</b>')
147
+
148
+ with gr.Row():
149
+
150
+ self.prompt_input_infr = gr.Text(
151
+ placeholder="photo of a person, with bokeh street background, realistic, 8k",
152
+ label="提示词",
153
+ info="生成图像的提示词",
154
+ value="photo of a person, with bokeh street background, realistic, 8k"
155
+ )
156
+
157
+ self.add_model_button = gr.Button(value="新增模型")
158
+
159
+ with gr.Row():
160
+ self.start_noise_infr = gr.Slider(
161
+ 600, 900,
162
+ value=750,
163
+ label="SDEdit时间步",
164
+ info="选择较小的值以保持更多结构"
165
+ )
166
+ self.model_type = gr.Dropdown(
167
+ label="模型",
168
+ choices=['SDXL Turbo', 'SDXL'],
169
+ value='SDXL Turbo',
170
+ interactive=True
171
+ )
172
+
173
+ with gr.Row():
174
+ self.infr_button = gr.Button(
175
+ value="生成",
176
+ interactive=True
177
+ )
178
+
179
+ with gr.Row():
180
+ self.image_orig = gr.Image(
181
+ label="原始SD",
182
+ interactive=False,
183
+ type='pil',
184
+ )
185
+
186
+ self.image_new = gr.Image(
187
+ label=f"概念滑块",
188
+ interactive=False,
189
+ type='pil',
190
+ )
191
+
192
+ with gr.Tab("训练") as training_column:
193
+
194
+ with gr.Row():
195
+
196
+ self.explain_train= gr.Markdown(value='在这一部分,您可以为Stable Diffusion XL训练文本概念滑块。输入您希望进行编辑的目标概念(例如:人)。接下来,输入您希望编辑的属性的增强提示词(例如:控制人的年龄,输入“person, old”)。然后,输入属性的抑制提示词(例如:输入“person, young”)。然后按“训练”按钮。使用默认设置,训练一个滑块大约需要25分钟;然后您可以在上面的“测试”选项卡中尝试推理或下载权重。为了更快的训练,请复制此存储库并使用A100或更大的GPU进行训练。代码和详细信息在[github链接](https://github.com/rohitgandikota/sliders)。')
197
+
198
+ with gr.Row():
199
+
200
+ with gr.Column(scale=3):
201
+
202
+ self.target_concept = gr.Text(
203
+ placeholder="输入要进行编辑的目标概念...",
204
+ label="编辑概念的提示词",
205
+ info="对应于要编辑的概念的提示词(例如:“person”)",
206
+ value = ''
207
+ )
208
+
209
+ self.positive_prompt = gr.Text(
210
+ placeholder="输入编辑的增强提示词...",
211
+ label="增强提示词",
212
+ info="对应于要增强的概念的提示词(例如:“person, old”)",
213
+ value = ''
214
+ )
215
+
216
+ self.negative_prompt = gr.Text(
217
+ placeholder="输入编辑的抑制提示词...",
218
+ label="抑制提示词",
219
+ info="对应于要抑制的概念的提示词(例如:“person, young”)",
220
+ value = ''
221
+ )
222
+
223
+ self.attributes_input = gr.Text(
224
+ placeholder="输入要保留的概念(用逗号分隔)。如果不需要,请留空...",
225
+ label="要保留的概念",
226
+ info="要保留/解缠的概念(例如:“male, female”)",
227
+ value = ''
228
+ )
229
+ self.is_person = gr.Checkbox(
230
+ label="人",
231
+ info="您是否在为人训练滑块?")
232
+
233
+ self.rank = gr.Number(
234
+ value=4,
235
+ label="滑块等级",
236
+ info='要训练的滑块等级'
237
+ )
238
+ choices = ['xattn', 'noxattn']
239
+ self.train_method_input = gr.Dropdown(
240
+ choices=choices,
241
+ value='xattn',
242
+ label='训练方法',
243
+ info='训练方法。如果[* xattn *] - loras将仅在交叉注意层上。如果[* noxattn *](官方实现) - 除交叉注意层外的所有层',
244
+ interactive=True
245
+ )
246
+ self.iterations_input = gr.Number(
247
+ value=500,
248
+ precision=0,
249
+ label="迭代次数",
250
+ info='用于训练的迭代次数 - 最大为1000'
251
+ )
252
+
253
+ self.lr_input = gr.Number(
254
+ value=2e-4,
255
+ label="学习率",
256
+ info='用于训练的学习率'
257
+ )
258
+
259
+ with gr.Column(scale=1):
260
+
261
+ self.train_status = gr.Button(value='', variant='primary', interactive=False)
262
+
263
+ self.train_button = gr.Button(
264
+ value="训练",
265
+ )
266
+
267
+ self.download = gr.Files()
268
+ self.model_dropdown = gr.Dropdown(choices=list(model_map.keys()))
269
+
270
+ self.add_model_button.click(lambda x: x + 1, self.model_sections_count, self.model_sections_count)
271
+ self.infr_button.click(self.inference, inputs=[
272
+ self.prompt_input_infr,
273
+ self.start_noise_infr,
274
+ self.model_type
275
+ ],
276
+ outputs=[
277
+ self.image_new,
278
+ self.image_orig
279
+ ]
280
+ )
281
+ self.train_button.click(self.train, inputs = [
282
+ self.target_concept,
283
+ self.positive_prompt,
284
+ self.negative_prompt,
285
+ self.rank,
286
+ self.iterations_input,
287
+ self.lr_input,
288
+ self.attributes_input,
289
+ self.is_person,
290
+ self.train_method_input
291
+ ],
292
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
293
+ )
294
+
295
+ # def add_model_section(self):
296
+ # with self.model_sections:
297
+ # model_dropdown = gr.Dropdown(
298
+ # label="预训练滑块",
299
+ # choices=list(model_map.keys()),
300
+ # value='年龄调整',
301
+ # interactive=True
302
+ # )
303
+ # seed_infr = gr.Number(
304
+ # label="种子值",
305
+ # value=42753
306
+ # )
307
+ # slider_scale_infr = gr.Slider(
308
+ # -4,
309
+ # 4,
310
+ # label="滑块刻度",
311
+ # value=3,
312
+ # info="较大的滑块刻度会导致更强的编辑效果"
313
+ # )
314
+ # self.model_sections.add_child(model_dropdown)
315
+ # self.model_sections.add_child(seed_infr)
316
+ # self.model_sections.add_child(slider_scale_infr)
317
+ # return gr.update(visible=True, children=self.model_sections.children)
318
+
319
+ def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
320
+ iterations_input = min(int(iterations_input),1000)
321
+ if attributes_input == '':
322
+ attributes_input = None
323
+ print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
324
+
325
+ randn = torch.randint(1, 10000000, (1,)).item()
326
+ save_name = f"{randn}_{positive_prompt.replace(',','').replace(' ','').replace('.','')[:20]}"
327
+ save_name += f'_alpha-{1}'
328
+ save_name += f'_{train_method_input}'
329
+ save_name += f'_rank_{int(rank)}.pt'
330
+
331
+ # if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
332
+ # return [gr.update(interactive=True, value='Train'), gr.update(value='GPU Memory is not enough for training... Please upgrade to GPU atleast 40GB or clone the repo to your local machine.'), None, gr.update()]
333
+ if self.training:
334
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
335
+
336
+ attributes = attributes_input
337
+ if is_person:
338
+ attributes = 'white, black, asian, hispanic, indian, male, female'
339
+
340
+ self.training = True
341
+ train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
342
+ self.training = False
343
+
344
+ torch.cuda.empty_cache()
345
+ model_map[save_name.replace('.pt','')] = f'models/{save_name}'
346
+
347
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
348
+
349
+
350
+ def inference(self, prompt, start_noise, model, pbar=gr.Progress(track_tqdm=True)):
351
+ model_sections = self.model_sections.get_children()
352
+ model_names = [section[0].value for section in model_sections]
353
+ seed_list = [section[1].value for section in model_sections]
354
+ scale_list = [section[2].value for section in model_sections]
355
+
356
+ if self.current_model != model:
357
+ if model=='SDXL Turbo':
358
+ model_id = "stabilityai/sdxl-turbo"
359
+ euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
360
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype).to(self.device)
361
+ self.pipe.enable_xformers_memory_efficient_attention()
362
+ self.guidance_scale = 1
363
+ self.num_inference_steps = 3
364
+ self.current_model = 'SDXL Turbo'
365
+ else:
366
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
367
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=self.weight_dtype).to(self.device)
368
+ self.pipe.enable_xformers_memory_efficient_attention()
369
+ self.guidance_scale = 7.5
370
+ self.num_inference_steps = 20
371
+ self.current_model = 'SDXL'
372
+
373
+ networks = []
374
+ for i, model_name in enumerate(model_names):
375
+ model_path = model_map[model_name]
376
+ unet = self.pipe.unet
377
+ network_type = "c3lier"
378
+ if 'full' in model_path:
379
+ train_method = 'full'
380
+ elif 'noxattn' in model_path:
381
+ train_method = 'noxattn'
382
+ elif 'xattn' in model_path:
383
+ train_method = 'xattn'
384
+ network_type = 'lierla'
385
+ else:
386
+ train_method = 'noxattn'
387
+
388
+ modules = DEFAULT_TARGET_REPLACE
389
+ if network_type == "c3lier":
390
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
391
+
392
+ name = os.path.basename(model_path)
393
+ rank = 4
394
+ alpha = 1
395
+ if 'rank' in model_path:
396
+ rank = int(float(model_path.split('_')[-1].replace('.pt','')))
397
+ if 'alpha1' in model_path:
398
+ alpha = 1.0
399
+ network = LoRANetwork(
400
+ unet,
401
+ rank=rank,
402
+ multiplier=1.0,
403
+ alpha=alpha,
404
+ train_method=train_method,
405
+ ).to(self.device, dtype=self.weight_dtype)
406
+ network.load_state_dict(torch.load(model_path))
407
+ networks.append((network, seed_list[i], scale_list[i]))
408
+
409
+ generator = torch.manual_seed(seed_list[0])
410
+ edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=networks[0][0], start_noise=int(start_noise), scale=networks[0][2], unet=unet, guidance_scale=self.guidance_scale).images[0]
411
+
412
+ generator = torch.manual_seed(seed_list[0])
413
+ original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=networks[0][0], start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0]
414
+
415
+ for network, seed, scale in networks[1:]:
416
+ generator = torch.manual_seed(seed)
417
+ edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=int(start_noise), scale=scale, unet=unet, guidance_scale=self.guidance_scale).images[0]
418
+
419
+ del unet, networks
420
+ unet = None
421
+ networks = None
422
+ torch.cuda.empty_cache()
423
+
424
+ return edited_image, original_image
425
+
426
+ demo = Demo()
demo.py ADDED
@@ -0,0 +1,38 @@
+ import gradio as gr
+
+ with gr.Blocks() as demo:
+
+     tasks = gr.State([])
+     new_task = gr.Textbox(label="Task Name", autofocus=True)
+
+     def add_task(tasks, new_task_name):
+         return tasks + [{"name": new_task_name, "complete": False}], ""
+
+     new_task.submit(add_task, [tasks, new_task], [tasks, new_task])
+
+     @gr.render(inputs=tasks)
+     def render_todos(task_list):
+         complete = [task for task in task_list if task["complete"]]
+         incomplete = [task for task in task_list if not task["complete"]]
+         gr.Markdown(f"### Incomplete Tasks ({len(incomplete)})")
+         for task in incomplete:
+             with gr.Row():
+                 gr.Textbox(task['name'], show_label=False, container=False)
+                 done_btn = gr.Button("Done", scale=0)
+                 def mark_done(task=task):
+                     task["complete"] = True
+                     return task_list
+                 done_btn.click(mark_done, None, [tasks])
+
+                 delete_btn = gr.Button("Delete", scale=0, variant="stop")
+                 def delete(task=task):
+                     task_list.remove(task)
+                     return task_list
+                 delete_btn.click(delete, None, [tasks])
+
+         gr.Markdown(f"### Complete Tasks ({len(complete)})")
+         for task in complete:
+             gr.Textbox(task['name'], show_label=False, container=False)
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
install_deps.bat ADDED
@@ -0,0 +1,4 @@
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+ pip install -U xformers --index-url https://download.pytorch.org/whl/cu124
+ pip install -r requirements.txt
+ pip install --upgrade gradio
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ bitsandbytes==0.41.1
+ dadaptation==3.1
+ diffusers==0.20.2
+ ipython==8.7.0
+ lion_pytorch==0.1.2
+ lpips==0.1.4
+ matplotlib==3.6.2
+ numpy==1.23.5
+ opencv_python==4.5.5.64
+ opencv_python_headless==4.7.0.68
+ pandas==1.5.2
+ Pillow==10.1.0
+ prodigyopt==1.0
+ pydantic==2.10.5
+ PyYAML==6.0.1
+ Requests==2.31.0
+ safetensors==0.3.1
+ torch==2.5.1
+ torchvision==0.20.1
+ xformers
+ tqdm==4.64.1
+ transformers==4.27.4
+ wandb==0.12.21
+ accelerate==0.16.0
+ gradio==5.12.0
+ gradio_client==1.5.4
+ huggingface-hub==0.27.1
trainscripts/__init__.py ADDED
@@ -0,0 +1 @@
+ # from textsliders import lora
trainscripts/imagesliders/config_util.py ADDED
@@ -0,0 +1,104 @@
+ from typing import Literal, Optional
+
+ import yaml
+
+ from pydantic import BaseModel
+ import torch
+
+ from lora import TRAINING_METHODS
+
+ PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
+ NETWORK_TYPES = Literal["lierla", "c3lier"]
+
+
+ class PretrainedModelConfig(BaseModel):
+     name_or_path: str
+     v2: bool = False
+     v_pred: bool = False
+
+     clip_skip: Optional[int] = None
+
+
+ class NetworkConfig(BaseModel):
+     type: NETWORK_TYPES = "lierla"
+     rank: int = 4
+     alpha: float = 1.0
+
+     training_method: TRAINING_METHODS = "full"
+
+
+ class TrainConfig(BaseModel):
+     precision: PRECISION_TYPES = "bfloat16"
+     noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
+
+     iterations: int = 500
+     lr: float = 1e-4
+     optimizer: str = "adamw"
+     optimizer_args: str = ""
+     lr_scheduler: str = "constant"
+
+     max_denoising_steps: int = 50
+
+
+ class SaveConfig(BaseModel):
+     name: str = "untitled"
+     path: str = "./output"
+     per_steps: int = 200
+     precision: PRECISION_TYPES = "float32"
+
+
+ class LoggingConfig(BaseModel):
+     use_wandb: bool = False
+
+     verbose: bool = False
+
+
+ class OtherConfig(BaseModel):
+     use_xformers: bool = False
+
+
+ class RootConfig(BaseModel):
+     prompts_file: str
+     pretrained_model: PretrainedModelConfig
+
+     network: NetworkConfig
+
+     train: Optional[TrainConfig]
+
+     save: Optional[SaveConfig]
+
+     logging: Optional[LoggingConfig]
+
+     other: Optional[OtherConfig]
+
+
+ def parse_precision(precision: str) -> torch.dtype:
+     if precision == "fp32" or precision == "float32":
+         return torch.float32
+     elif precision == "fp16" or precision == "float16":
+         return torch.float16
+     elif precision == "bf16" or precision == "bfloat16":
+         return torch.bfloat16
+
+     raise ValueError(f"Invalid precision type: {precision}")
+
+
+ def load_config_from_yaml(config_path: str) -> RootConfig:
+     with open(config_path, "r") as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+
+     root = RootConfig(**config)
+
+     if root.train is None:
+         root.train = TrainConfig()
+
+     if root.save is None:
+         root.save = SaveConfig()
+
+     if root.logging is None:
+         root.logging = LoggingConfig()
+
+     if root.other is None:
+         root.other = OtherConfig()
+
+     return root
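A brief usage sketch for the helpers above (illustrative, not part of this file), assuming trainscripts/imagesliders/ is on the import path so that the local `from lora import TRAINING_METHODS` import resolves; it loads the XL config added below:

from config_util import load_config_from_yaml, parse_precision

config = load_config_from_yaml("trainscripts/imagesliders/data/config-xl.yaml")
print(config.pretrained_model.name_or_path)        # "stabilityai/stable-diffusion-xl-base-1.0"
print(config.network.type, config.network.rank)    # "c3lier" 4
weight_dtype = parse_precision(config.train.precision)  # torch.bfloat16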
trainscripts/imagesliders/data/config-xl.yaml ADDED
@@ -0,0 +1,28 @@
1
+ prompts_file: "trainscripts/imagesliders/data/prompts-xl.yaml"
2
+ pretrained_model:
3
+ name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # or "c3lier" or "lierla"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
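
The `type` field above decides which U-Net modules receive LoRA weights. A small sketch of that mapping, mirroring the module-selection logic in train_lora-scale-xl.py later in this commit (assumes it is run from trainscripts/imagesliders/ so lora.py is importable):

from lora import UNET_TARGET_REPLACE_MODULE_TRANSFORMER, UNET_TARGET_REPLACE_MODULE_CONV

network_type = "c3lier"  # the value of network.type above
modules = list(UNET_TARGET_REPLACE_MODULE_TRANSFORMER)  # "lierla": attention layers only
if network_type == "c3lier":
    modules += UNET_TARGET_REPLACE_MODULE_CONV  # adds ResnetBlock2D / Downsample2D / Upsample2D
print(modules)  # ['Attention', 'ResnetBlock2D', 'Downsample2D', 'Upsample2D']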
trainscripts/imagesliders/data/config.yaml ADDED
@@ -0,0 +1,28 @@
1
+ prompts_file: "trainscripts/imagesliders/data/prompts.yaml"
2
+ pretrained_model:
3
+ name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # or "c3lier" or "lierla"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
trainscripts/imagesliders/data/prompts-xl.yaml ADDED
@@ -0,0 +1,275 @@
1
+ ####################################################################################################### AGE SLIDER
2
+ # - target: "male person" # what word for erasing the positive concept from
3
+ # positive: "male person, very old" # concept to erase
4
+ # unconditional: "male person, very young" # word to take the difference from the positive concept
5
+ # neutral: "male person" # starting point for conditioning the target
6
+ # action: "enhance" # erase or enhance
7
+ # guidance_scale: 4
8
+ # resolution: 512
9
+ # dynamic_resolution: false
10
+ # batch_size: 1
11
+ # - target: "female person" # what word for erasing the positive concept from
12
+ # positive: "female person, very old" # concept to erase
13
+ # unconditional: "female person, very young" # word to take the difference from the positive concept
14
+ # neutral: "female person" # starting point for conditioning the target
15
+ # action: "enhance" # erase or enhance
16
+ # guidance_scale: 4
17
+ # resolution: 512
18
+ # dynamic_resolution: false
19
+ # batch_size: 1
20
+ ####################################################################################################### GLASSES SLIDER
21
+ # - target: "male person" # what word for erasing the positive concept from
22
+ # positive: "male person, wearing glasses" # concept to erase
23
+ # unconditional: "male person" # word to take the difference from the positive concept
24
+ # neutral: "male person" # starting point for conditioning the target
25
+ # action: "enhance" # erase or enhance
26
+ # guidance_scale: 4
27
+ # resolution: 512
28
+ # dynamic_resolution: false
29
+ # batch_size: 1
30
+ # - target: "female person" # what word for erasing the positive concept from
31
+ # positive: "female person, wearing glasses" # concept to erase
32
+ # unconditional: "female person" # word to take the difference from the positive concept
33
+ # neutral: "female person" # starting point for conditioning the target
34
+ # action: "enhance" # erase or enhance
35
+ # guidance_scale: 4
36
+ # resolution: 512
37
+ # dynamic_resolution: false
38
+ # batch_size: 1
39
+ ####################################################################################################### ASTRONAUT SLIDER
40
+ # - target: "astronaut" # what word for erasing the positive concept from
41
+ # positive: "astronaut, with orange colored spacesuit" # concept to erase
42
+ # unconditional: "astronaut" # word to take the difference from the positive concept
43
+ # neutral: "astronaut" # starting point for conditioning the target
44
+ # action: "enhance" # erase or enhance
45
+ # guidance_scale: 4
46
+ # resolution: 512
47
+ # dynamic_resolution: false
48
+ # batch_size: 1
49
+ ####################################################################################################### SMILING SLIDER
50
+ # - target: "male person" # what word for erasing the positive concept from
51
+ # positive: "male person, smiling" # concept to erase
52
+ # unconditional: "male person, frowning" # word to take the difference from the positive concept
53
+ # neutral: "male person" # starting point for conditioning the target
54
+ # action: "enhance" # erase or enhance
55
+ # guidance_scale: 4
56
+ # resolution: 512
57
+ # dynamic_resolution: false
58
+ # batch_size: 1
59
+ # - target: "female person" # what word for erasing the positive concept from
60
+ # positive: "female person, smiling" # concept to erase
61
+ # unconditional: "female person, frowning" # word to take the difference from the positive concept
62
+ # neutral: "female person" # starting point for conditioning the target
63
+ # action: "enhance" # erase or enhance
64
+ # guidance_scale: 4
65
+ # resolution: 512
66
+ # dynamic_resolution: false
67
+ # batch_size: 1
68
+ ####################################################################################################### CAR COLOR SLIDER
69
+ # - target: "car" # what word for erasing the positive concept from
70
+ # positive: "car, white color" # concept to erase
71
+ # unconditional: "car, black color" # word to take the difference from the positive concept
72
+ # neutral: "car" # starting point for conditioning the target
73
+ # action: "enhance" # erase or enhance
74
+ # guidance_scale: 4
75
+ # resolution: 512
76
+ # dynamic_resolution: false
77
+ # batch_size: 1
78
+ ####################################################################################################### DETAILS SLIDER
79
+ # - target: "" # what word for erasing the positive concept from
80
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase
81
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
82
+ # neutral: "" # starting point for conditioning the target
83
+ # action: "enhance" # erase or enhance
84
+ # guidance_scale: 4
85
+ # resolution: 512
86
+ # dynamic_resolution: false
87
+ # batch_size: 1
88
+ ####################################################################################################### BOKEH SLIDER
89
+ # - target: "" # what word for erasing the positive concept from
90
+ # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
91
+ # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
92
+ # unconditional: ""
93
+ # neutral: "" # starting point for conditioning the target
94
+ # action: "enhance" # erase or enhance
95
+ # guidance_scale: 4
96
+ # resolution: 512
97
+ # dynamic_resolution: false
98
+ # batch_size: 1
99
+ ####################################################################################################### LONG HAIR SLIDER
100
+ # - target: "male person" # what word for erasing the positive concept from
101
+ # positive: "male person, with long hair" # concept to erase
102
+ # unconditional: "male person, with short hair" # word to take the difference from the positive concept
103
+ # neutral: "male person" # starting point for conditioning the target
104
+ # action: "enhance" # erase or enhance
105
+ # guidance_scale: 4
106
+ # resolution: 512
107
+ # dynamic_resolution: false
108
+ # batch_size: 1
109
+ # - target: "female person" # what word for erasing the positive concept from
110
+ # positive: "female person, with long hair" # concept to erase
111
+ # unconditional: "female person, with short hair" # word to take the difference from the positive concept
112
+ # neutral: "female person" # starting point for conditioning the target
113
+ # action: "enhance" # erase or enhance
114
+ # guidance_scale: 4
115
+ # resolution: 512
116
+ # dynamic_resolution: false
117
+ # batch_size: 1
118
+ ####################################################################################################### IMAGE SLIDER
119
+ - target: "" # what word for erasing the positive concept from
120
+ positive: "" # concept to erase
121
+ unconditional: "" # word to take the difference from the positive concept
122
+ neutral: "" # starting point for conditioning the target
123
+ action: "enhance" # erase or enhance
124
+ guidance_scale: 4
125
+ resolution: 512
126
+ dynamic_resolution: false
127
+ batch_size: 1
128
+ ####################################################################################################### IMAGE SLIDER
129
+ # - target: "food" # what word for erasing the positive concept from
130
+ # positive: "food, expensive and fine dining" # concept to erase
131
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
132
+ # neutral: "food" # starting point for conditioning the target
133
+ # action: "enhance" # erase or enhance
134
+ # guidance_scale: 4
135
+ # resolution: 512
136
+ # dynamic_resolution: false
137
+ # batch_size: 1
138
+ # - target: "room" # what word for erasing the positive concept from
139
+ # positive: "room, dirty disorganised and cluttered" # concept to erase
140
+ # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
141
+ # neutral: "room" # starting point for conditioning the target
142
+ # action: "enhance" # erase or enhance
143
+ # guidance_scale: 4
144
+ # resolution: 512
145
+ # dynamic_resolution: false
146
+ # batch_size: 1
147
+ # - target: "male person" # what word for erasing the positive concept from
148
+ # positive: "male person, with a surprised look" # concept to erase
149
+ # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
150
+ # neutral: "male person" # starting point for conditioning the target
151
+ # action: "enhance" # erase or enhance
152
+ # guidance_scale: 4
153
+ # resolution: 512
154
+ # dynamic_resolution: false
155
+ # batch_size: 1
156
+ # - target: "female person" # what word for erasing the positive concept from
157
+ # positive: "female person, with a surprised look" # concept to erase
158
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
159
+ # neutral: "female person" # starting point for conditioning the target
160
+ # action: "enhance" # erase or enhance
161
+ # guidance_scale: 4
162
+ # resolution: 512
163
+ # dynamic_resolution: false
164
+ # batch_size: 1
165
+ # - target: "sky" # what word for erasing the positive concept from
166
+ # positive: "peaceful sky" # concept to erase
167
+ # unconditional: "sky" # word to take the difference from the positive concept
168
+ # neutral: "sky" # starting point for conditioning the target
169
+ # action: "enhance" # erase or enhance
170
+ # guidance_scale: 4
171
+ # resolution: 512
172
+ # dynamic_resolution: false
173
+ # batch_size: 1
174
+ # - target: "sky" # what word for erasing the positive concept from
175
+ # positive: "chaotic dark sky" # concept to erase
176
+ # unconditional: "sky" # word to take the difference from the positive concept
177
+ # neutral: "sky" # starting point for conditioning the target
178
+ # action: "erase" # erase or enhance
179
+ # guidance_scale: 4
180
+ # resolution: 512
181
+ # dynamic_resolution: false
182
+ # batch_size: 1
183
+ # - target: "person" # what word for erasing the positive concept from
184
+ # positive: "person, very young" # concept to erase
185
+ # unconditional: "person" # word to take the difference from the positive concept
186
+ # neutral: "person" # starting point for conditioning the target
187
+ # action: "erase" # erase or enhance
188
+ # guidance_scale: 4
189
+ # resolution: 512
190
+ # dynamic_resolution: false
191
+ # batch_size: 1
192
+ # overweight
193
+ # - target: "art" # what word for erasing the positive concept from
194
+ # positive: "realistic art" # concept to erase
195
+ # unconditional: "art" # word to take the difference from the positive concept
196
+ # neutral: "art" # starting point for conditioning the target
197
+ # action: "enhance" # erase or enhance
198
+ # guidance_scale: 4
199
+ # resolution: 512
200
+ # dynamic_resolution: false
201
+ # batch_size: 1
202
+ # - target: "art" # what word for erasing the positive concept from
203
+ # positive: "abstract art" # concept to erase
204
+ # unconditional: "art" # word to take the difference from the positive concept
205
+ # neutral: "art" # starting point for conditioning the target
206
+ # action: "erase" # erase or enhance
207
+ # guidance_scale: 4
208
+ # resolution: 512
209
+ # dynamic_resolution: false
210
+ # batch_size: 1
211
+ # sky
212
+ # - target: "weather" # what word for erasing the positive concept from
213
+ # positive: "bright pleasant weather" # concept to erase
214
+ # unconditional: "weather" # word to take the difference from the positive concept
215
+ # neutral: "weather" # starting point for conditioning the target
216
+ # action: "enhance" # erase or enhance
217
+ # guidance_scale: 4
218
+ # resolution: 512
219
+ # dynamic_resolution: false
220
+ # batch_size: 1
221
+ # - target: "weather" # what word for erasing the positive concept from
222
+ # positive: "dark gloomy weather" # concept to erase
223
+ # unconditional: "weather" # word to take the difference from the positive concept
224
+ # neutral: "weather" # starting point for conditioning the target
225
+ # action: "erase" # erase or enhance
226
+ # guidance_scale: 4
227
+ # resolution: 512
228
+ # dynamic_resolution: false
229
+ # batch_size: 1
230
+ # hair
231
+ # - target: "person" # what word for erasing the positive concept from
232
+ # positive: "person with long hair" # concept to erase
233
+ # unconditional: "person" # word to take the difference from the positive concept
234
+ # neutral: "person" # starting point for conditioning the target
235
+ # action: "enhance" # erase or enhance
236
+ # guidance_scale: 4
237
+ # resolution: 512
238
+ # dynamic_resolution: false
239
+ # batch_size: 1
240
+ # - target: "person" # what word for erasing the positive concept from
241
+ # positive: "person with short hair" # concept to erase
242
+ # unconditional: "person" # word to take the difference from the positive concept
243
+ # neutral: "person" # starting point for conditioning the target
244
+ # action: "erase" # erase or enhance
245
+ # guidance_scale: 4
246
+ # resolution: 512
247
+ # dynamic_resolution: false
248
+ # batch_size: 1
249
+ # - target: "girl" # what word for erasing the positive concept from
250
+ # positive: "baby girl" # concept to erase
251
+ # unconditional: "girl" # word to take the difference from the positive concept
252
+ # neutral: "girl" # starting point for conditioning the target
253
+ # action: "enhance" # erase or enhance
254
+ # guidance_scale: -4
255
+ # resolution: 512
256
+ # dynamic_resolution: false
257
+ # batch_size: 1
258
+ # - target: "boy" # what word for erasing the positive concept from
259
+ # positive: "old man" # concept to erase
260
+ # unconditional: "boy" # word to take the difference from the positive concept
261
+ # neutral: "boy" # starting point for conditioning the target
262
+ # action: "enhance" # erase or enhance
263
+ # guidance_scale: 4
264
+ # resolution: 512
265
+ # dynamic_resolution: false
266
+ # batch_size: 1
267
+ # - target: "boy" # what word for erasing the positive concept from
268
+ # positive: "baby boy" # concept to erase
269
+ # unconditional: "boy" # word to take the difference from the positive concept
270
+ # neutral: "boy" # starting point for conditioning the target
271
+ # action: "enhance" # erase or enhance
272
+ # guidance_scale: -4
273
+ # resolution: 512
274
+ # dynamic_resolution: false
275
+ # batch_size: 1
trainscripts/imagesliders/data/prompts.yaml ADDED
@@ -0,0 +1,174 @@
1
+ # - target: "person" # what word for erasing the positive concept from
2
+ # positive: "person, very old" # concept to erase
3
+ # unconditional: "person" # word to take the difference from the positive concept
4
+ # neutral: "person" # starting point for conditioning the target
5
+ # action: "enhance" # erase or enhance
6
+ # guidance_scale: 4
7
+ # resolution: 512
8
+ # dynamic_resolution: false
9
+ # batch_size: 1
10
+ - target: "" # what word for erasing the positive concept from
11
+ positive: "" # concept to erase
12
+ unconditional: "" # word to take the difference from the positive concept
13
+ neutral: "" # starting point for conditioning the target
14
+ action: "enhance" # erase or enhance
15
+ guidance_scale: 1
16
+ resolution: 512
17
+ dynamic_resolution: false
18
+ batch_size: 1
19
+ # - target: "" # what word for erasing the positive concept from
20
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
21
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
22
+ # neutral: "" # starting point for conditioning the target
23
+ # action: "enhance" # erase or enhance
24
+ # guidance_scale: 4
25
+ # resolution: 512
26
+ # dynamic_resolution: false
27
+ # batch_size: 1
28
+ # - target: "food" # what word for erasing the positive concept from
29
+ # positive: "food, expensive and fine dining" # concept to erase
30
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
31
+ # neutral: "food" # starting point for conditioning the target
32
+ # action: "enhance" # erase or enhance
33
+ # guidance_scale: 4
34
+ # resolution: 512
35
+ # dynamic_resolution: false
36
+ # batch_size: 1
37
+ # - target: "room" # what word for erasing the positive concept from
38
+ # positive: "room, dirty disorganised and cluttered" # concept to erase
39
+ # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
40
+ # neutral: "room" # starting point for conditioning the target
41
+ # action: "enhance" # erase or enhance
42
+ # guidance_scale: 4
43
+ # resolution: 512
44
+ # dynamic_resolution: false
45
+ # batch_size: 1
46
+ # - target: "male person" # what word for erasing the positive concept from
47
+ # positive: "male person, with a surprised look" # concept to erase
48
+ # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
49
+ # neutral: "male person" # starting point for conditioning the target
50
+ # action: "enhance" # erase or enhance
51
+ # guidance_scale: 4
52
+ # resolution: 512
53
+ # dynamic_resolution: false
54
+ # batch_size: 1
55
+ # - target: "female person" # what word for erasing the positive concept from
56
+ # positive: "female person, with a surprised look" # concept to erase
57
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
58
+ # neutral: "female person" # starting point for conditioning the target
59
+ # action: "enhance" # erase or enhance
60
+ # guidance_scale: 4
61
+ # resolution: 512
62
+ # dynamic_resolution: false
63
+ # batch_size: 1
64
+ # - target: "sky" # what word for erasing the positive concept from
65
+ # positive: "peaceful sky" # concept to erase
66
+ # unconditional: "sky" # word to take the difference from the positive concept
67
+ # neutral: "sky" # starting point for conditioning the target
68
+ # action: "enhance" # erase or enhance
69
+ # guidance_scale: 4
70
+ # resolution: 512
71
+ # dynamic_resolution: false
72
+ # batch_size: 1
73
+ # - target: "sky" # what word for erasing the positive concept from
74
+ # positive: "chaotic dark sky" # concept to erase
75
+ # unconditional: "sky" # word to take the difference from the positive concept
76
+ # neutral: "sky" # starting point for conditioning the target
77
+ # action: "erase" # erase or enhance
78
+ # guidance_scale: 4
79
+ # resolution: 512
80
+ # dynamic_resolution: false
81
+ # batch_size: 1
82
+ # - target: "person" # what word for erasing the positive concept from
83
+ # positive: "person, very young" # concept to erase
84
+ # unconditional: "person" # word to take the difference from the positive concept
85
+ # neutral: "person" # starting point for conditioning the target
86
+ # action: "erase" # erase or enhance
87
+ # guidance_scale: 4
88
+ # resolution: 512
89
+ # dynamic_resolution: false
90
+ # batch_size: 1
91
+ # overweight
92
+ # - target: "art" # what word for erasing the positive concept from
93
+ # positive: "realistic art" # concept to erase
94
+ # unconditional: "art" # word to take the difference from the positive concept
95
+ # neutral: "art" # starting point for conditioning the target
96
+ # action: "enhance" # erase or enhance
97
+ # guidance_scale: 4
98
+ # resolution: 512
99
+ # dynamic_resolution: false
100
+ # batch_size: 1
101
+ # - target: "art" # what word for erasing the positive concept from
102
+ # positive: "abstract art" # concept to erase
103
+ # unconditional: "art" # word to take the difference from the positive concept
104
+ # neutral: "art" # starting point for conditioning the target
105
+ # action: "erase" # erase or enhance
106
+ # guidance_scale: 4
107
+ # resolution: 512
108
+ # dynamic_resolution: false
109
+ # batch_size: 1
110
+ # sky
111
+ # - target: "weather" # what word for erasing the positive concept from
112
+ # positive: "bright pleasant weather" # concept to erase
113
+ # unconditional: "weather" # word to take the difference from the positive concept
114
+ # neutral: "weather" # starting point for conditioning the target
115
+ # action: "enhance" # erase or enhance
116
+ # guidance_scale: 4
117
+ # resolution: 512
118
+ # dynamic_resolution: false
119
+ # batch_size: 1
120
+ # - target: "weather" # what word for erasing the positive concept from
121
+ # positive: "dark gloomy weather" # concept to erase
122
+ # unconditional: "weather" # word to take the difference from the positive concept
123
+ # neutral: "weather" # starting point for conditioning the target
124
+ # action: "erase" # erase or enhance
125
+ # guidance_scale: 4
126
+ # resolution: 512
127
+ # dynamic_resolution: false
128
+ # batch_size: 1
129
+ # hair
130
+ # - target: "person" # what word for erasing the positive concept from
131
+ # positive: "person with long hair" # concept to erase
132
+ # unconditional: "person" # word to take the difference from the positive concept
133
+ # neutral: "person" # starting point for conditioning the target
134
+ # action: "enhance" # erase or enhance
135
+ # guidance_scale: 4
136
+ # resolution: 512
137
+ # dynamic_resolution: false
138
+ # batch_size: 1
139
+ # - target: "person" # what word for erasing the positive concept from
140
+ # positive: "person with short hair" # concept to erase
141
+ # unconditional: "person" # word to take the difference from the positive concept
142
+ # neutral: "person" # starting point for conditioning the target
143
+ # action: "erase" # erase or enhance
144
+ # guidance_scale: 4
145
+ # resolution: 512
146
+ # dynamic_resolution: false
147
+ # batch_size: 1
148
+ # - target: "girl" # what word for erasing the positive concept from
149
+ # positive: "baby girl" # concept to erase
150
+ # unconditional: "girl" # word to take the difference from the positive concept
151
+ # neutral: "girl" # starting point for conditioning the target
152
+ # action: "enhance" # erase or enhance
153
+ # guidance_scale: -4
154
+ # resolution: 512
155
+ # dynamic_resolution: false
156
+ # batch_size: 1
157
+ # - target: "boy" # what word for erasing the positive concept from
158
+ # positive: "old man" # concept to erase
159
+ # unconditional: "boy" # word to take the difference from the positive concept
160
+ # neutral: "boy" # starting point for conditioning the target
161
+ # action: "enhance" # erase or enhance
162
+ # guidance_scale: 4
163
+ # resolution: 512
164
+ # dynamic_resolution: false
165
+ # batch_size: 1
166
+ # - target: "boy" # what word for erasing the positive concept from
167
+ # positive: "baby boy" # concept to erase
168
+ # unconditional: "boy" # word to take the difference from the positive concept
169
+ # neutral: "boy" # starting point for conditioning the target
170
+ # action: "enhance" # erase or enhance
171
+ # guidance_scale: -4
172
+ # resolution: 512
173
+ # dynamic_resolution: false
174
+ # batch_size: 1
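
Each list item in these prompt files is parsed into a PromptSettings object (defined in prompt_util.py below), and omitted keys are filled in by its root validator. A small hypothetical sketch:

from prompt_util import PromptSettings

entry = {
    "target": "male person",
    "positive": "male person, very old",
    "unconditional": "male person, very young",
    "action": "enhance",
    "guidance_scale": 4,
}
settings = PromptSettings(**entry)
# neutral was omitted, so it defaults to the unconditional prompt;
# resolution defaults to 512 and batch_size to 1
print(settings.neutral, settings.resolution, settings.batch_size)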
trainscripts/imagesliders/debug_util.py ADDED
@@ -0,0 +1,16 @@
1
+ # For debugging...
2
+
3
+ import torch
4
+
5
+
6
+ def check_requires_grad(model: torch.nn.Module):
7
+ for name, module in list(model.named_modules())[:5]:
8
+ if len(list(module.parameters())) > 0:
9
+ print(f"Module: {name}")
10
+ for name, param in list(module.named_parameters())[:2]:
11
+ print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12
+
13
+
14
+ def check_training_mode(model: torch.nn.Module):
15
+ for name, module in list(model.named_modules())[:5]:
16
+ print(f"Module: {name}, Training Mode: {module.training}")
trainscripts/imagesliders/lora.py ADDED
@@ -0,0 +1,256 @@
1
+ # ref:
2
+ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3
+ # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4
+
5
+ import os
6
+ import math
7
+ from typing import Optional, List, Type, Set, Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers import UNet2DConditionModel
12
+ from safetensors.torch import save_file
13
+
14
+
15
+ UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16
+ # "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
17
+ "Attention"
18
+ ]
19
+ UNET_TARGET_REPLACE_MODULE_CONV = [
20
+ "ResnetBlock2D",
21
+ "Downsample2D",
22
+ "Upsample2D",
23
+ # "DownBlock2D",
24
+ # "UpBlock2D"
25
+ ] # locon, c3lier
26
+
27
+ LORA_PREFIX_UNET = "lora_unet"
28
+
29
+ DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
30
+
31
+ TRAINING_METHODS = Literal[
32
+ "noxattn", # train all layers except x-attns and time_embed layers
33
+ "innoxattn", # train all layers except self attention layers
34
+ "selfattn", # ESD-u, train only self attention layers
35
+ "xattn", # ESD-x, train only x attention layers
36
+ "full", # train all layers
37
+ "xattn-strict", # q and k values
38
+ "noxattn-hspace",
39
+ "noxattn-hspace-last",
40
+ # "xlayer",
41
+ # "outxattn",
42
+ # "outsattn",
43
+ # "inxattn",
44
+ # "inmidsattn",
45
+ # "selflayer",
46
+ ]
47
+
48
+
49
+ class LoRAModule(nn.Module):
50
+ """
51
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ lora_name,
57
+ org_module: nn.Module,
58
+ multiplier=1.0,
59
+ lora_dim=4,
60
+ alpha=1,
61
+ ):
62
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
63
+ super().__init__()
64
+ self.lora_name = lora_name
65
+ self.lora_dim = lora_dim
66
+
67
+ if "Linear" in org_module.__class__.__name__:
68
+ in_dim = org_module.in_features
69
+ out_dim = org_module.out_features
70
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
71
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
72
+
73
+ elif "Conv" in org_module.__class__.__name__: # 一応
74
+ in_dim = org_module.in_channels
75
+ out_dim = org_module.out_channels
76
+
77
+ self.lora_dim = min(self.lora_dim, in_dim, out_dim)
78
+ if self.lora_dim != lora_dim:
79
+ print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
80
+
81
+ kernel_size = org_module.kernel_size
82
+ stride = org_module.stride
83
+ padding = org_module.padding
84
+ self.lora_down = nn.Conv2d(
85
+ in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
86
+ )
87
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
88
+
89
+ if type(alpha) == torch.Tensor:
90
+ alpha = alpha.detach().numpy()
91
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
92
+ self.scale = alpha / self.lora_dim
93
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
94
+
95
+ # same as microsoft's
96
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
97
+ nn.init.zeros_(self.lora_up.weight)
98
+
99
+ self.multiplier = multiplier
100
+ self.org_module = org_module # reference removed when apply_to() is called
101
+
102
+ def apply_to(self):
103
+ self.org_forward = self.org_module.forward
104
+ self.org_module.forward = self.forward
105
+ del self.org_module
106
+
107
+ def forward(self, x):
108
+ return (
109
+ self.org_forward(x)
110
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
111
+ )
112
+
113
+
114
+ class LoRANetwork(nn.Module):
115
+ def __init__(
116
+ self,
117
+ unet: UNet2DConditionModel,
118
+ rank: int = 4,
119
+ multiplier: float = 1.0,
120
+ alpha: float = 1.0,
121
+ train_method: TRAINING_METHODS = "full",
122
+ ) -> None:
123
+ super().__init__()
124
+ self.lora_scale = 1
125
+ self.multiplier = multiplier
126
+ self.lora_dim = rank
127
+ self.alpha = alpha
128
+
129
+ # LoRA modules only
130
+ self.module = LoRAModule
131
+
132
+ # create the LoRA modules for the U-Net
133
+ self.unet_loras = self.create_modules(
134
+ LORA_PREFIX_UNET,
135
+ unet,
136
+ DEFAULT_TARGET_REPLACE,
137
+ self.lora_dim,
138
+ self.multiplier,
139
+ train_method=train_method,
140
+ )
141
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
142
+
143
+ # assertion: make sure no LoRA names are duplicated
144
+ lora_names = set()
145
+ for lora in self.unet_loras:
146
+ assert (
147
+ lora.lora_name not in lora_names
148
+ ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
149
+ lora_names.add(lora.lora_name)
150
+
151
+ # apply the LoRA modules to the U-Net
152
+ for lora in self.unet_loras:
153
+ lora.apply_to()
154
+ self.add_module(
155
+ lora.lora_name,
156
+ lora,
157
+ )
158
+
159
+ del unet
160
+
161
+ torch.cuda.empty_cache()
162
+
163
+ def create_modules(
164
+ self,
165
+ prefix: str,
166
+ root_module: nn.Module,
167
+ target_replace_modules: List[str],
168
+ rank: int,
169
+ multiplier: float,
170
+ train_method: TRAINING_METHODS,
171
+ ) -> list:
172
+ loras = []
173
+ names = []
174
+ for name, module in root_module.named_modules():
175
+ if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習
176
+ if "attn2" in name or "time_embed" in name:
177
+ continue
178
+ elif train_method == "innoxattn": # Cross Attention 以外学習
179
+ if "attn2" in name:
180
+ continue
181
+ elif train_method == "selfattn": # Self Attention のみ学習
182
+ if "attn1" not in name:
183
+ continue
184
+ elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習
185
+ if "attn2" not in name:
186
+ continue
187
+ elif train_method == "full": # 全部学習
188
+ pass
189
+ else:
190
+ raise NotImplementedError(
191
+ f"train_method: {train_method} is not implemented."
192
+ )
193
+ if module.__class__.__name__ in target_replace_modules:
194
+ for child_name, child_module in module.named_modules():
195
+ if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
196
+ if train_method == 'xattn-strict':
197
+ if 'out' in child_name:
198
+ continue
199
+ if train_method == 'noxattn-hspace':
200
+ if 'mid_block' not in name:
201
+ continue
202
+ if train_method == 'noxattn-hspace-last':
203
+ if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
204
+ continue
205
+ lora_name = prefix + "." + name + "." + child_name
206
+ lora_name = lora_name.replace(".", "_")
207
+ # print(f"{lora_name}")
208
+ lora = self.module(
209
+ lora_name, child_module, multiplier, rank, self.alpha
210
+ )
211
+ # print(name, child_name)
212
+ # print(child_module.weight.shape)
213
+ loras.append(lora)
214
+ names.append(lora_name)
215
+ # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
216
+ return loras
217
+
218
+ def prepare_optimizer_params(self):
219
+ all_params = []
220
+
221
+ if self.unet_loras: # effectively the only case
222
+ params = []
223
+ [params.extend(lora.parameters()) for lora in self.unet_loras]
224
+ param_data = {"params": params}
225
+ all_params.append(param_data)
226
+
227
+ return all_params
228
+
229
+ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
230
+ state_dict = self.state_dict()
231
+
232
+ if dtype is not None:
233
+ for key in list(state_dict.keys()):
234
+ v = state_dict[key]
235
+ v = v.detach().clone().to("cpu").to(dtype)
236
+ state_dict[key] = v
237
+
238
+ # for key in list(state_dict.keys()):
239
+ # if not key.startswith("lora"):
240
+ # # drop everything except lora keys
241
+ # del state_dict[key]
242
+
243
+ if os.path.splitext(file)[1] == ".safetensors":
244
+ save_file(state_dict, file, metadata)
245
+ else:
246
+ torch.save(state_dict, file)
247
+ def set_lora_slider(self, scale):
248
+ self.lora_scale = scale
249
+
250
+ def __enter__(self):
251
+ for lora in self.unet_loras:
252
+ lora.multiplier = 1.0 * self.lora_scale
253
+
254
+ def __exit__(self, exc_type, exc_value, tb):
255
+ for lora in self.unet_loras:
256
+ lora.multiplier = 0
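
A minimal sketch of the intended slider workflow, based on how the training script below drives this class; the base model name is only an example and downloading it needs network access.

import torch
from diffusers import UNet2DConditionModel
from lora import LoRANetwork

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0, train_method="noxattn")

network.set_lora_slider(scale=2.0)  # slider strength for the next forward passes
with network:
    # inside this block every LoRAModule uses multiplier = 1.0 * scale
    pass  # run unet(...) here
# on exit the multipliers are reset to 0, so the UNet behaves like the original

network.save_weights("slider_example.safetensors", dtype=torch.float16)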
trainscripts/imagesliders/model_util.py ADDED
@@ -0,0 +1,283 @@
1
+ from typing import Literal, Union, Optional
2
+
3
+ import torch
4
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5
+ from diffusers import (
6
+ UNet2DConditionModel,
7
+ SchedulerMixin,
8
+ StableDiffusionPipeline,
9
+ StableDiffusionXLPipeline,
10
+ AutoencoderKL,
11
+ )
12
+ from diffusers.schedulers import (
13
+ DDIMScheduler,
14
+ DDPMScheduler,
15
+ LMSDiscreteScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ )
18
+
19
+
20
+ TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
21
+ TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
22
+
23
+ AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
24
+
25
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
26
+
27
+ DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
28
+
29
+
30
+ def load_diffusers_model(
31
+ pretrained_model_name_or_path: str,
32
+ v2: bool = False,
33
+ clip_skip: Optional[int] = None,
34
+ weight_dtype: torch.dtype = torch.float32,
35
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
36
+ # the VAE is loaded as well (the image-slider scripts need it)
37
+
38
+ if v2:
39
+ tokenizer = CLIPTokenizer.from_pretrained(
40
+ TOKENIZER_V2_MODEL_NAME,
41
+ subfolder="tokenizer",
42
+ torch_dtype=weight_dtype,
43
+ cache_dir=DIFFUSERS_CACHE_DIR,
44
+ )
45
+ text_encoder = CLIPTextModel.from_pretrained(
46
+ pretrained_model_name_or_path,
47
+ subfolder="text_encoder",
48
+ # default is clip skip 2
49
+ num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
50
+ torch_dtype=weight_dtype,
51
+ cache_dir=DIFFUSERS_CACHE_DIR,
52
+ )
53
+ else:
54
+ tokenizer = CLIPTokenizer.from_pretrained(
55
+ TOKENIZER_V1_MODEL_NAME,
56
+ subfolder="tokenizer",
57
+ torch_dtype=weight_dtype,
58
+ cache_dir=DIFFUSERS_CACHE_DIR,
59
+ )
60
+ text_encoder = CLIPTextModel.from_pretrained(
61
+ pretrained_model_name_or_path,
62
+ subfolder="text_encoder",
63
+ num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
64
+ torch_dtype=weight_dtype,
65
+ cache_dir=DIFFUSERS_CACHE_DIR,
66
+ )
67
+
68
+ unet = UNet2DConditionModel.from_pretrained(
69
+ pretrained_model_name_or_path,
70
+ subfolder="unet",
71
+ torch_dtype=weight_dtype,
72
+ cache_dir=DIFFUSERS_CACHE_DIR,
73
+ )
74
+
75
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
76
+
77
+ return tokenizer, text_encoder, unet, vae
78
+
79
+
80
+ def load_checkpoint_model(
81
+ checkpoint_path: str,
82
+ v2: bool = False,
83
+ clip_skip: Optional[int] = None,
84
+ weight_dtype: torch.dtype = torch.float32,
85
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
86
+ pipe = StableDiffusionPipeline.from_ckpt(
87
+ checkpoint_path,
88
+ upcast_attention=True if v2 else False,
89
+ torch_dtype=weight_dtype,
90
+ cache_dir=DIFFUSERS_CACHE_DIR,
91
+ )
92
+
93
+ unet = pipe.unet
94
+ tokenizer = pipe.tokenizer
95
+ text_encoder = pipe.text_encoder
96
+ vae = pipe.vae
97
+ if clip_skip is not None:
98
+ if v2:
99
+ text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
100
+ else:
101
+ text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
102
+
103
+ del pipe
104
+
105
+ return tokenizer, text_encoder, unet, vae
106
+
107
+
108
+ def load_models(
109
+ pretrained_model_name_or_path: str,
110
+ scheduler_name: AVAILABLE_SCHEDULERS,
111
+ v2: bool = False,
112
+ v_pred: bool = False,
113
+ weight_dtype: torch.dtype = torch.float32,
114
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin, AutoencoderKL]:
115
+ if pretrained_model_name_or_path.endswith(
116
+ ".ckpt"
117
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
118
+ tokenizer, text_encoder, unet, vae = load_checkpoint_model(
119
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
120
+ )
121
+ else: # diffusers
122
+ tokenizer, text_encoder, unet, vae = load_diffusers_model(
123
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
124
+ )
125
+
126
+ # the VAE is kept and returned along with the scheduler
127
+
128
+ scheduler = create_noise_scheduler(
129
+ scheduler_name,
130
+ prediction_type="v_prediction" if v_pred else "epsilon",
131
+ )
132
+
133
+ return tokenizer, text_encoder, unet, scheduler, vae
134
+
135
+
136
+ def load_diffusers_model_xl(
137
+ pretrained_model_name_or_path: str,
138
+ weight_dtype: torch.dtype = torch.float32,
139
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
140
+ # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet, vae
141
+
142
+ tokenizers = [
143
+ CLIPTokenizer.from_pretrained(
144
+ pretrained_model_name_or_path,
145
+ subfolder="tokenizer",
146
+ torch_dtype=weight_dtype,
147
+ cache_dir=DIFFUSERS_CACHE_DIR,
148
+ ),
149
+ CLIPTokenizer.from_pretrained(
150
+ pretrained_model_name_or_path,
151
+ subfolder="tokenizer_2",
152
+ torch_dtype=weight_dtype,
153
+ cache_dir=DIFFUSERS_CACHE_DIR,
154
+ pad_token_id=0, # same as open clip
155
+ ),
156
+ ]
157
+
158
+ text_encoders = [
159
+ CLIPTextModel.from_pretrained(
160
+ pretrained_model_name_or_path,
161
+ subfolder="text_encoder",
162
+ torch_dtype=weight_dtype,
163
+ cache_dir=DIFFUSERS_CACHE_DIR,
164
+ ),
165
+ CLIPTextModelWithProjection.from_pretrained(
166
+ pretrained_model_name_or_path,
167
+ subfolder="text_encoder_2",
168
+ torch_dtype=weight_dtype,
169
+ cache_dir=DIFFUSERS_CACHE_DIR,
170
+ ),
171
+ ]
172
+
173
+ unet = UNet2DConditionModel.from_pretrained(
174
+ pretrained_model_name_or_path,
175
+ subfolder="unet",
176
+ torch_dtype=weight_dtype,
177
+ cache_dir=DIFFUSERS_CACHE_DIR,
178
+ )
179
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
180
+ return tokenizers, text_encoders, unet, vae
181
+
182
+
183
+ def load_checkpoint_model_xl(
184
+ checkpoint_path: str,
185
+ weight_dtype: torch.dtype = torch.float32,
186
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
187
+ pipe = StableDiffusionXLPipeline.from_single_file(
188
+ checkpoint_path,
189
+ torch_dtype=weight_dtype,
190
+ cache_dir=DIFFUSERS_CACHE_DIR,
191
+ )
192
+
193
+ unet = pipe.unet
194
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
195
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
196
+ if len(text_encoders) == 2:
197
+ text_encoders[1].pad_token_id = 0
198
+ vae = pipe.vae
199
+ del pipe
200
+
201
+ return tokenizers, text_encoders, unet, vae
202
+
203
+
204
+ def load_models_xl(
205
+ pretrained_model_name_or_path: str,
206
+ scheduler_name: AVAILABLE_SCHEDULERS,
207
+ weight_dtype: torch.dtype = torch.float32,
208
+ ) -> tuple[
209
+ list[CLIPTokenizer],
210
+ list[SDXL_TEXT_ENCODER_TYPE],
211
+ UNet2DConditionModel,
212
+ SchedulerMixin,
+ AutoencoderKL,
213
+ ]:
214
+ if pretrained_model_name_or_path.endswith(
215
+ ".ckpt"
216
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
217
+ (
218
+ tokenizers,
219
+ text_encoders,
220
+ unet,
+ vae,
221
+ ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
222
+ else: # diffusers
223
+ (
224
+ tokenizers,
225
+ text_encoders,
226
+ unet,
227
+ vae
228
+ ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
229
+
230
+ scheduler = create_noise_scheduler(scheduler_name)
231
+
232
+ return tokenizers, text_encoders, unet, scheduler, vae
233
+
234
+
235
+ def create_noise_scheduler(
236
+ scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
237
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
238
+ ) -> SchedulerMixin:
239
+ # Honestly not sure which scheduler is best. The original implementation allowed DDIM, DDPM, or LMS, but it is unclear which works best.
240
+
241
+ name = scheduler_name.lower().replace(" ", "_")
242
+ if name == "ddim":
243
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
244
+ scheduler = DDIMScheduler(
245
+ beta_start=0.00085,
246
+ beta_end=0.012,
247
+ beta_schedule="scaled_linear",
248
+ num_train_timesteps=1000,
249
+ clip_sample=False,
250
+ prediction_type=prediction_type, # is this right?
251
+ )
252
+ elif name == "ddpm":
253
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
254
+ scheduler = DDPMScheduler(
255
+ beta_start=0.00085,
256
+ beta_end=0.012,
257
+ beta_schedule="scaled_linear",
258
+ num_train_timesteps=1000,
259
+ clip_sample=False,
260
+ prediction_type=prediction_type,
261
+ )
262
+ elif name == "lms":
263
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
264
+ scheduler = LMSDiscreteScheduler(
265
+ beta_start=0.00085,
266
+ beta_end=0.012,
267
+ beta_schedule="scaled_linear",
268
+ num_train_timesteps=1000,
269
+ prediction_type=prediction_type,
270
+ )
271
+ elif name == "euler_a":
272
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
273
+ scheduler = EulerAncestralDiscreteScheduler(
274
+ beta_start=0.00085,
275
+ beta_end=0.012,
276
+ beta_schedule="scaled_linear",
277
+ num_train_timesteps=1000,
278
+ prediction_type=prediction_type,
279
+ )
280
+ else:
281
+ raise ValueError(f"Unknown scheduler name: {name}")
282
+
283
+ return scheduler
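
For reference, the shape of the call the image-slider training scripts make into this module; the model name matches the default in data/config.yaml and downloading it requires network access.

import torch
import model_util

tokenizer, text_encoder, unet, noise_scheduler, vae = model_util.load_models(
    "CompVis/stable-diffusion-v1-4",
    scheduler_name="ddim",
    v2=False,
    v_pred=False,
    weight_dtype=torch.float32,
)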
trainscripts/imagesliders/prompt_util.py ADDED
@@ -0,0 +1,174 @@
1
+ from typing import Literal, Optional, Union, List
2
+
3
+ import yaml
4
+ from pathlib import Path
5
+
6
+
7
+ from pydantic import BaseModel, root_validator
8
+ import torch
9
+ import copy
10
+
11
+ ACTION_TYPES = Literal[
12
+ "erase",
13
+ "enhance",
14
+ ]
15
+
16
+
17
+ # SDXL needs two kinds of embeddings (text and pooled)
18
+ class PromptEmbedsXL:
19
+ text_embeds: torch.FloatTensor
20
+ pooled_embeds: torch.FloatTensor
21
+
22
+ def __init__(self, *args) -> None:
23
+ self.text_embeds = args[0]
24
+ self.pooled_embeds = args[1]
25
+
26
+
27
+ # SDv1.x / SDv2.x use a FloatTensor; SDXL uses PromptEmbedsXL
28
+ PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29
+
30
+
31
+ class PromptEmbedsCache: # cache embeddings so they can be reused
32
+ prompts: dict[str, PROMPT_EMBEDDING] = {}
33
+
34
+ def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35
+ self.prompts[__name] = __value
36
+
37
+ def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38
+ if __name in self.prompts:
39
+ return self.prompts[__name]
40
+ else:
41
+ return None
42
+
43
+
44
+ class PromptSettings(BaseModel): # one entry of the prompts YAML
45
+ target: str
46
+ positive: str = None # if None, target will be used
47
+ unconditional: str = "" # default is ""
48
+ neutral: str = None # if None, unconditional will be used
49
+ action: ACTION_TYPES = "erase" # default is "erase"
50
+ guidance_scale: float = 1.0 # default is 1.0
51
+ resolution: int = 512 # default is 512
52
+ dynamic_resolution: bool = False # default is False
53
+ batch_size: int = 1 # default is 1
54
+ dynamic_crops: bool = False # default is False. only used when model is XL
55
+
56
+ @root_validator(pre=True)
57
+ def fill_prompts(cls, values):
58
+ keys = values.keys()
59
+ if "target" not in keys:
60
+ raise ValueError("target must be specified")
61
+ if "positive" not in keys:
62
+ values["positive"] = values["target"]
63
+ if "unconditional" not in keys:
64
+ values["unconditional"] = ""
65
+ if "neutral" not in keys:
66
+ values["neutral"] = values["unconditional"]
67
+
68
+ return values
69
+
70
+
71
+ class PromptEmbedsPair:
72
+ target: PROMPT_EMBEDDING # not want to generate the concept
73
+ positive: PROMPT_EMBEDDING # generate the concept
74
+ unconditional: PROMPT_EMBEDDING # uncondition (default should be empty)
75
+ neutral: PROMPT_EMBEDDING # base condition (default should be empty)
76
+
77
+ guidance_scale: float
78
+ resolution: int
79
+ dynamic_resolution: bool
80
+ batch_size: int
81
+ dynamic_crops: bool
82
+
83
+ loss_fn: torch.nn.Module
84
+ action: ACTION_TYPES
85
+
86
+ def __init__(
87
+ self,
88
+ loss_fn: torch.nn.Module,
89
+ target: PROMPT_EMBEDDING,
90
+ positive: PROMPT_EMBEDDING,
91
+ unconditional: PROMPT_EMBEDDING,
92
+ neutral: PROMPT_EMBEDDING,
93
+ settings: PromptSettings,
94
+ ) -> None:
95
+ self.loss_fn = loss_fn
96
+ self.target = target
97
+ self.positive = positive
98
+ self.unconditional = unconditional
99
+ self.neutral = neutral
100
+
101
+ self.guidance_scale = settings.guidance_scale
102
+ self.resolution = settings.resolution
103
+ self.dynamic_resolution = settings.dynamic_resolution
104
+ self.batch_size = settings.batch_size
105
+ self.dynamic_crops = settings.dynamic_crops
106
+ self.action = settings.action
107
+
108
+ def _erase(
109
+ self,
110
+ target_latents: torch.FloatTensor, # "van gogh"
111
+ positive_latents: torch.FloatTensor, # "van gogh"
112
+ unconditional_latents: torch.FloatTensor, # ""
113
+ neutral_latents: torch.FloatTensor, # ""
114
+ ) -> torch.FloatTensor:
115
+ """Target latents are going not to have the positive concept."""
116
+ return self.loss_fn(
117
+ target_latents,
118
+ neutral_latents
119
+ - self.guidance_scale * (positive_latents - unconditional_latents)
120
+ )
121
+
122
+
123
+ def _enhance(
124
+ self,
125
+ target_latents: torch.FloatTensor, # "van gogh"
126
+ positive_latents: torch.FloatTensor, # "van gogh"
127
+ unconditional_latents: torch.FloatTensor, # ""
128
+ neutral_latents: torch.FloatTensor, # ""
129
+ ):
130
+ """Target latents are going to have the positive concept."""
131
+ return self.loss_fn(
132
+ target_latents,
133
+ neutral_latents
134
+ + self.guidance_scale * (positive_latents - unconditional_latents)
135
+ )
136
+
137
+ def loss(
138
+ self,
139
+ **kwargs,
140
+ ):
141
+ if self.action == "erase":
142
+ return self._erase(**kwargs)
143
+
144
+ elif self.action == "enhance":
145
+ return self._enhance(**kwargs)
146
+
147
+ else:
148
+ raise ValueError("action must be erase or enhance")
149
+
150
+
151
+ def load_prompts_from_yaml(path, attributes = []):
152
+ with open(path, "r") as f:
153
+ prompts = yaml.safe_load(f)
154
+ print(prompts)
155
+ if len(prompts) == 0:
156
+ raise ValueError("prompts file is empty")
157
+ if len(attributes)!=0:
158
+ newprompts = []
159
+ for i in range(len(prompts)):
160
+ for att in attributes:
161
+ copy_ = copy.deepcopy(prompts[i])
162
+ copy_['target'] = att + ' ' + copy_['target']
163
+ copy_['positive'] = att + ' ' + copy_['positive']
164
+ copy_['neutral'] = att + ' ' + copy_['neutral']
165
+ copy_['unconditional'] = att + ' ' + copy_['unconditional']
166
+ newprompts.append(copy_)
167
+ else:
168
+ newprompts = copy.deepcopy(prompts)
169
+
170
+ print(newprompts)
171
+ print(len(prompts), len(newprompts))
172
+ prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
173
+
174
+ return prompt_settings
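
A sketch of the optional `attributes` argument: each attribute string is prefixed to every prompt field, so one YAML entry fans out into one PromptSettings per attribute. The path is an example and assumes trainscripts/imagesliders/ is the working directory.

from prompt_util import load_prompts_from_yaml

settings = load_prompts_from_yaml(
    "data/prompts.yaml",
    attributes=["male person", "female person"],
)
print(len(settings))  # 1 entry x 2 attributes -> 2 PromptSettings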
trainscripts/imagesliders/train_lora-scale-xl.py ADDED
@@ -0,0 +1,548 @@
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc, os
10
+ import numpy as np
11
+
12
+ import torch
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+
16
+
17
+
18
+ import train_util
19
+ import random
20
+ import model_util
21
+ import prompt_util
22
+ from prompt_util import (
23
+ PromptEmbedsCache,
24
+ PromptEmbedsPair,
25
+ PromptSettings,
26
+ PromptEmbedsXL,
27
+ )
28
+ import debug_util
29
+ import config_util
30
+ from config_util import RootConfig
31
+
32
+ import wandb
33
+
34
+ NUM_IMAGES_PER_PROMPT = 1
35
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
36
+
37
+ def flush():
38
+ torch.cuda.empty_cache()
39
+ gc.collect()
40
+
41
+
42
+ def train(
43
+ config: RootConfig,
44
+ prompts: list[PromptSettings],
45
+ device,
46
+ folder_main: str,
47
+ folders,
48
+ scales,
49
+
50
+ ):
51
+ scales = np.array(scales)
52
+ folders = np.array(folders)
53
+ scales_unique = list(scales)
54
+
55
+ metadata = {
56
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
57
+ "config": config.json(),
58
+ }
59
+ save_path = Path(config.save.path)
60
+
61
+ modules = DEFAULT_TARGET_REPLACE
62
+ if config.network.type == "c3lier":
63
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
64
+
65
+ if config.logging.verbose:
66
+ print(metadata)
67
+
68
+ if config.logging.use_wandb:
69
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
70
+
71
+ weight_dtype = config_util.parse_precision(config.train.precision)
72
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
73
+
74
+ (
75
+ tokenizers,
76
+ text_encoders,
77
+ unet,
78
+ noise_scheduler,
79
+ vae
80
+ ) = model_util.load_models_xl(
81
+ config.pretrained_model.name_or_path,
82
+ scheduler_name=config.train.noise_scheduler,
83
+ )
84
+
85
+ for text_encoder in text_encoders:
86
+ text_encoder.to(device, dtype=weight_dtype)
87
+ text_encoder.requires_grad_(False)
88
+ text_encoder.eval()
89
+
90
+ unet.to(device, dtype=weight_dtype)
91
+ if config.other.use_xformers:
92
+ unet.enable_xformers_memory_efficient_attention()
93
+ unet.requires_grad_(False)
94
+ unet.eval()
95
+
96
+ vae.to(device)
97
+ vae.requires_grad_(False)
98
+ vae.eval()
99
+
100
+ network = LoRANetwork(
101
+ unet,
102
+ rank=config.network.rank,
103
+ multiplier=1.0,
104
+ alpha=config.network.alpha,
105
+ train_method=config.network.training_method,
106
+ ).to(device, dtype=weight_dtype)
107
+
108
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
109
+ #optimizer_args
110
+ optimizer_kwargs = {}
111
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
112
+ for arg in config.train.optimizer_args.split(" "):
113
+ key, value = arg.split("=")
114
+ value = ast.literal_eval(value)
115
+ optimizer_kwargs[key] = value
116
+
117
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
118
+ lr_scheduler = train_util.get_lr_scheduler(
119
+ config.train.lr_scheduler,
120
+ optimizer,
121
+ max_iterations=config.train.iterations,
122
+ lr_min=config.train.lr / 100,
123
+ )
124
+ criteria = torch.nn.MSELoss()
125
+
126
+ print("Prompts")
127
+ for settings in prompts:
128
+ print(settings)
129
+
130
+ # debug
131
+ debug_util.check_requires_grad(network)
132
+ debug_util.check_training_mode(network)
133
+
134
+ cache = PromptEmbedsCache()
135
+ prompt_pairs: list[PromptEmbedsPair] = []
136
+
137
+ with torch.no_grad():
138
+ for settings in prompts:
139
+ print(settings)
140
+ for prompt in [
141
+ settings.target,
142
+ settings.positive,
143
+ settings.neutral,
144
+ settings.unconditional,
145
+ ]:
146
+ if cache[prompt] is None:
147
+ tex_embs, pool_embs = train_util.encode_prompts_xl(
148
+ tokenizers,
149
+ text_encoders,
150
+ [prompt],
151
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
152
+ )
153
+ cache[prompt] = PromptEmbedsXL(
154
+ tex_embs,
155
+ pool_embs
156
+ )
157
+
158
+ prompt_pairs.append(
159
+ PromptEmbedsPair(
160
+ criteria,
161
+ cache[settings.target],
162
+ cache[settings.positive],
163
+ cache[settings.unconditional],
164
+ cache[settings.neutral],
165
+ settings,
166
+ )
167
+ )
168
+
169
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
170
+ del tokenizer, text_encoder
171
+
172
+ flush()
173
+
174
+ pbar = tqdm(range(config.train.iterations))
175
+
176
+ loss = None
177
+
178
+ for i in pbar:
179
+ with torch.no_grad():
180
+ noise_scheduler.set_timesteps(
181
+ config.train.max_denoising_steps, device=device
182
+ )
183
+
184
+ optimizer.zero_grad()
185
+
186
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
187
+ torch.randint(0, len(prompt_pairs), (1,)).item()
188
+ ]
189
+
190
+ # pick a random timestep index in 1 ~ (max_denoising_steps - 1)
191
+ timesteps_to = torch.randint(
192
+ 1, config.train.max_denoising_steps, (1,)
193
+ ).item()
194
+
195
+ height, width = prompt_pair.resolution, prompt_pair.resolution
196
+ if prompt_pair.dynamic_resolution:
197
+ height, width = train_util.get_random_resolution_in_bucket(
198
+ prompt_pair.resolution
199
+ )
200
+
201
+ if config.logging.verbose:
202
+ print("guidance_scale:", prompt_pair.guidance_scale)
203
+ print("resolution:", prompt_pair.resolution)
204
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
205
+ if prompt_pair.dynamic_resolution:
206
+ print("bucketed resolution:", (height, width))
207
+ print("batch_size:", prompt_pair.batch_size)
208
+ print("dynamic_crops:", prompt_pair.dynamic_crops)
209
+
210
+
211
+
212
+ scale_to_look = abs(random.choice(list(scales_unique)))
213
+ folder1 = folders[scales==-scale_to_look][0]
214
+ folder2 = folders[scales==scale_to_look][0]
215
+
216
+ ims = os.listdir(f'{folder_main}/{folder1}/')
217
+ ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_]
218
+ random_sampler = random.randint(0, len(ims)-1)
219
+
220
+ img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((512,512))
221
+ img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((512,512))
222
+
223
+ seed = random.randint(0,2*15)
224
+
225
+ generator = torch.manual_seed(seed)
226
+ denoised_latents_low, low_noise = train_util.get_noisy_image(
227
+ img1,
228
+ vae,
229
+ generator,
230
+ unet,
231
+ noise_scheduler,
232
+ start_timesteps=0,
233
+ total_timesteps=timesteps_to)
234
+ denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
235
+ low_noise = low_noise.to(device, dtype=weight_dtype)
236
+
237
+ generator = torch.manual_seed(seed)
238
+ denoised_latents_high, high_noise = train_util.get_noisy_image(
239
+ img2,
240
+ vae,
241
+ generator,
242
+ unet,
243
+ noise_scheduler,
244
+ start_timesteps=0,
245
+ total_timesteps=timesteps_to)
246
+ denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
247
+ high_noise = high_noise.to(device, dtype=weight_dtype)
248
+ noise_scheduler.set_timesteps(1000)
249
+
250
+ add_time_ids = train_util.get_add_time_ids(
251
+ height,
252
+ width,
253
+ dynamic_crops=prompt_pair.dynamic_crops,
254
+ dtype=weight_dtype,
255
+ ).to(device, dtype=weight_dtype)
256
+
257
+
258
+ current_timestep = noise_scheduler.timesteps[
259
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
260
+ ]
261
+ try:
262
+ # outside of "with network:" only the empty LoRA is active
263
+ high_latents = train_util.predict_noise_xl(
264
+ unet,
265
+ noise_scheduler,
266
+ current_timestep,
267
+ denoised_latents_high,
268
+ text_embeddings=train_util.concat_embeddings(
269
+ prompt_pair.unconditional.text_embeds,
270
+ prompt_pair.positive.text_embeds,
271
+ prompt_pair.batch_size,
272
+ ),
273
+ add_text_embeddings=train_util.concat_embeddings(
274
+ prompt_pair.unconditional.pooled_embeds,
275
+ prompt_pair.positive.pooled_embeds,
276
+ prompt_pair.batch_size,
277
+ ),
278
+ add_time_ids=train_util.concat_embeddings(
279
+ add_time_ids, add_time_ids, prompt_pair.batch_size
280
+ ),
281
+ guidance_scale=1,
282
+ ).to(device, dtype=torch.float32)
283
+ except Exception:
284
+ flush()
285
+ print(f'Error occurred!: {np.array(img1).shape} {np.array(img2).shape}')
286
+ continue
287
+ # outside of "with network:" only the empty LoRA is active
288
+
289
+ low_latents = train_util.predict_noise_xl(
290
+ unet,
291
+ noise_scheduler,
292
+ current_timestep,
293
+ denoised_latents_low,
294
+ text_embeddings=train_util.concat_embeddings(
295
+ prompt_pair.unconditional.text_embeds,
296
+ prompt_pair.neutral.text_embeds,
297
+ prompt_pair.batch_size,
298
+ ),
299
+ add_text_embeddings=train_util.concat_embeddings(
300
+ prompt_pair.unconditional.pooled_embeds,
301
+ prompt_pair.neutral.pooled_embeds,
302
+ prompt_pair.batch_size,
303
+ ),
304
+ add_time_ids=train_util.concat_embeddings(
305
+ add_time_ids, add_time_ids, prompt_pair.batch_size
306
+ ),
307
+ guidance_scale=1,
308
+ ).to(device, dtype=torch.float32)
309
+
310
+
311
+
312
+ if config.logging.verbose:
313
+ print("high_latents:", high_latents[0, 0, :5, :5])
314
+ print("low_latents:", low_latents[0, 0, :5, :5])
316
+
317
+ network.set_lora_slider(scale=scale_to_look)
318
+ with network:
319
+ target_latents_high = train_util.predict_noise_xl(
320
+ unet,
321
+ noise_scheduler,
322
+ current_timestep,
323
+ denoised_latents_high,
324
+ text_embeddings=train_util.concat_embeddings(
325
+ prompt_pair.unconditional.text_embeds,
326
+ prompt_pair.positive.text_embeds,
327
+ prompt_pair.batch_size,
328
+ ),
329
+ add_text_embeddings=train_util.concat_embeddings(
330
+ prompt_pair.unconditional.pooled_embeds,
331
+ prompt_pair.positive.pooled_embeds,
332
+ prompt_pair.batch_size,
333
+ ),
334
+ add_time_ids=train_util.concat_embeddings(
335
+ add_time_ids, add_time_ids, prompt_pair.batch_size
336
+ ),
337
+ guidance_scale=1,
338
+ ).to(device, dtype=torch.float32)
339
+
340
+ high_latents.requires_grad = False
341
+ low_latents.requires_grad = False
342
+
343
+ loss_high = criteria(target_latents_high, high_noise.to(torch.float32))
344
+ pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
345
+ loss_high.backward()
346
+
347
+ # opposite
348
+ network.set_lora_slider(scale=-scale_to_look)
349
+ with network:
350
+ target_latents_low = train_util.predict_noise_xl(
351
+ unet,
352
+ noise_scheduler,
353
+ current_timestep,
354
+ denoised_latents_low,
355
+ text_embeddings=train_util.concat_embeddings(
356
+ prompt_pair.unconditional.text_embeds,
357
+ prompt_pair.neutral.text_embeds,
358
+ prompt_pair.batch_size,
359
+ ),
360
+ add_text_embeddings=train_util.concat_embeddings(
361
+ prompt_pair.unconditional.pooled_embeds,
362
+ prompt_pair.neutral.pooled_embeds,
363
+ prompt_pair.batch_size,
364
+ ),
365
+ add_time_ids=train_util.concat_embeddings(
366
+ add_time_ids, add_time_ids, prompt_pair.batch_size
367
+ ),
368
+ guidance_scale=1,
369
+ ).to(device, dtype=torch.float32)
370
+
371
+
372
+ high_latents.requires_grad = False
373
+ low_latents.requires_grad = False
374
+
375
+ loss_low = criteria(target_latents_low, low_noise.to(torch.float32))
376
+ pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
377
+ loss_low.backward()
378
+
379
+
380
+ optimizer.step()
381
+ lr_scheduler.step()
382
+
383
+ del (
384
+ high_latents,
385
+ low_latents,
386
+ target_latents_low,
387
+ target_latents_high,
388
+ )
389
+ flush()
390
+
391
+ if (
392
+ i % config.save.per_steps == 0
393
+ and i != 0
394
+ and i != config.train.iterations - 1
395
+ ):
396
+ print("Saving...")
397
+ save_path.mkdir(parents=True, exist_ok=True)
398
+ network.save_weights(
399
+ save_path / f"{config.save.name}_{i}steps.pt",
400
+ dtype=save_weight_dtype,
401
+ )
402
+
403
+ print("Saving...")
404
+ save_path.mkdir(parents=True, exist_ok=True)
405
+ network.save_weights(
406
+ save_path / f"{config.save.name}_last.pt",
407
+ dtype=save_weight_dtype,
408
+ )
409
+
410
+ del (
411
+ unet,
412
+ noise_scheduler,
413
+ loss,
414
+ optimizer,
415
+ network,
416
+ )
417
+
418
+ flush()
419
+
420
+ print("Done.")
421
+
422
+
423
+ def main(args):
424
+ config_file = args.config_file
425
+
426
+ config = config_util.load_config_from_yaml(config_file)
427
+ if args.name is not None:
428
+ config.save.name = args.name
429
+ attributes = []
430
+ if args.attributes is not None:
431
+ attributes = args.attributes.split(',')
432
+ attributes = [a.strip() for a in attributes]
433
+
434
+ config.network.alpha = args.alpha
435
+ config.network.rank = args.rank
436
+ config.save.name += f'_alpha{args.alpha}'
437
+ config.save.name += f'_rank{config.network.rank }'
438
+ config.save.name += f'_{config.network.training_method}'
439
+ config.save.path += f'/{config.save.name}'
440
+
441
+ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
442
+
443
+ device = torch.device(f"cuda:{args.device}")
444
+
445
+ folders = args.folders.split(',')
446
+ folders = [f.strip() for f in folders]
447
+ scales = args.scales.split(',')
448
+ scales = [f.strip() for f in scales]
449
+ scales = [int(s) for s in scales]
450
+
451
+ print(folders, scales)
452
+ if len(scales) != len(folders):
453
+ raise Exception('the number of folders needs to match the number of scales')
454
+
455
+ if args.stylecheck is not None:
456
+ check = args.stylecheck.split('-')
457
+
458
+ for i in range(int(check[0]), int(check[1])):
459
+ folder_main = args.folder_main+ f'{i}'
460
+ config.save.name = f'{os.path.basename(folder_main)}'
461
+ config.save.name += f'_alpha{args.alpha}'
462
+ config.save.name += f'_rank{config.network.rank }'
463
+ config.save.path = f'models/{config.save.name}'
464
+ train(config=config, prompts=prompts, device=device, folder_main = folder_main, folders = folders, scales = scales)
465
+ else:
466
+ train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales)
467
+
468
+
469
+ if __name__ == "__main__":
470
+ parser = argparse.ArgumentParser()
471
+ parser.add_argument(
472
+ "--config_file",
473
+ required=True,
474
+ help="Config file for training.",
475
+ )
476
+ # config_file 'data/config.yaml'
477
+ parser.add_argument(
478
+ "--alpha",
479
+ type=float,
480
+ required=True,
481
+ help="LoRA weight.",
482
+ )
483
+ # --alpha 1.0
484
+ parser.add_argument(
485
+ "--rank",
486
+ type=int,
487
+ required=False,
488
+ help="Rank of LoRA.",
489
+ default=4,
490
+ )
491
+ # --rank 4
492
+ parser.add_argument(
493
+ "--device",
494
+ type=int,
495
+ required=False,
496
+ default=0,
497
+ help="Device to train on.",
498
+ )
499
+ # --device 0
500
+ parser.add_argument(
501
+ "--name",
502
+ type=str,
503
+ required=False,
504
+ default=None,
505
+ help="Name of the trained slider (used in the save path).",
506
+ )
507
+ # --name 'eyesize_slider'
508
+ parser.add_argument(
509
+ "--attributes",
510
+ type=str,
511
+ required=False,
512
+ default=None,
513
+ help="attributes to disentangle (comma separated string)",
514
+ )
515
+ parser.add_argument(
516
+ "--folder_main",
517
+ type=str,
518
+ required=True,
519
+ help="Root folder containing the attribute-scaled image folders",
520
+ )
521
+
522
+ parser.add_argument(
523
+ "--stylecheck",
524
+ type=str,
525
+ required=False,
526
+ default = None,
527
+ help="Range of numeric suffixes appended to folder_main, e.g. '0-5'",
528
+ )
529
+
530
+ parser.add_argument(
531
+ "--folders",
532
+ type=str,
533
+ required=False,
534
+ default = 'verylow, low, high, veryhigh',
535
+ help="folders with different attribute-scaled images",
536
+ )
537
+ parser.add_argument(
538
+ "--scales",
539
+ type=str,
540
+ required=False,
541
+ default = '-2, -1, 1, 2',
542
+ help="scales for different attribute-scaled images",
543
+ )
544
+
545
+
546
+ args = parser.parse_args()
547
+
548
+ main(args)
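The trainer above never loads an explicit pairing table; it pairs images implicitly by filename across two folders whose scales are negatives of each other. A minimal sketch of that pairing logic, using the script's default --folders/--scales values (the dataset root folder_main and the concrete subfolder names are illustrative assumptions):

import random
import numpy as np

# Default CLI values from the argument parser above.
folders = np.array(["verylow", "low", "high", "veryhigh"])
scales = np.array([-2, -1, 1, 2])

scale_to_look = abs(random.choice(list(scales)))     # picks 1 or 2
folder_low = folders[scales == -scale_to_look][0]    # e.g. "low" when the scale is 1
folder_high = folders[scales == scale_to_look][0]    # e.g. "high" when the scale is 1

# train() then opens the *same* filename from
#   f"{folder_main}/{folder_low}/" and f"{folder_main}/{folder_high}/",
# so each pair differs only in the strength of the attribute being learned; the
# LoRA is pushed toward the high-attribute image at +scale and toward the
# low-attribute image at -scale.
print(scale_to_look, folder_low, folder_high)

This is also why main() rejects a mismatched number of folders and scales: each signed scale must select exactly one folder through the boolean masks above.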
trainscripts/imagesliders/train_lora-scale.py ADDED
@@ -0,0 +1,501 @@
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+ import os, glob
14
+
15
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
+ import train_util
17
+ import model_util
18
+ import prompt_util
19
+ from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings
20
+ import debug_util
21
+ import config_util
22
+ from config_util import RootConfig
23
+ import random
24
+ import numpy as np
25
+ import wandb
26
+ from PIL import Image
27
+
28
+ def flush():
29
+ torch.cuda.empty_cache()
30
+ gc.collect()
31
+ def prev_step(model_output, timestep, scheduler, sample):
32
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
33
+ alpha_prod_t =scheduler.alphas_cumprod[timestep]
34
+ alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
35
+ beta_prod_t = 1 - alpha_prod_t
36
+ pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
37
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
38
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
39
+ return prev_sample
40
+
41
+ def train(
42
+ config: RootConfig,
43
+ prompts: list[PromptSettings],
44
+ device: torch.device,
45
+ folder_main: str,
46
+ folders,
47
+ scales,
48
+ ):
49
+ scales = np.array(scales)
50
+ folders = np.array(folders)
51
+ scales_unique = list(scales)
52
+
53
+ metadata = {
54
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
55
+ "config": config.json(),
56
+ }
57
+ save_path = Path(config.save.path)
58
+
59
+ modules = DEFAULT_TARGET_REPLACE
60
+ if config.network.type == "c3lier":
61
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
62
+
63
+ if config.logging.verbose:
64
+ print(metadata)
65
+
66
+ if config.logging.use_wandb:
67
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
68
+
69
+ weight_dtype = config_util.parse_precision(config.train.precision)
70
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
71
+
72
+ tokenizer, text_encoder, unet, noise_scheduler, vae = model_util.load_models(
73
+ config.pretrained_model.name_or_path,
74
+ scheduler_name=config.train.noise_scheduler,
75
+ v2=config.pretrained_model.v2,
76
+ v_pred=config.pretrained_model.v_pred,
77
+ )
78
+
79
+ text_encoder.to(device, dtype=weight_dtype)
80
+ text_encoder.eval()
81
+
82
+ unet.to(device, dtype=weight_dtype)
83
+ unet.enable_xformers_memory_efficient_attention()
84
+ unet.requires_grad_(False)
85
+ unet.eval()
86
+
87
+ vae.to(device)
88
+ vae.requires_grad_(False)
89
+ vae.eval()
90
+
91
+ network = LoRANetwork(
92
+ unet,
93
+ rank=config.network.rank,
94
+ multiplier=1.0,
95
+ alpha=config.network.alpha,
96
+ train_method=config.network.training_method,
97
+ ).to(device, dtype=weight_dtype)
98
+
99
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
100
+ #optimizer_args
101
+ optimizer_kwargs = {}
102
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
103
+ for arg in config.train.optimizer_args.split(" "):
104
+ key, value = arg.split("=")
105
+ value = ast.literal_eval(value)
106
+ optimizer_kwargs[key] = value
107
+
108
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
109
+ lr_scheduler = train_util.get_lr_scheduler(
110
+ config.train.lr_scheduler,
111
+ optimizer,
112
+ max_iterations=config.train.iterations,
113
+ lr_min=config.train.lr / 100,
114
+ )
115
+ criteria = torch.nn.MSELoss()
116
+
117
+ print("Prompts")
118
+ for settings in prompts:
119
+ print(settings)
120
+
121
+ # debug
122
+ debug_util.check_requires_grad(network)
123
+ debug_util.check_training_mode(network)
124
+
125
+ cache = PromptEmbedsCache()
126
+ prompt_pairs: list[PromptEmbedsPair] = []
127
+
128
+ with torch.no_grad():
129
+ for settings in prompts:
130
+ print(settings)
131
+ for prompt in [
132
+ settings.target,
133
+ settings.positive,
134
+ settings.neutral,
135
+ settings.unconditional,
136
+ ]:
137
+ print(prompt)
138
+ if isinstance(prompt, list):
139
+ if prompt == settings.positive:
140
+ key_setting = 'positive'
141
+ else:
142
+ key_setting = 'attributes'
143
+ if len(prompt) == 0:
144
+ cache[key_setting] = []
145
+ else:
146
+ if cache[key_setting] is None:
147
+ cache[key_setting] = train_util.encode_prompts(
148
+ tokenizer, text_encoder, prompt
149
+ )
150
+ else:
151
+ if cache[prompt] is None:
152
+ cache[prompt] = train_util.encode_prompts(
153
+ tokenizer, text_encoder, [prompt]
154
+ )
155
+
156
+ prompt_pairs.append(
157
+ PromptEmbedsPair(
158
+ criteria,
159
+ cache[settings.target],
160
+ cache[settings.positive],
161
+ cache[settings.unconditional],
162
+ cache[settings.neutral],
163
+ settings,
164
+ )
165
+ )
166
+
167
+ del tokenizer
168
+ del text_encoder
169
+
170
+ flush()
171
+
172
+ pbar = tqdm(range(config.train.iterations))
173
+ for i in pbar:
174
+ with torch.no_grad():
175
+ noise_scheduler.set_timesteps(
176
+ config.train.max_denoising_steps, device=device
177
+ )
178
+
179
+ optimizer.zero_grad()
180
+
181
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
182
+ torch.randint(0, len(prompt_pairs), (1,)).item()
183
+ ]
184
+
185
+ # random timestep from 1 to 49
186
+ timesteps_to = torch.randint(
187
+ 1, config.train.max_denoising_steps-1, (1,)
188
+ # 1, 25, (1,)
189
+ ).item()
190
+
191
+ height, width = (
192
+ prompt_pair.resolution,
193
+ prompt_pair.resolution,
194
+ )
195
+ if prompt_pair.dynamic_resolution:
196
+ height, width = train_util.get_random_resolution_in_bucket(
197
+ prompt_pair.resolution
198
+ )
199
+
200
+ if config.logging.verbose:
201
+ print("guidance_scale:", prompt_pair.guidance_scale)
202
+ print("resolution:", prompt_pair.resolution)
203
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
204
+ if prompt_pair.dynamic_resolution:
205
+ print("bucketed resolution:", (height, width))
206
+ print("batch_size:", prompt_pair.batch_size)
207
+
208
+
209
+
210
+
211
+ scale_to_look = abs(random.choice(list(scales_unique)))
212
+ folder1 = folders[scales==-scale_to_look][0]
213
+ folder2 = folders[scales==scale_to_look][0]
214
+
215
+ ims = os.listdir(f'{folder_main}/{folder1}/')
216
+ ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_]
217
+ random_sampler = random.randint(0, len(ims)-1)
218
+
219
+ img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((256,256))
220
+ img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((256,256))
221
+
222
+ seed = random.randint(0,2*15)
223
+
224
+ generator = torch.manual_seed(seed)
225
+ denoised_latents_low, low_noise = train_util.get_noisy_image(
226
+ img1,
227
+ vae,
228
+ generator,
229
+ unet,
230
+ noise_scheduler,
231
+ start_timesteps=0,
232
+ total_timesteps=timesteps_to)
233
+ denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
234
+ low_noise = low_noise.to(device, dtype=weight_dtype)
235
+
236
+ generator = torch.manual_seed(seed)
237
+ denoised_latents_high, high_noise = train_util.get_noisy_image(
238
+ img2,
239
+ vae,
240
+ generator,
241
+ unet,
242
+ noise_scheduler,
243
+ start_timesteps=0,
244
+ total_timesteps=timesteps_to)
245
+ denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
246
+ high_noise = high_noise.to(device, dtype=weight_dtype)
247
+ noise_scheduler.set_timesteps(1000)
248
+
249
+ current_timestep = noise_scheduler.timesteps[
250
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
251
+ ]
252
+
253
+ # outside of "with network:" only the empty LoRA is active
254
+ high_latents = train_util.predict_noise(
255
+ unet,
256
+ noise_scheduler,
257
+ current_timestep,
258
+ denoised_latents_high,
259
+ train_util.concat_embeddings(
260
+ prompt_pair.unconditional,
261
+ prompt_pair.positive,
262
+ prompt_pair.batch_size,
263
+ ),
264
+ guidance_scale=1,
265
+ ).to("cpu", dtype=torch.float32)
266
+ # outside of "with network:" only the empty LoRA is active
267
+ low_latents = train_util.predict_noise(
268
+ unet,
269
+ noise_scheduler,
270
+ current_timestep,
271
+ denoised_latents_low,
272
+ train_util.concat_embeddings(
273
+ prompt_pair.unconditional,
274
+ prompt_pair.unconditional,
275
+ prompt_pair.batch_size,
276
+ ),
277
+ guidance_scale=1,
278
+ ).to("cpu", dtype=torch.float32)
279
+ if config.logging.verbose:
280
+ print("high_latents:", high_latents[0, 0, :5, :5])
281
+ print("low_latents:", low_latents[0, 0, :5, :5])
283
+
284
+ network.set_lora_slider(scale=scale_to_look)
285
+ with network:
286
+ target_latents_high = train_util.predict_noise(
287
+ unet,
288
+ noise_scheduler,
289
+ current_timestep,
290
+ denoised_latents_high,
291
+ train_util.concat_embeddings(
292
+ prompt_pair.unconditional,
293
+ prompt_pair.positive,
294
+ prompt_pair.batch_size,
295
+ ),
296
+ guidance_scale=1,
297
+ ).to("cpu", dtype=torch.float32)
298
+
299
+
300
+ high_latents.requires_grad = False
301
+ low_latents.requires_grad = False
302
+
303
+ loss_high = criteria(target_latents_high, high_noise.cpu().to(torch.float32))
304
+ pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
305
+ loss_high.backward()
306
+
307
+
308
+ network.set_lora_slider(scale=-scale_to_look)
309
+ with network:
310
+ target_latents_low = train_util.predict_noise(
311
+ unet,
312
+ noise_scheduler,
313
+ current_timestep,
314
+ denoised_latents_low,
315
+ train_util.concat_embeddings(
316
+ prompt_pair.unconditional,
317
+ prompt_pair.neutral,
318
+ prompt_pair.batch_size,
319
+ ),
320
+ guidance_scale=1,
321
+ ).to("cpu", dtype=torch.float32)
322
+
323
+
324
+ high_latents.requires_grad = False
325
+ low_latents.requires_grad = False
326
+
327
+ loss_low = criteria(target_latents_low, low_noise.cpu().to(torch.float32))
328
+ pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
329
+ loss_low.backward()
330
+
331
+ ## NOTICE NO zero_grad between these steps (accumulating gradients)
332
+ # following guidelines from Ostris (https://github.com/ostris/ai-toolkit)
333
+
334
+ optimizer.step()
335
+ lr_scheduler.step()
336
+
337
+ del (
338
+ high_latents,
339
+ low_latents,
340
+ target_latents_low,
341
+ target_latents_high,
342
+ )
343
+ flush()
344
+
345
+ if (
346
+ i % config.save.per_steps == 0
347
+ and i != 0
348
+ and i != config.train.iterations - 1
349
+ ):
350
+ print("Saving...")
351
+ save_path.mkdir(parents=True, exist_ok=True)
352
+ network.save_weights(
353
+ save_path / f"{config.save.name}_{i}steps.pt",
354
+ dtype=save_weight_dtype,
355
+ )
356
+
357
+ print("Saving...")
358
+ save_path.mkdir(parents=True, exist_ok=True)
359
+ network.save_weights(
360
+ save_path / f"{config.save.name}_last.pt",
361
+ dtype=save_weight_dtype,
362
+ )
363
+
364
+ del (
365
+ unet,
366
+ noise_scheduler,
367
+ optimizer,
368
+ network,
369
+ )
370
+
371
+ flush()
372
+
373
+ print("Done.")
374
+
375
+
376
+ def main(args):
377
+ config_file = args.config_file
378
+
379
+ config = config_util.load_config_from_yaml(config_file)
380
+ if args.name is not None:
381
+ config.save.name = args.name
382
+ attributes = []
383
+ if args.attributes is not None:
384
+ attributes = args.attributes.split(',')
385
+ attributes = [a.strip() for a in attributes]
386
+
387
+ config.network.alpha = args.alpha
388
+ config.network.rank = args.rank
389
+ config.save.name += f'_alpha{args.alpha}'
390
+ config.save.name += f'_rank{config.network.rank }'
391
+ config.save.name += f'_{config.network.training_method}'
392
+ config.save.path += f'/{config.save.name}'
393
+
394
+ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
395
+ device = torch.device(f"cuda:{args.device}")
396
+
397
+
398
+ folders = args.folders.split(',')
399
+ folders = [f.strip() for f in folders]
400
+ scales = args.scales.split(',')
401
+ scales = [f.strip() for f in scales]
402
+ scales = [int(s) for s in scales]
403
+
404
+ print(folders, scales)
405
+ if len(scales) != len(folders):
406
+ raise Exception('the number of folders needs to match the number of scales')
407
+
408
+ if args.stylecheck is not None:
409
+ check = args.stylecheck.split('-')
410
+
411
+ for i in range(int(check[0]), int(check[1])):
412
+ folder_main = args.folder_main+ f'{i}'
413
+ config.save.name = f'{os.path.basename(folder_main)}'
414
+ config.save.name += f'_alpha{args.alpha}'
415
+ config.save.name += f'_rank{config.network.rank }'
416
+ config.save.path = f'models/{config.save.name}'
417
+ train(config=config, prompts=prompts, device=device, folder_main = folder_main, folders = folders, scales = scales)
418
+ else:
419
+ train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales)
420
+
421
+ if __name__ == "__main__":
422
+ parser = argparse.ArgumentParser()
423
+ parser.add_argument(
424
+ "--config_file",
425
+ required=False,
426
+ default = 'data/config.yaml',
427
+ help="Config file for training.",
428
+ )
429
+ parser.add_argument(
430
+ "--alpha",
431
+ type=float,
432
+ required=True,
433
+ help="LoRA weight.",
434
+ )
435
+
436
+ parser.add_argument(
437
+ "--rank",
438
+ type=int,
439
+ required=False,
440
+ help="Rank of LoRA.",
441
+ default=4,
442
+ )
443
+
444
+ parser.add_argument(
445
+ "--device",
446
+ type=int,
447
+ required=False,
448
+ default=0,
449
+ help="Device to train on.",
450
+ )
451
+
452
+ parser.add_argument(
453
+ "--name",
454
+ type=str,
455
+ required=False,
456
+ default=None,
457
+ help="Name of the trained slider (used in the save path).",
458
+ )
459
+
460
+ parser.add_argument(
461
+ "--attributes",
462
+ type=str,
463
+ required=False,
464
+ default=None,
465
+ help="attributes to disentangle (comma separated string)",
466
+ )
467
+
468
+ parser.add_argument(
469
+ "--folder_main",
470
+ type=str,
471
+ required=True,
472
+ help="Root folder containing the attribute-scaled image folders",
473
+ )
474
+
475
+ parser.add_argument(
476
+ "--stylecheck",
477
+ type=str,
478
+ required=False,
479
+ default = None,
480
+ help="Range of numeric suffixes appended to folder_main, e.g. '0-5'",
481
+ )
482
+
483
+ parser.add_argument(
484
+ "--folders",
485
+ type=str,
486
+ required=False,
487
+ default = 'verylow, low, high, veryhigh',
488
+ help="folders with different attribute-scaled images",
489
+ )
490
+ parser.add_argument(
491
+ "--scales",
492
+ type=str,
493
+ required=False,
494
+ default = '-2, -1, 1, 2',
495
+ help="scales for different attribute-scaled images",
496
+ )
497
+
498
+
499
+ args = parser.parse_args()
500
+
501
+ main(args)
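Both image-slider trainers call backward() twice per iteration, once at +scale and once at -scale, with no zero_grad() in between, so the two losses accumulate into a single optimizer step (the comment above credits guidelines from Ostris' ai-toolkit). A minimal, model-agnostic sketch of that update, with a toy nn.Linear standing in for the LoRA network; in the real loop the slider sign is switched via network.set_lora_slider between the two passes:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)                        # stand-in for LoRANetwork
criteria = nn.MSELoss()
optimizer = torch.optim.AdamW(net.parameters(), lr=2e-4)

x = torch.randn(1, 4)
target_high = torch.randn(1, 4)              # plays the role of high_noise
target_low = torch.randn(1, 4)               # plays the role of low_noise

optimizer.zero_grad()

loss_high = criteria(net(x), target_high)    # pass at +scale
loss_high.backward()

loss_low = criteria(net(x), target_low)      # pass at -scale
loss_low.backward()                          # gradients add onto those from loss_high

optimizer.step()                             # one step driven by both passes

The real loop additionally steps the learning-rate scheduler and frees the intermediate latents after each iteration.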
trainscripts/imagesliders/train_util.py ADDED
@@ -0,0 +1,458 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers import UNet2DConditionModel, SchedulerMixin
7
+ from diffusers.image_processor import VaeImageProcessor
8
+ from model_util import SDXL_TEXT_ENCODER_TYPE
9
+ from diffusers.utils import randn_tensor
10
+
11
+ from tqdm import tqdm
12
+
13
+ UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4; same for XL.
14
+ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
15
+
16
+ UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
17
+ TEXT_ENCODER_2_PROJECTION_DIM = 1280
18
+ UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
19
+
20
+
21
+ def get_random_noise(
22
+ batch_size: int, height: int, width: int, generator: torch.Generator = None
23
+ ) -> torch.Tensor:
24
+ return torch.randn(
25
+ (
26
+ batch_size,
27
+ UNET_IN_CHANNELS,
28
+ height // VAE_SCALE_FACTOR, # not sure whether height/width are in the right order, but either way it causes no real problem
29
+ width // VAE_SCALE_FACTOR,
30
+ ),
31
+ generator=generator,
32
+ device="cpu",
33
+ )
34
+
35
+
36
+ # https://www.crosslabs.org/blog/diffusion-with-offset-noise
37
+ def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
38
+ latents = latents + noise_offset * torch.randn(
39
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
40
+ )
41
+ return latents
42
+
43
+
44
+ def get_initial_latents(
45
+ scheduler: SchedulerMixin,
46
+ n_imgs: int,
47
+ height: int,
48
+ width: int,
49
+ n_prompts: int,
50
+ generator=None,
51
+ ) -> torch.Tensor:
52
+ noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
53
+ n_prompts, 1, 1, 1
54
+ )
55
+
56
+ latents = noise * scheduler.init_noise_sigma
57
+
58
+ return latents
59
+
60
+
61
+ def text_tokenize(
62
+ tokenizer: CLIPTokenizer, # normally one tokenizer; two for XL
63
+ prompts: list[str],
64
+ ):
65
+ return tokenizer(
66
+ prompts,
67
+ padding="max_length",
68
+ max_length=tokenizer.model_max_length,
69
+ truncation=True,
70
+ return_tensors="pt",
71
+ ).input_ids
72
+
73
+
74
+ def text_encode(text_encoder: CLIPTextModel, tokens):
75
+ return text_encoder(tokens.to(text_encoder.device))[0]
76
+
77
+
78
+ def encode_prompts(
79
+ tokenizer: CLIPTokenizer,
80
+ text_encoder: CLIPTextModel,
81
+ prompts: list[str],
82
+ ):
83
+
84
+ text_tokens = text_tokenize(tokenizer, prompts)
85
+ text_embeddings = text_encode(text_encoder, text_tokens)
86
+
87
+
88
+
89
+ return text_embeddings
90
+
91
+
92
+ # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
93
+ def text_encode_xl(
94
+ text_encoder: SDXL_TEXT_ENCODER_TYPE,
95
+ tokens: torch.FloatTensor,
96
+ num_images_per_prompt: int = 1,
97
+ ):
98
+ prompt_embeds = text_encoder(
99
+ tokens.to(text_encoder.device), output_hidden_states=True
100
+ )
101
+ pooled_prompt_embeds = prompt_embeds[0]
102
+ prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
103
+
104
+ bs_embed, seq_len, _ = prompt_embeds.shape
105
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
106
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
107
+
108
+ return prompt_embeds, pooled_prompt_embeds
109
+
110
+
111
+ def encode_prompts_xl(
112
+ tokenizers: list[CLIPTokenizer],
113
+ text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
114
+ prompts: list[str],
115
+ num_images_per_prompt: int = 1,
116
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
117
+ # text_encoder and text_encoder_2's penultimate layer's output
118
+ text_embeds_list = []
119
+ pooled_text_embeds = None # always text_encoder_2's pool
120
+
121
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
122
+ text_tokens_input_ids = text_tokenize(tokenizer, prompts)
123
+ text_embeds, pooled_text_embeds = text_encode_xl(
124
+ text_encoder, text_tokens_input_ids, num_images_per_prompt
125
+ )
126
+
127
+ text_embeds_list.append(text_embeds)
128
+
129
+ bs_embed = pooled_text_embeds.shape[0]
130
+ pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
131
+ bs_embed * num_images_per_prompt, -1
132
+ )
133
+
134
+ return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
135
+
136
+
137
+ def concat_embeddings(
138
+ unconditional: torch.FloatTensor,
139
+ conditional: torch.FloatTensor,
140
+ n_imgs: int,
141
+ ):
142
+ return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
143
+
144
+
145
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
146
+ def predict_noise(
147
+ unet: UNet2DConditionModel,
148
+ scheduler: SchedulerMixin,
149
+ timestep: int, # current timestep
150
+ latents: torch.FloatTensor,
151
+ text_embeddings: torch.FloatTensor, # concatenation of the unconditional and conditional text embeddings
152
+ guidance_scale=7.5,
153
+ ) -> torch.FloatTensor:
154
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
155
+ latent_model_input = torch.cat([latents] * 2)
156
+
157
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
158
+
159
+ # predict the noise residual
160
+ noise_pred = unet(
161
+ latent_model_input,
162
+ timestep,
163
+ encoder_hidden_states=text_embeddings,
164
+ ).sample
165
+
166
+ # perform guidance
167
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
168
+ guided_target = noise_pred_uncond + guidance_scale * (
169
+ noise_pred_text - noise_pred_uncond
170
+ )
171
+
172
+ return guided_target
173
+
174
+
175
+
176
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
177
+ @torch.no_grad()
178
+ def diffusion(
179
+ unet: UNet2DConditionModel,
180
+ scheduler: SchedulerMixin,
181
+ latents: torch.FloatTensor, # pure-noise latents
182
+ text_embeddings: torch.FloatTensor,
183
+ total_timesteps: int = 1000,
184
+ start_timesteps=0,
185
+ **kwargs,
186
+ ):
187
+ # latents_steps = []
188
+
189
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
190
+ noise_pred = predict_noise(
191
+ unet, scheduler, timestep, latents, text_embeddings, **kwargs
192
+ )
193
+
194
+ # compute the previous noisy sample x_t -> x_t-1
195
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
196
+
197
+ # return latents_steps
198
+ return latents
199
+
200
+ @torch.no_grad()
201
+ def get_noisy_image(
202
+ img,
203
+ vae,
204
+ generator,
205
+ unet: UNet2DConditionModel,
206
+ scheduler: SchedulerMixin,
207
+ total_timesteps: int = 1000,
208
+ start_timesteps=0,
209
+
210
+ **kwargs,
211
+ ):
212
+ # latents_steps = []
213
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
214
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
215
+
216
+ image = img
217
+ im_orig = image
218
+ device = vae.device
219
+ image = image_processor.preprocess(image).to(device)
220
+
221
+ init_latents = vae.encode(image).latent_dist.sample(None)
222
+ init_latents = vae.config.scaling_factor * init_latents
223
+
224
+ init_latents = torch.cat([init_latents], dim=0)
225
+
226
+ shape = init_latents.shape
227
+
228
+ noise = randn_tensor(shape, generator=generator, device=device)
229
+
230
+ time_ = total_timesteps
231
+ timestep = scheduler.timesteps[time_:time_+1]
232
+ # get latents
233
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
234
+
235
+ return init_latents, noise
236
+
237
+
238
+ def rescale_noise_cfg(
239
+ noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
240
+ ):
241
+ """
242
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
243
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
244
+ """
245
+ std_text = noise_pred_text.std(
246
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
247
+ )
248
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
249
+ # rescale the results from guidance (fixes overexposure)
250
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
251
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
252
+ noise_cfg = (
253
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
254
+ )
255
+
256
+ return noise_cfg
257
+
258
+
259
+ def predict_noise_xl(
260
+ unet: UNet2DConditionModel,
261
+ scheduler: SchedulerMixin,
262
+ timestep: int, # current timestep
263
+ latents: torch.FloatTensor,
264
+ text_embeddings: torch.FloatTensor, # concatenation of the unconditional and conditional text embeddings
265
+ add_text_embeddings: torch.FloatTensor, # the pooled text embeddings
266
+ add_time_ids: torch.FloatTensor,
267
+ guidance_scale=7.5,
268
+ guidance_rescale=0.7,
269
+ ) -> torch.FloatTensor:
270
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
271
+ latent_model_input = torch.cat([latents] * 2)
272
+
273
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
274
+
275
+ added_cond_kwargs = {
276
+ "text_embeds": add_text_embeddings,
277
+ "time_ids": add_time_ids,
278
+ }
279
+
280
+ # predict the noise residual
281
+ noise_pred = unet(
282
+ latent_model_input,
283
+ timestep,
284
+ encoder_hidden_states=text_embeddings,
285
+ added_cond_kwargs=added_cond_kwargs,
286
+ ).sample
287
+
288
+ # perform guidance
289
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
290
+ guided_target = noise_pred_uncond + guidance_scale * (
291
+ noise_pred_text - noise_pred_uncond
292
+ )
293
+
294
+ # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
295
+ noise_pred = rescale_noise_cfg(
296
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
297
+ )
298
+
299
+ return guided_target
300
+
301
+
302
+ @torch.no_grad()
303
+ def diffusion_xl(
304
+ unet: UNet2DConditionModel,
305
+ scheduler: SchedulerMixin,
306
+ latents: torch.FloatTensor, # pure-noise latents
307
+ text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
308
+ add_text_embeddings: torch.FloatTensor, # the pooled text embeddings
309
+ add_time_ids: torch.FloatTensor,
310
+ guidance_scale: float = 1.0,
311
+ total_timesteps: int = 1000,
312
+ start_timesteps=0,
313
+ ):
314
+ # latents_steps = []
315
+
316
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
317
+ noise_pred = predict_noise_xl(
318
+ unet,
319
+ scheduler,
320
+ timestep,
321
+ latents,
322
+ text_embeddings,
323
+ add_text_embeddings,
324
+ add_time_ids,
325
+ guidance_scale=guidance_scale,
326
+ guidance_rescale=0.7,
327
+ )
328
+
329
+ # compute the previous noisy sample x_t -> x_t-1
330
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
331
+
332
+ # return latents_steps
333
+ return latents
334
+
335
+
336
+ # for XL
337
+ def get_add_time_ids(
338
+ height: int,
339
+ width: int,
340
+ dynamic_crops: bool = False,
341
+ dtype: torch.dtype = torch.float32,
342
+ ):
343
+ if dynamic_crops:
344
+ # random float scale between 1 and 3
345
+ random_scale = torch.rand(1).item() * 2 + 1
346
+ original_size = (int(height * random_scale), int(width * random_scale))
347
+ # random position
348
+ crops_coords_top_left = (
349
+ torch.randint(0, original_size[0] - height, (1,)).item(),
350
+ torch.randint(0, original_size[1] - width, (1,)).item(),
351
+ )
352
+ target_size = (height, width)
353
+ else:
354
+ original_size = (height, width)
355
+ crops_coords_top_left = (0, 0)
356
+ target_size = (height, width)
357
+
358
+ # this is expected as 6
359
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
360
+
361
+ # this is expected as 2816
362
+ passed_add_embed_dim = (
363
+ UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
364
+ + TEXT_ENCODER_2_PROJECTION_DIM # + 1280
365
+ )
366
+ if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
367
+ raise ValueError(
368
+ f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
369
+ )
370
+
371
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
372
+ return add_time_ids
373
+
374
+
375
+ def get_optimizer(name: str):
376
+ name = name.lower()
377
+
378
+ if name.startswith("dadapt"):
379
+ import dadaptation
380
+
381
+ if name == "dadaptadam":
382
+ return dadaptation.DAdaptAdam
383
+ elif name == "dadaptlion":
384
+ return dadaptation.DAdaptLion
385
+ else:
386
+ raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
387
+
388
+ elif name.endswith("8bit"): # not tested
389
+ import bitsandbytes as bnb
390
+
391
+ if name == "adam8bit":
392
+ return bnb.optim.Adam8bit
393
+ elif name == "lion8bit":
394
+ return bnb.optim.Lion8bit
395
+ else:
396
+ raise ValueError("8bit optimizer must be adam8bit or lion8bit")
397
+
398
+ else:
399
+ if name == "adam":
400
+ return torch.optim.Adam
401
+ elif name == "adamw":
402
+ return torch.optim.AdamW
403
+ elif name == "lion":
404
+ from lion_pytorch import Lion
405
+
406
+ return Lion
407
+ elif name == "prodigy":
408
+ import prodigyopt
409
+
410
+ return prodigyopt.Prodigy
411
+ else:
412
+ raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
413
+
414
+
415
+ def get_lr_scheduler(
416
+ name: Optional[str],
417
+ optimizer: torch.optim.Optimizer,
418
+ max_iterations: Optional[int],
419
+ lr_min: Optional[float],
420
+ **kwargs,
421
+ ):
422
+ if name == "cosine":
423
+ return torch.optim.lr_scheduler.CosineAnnealingLR(
424
+ optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
425
+ )
426
+ elif name == "cosine_with_restarts":
427
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
428
+ optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
429
+ )
430
+ elif name == "step":
431
+ return torch.optim.lr_scheduler.StepLR(
432
+ optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
433
+ )
434
+ elif name == "constant":
435
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
436
+ elif name == "linear":
437
+ return torch.optim.lr_scheduler.LinearLR(
438
+ optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
439
+ )
440
+ else:
441
+ raise ValueError(
442
+ "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
443
+ )
444
+
445
+
446
+ def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
447
+ max_resolution = bucket_resolution
448
+ min_resolution = bucket_resolution // 2
449
+
450
+ step = 64
451
+
452
+ min_step = min_resolution // step
453
+ max_step = max_resolution // step
454
+
455
+ height = torch.randint(min_step, max_step, (1,)).item() * step
456
+ width = torch.randint(min_step, max_step, (1,)).item() * step
457
+
458
+ return height, width
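get_add_time_ids above packs six integers (original size, crop offset, target size) into the SDXL micro-conditioning vector and checks that 256 per id plus the 1280-dimensional pooled projection equals the 2816 inputs expected by the UNet's class embedding, raising a ValueError otherwise. A small standalone check of that arithmetic, using only torch:

import torch

UNET_ATTENTION_TIME_EMBED_DIM = 256              # per-id time-embedding width, as defined above
TEXT_ENCODER_2_PROJECTION_DIM = 1280             # pooled text_encoder_2 projection
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816

height, width = 1024, 1024
add_time_ids = list((height, width) + (0, 0) + (height, width))  # original + crop + target
assert len(add_time_ids) == 6

passed_add_embed_dim = UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) + TEXT_ENCODER_2_PROJECTION_DIM
assert passed_add_embed_dim == UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM  # 256 * 6 + 1280 == 2816

print(torch.tensor([add_time_ids], dtype=torch.float32).shape)   # torch.Size([1, 6])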
trainscripts/textsliders/__init__.py ADDED
File without changes
trainscripts/textsliders/config_util.py ADDED
@@ -0,0 +1,104 @@
1
+ from typing import Literal, Optional
2
+
3
+ import yaml
4
+
5
+ from pydantic import BaseModel
6
+ import torch
7
+
8
+ from trainscripts.textsliders.lora import TRAINING_METHODS
9
+
10
+ PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
11
+ NETWORK_TYPES = Literal["lierla", "c3lier"]
12
+
13
+
14
+ class PretrainedModelConfig(BaseModel):
15
+ name_or_path: str
16
+ v2: bool = False
17
+ v_pred: bool = False
18
+
19
+ clip_skip: Optional[int] = None
20
+
21
+
22
+ class NetworkConfig(BaseModel):
23
+ type: NETWORK_TYPES = "lierla"
24
+ rank: int = 4
25
+ alpha: float = 1.0
26
+
27
+ training_method: TRAINING_METHODS = "full"
28
+
29
+
30
+ class TrainConfig(BaseModel):
31
+ precision: PRECISION_TYPES = "bfloat16"
32
+ noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
33
+
34
+ iterations: int = 500
35
+ lr: float = 1e-4
36
+ optimizer: str = "adamw"
37
+ optimizer_args: str = ""
38
+ lr_scheduler: str = "constant"
39
+
40
+ max_denoising_steps: int = 50
41
+
42
+
43
+ class SaveConfig(BaseModel):
44
+ name: str = "untitled"
45
+ path: str = "./output"
46
+ per_steps: int = 200
47
+ precision: PRECISION_TYPES = "float32"
48
+
49
+
50
+ class LoggingConfig(BaseModel):
51
+ use_wandb: bool = False
52
+
53
+ verbose: bool = False
54
+
55
+
56
+ class OtherConfig(BaseModel):
57
+ use_xformers: bool = False
58
+
59
+
60
+ class RootConfig(BaseModel):
61
+ prompts_file: str
62
+ pretrained_model: PretrainedModelConfig
63
+
64
+ network: NetworkConfig
65
+
66
+ train: Optional[TrainConfig]
67
+
68
+ save: Optional[SaveConfig]
69
+
70
+ logging: Optional[LoggingConfig]
71
+
72
+ other: Optional[OtherConfig]
73
+
74
+
75
+ def parse_precision(precision: str) -> torch.dtype:
76
+ if precision == "fp32" or precision == "float32":
77
+ return torch.float32
78
+ elif precision == "fp16" or precision == "float16":
79
+ return torch.float16
80
+ elif precision == "bf16" or precision == "bfloat16":
81
+ return torch.bfloat16
82
+
83
+ raise ValueError(f"Invalid precision type: {precision}")
84
+
85
+
86
+ def load_config_from_yaml(config_path: str) -> RootConfig:
87
+ with open(config_path, "r") as f:
88
+ config = yaml.load(f, Loader=yaml.FullLoader)
89
+
90
+ root = RootConfig(**config)
91
+
92
+ if root.train is None:
93
+ root.train = TrainConfig()
94
+
95
+ if root.save is None:
96
+ root.save = SaveConfig()
97
+
98
+ if root.logging is None:
99
+ root.logging = LoggingConfig()
100
+
101
+ if root.other is None:
102
+ root.other = OtherConfig()
103
+
104
+ return root
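load_config_from_yaml above back-fills TrainConfig, SaveConfig, LoggingConfig and OtherConfig with their defaults when those sections are missing from the YAML. A minimal sketch of the same behaviour from an in-memory dict; it assumes the repository root is on PYTHONPATH so that trainscripts.textsliders.config_util (and the lora module it imports) resolves:

from trainscripts.textsliders import config_util

# Only the required sections are supplied; the optional ones are left as None.
config = config_util.RootConfig(
    prompts_file="trainscripts/textsliders/data/prompts.yaml",
    pretrained_model={"name_or_path": "CompVis/stable-diffusion-v1-4"},
    network={"rank": 4, "alpha": 1.0},
    train=None,
    save=None,
    logging=None,
    other=None,
)

# Reproduce the default filling done by load_config_from_yaml.
if config.train is None:
    config.train = config_util.TrainConfig()

print(config.train.lr, config.train.precision)               # 0.0001 bfloat16
print(config_util.parse_precision(config.train.precision))   # torch.bfloat16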
trainscripts/textsliders/data/config-xl.yaml ADDED
@@ -0,0 +1,28 @@
1
+ prompts_file: "trainscripts/textsliders/data/prompts-xl.yaml"
2
+ pretrained_model:
3
+ name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # "lierla" or "c3lier"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 5000000
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
trainscripts/textsliders/data/config.yaml ADDED
@@ -0,0 +1,28 @@
1
+ prompts_file: "trainscripts/textsliders/data/prompts.yaml"
2
+ pretrained_model:
3
+ name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # "lierla" or "c3lier"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
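The prompts files that follow are plain YAML lists; prompt_util.load_prompts_from_yaml (added elsewhere in this commit) wraps each entry in a PromptSettings object whose target/positive/unconditional/neutral strings feed the embedding cache, while guidance_scale, resolution, dynamic_resolution and batch_size control sampling. A minimal check of the raw structure with PyYAML, using the commented-out age-slider entry from prompts-xl.yaml below:

import yaml

entry = yaml.safe_load("""
- target: "male person"
  positive: "male person, very old"
  unconditional: "male person, very young"
  neutral: "male person"
  action: "enhance"
  guidance_scale: 4
  resolution: 512
  dynamic_resolution: false
  batch_size: 1
""")[0]

print(entry["positive"])         # male person, very old
print(entry["guidance_scale"])   # 4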
trainscripts/textsliders/data/prompts-xl.yaml ADDED
@@ -0,0 +1,486 @@
1
+ - target: "" # what word for erasing the positive concept from
2
+ positive: "" # concept to erase
3
+ unconditional: "" # word to take the difference from the positive concept
4
+ neutral: "" # starting point for conditioning the target
5
+ action: "enhance" # erase or enhance
6
+ guidance_scale: 4
7
+ resolution: 512
8
+ dynamic_resolution: false
9
+ batch_size: 1
10
+ ####################################################################################################### AGE SLIDER
11
+ # - target: "male person" # what word for erasing the positive concept from
12
+ # positive: "male person, very old" # concept to erase
13
+ # unconditional: "male person, very young" # word to take the difference from the positive concept
14
+ # neutral: "male person" # starting point for conditioning the target
15
+ # action: "enhance" # erase or enhance
16
+ # guidance_scale: 4
17
+ # resolution: 512
18
+ # dynamic_resolution: false
19
+ # batch_size: 1
20
+ # - target: "female person" # what word for erasing the positive concept from
21
+ # positive: "female person, very old" # concept to erase
22
+ # unconditional: "female person, very young" # word to take the difference from the positive concept
23
+ # neutral: "female person" # starting point for conditioning the target
24
+ # action: "enhance" # erase or enhance
25
+ # guidance_scale: 4
26
+ # resolution: 512
27
+ # dynamic_resolution: false
28
+ # batch_size: 1
29
+ ####################################################################################################### MUSCULAR SLIDER
30
+ # - target: "male person" # what word for erasing the positive concept from
31
+ # positive: "male person, muscular, strong, biceps, greek god physique, body builder" # concept to erase
32
+ # unconditional: "male person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept
33
+ # neutral: "male person" # starting point for conditioning the target
34
+ # action: "enhance" # erase or enhance
35
+ # guidance_scale: 4
36
+ # resolution: 512
37
+ # dynamic_resolution: false
38
+ # batch_size: 1
39
+ # - target: "female person" # what word for erasing the positive concept from
40
+ # positive: "female person, muscular, strong, biceps, greek god physique, body builder" # concept to erase
41
+ # unconditional: "female person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept
42
+ # neutral: "female person" # starting point for conditioning the target
43
+ # action: "enhance" # erase or enhance
44
+ # guidance_scale: 4
45
+ # resolution: 512
46
+ # dynamic_resolution: false
47
+ # batch_size: 1
48
+ ####################################################################################################### CURLY HAIR SLIDER
49
+ # - target: "male person" # what word for erasing the positive concept from
50
+ # positive: "male person, curly hair, wavy hair" # concept to erase
51
+ # unconditional: "male person, straight hair" # word to take the difference from the positive concept
52
+ # neutral: "male person" # starting point for conditioning the target
53
+ # action: "enhance" # erase or enhance
54
+ # guidance_scale: 4
55
+ # resolution: 512
56
+ # dynamic_resolution: false
57
+ # batch_size: 1
58
+ # - target: "female person" # what word for erasing the positive concept from
59
+ # positive: "female person, curly hair, wavy hair" # concept to erase
60
+ # unconditional: "female person, straight hair" # word to take the difference from the positive concept
61
+ # neutral: "female person" # starting point for conditioning the target
62
+ # action: "enhance" # erase or enhance
63
+ # guidance_scale: 4
64
+ # resolution: 512
65
+ # dynamic_resolution: false
66
+ # batch_size: 1
67
+ ####################################################################################################### BEARD SLIDER
68
+ # - target: "male person" # what word for erasing the positive concept from
69
+ # positive: "male person, with beard" # concept to erase
70
+ # unconditional: "male person, clean shaven" # word to take the difference from the positive concept
71
+ # neutral: "male person" # starting point for conditioning the target
72
+ # action: "enhance" # erase or enhance
73
+ # guidance_scale: 4
74
+ # resolution: 512
75
+ # dynamic_resolution: false
76
+ # batch_size: 1
77
+ # - target: "female person" # what word for erasing the positive concept from
78
+ # positive: "female person, with beard, lipstick and feminine" # concept to erase
79
+ # unconditional: "female person, clean shaven" # word to take the difference from the positive concept
80
+ # neutral: "female person" # starting point for conditioning the target
81
+ # action: "enhance" # erase or enhance
82
+ # guidance_scale: 4
83
+ # resolution: 512
84
+ # dynamic_resolution: false
85
+ # batch_size: 1
86
+ ####################################################################################################### MAKEUP SLIDER
87
+ # - target: "male person" # what word for erasing the positive concept from
88
+ # positive: "male person, with makeup, cosmetic, concealer, mascara" # concept to erase
89
+ # unconditional: "male person, barefaced, ugly" # word to take the difference from the positive concept
90
+ # neutral: "male person" # starting point for conditioning the target
91
+ # action: "enhance" # erase or enhance
92
+ # guidance_scale: 4
93
+ # resolution: 512
94
+ # dynamic_resolution: false
95
+ # batch_size: 1
96
+ # - target: "female person" # what word for erasing the positive concept from
97
+ # positive: "female person, with makeup, cosmetic, concealer, mascara, lipstick" # concept to erase
98
+ # unconditional: "female person, barefaced, ugly" # word to take the difference from the positive concept
99
+ # neutral: "female person" # starting point for conditioning the target
100
+ # action: "enhance" # erase or enhance
101
+ # guidance_scale: 4
102
+ # resolution: 512
103
+ # dynamic_resolution: false
104
+ # batch_size: 1
105
+ ####################################################################################################### SURPRISED SLIDER
106
+ # - target: "male person" # what word for erasing the positive concept from
107
+ # positive: "male person, with shocked look, surprised, stunned, amazed" # concept to erase
108
+ # unconditional: "male person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept
109
+ # neutral: "male person" # starting point for conditioning the target
110
+ # action: "enhance" # erase or enhance
111
+ # guidance_scale: 4
112
+ # resolution: 512
113
+ # dynamic_resolution: false
114
+ # batch_size: 1
115
+ # - target: "female person" # what word for erasing the positive concept from
116
+ # positive: "female person, with shocked look, surprised, stunned, amazed" # concept to erase
117
+ # unconditional: "female person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept
118
+ # neutral: "female person" # starting point for conditioning the target
119
+ # action: "enhance" # erase or enhance
120
+ # guidance_scale: 4
121
+ # resolution: 512
122
+ # dynamic_resolution: false
123
+ # batch_size: 1
124
+ ####################################################################################################### OBESE SLIDER
125
+ # - target: "male person" # what word for erasing the positive concept from
126
+ # positive: "male person, fat, chubby, overweight, obese" # concept to erase
127
+ # unconditional: "male person, lean, fit, slim, slender" # word to take the difference from the positive concept
128
+ # neutral: "male person" # starting point for conditioning the target
129
+ # action: "enhance" # erase or enhance
130
+ # guidance_scale: 4
131
+ # resolution: 512
132
+ # dynamic_resolution: false
133
+ # batch_size: 1
134
+ # - target: "female person" # what word for erasing the positive concept from
135
+ # positive: "female person, fat, chubby, overweight, obese" # concept to erase
136
+ # unconditional: "female person, lean, fit, slim, slender" # word to take the difference from the positive concept
137
+ # neutral: "female person" # starting point for conditioning the target
138
+ # action: "enhance" # erase or enhance
139
+ # guidance_scale: 4
140
+ # resolution: 512
141
+ # dynamic_resolution: false
142
+ # batch_size: 1
143
+ ####################################################################################################### PROFESSIONAL SLIDER
144
+ # - target: "male person" # what word for erasing the positive concept from
145
+ # positive: "male person, professionally dressed, stylised hair, clean face" # concept to erase
146
+ # unconditional: "male person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept
147
+ # neutral: "male person" # starting point for conditioning the target
148
+ # action: "enhance" # erase or enhance
149
+ # guidance_scale: 4
150
+ # resolution: 512
151
+ # dynamic_resolution: false
152
+ # batch_size: 1
153
+ # - target: "female person" # what word for erasing the positive concept from
154
+ # positive: "female person, professionally dressed, stylised hair, clean face" # concept to erase
155
+ # unconditional: "female person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept
156
+ # neutral: "female person" # starting point for conditioning the target
157
+ # action: "enhance" # erase or enhance
158
+ # guidance_scale: 4
159
+ # resolution: 512
160
+ # dynamic_resolution: false
161
+ # batch_size: 1
162
+ ####################################################################################################### GLASSES SLIDER
163
+ # - target: "male person" # what word for erasing the positive concept from
164
+ # positive: "male person, wearing glasses" # concept to erase
165
+ # unconditional: "male person" # word to take the difference from the positive concept
166
+ # neutral: "male person" # starting point for conditioning the target
167
+ # action: "enhance" # erase or enhance
168
+ # guidance_scale: 4
169
+ # resolution: 512
170
+ # dynamic_resolution: false
171
+ # batch_size: 1
172
+ # - target: "female person" # what word for erasing the positive concept from
173
+ # positive: "female person, wearing glasses" # concept to erase
174
+ # unconditional: "female person" # word to take the difference from the positive concept
175
+ # neutral: "female person" # starting point for conditioning the target
176
+ # action: "enhance" # erase or enhance
177
+ # guidance_scale: 4
178
+ # resolution: 512
179
+ # dynamic_resolution: false
180
+ # batch_size: 1
181
+ ####################################################################################################### ASTRONAUT SLIDER
182
+ # - target: "astronaught" # what word for erasing the positive concept from
183
+ # positive: "astronaught, with orange colored spacesuit" # concept to erase
184
+ # unconditional: "astronaught" # word to take the difference from the positive concept
185
+ # neutral: "astronaught" # starting point for conditioning the target
186
+ # action: "enhance" # erase or enhance
187
+ # guidance_scale: 4
188
+ # resolution: 512
189
+ # dynamic_resolution: false
190
+ # batch_size: 1
191
+ ####################################################################################################### SMILING SLIDER
192
+ # - target: "male person" # what word for erasing the positive concept from
193
+ # positive: "male person, smiling" # concept to erase
194
+ # unconditional: "male person, frowning" # word to take the difference from the positive concept
195
+ # neutral: "male person" # starting point for conditioning the target
196
+ # action: "enhance" # erase or enhance
197
+ # guidance_scale: 4
198
+ # resolution: 512
199
+ # dynamic_resolution: false
200
+ # batch_size: 1
201
+ # - target: "female person" # what word for erasing the positive concept from
202
+ # positive: "female person, smiling" # concept to erase
203
+ # unconditional: "female person, frowning" # word to take the difference from the positive concept
204
+ # neutral: "female person" # starting point for conditioning the target
205
+ # action: "enhance" # erase or enhance
206
+ # guidance_scale: 4
207
+ # resolution: 512
208
+ # dynamic_resolution: false
209
+ # batch_size: 1
210
+ ####################################################################################################### CAR COLOR SLIDER
211
+ # - target: "car" # what word for erasing the positive concept from
212
+ # positive: "car, white color" # concept to erase
213
+ # unconditional: "car, black color" # word to take the difference from the positive concept
214
+ # neutral: "car" # starting point for conditioning the target
215
+ # action: "enhance" # erase or enhance
216
+ # guidance_scale: 4
217
+ # resolution: 512
218
+ # dynamic_resolution: false
219
+ # batch_size: 1
220
+ ####################################################################################################### DETAILS SLIDER
221
+ # - target: "" # what word for erasing the positive concept from
222
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase
223
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
224
+ # neutral: "" # starting point for conditioning the target
225
+ # action: "enhance" # erase or enhance
226
+ # guidance_scale: 4
227
+ # resolution: 512
228
+ # dynamic_resolution: false
229
+ # batch_size: 1
230
+ ####################################################################################################### CARTOON SLIDER
231
+ # - target: "male person" # what word for erasing the positive concept from
232
+ # positive: "male person, cartoon style, pixar style, animated style" # concept to erase
233
+ # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
234
+ # neutral: "male person" # starting point for conditioning the target
235
+ # action: "enhance" # erase or enhance
236
+ # guidance_scale: 4
237
+ # resolution: 512
238
+ # dynamic_resolution: false
239
+ # batch_size: 1
240
+ # - target: "female person" # what word for erasing the positive concept from
241
+ # positive: "female person, cartoon style, pixar style, animated style" # concept to erase
242
+ # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
243
+ # neutral: "female person" # starting point for conditioning the target
244
+ # action: "enhance" # erase or enhance
245
+ # guidance_scale: 4
246
+ # resolution: 512
247
+ # dynamic_resolution: false
248
+ # batch_size: 1
249
+ ####################################################################################################### CLAY SLIDER
250
+ # - target: "male person" # what word for erasing the positive concept from
251
+ # positive: "male person, clay style, made out of clay, clay sculpture" # concept to erase
252
+ # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
253
+ # neutral: "male person" # starting point for conditioning the target
254
+ # action: "enhance" # erase or enhance
255
+ # guidance_scale: 4
256
+ # resolution: 512
257
+ # dynamic_resolution: false
258
+ # batch_size: 1
259
+ # - target: "female person" # what word for erasing the positive concept from
260
+ # positive: "female person, clay style, made out of clay, clay sculpture" # concept to erase
261
+ # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
262
+ # neutral: "female person" # starting point for conditioning the target
263
+ # action: "enhance" # erase or enhance
264
+ # guidance_scale: 4
265
+ # resolution: 512
266
+ # dynamic_resolution: false
267
+ # batch_size: 1
268
+ ####################################################################################################### SCULPTURE SLIDER
269
+ # - target: "male person" # what word for erasing the positive concept from
270
+ # positive: "male person, cement sculpture, cement greek statue style" # concept to erase
271
+ # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
272
+ # neutral: "male person" # starting point for conditioning the target
273
+ # action: "enhance" # erase or enhance
274
+ # guidance_scale: 4
275
+ # resolution: 512
276
+ # dynamic_resolution: false
277
+ # batch_size: 1
278
+ # - target: "female person" # what word for erasing the positive concept from
279
+ # positive: "female person, cement sculpture, cement greek statue style" # concept to erase
280
+ # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
281
+ # neutral: "female person" # starting point for conditioning the target
282
+ # action: "enhance" # erase or enhance
283
+ # guidance_scale: 4
284
+ # resolution: 512
285
+ # dynamic_resolution: false
286
+ # batch_size: 1
287
+ ####################################################################################################### METAL SLIDER
288
+ # - target: "" # what word for erasing the positive concept from
289
+ # positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase
290
+ # unconditional: "wooden style, made out of wood" # word to take the difference from the positive concept
291
+ # neutral: "" # starting point for conditioning the target
292
+ # action: "enhance" # erase or enhance
293
+ # guidance_scale: 4
294
+ # resolution: 512
295
+ # dynamic_resolution: false
296
+ # batch_size: 1
297
+ ####################################################################################################### FESTIVE SLIDER
298
+ # - target: "" # what word for erasing the positive concept from
299
+ # positive: "festive, colorful banners, confetti, indian festival decorations, chinese festival decorations, fireworks, parade, cherry, gala, happy, celebrations" # concept to erase
300
+ # unconditional: "dull, dark, sad, desserted, empty, alone" # word to take the difference from the positive concept
301
+ # neutral: "" # starting point for conditioning the target
302
+ # action: "enhance" # erase or enhance
303
+ # guidance_scale: 4
304
+ # resolution: 512
305
+ # dynamic_resolution: false
306
+ # batch_size: 1
307
+ ####################################################################################################### TROPICAL SLIDER
308
+ # - target: "" # what word for erasing the positive concept from
309
+ # positive: "tropical, beach, sunny, hot" # concept to erase
310
+ # unconditional: "arctic, winter, snow, ice, iceburg, snowfall" # word to take the difference from the positive concept
311
+ # neutral: "" # starting point for conditioning the target
312
+ # action: "enhance" # erase or enhance
313
+ # guidance_scale: 4
314
+ # resolution: 512
315
+ # dynamic_resolution: false
316
+ # batch_size: 1
317
+ ####################################################################################################### MODERN SLIDER
318
+ # - target: "" # what word for erasing the positive concept from
319
+ # positive: "modern, futuristic style, trendy, stylish, swank" # concept to erase
320
+ # unconditional: "ancient, classic style, regal, vintage" # word to take the difference from the positive concept
321
+ # neutral: "" # starting point for conditioning the target
322
+ # action: "enhance" # erase or enhance
323
+ # guidance_scale: 4
324
+ # resolution: 512
325
+ # dynamic_resolution: false
326
+ # batch_size: 1
327
+ ####################################################################################################### BOKEH SLIDER
328
+ # - target: "" # what word for erasing the positive concept from
329
+ # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
330
+ # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
331
+ # unconditional: ""
332
+ # neutral: "" # starting point for conditioning the target
333
+ # action: "enhance" # erase or enhance
334
+ # guidance_scale: 4
335
+ # resolution: 512
336
+ # dynamic_resolution: false
337
+ # batch_size: 1
338
+ ####################################################################################################### LONG HAIR SLIDER
339
+ # - target: "male person" # what word for erasing the positive concept from
340
+ # positive: "male person, with long hair" # concept to erase
341
+ # unconditional: "male person, with short hair" # word to take the difference from the positive concept
342
+ # neutral: "male person" # starting point for conditioning the target
343
+ # action: "enhance" # erase or enhance
344
+ # guidance_scale: 4
345
+ # resolution: 512
346
+ # dynamic_resolution: false
347
+ # batch_size: 1
348
+ # - target: "female person" # what word for erasing the positive concept from
349
+ # positive: "female person, with long hair" # concept to erase
350
+ # unconditional: "female person, with short hair" # word to take the difference from the positive concept
351
+ # neutral: "female person" # starting point for conditioning the target
352
+ # action: "enhance" # erase or enhance
353
+ # guidance_scale: 4
354
+ # resolution: 512
355
+ # dynamic_resolution: false
356
+ # batch_size: 1
357
+ ####################################################################################################### NEGPROMPT SLIDER
358
+ # - target: "" # what word for erasing the positive concept from
359
+ # positive: "cartoon, cgi, render, illustration, painting, drawing, bad quality, grainy, low resolution" # concept to erase
360
+ # unconditional: ""
361
+ # neutral: "" # starting point for conditioning the target
362
+ # action: "erase" # erase or enhance
363
+ # guidance_scale: 4
364
+ # resolution: 512
365
+ # dynamic_resolution: false
366
+ # batch_size: 1
367
+ ####################################################################################################### EXPENSIVE FOOD SLIDER
368
+ # - target: "food" # what word for erasing the positive concept from
369
+ # positive: "food, expensive and fine dining" # concept to erase
370
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
371
+ # neutral: "food" # starting point for conditioning the target
372
+ # action: "enhance" # erase or enhance
373
+ # guidance_scale: 4
374
+ # resolution: 512
375
+ # dynamic_resolution: false
376
+ # batch_size: 1
377
+ ####################################################################################################### COOKED FOOD SLIDER
378
+ # - target: "food" # what word for erasing the positive concept from
379
+ # positive: "food, cooked, baked, roasted, fried" # concept to erase
380
+ # unconditional: "food, raw, uncooked, fresh, undone" # word to take the difference from the positive concept
381
+ # neutral: "food" # starting point for conditioning the target
382
+ # action: "enhance" # erase or enhance
383
+ # guidance_scale: 4
384
+ # resolution: 512
385
+ # dynamic_resolution: false
386
+ # batch_size: 1
387
+ ####################################################################################################### MEAT FOOD SLIDER
388
+ # - target: "food" # what word for erasing the positive concept from
389
+ # positive: "food, meat, steak, fish, non-vegetrian, beef, lamb, pork, chicken, salmon" # concept to erase
390
+ # unconditional: "food, vegetables, fruits, leafy-vegetables, greens, vegetarian, vegan, tomatoes, onions, carrots" # word to take the difference from the positive concept
391
+ # neutral: "food" # starting point for conditioning the target
392
+ # action: "enhance" # erase or enhance
393
+ # guidance_scale: 4
394
+ # resolution: 512
395
+ # dynamic_resolution: false
396
+ # batch_size: 1
397
+ ####################################################################################################### WEATHER SLIDER
398
+ # - target: "" # what word for erasing the positive concept from
399
+ # positive: "snowy, winter, cold, ice, snowfall, white" # concept to erase
400
+ # unconditional: "hot, summer, bright, sunny" # word to take the difference from the positive concept
401
+ # neutral: "" # starting point for conditioning the target
402
+ # action: "enhance" # erase or enhance
403
+ # guidance_scale: 4
404
+ # resolution: 512
405
+ # dynamic_resolution: false
406
+ # batch_size: 1
407
+ ####################################################################################################### NIGHT/DAY SLIDER
408
+ # - target: "" # what word for erasing the positive concept from
409
+ # positive: "night time, dark, darkness, pitch black, nighttime" # concept to erase
410
+ # unconditional: "day time, bright, sunny, daytime, sunlight" # word to take the difference from the positive concept
411
+ # neutral: "" # starting point for conditioning the target
412
+ # action: "enhance" # erase or enhance
413
+ # guidance_scale: 4
414
+ # resolution: 512
415
+ # dynamic_resolution: false
416
+ # batch_size: 1
417
+ ####################################################################################################### INDOOR/OUTDOOR SLIDER
418
+ # - target: "" # what word for erasing the positive concept from
419
+ # positive: "indoor, inside a room, inside, interior" # concept to erase
420
+ # unconditional: "outdoor, outside, open air, exterior" # word to take the difference from the positive concept
421
+ # neutral: "" # starting point for conditioning the target
422
+ # action: "enhance" # erase or enhance
423
+ # guidance_scale: 4
424
+ # resolution: 512
425
+ # dynamic_resolution: false
426
+ # batch_size: 1
427
+ ####################################################################################################### GOODHANDS SLIDER
428
+ # - target: "" # what word for erasing the positive concept from
429
+ # positive: "realistic hands, realistic limbs, perfect limbs, perfect hands, 5 fingers, five fingers, hyper realisitc hands" # concept to erase
430
+ # unconditional: "poorly drawn limbs, distorted limbs, poorly rendered hands,bad anatomy, disfigured, mutated body parts, bad composition" # word to take the difference from the positive concept
431
+ # neutral: "" # starting point for conditioning the target
432
+ # action: "enhance" # erase or enhance
433
+ # guidance_scale: 4
434
+ # resolution: 512
435
+ # dynamic_resolution: false
436
+ # batch_size: 1
437
+ ####################################################################################################### RUSTY CAR SLIDER
438
+ # - target: "car" # what word for erasing the positive concept from
439
+ # positive: "car, rusty conditioned" # concept to erase
440
+ # unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept
441
+ # neutral: "car" # starting point for conditioning the target
442
+ # action: "enhance" # erase or enhance
443
+ # guidance_scale: 4
444
+ # resolution: 512
445
+ # dynamic_resolution: false
446
+ # batch_size: 1
447
+ ####################################################################################################### DAMAGED CAR SLIDER
448
+ # - target: "car" # what word for erasing the positive concept from
449
+ # positive: "car, damaged, broken headlights, dented car, with scrapped paintwork" # concept to erase
450
+ # unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept
451
+ # neutral: "car" # starting point for conditioning the target
452
+ # action: "enhance" # erase or enhance
453
+ # guidance_scale: 4
454
+ # resolution: 512
455
+ # dynamic_resolution: false
456
+ # batch_size: 1
457
+ ####################################################################################################### CLUTTERED ROOM SLIDER
458
+ # - target: "room" # what word for erasing the positive concept from
459
+ # positive: "room, cluttered, disorganized, dirty, jumbled, scattered" # concept to erase
460
+ # unconditional: "room, super organized, clean, ordered, neat, tidy" # word to take the difference from the positive concept
461
+ # neutral: "room" # starting point for conditioning the target
462
+ # action: "enhance" # erase or enhance
463
+ # guidance_scale: 4
464
+ # resolution: 512
465
+ # dynamic_resolution: false
466
+ # batch_size: 1
467
+ ####################################################################################################### HANDS SLIDER
468
+ # - target: "hands" # what word for erasing the positive concept from
469
+ # positive: "realistic hands, five fingers, 8k hyper realistic hands" # concept to erase
470
+ # unconditional: "poorly drawn hands, distorted hands, amputed fingers" # word to take the difference from the positive concept
471
+ # neutral: "hands" # starting point for conditioning the target
472
+ # action: "enhance" # erase or enhance
473
+ # guidance_scale: 4
474
+ # resolution: 512
475
+ # dynamic_resolution: false
476
+ # batch_size: 1
477
+ ####################################################################################################### SURPRISED LOOK SLIDER
478
+ # - target: "female person" # what word for erasing the positive concept from
479
+ # positive: "female person, with a surprised look" # concept to erase
480
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
481
+ # neutral: "female person" # starting point for conditioning the target
482
+ # action: "enhance" # erase or enhance
483
+ # guidance_scale: 4
484
+ # resolution: 512
485
+ # dynamic_resolution: false
486
+ # batch_size: 1
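
Each slider entry in these prompt files follows the same nine-key schema (target, positive, unconditional, neutral, action, guidance_scale, resolution, dynamic_resolution, batch_size). The training scripts read it through `prompt_util.load_prompts_from_yaml`; purely as an illustration of that schema, a minimal stand-alone loader might look like the sketch below (the `SliderPrompt` dataclass name and the loader function are hypothetical, not part of this repo).

```python
# Minimal sketch of the slider prompt schema used by the YAML files above.
# The real loader is trainscripts.textsliders.prompt_util.load_prompts_from_yaml;
# this stand-alone version only mirrors the fields visible in the config.
from dataclasses import dataclass
import yaml  # pip install pyyaml


@dataclass
class SliderPrompt:  # hypothetical name, for illustration only
    target: str
    positive: str
    unconditional: str
    neutral: str
    action: str = "enhance"        # "erase" or "enhance"
    guidance_scale: float = 4.0
    resolution: int = 512
    dynamic_resolution: bool = False
    batch_size: int = 1


def load_slider_prompts(path: str) -> list[SliderPrompt]:
    with open(path, "r") as f:
        entries = yaml.safe_load(f) or []  # a fully commented-out file parses to None
    return [SliderPrompt(**entry) for entry in entries]


# prompts = load_slider_prompts("trainscripts/textsliders/data/prompts-xl.yaml")
```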
trainscripts/textsliders/data/prompts.yaml ADDED
@@ -0,0 +1,193 @@
1
+ - target: "male person" # what word for erasing the positive concept from
2
+ positive: "male person, very old" # concept to erase
3
+ unconditional: "male person, very young" # word to take the difference from the positive concept
4
+ neutral: "male person" # starting point for conditioning the target
5
+ action: "enhance" # erase or enhance
6
+ guidance_scale: 4
7
+ resolution: 512
8
+ dynamic_resolution: false
9
+ batch_size: 1
10
+ - target: "female person" # what word for erasing the positive concept from
11
+ positive: "female person, very old" # concept to erase
12
+ unconditional: "female person, very young" # word to take the difference from the positive concept
13
+ neutral: "female person" # starting point for conditioning the target
14
+ action: "enhance" # erase or enhance
15
+ guidance_scale: 4
16
+ resolution: 512
17
+ dynamic_resolution: false
18
+ batch_size: 1
19
+ # - target: "" # what word for erasing the positive concept from
20
+ # positive: "a group of people" # concept to erase
21
+ # unconditional: "a person" # word to take the difference from the positive concept
22
+ # neutral: "" # starting point for conditioning the target
23
+ # action: "enhance" # erase or enhance
24
+ # guidance_scale: 4
25
+ # resolution: 512
26
+ # dynamic_resolution: false
27
+ # batch_size: 1
28
+ # - target: "" # what word for erasing the positive concept from
29
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
30
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
31
+ # neutral: "" # starting point for conditioning the target
32
+ # action: "enhance" # erase or enhance
33
+ # guidance_scale: 4
34
+ # resolution: 512
35
+ # dynamic_resolution: false
36
+ # batch_size: 1
37
+ # - target: "" # what word for erasing the positive concept from
38
+ # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
39
+ # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
40
+ # unconditional: ""
41
+ # neutral: "" # starting point for conditioning the target
42
+ # action: "enhance" # erase or enhance
43
+ # guidance_scale: 4
44
+ # resolution: 512
45
+ # dynamic_resolution: false
46
+ # batch_size: 1
47
+ # - target: "food" # what word for erasing the positive concept from
48
+ # positive: "food, expensive and fine dining" # concept to erase
49
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
50
+ # neutral: "food" # starting point for conditioning the target
51
+ # action: "enhance" # erase or enhance
52
+ # guidance_scale: 4
53
+ # resolution: 512
54
+ # dynamic_resolution: false
55
+ # batch_size: 1
56
+ # - target: "room" # what word for erasing the positive concept from
57
+ # positive: "room, dirty disorganised and cluttered" # concept to erase
58
+ # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
59
+ # neutral: "room" # starting point for conditioning the target
60
+ # action: "enhance" # erase or enhance
61
+ # guidance_scale: 4
62
+ # resolution: 512
63
+ # dynamic_resolution: false
64
+ # batch_size: 1
65
+ # - target: "male person" # what word for erasing the positive concept from
66
+ # positive: "male person, with a surprised look" # concept to erase
67
+ # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
68
+ # neutral: "male person" # starting point for conditioning the target
69
+ # action: "enhance" # erase or enhance
70
+ # guidance_scale: 4
71
+ # resolution: 512
72
+ # dynamic_resolution: false
73
+ # batch_size: 1
74
+ # - target: "female person" # what word for erasing the positive concept from
75
+ # positive: "female person, with a surprised look" # concept to erase
76
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
77
+ # neutral: "female person" # starting point for conditioning the target
78
+ # action: "enhance" # erase or enhance
79
+ # guidance_scale: 4
80
+ # resolution: 512
81
+ # dynamic_resolution: false
82
+ # batch_size: 1
83
+ # - target: "sky" # what word for erasing the positive concept from
84
+ # positive: "peaceful sky" # concept to erase
85
+ # unconditional: "sky" # word to take the difference from the positive concept
86
+ # neutral: "sky" # starting point for conditioning the target
87
+ # action: "enhance" # erase or enhance
88
+ # guidance_scale: 4
89
+ # resolution: 512
90
+ # dynamic_resolution: false
91
+ # batch_size: 1
92
+ # - target: "sky" # what word for erasing the positive concept from
93
+ # positive: "chaotic dark sky" # concept to erase
94
+ # unconditional: "sky" # word to take the difference from the positive concept
95
+ # neutral: "sky" # starting point for conditioning the target
96
+ # action: "erase" # erase or enhance
97
+ # guidance_scale: 4
98
+ # resolution: 512
99
+ # dynamic_resolution: false
100
+ # batch_size: 1
101
+ # - target: "person" # what word for erasing the positive concept from
102
+ # positive: "person, very young" # concept to erase
103
+ # unconditional: "person" # word to take the difference from the positive concept
104
+ # neutral: "person" # starting point for conditioning the target
105
+ # action: "erase" # erase or enhance
106
+ # guidance_scale: 4
107
+ # resolution: 512
108
+ # dynamic_resolution: false
109
+ # batch_size: 1
110
+ # overweight
111
+ # - target: "art" # what word for erasing the positive concept from
112
+ # positive: "realistic art" # concept to erase
113
+ # unconditional: "art" # word to take the difference from the positive concept
114
+ # neutral: "art" # starting point for conditioning the target
115
+ # action: "enhance" # erase or enhance
116
+ # guidance_scale: 4
117
+ # resolution: 512
118
+ # dynamic_resolution: false
119
+ # batch_size: 1
120
+ # - target: "art" # what word for erasing the positive concept from
121
+ # positive: "abstract art" # concept to erase
122
+ # unconditional: "art" # word to take the difference from the positive concept
123
+ # neutral: "art" # starting point for conditioning the target
124
+ # action: "erase" # erase or enhance
125
+ # guidance_scale: 4
126
+ # resolution: 512
127
+ # dynamic_resolution: false
128
+ # batch_size: 1
129
+ # sky
130
+ # - target: "weather" # what word for erasing the positive concept from
131
+ # positive: "bright pleasant weather" # concept to erase
132
+ # unconditional: "weather" # word to take the difference from the positive concept
133
+ # neutral: "weather" # starting point for conditioning the target
134
+ # action: "enhance" # erase or enhance
135
+ # guidance_scale: 4
136
+ # resolution: 512
137
+ # dynamic_resolution: false
138
+ # batch_size: 1
139
+ # - target: "weather" # what word for erasing the positive concept from
140
+ # positive: "dark gloomy weather" # concept to erase
141
+ # unconditional: "weather" # word to take the difference from the positive concept
142
+ # neutral: "weather" # starting point for conditioning the target
143
+ # action: "erase" # erase or enhance
144
+ # guidance_scale: 4
145
+ # resolution: 512
146
+ # dynamic_resolution: false
147
+ # batch_size: 1
148
+ # hair
149
+ # - target: "person" # what word for erasing the positive concept from
150
+ # positive: "person with long hair" # concept to erase
151
+ # unconditional: "person" # word to take the difference from the positive concept
152
+ # neutral: "person" # starting point for conditioning the target
153
+ # action: "enhance" # erase or enhance
154
+ # guidance_scale: 4
155
+ # resolution: 512
156
+ # dynamic_resolution: false
157
+ # batch_size: 1
158
+ # - target: "person" # what word for erasing the positive concept from
159
+ # positive: "person with short hair" # concept to erase
160
+ # unconditional: "person" # word to take the difference from the positive concept
161
+ # neutral: "person" # starting point for conditioning the target
162
+ # action: "erase" # erase or enhance
163
+ # guidance_scale: 4
164
+ # resolution: 512
165
+ # dynamic_resolution: false
166
+ # batch_size: 1
167
+ # - target: "girl" # what word for erasing the positive concept from
168
+ # positive: "baby girl" # concept to erase
169
+ # unconditional: "girl" # word to take the difference from the positive concept
170
+ # neutral: "girl" # starting point for conditioning the target
171
+ # action: "enhance" # erase or enhance
172
+ # guidance_scale: -4
173
+ # resolution: 512
174
+ # dynamic_resolution: false
175
+ # batch_size: 1
176
+ # - target: "boy" # what word for erasing the positive concept from
177
+ # positive: "old man" # concept to erase
178
+ # unconditional: "boy" # word to take the difference from the positive concept
179
+ # neutral: "boy" # starting point for conditioning the target
180
+ # action: "enhance" # erase or enhance
181
+ # guidance_scale: 4
182
+ # resolution: 512
183
+ # dynamic_resolution: false
184
+ # batch_size: 1
185
+ # - target: "boy" # what word for erasing the positive concept from
186
+ # positive: "baby boy" # concept to erase
187
+ # unconditional: "boy" # word to take the difference from the positive concept
188
+ # neutral: "boy" # starting point for conditioning the target
189
+ # action: "enhance" # erase or enhance
190
+ # guidance_scale: -4
191
+ # resolution: 512
192
+ # dynamic_resolution: false
193
+ # batch_size: 1
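
In both prompt files, `action` and `guidance_scale` determine the direction and strength of the edit: for "enhance" the LoRA-edited prediction is pushed towards the positive prompt and away from the unconditional one, while "erase" flips that sign. The actual loss is implemented in `trainscripts/textsliders/prompt_util.py` (`PromptEmbedsPair.loss`, not shown in this section); the snippet below is only a hedged sketch of that ESD-style target, with illustrative names and weighting.

```python
# Hedged sketch of the ESD-style target implied by "action" / "guidance_scale".
# The repo's real implementation lives in prompt_util.PromptEmbedsPair.loss (not
# shown in this commit view); names and exact weighting here are assumptions.
import torch.nn.functional as F


def slider_loss(target_latents, positive_latents, neutral_latents,
                unconditional_latents, guidance_scale=4.0, action="enhance"):
    sign = 1.0 if action == "enhance" else -1.0
    # Shift the neutral prediction along the (positive - unconditional) direction
    # and train the LoRA-edited prediction to match it.
    desired = neutral_latents + sign * guidance_scale * (
        positive_latents - unconditional_latents
    )
    return F.mse_loss(target_latents, desired)
```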
trainscripts/textsliders/debug_util.py ADDED
@@ -0,0 +1,16 @@
1
+ # For debugging...
2
+
3
+ import torch
4
+
5
+
6
+ def check_requires_grad(model: torch.nn.Module):
7
+ for name, module in list(model.named_modules())[:5]:
8
+ if len(list(module.parameters())) > 0:
9
+ print(f"Module: {name}")
10
+ for name, param in list(module.named_parameters())[:2]:
11
+ print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12
+
13
+
14
+ def check_training_mode(model: torch.nn.Module):
15
+ for name, module in list(model.named_modules())[:5]:
16
+ print(f"Module: {name}, Training Mode: {module.training}")
trainscripts/textsliders/demotrain.py ADDED
@@ -0,0 +1,437 @@
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+
15
+ from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
+ import trainscripts.textsliders.train_util as train_util
17
+ import trainscripts.textsliders.model_util as model_util
18
+ import trainscripts.textsliders.prompt_util as prompt_util
19
+ from trainscripts.textsliders.prompt_util import (
20
+ PromptEmbedsCache,
21
+ PromptEmbedsPair,
22
+ PromptSettings,
23
+ PromptEmbedsXL,
24
+ )
25
+ import trainscripts.textsliders.debug_util as debug_util
26
+ import trainscripts.textsliders.config_util as config_util
27
+ from trainscripts.textsliders.config_util import RootConfig
28
+
29
+ import wandb
30
+
31
+ NUM_IMAGES_PER_PROMPT = 1
32
+
33
+
34
+ def flush():
35
+ torch.cuda.empty_cache()
36
+ gc.collect()
37
+
38
+
39
+ def train(
40
+ config: RootConfig,
41
+ prompts: list[PromptSettings],
42
+ device,
43
+ ):
44
+ metadata = {
45
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
46
+ "config": config.json(),
47
+ }
48
+ save_path = Path(config.save.path)
49
+
50
+ modules = DEFAULT_TARGET_REPLACE
51
+ if config.network.type == "c3lier":
52
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
53
+
54
+ if config.logging.verbose:
55
+ print(metadata)
56
+
57
+ if config.logging.use_wandb:
58
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
59
+
60
+ weight_dtype = config_util.parse_precision(config.train.precision)
61
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
62
+
63
+ (
64
+ tokenizers,
65
+ text_encoders,
66
+ unet,
67
+ noise_scheduler,
68
+ ) = model_util.load_models_xl(
69
+ config.pretrained_model.name_or_path,
70
+ scheduler_name=config.train.noise_scheduler,
71
+ )
72
+
73
+ for text_encoder in text_encoders:
74
+ text_encoder.to(device, dtype=weight_dtype)
75
+ text_encoder.requires_grad_(False)
76
+ text_encoder.eval()
77
+
78
+ unet.to(device, dtype=weight_dtype)
79
+ if config.other.use_xformers:
80
+ unet.enable_xformers_memory_efficient_attention()
81
+ unet.requires_grad_(False)
82
+ unet.eval()
83
+
84
+ network = LoRANetwork(
85
+ unet,
86
+ rank=config.network.rank,
87
+ multiplier=1.0,
88
+ alpha=config.network.alpha,
89
+ train_method=config.network.training_method,
90
+ ).to(device, dtype=weight_dtype)
91
+
92
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
93
+ #optimizer_args
94
+ optimizer_kwargs = {}
95
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
96
+ for arg in config.train.optimizer_args.split(" "):
97
+ key, value = arg.split("=")
98
+ value = ast.literal_eval(value)
99
+ optimizer_kwargs[key] = value
100
+
101
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
102
+ lr_scheduler = train_util.get_lr_scheduler(
103
+ config.train.lr_scheduler,
104
+ optimizer,
105
+ max_iterations=config.train.iterations,
106
+ lr_min=config.train.lr / 100,
107
+ )
108
+ criteria = torch.nn.MSELoss()
109
+
110
+ print("Prompts")
111
+ for settings in prompts:
112
+ print(settings)
113
+
114
+ # debug
115
+ debug_util.check_requires_grad(network)
116
+ debug_util.check_training_mode(network)
117
+
118
+ cache = PromptEmbedsCache()
119
+ prompt_pairs: list[PromptEmbedsPair] = []
120
+
121
+ with torch.no_grad():
122
+ for settings in prompts:
123
+ print(settings)
124
+ for prompt in [
125
+ settings.target,
126
+ settings.positive,
127
+ settings.neutral,
128
+ settings.unconditional,
129
+ ]:
130
+ if cache[prompt] is None:
131
+ tex_embs, pool_embs = train_util.encode_prompts_xl(
132
+ tokenizers,
133
+ text_encoders,
134
+ [prompt],
135
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
136
+ )
137
+ cache[prompt] = PromptEmbedsXL(
138
+ tex_embs,
139
+ pool_embs
140
+ )
141
+
142
+ prompt_pairs.append(
143
+ PromptEmbedsPair(
144
+ criteria,
145
+ cache[settings.target],
146
+ cache[settings.positive],
147
+ cache[settings.unconditional],
148
+ cache[settings.neutral],
149
+ settings,
150
+ )
151
+ )
152
+
153
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
154
+ del tokenizer, text_encoder
155
+
156
+ flush()
157
+
158
+ pbar = tqdm(range(config.train.iterations))
159
+
160
+ loss = None
161
+
162
+ for i in pbar:
163
+ with torch.no_grad():
164
+ noise_scheduler.set_timesteps(
165
+ config.train.max_denoising_steps, device=device
166
+ )
167
+
168
+ optimizer.zero_grad()
169
+
170
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
171
+ torch.randint(0, len(prompt_pairs), (1,)).item()
172
+ ]
173
+
174
+ # pick a random timestep from 1 to 49
175
+ timesteps_to = torch.randint(
176
+ 1, config.train.max_denoising_steps, (1,)
177
+ ).item()
178
+
179
+ height, width = prompt_pair.resolution, prompt_pair.resolution
180
+ if prompt_pair.dynamic_resolution:
181
+ height, width = train_util.get_random_resolution_in_bucket(
182
+ prompt_pair.resolution
183
+ )
184
+
185
+ if config.logging.verbose:
186
+ print("gudance_scale:", prompt_pair.guidance_scale)
187
+ print("resolution:", prompt_pair.resolution)
188
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
189
+ if prompt_pair.dynamic_resolution:
190
+ print("bucketed resolution:", (height, width))
191
+ print("batch_size:", prompt_pair.batch_size)
192
+ print("dynamic_crops:", prompt_pair.dynamic_crops)
193
+
194
+ latents = train_util.get_initial_latents(
195
+ noise_scheduler, prompt_pair.batch_size, height, width, 1
196
+ ).to(device, dtype=weight_dtype)
197
+
198
+ add_time_ids = train_util.get_add_time_ids(
199
+ height,
200
+ width,
201
+ dynamic_crops=prompt_pair.dynamic_crops,
202
+ dtype=weight_dtype,
203
+ ).to(device, dtype=weight_dtype)
204
+
205
+ with network:
206
+ # returns latents that have been partially denoised
207
+ denoised_latents = train_util.diffusion_xl(
208
+ unet,
209
+ noise_scheduler,
210
+ latents, # pass latents that are plain noise
211
+ text_embeddings=train_util.concat_embeddings(
212
+ prompt_pair.unconditional.text_embeds,
213
+ prompt_pair.target.text_embeds,
214
+ prompt_pair.batch_size,
215
+ ),
216
+ add_text_embeddings=train_util.concat_embeddings(
217
+ prompt_pair.unconditional.pooled_embeds,
218
+ prompt_pair.target.pooled_embeds,
219
+ prompt_pair.batch_size,
220
+ ),
221
+ add_time_ids=train_util.concat_embeddings(
222
+ add_time_ids, add_time_ids, prompt_pair.batch_size
223
+ ),
224
+ start_timesteps=0,
225
+ total_timesteps=timesteps_to,
226
+ guidance_scale=3,
227
+ )
228
+
229
+ noise_scheduler.set_timesteps(1000)
230
+
231
+ current_timestep = noise_scheduler.timesteps[
232
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
233
+ ]
234
+
235
+ # outside of "with network:", only the empty LoRA is active
236
+ positive_latents = train_util.predict_noise_xl(
237
+ unet,
238
+ noise_scheduler,
239
+ current_timestep,
240
+ denoised_latents,
241
+ text_embeddings=train_util.concat_embeddings(
242
+ prompt_pair.unconditional.text_embeds,
243
+ prompt_pair.positive.text_embeds,
244
+ prompt_pair.batch_size,
245
+ ),
246
+ add_text_embeddings=train_util.concat_embeddings(
247
+ prompt_pair.unconditional.pooled_embeds,
248
+ prompt_pair.positive.pooled_embeds,
249
+ prompt_pair.batch_size,
250
+ ),
251
+ add_time_ids=train_util.concat_embeddings(
252
+ add_time_ids, add_time_ids, prompt_pair.batch_size
253
+ ),
254
+ guidance_scale=1,
255
+ ).to(device, dtype=weight_dtype)
256
+ neutral_latents = train_util.predict_noise_xl(
257
+ unet,
258
+ noise_scheduler,
259
+ current_timestep,
260
+ denoised_latents,
261
+ text_embeddings=train_util.concat_embeddings(
262
+ prompt_pair.unconditional.text_embeds,
263
+ prompt_pair.neutral.text_embeds,
264
+ prompt_pair.batch_size,
265
+ ),
266
+ add_text_embeddings=train_util.concat_embeddings(
267
+ prompt_pair.unconditional.pooled_embeds,
268
+ prompt_pair.neutral.pooled_embeds,
269
+ prompt_pair.batch_size,
270
+ ),
271
+ add_time_ids=train_util.concat_embeddings(
272
+ add_time_ids, add_time_ids, prompt_pair.batch_size
273
+ ),
274
+ guidance_scale=1,
275
+ ).to(device, dtype=weight_dtype)
276
+ unconditional_latents = train_util.predict_noise_xl(
277
+ unet,
278
+ noise_scheduler,
279
+ current_timestep,
280
+ denoised_latents,
281
+ text_embeddings=train_util.concat_embeddings(
282
+ prompt_pair.unconditional.text_embeds,
283
+ prompt_pair.unconditional.text_embeds,
284
+ prompt_pair.batch_size,
285
+ ),
286
+ add_text_embeddings=train_util.concat_embeddings(
287
+ prompt_pair.unconditional.pooled_embeds,
288
+ prompt_pair.unconditional.pooled_embeds,
289
+ prompt_pair.batch_size,
290
+ ),
291
+ add_time_ids=train_util.concat_embeddings(
292
+ add_time_ids, add_time_ids, prompt_pair.batch_size
293
+ ),
294
+ guidance_scale=1,
295
+ ).to(device, dtype=weight_dtype)
296
+
297
+ if config.logging.verbose:
298
+ print("positive_latents:", positive_latents[0, 0, :5, :5])
299
+ print("neutral_latents:", neutral_latents[0, 0, :5, :5])
300
+ print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
301
+
302
+ with network:
303
+ target_latents = train_util.predict_noise_xl(
304
+ unet,
305
+ noise_scheduler,
306
+ current_timestep,
307
+ denoised_latents,
308
+ text_embeddings=train_util.concat_embeddings(
309
+ prompt_pair.unconditional.text_embeds,
310
+ prompt_pair.target.text_embeds,
311
+ prompt_pair.batch_size,
312
+ ),
313
+ add_text_embeddings=train_util.concat_embeddings(
314
+ prompt_pair.unconditional.pooled_embeds,
315
+ prompt_pair.target.pooled_embeds,
316
+ prompt_pair.batch_size,
317
+ ),
318
+ add_time_ids=train_util.concat_embeddings(
319
+ add_time_ids, add_time_ids, prompt_pair.batch_size
320
+ ),
321
+ guidance_scale=1,
322
+ ).to(device, dtype=weight_dtype)
323
+
324
+ if config.logging.verbose:
325
+ print("target_latents:", target_latents[0, 0, :5, :5])
326
+
327
+ positive_latents.requires_grad = False
328
+ neutral_latents.requires_grad = False
329
+ unconditional_latents.requires_grad = False
330
+
331
+ loss = prompt_pair.loss(
332
+ target_latents=target_latents,
333
+ positive_latents=positive_latents,
334
+ neutral_latents=neutral_latents,
335
+ unconditional_latents=unconditional_latents,
336
+ )
337
+
338
+ # multiply by 1000; otherwise the displayed loss stays at 0.000... and is hard to read
339
+ pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}")
340
+ if config.logging.use_wandb:
341
+ wandb.log(
342
+ {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]}
343
+ )
344
+
345
+ loss.backward()
346
+ optimizer.step()
347
+ lr_scheduler.step()
348
+
349
+ del (
350
+ positive_latents,
351
+ neutral_latents,
352
+ unconditional_latents,
353
+ target_latents,
354
+ latents,
355
+ )
356
+ flush()
357
+
358
+ # if (
359
+ # i % config.save.per_steps == 0
360
+ # and i != 0
361
+ # and i != config.train.iterations - 1
362
+ # ):
363
+ # print("Saving...")
364
+ # save_path.mkdir(parents=True, exist_ok=True)
365
+ # network.save_weights(
366
+ # save_path / f"{config.save.name}_{i}steps.pt",
367
+ # dtype=save_weight_dtype,
368
+ # )
369
+
370
+ print("Saving...")
371
+ save_path.mkdir(parents=True, exist_ok=True)
372
+ network.save_weights(
373
+ save_path / f"{config.save.name}",
374
+ dtype=save_weight_dtype,
375
+ )
376
+
377
+ del (
378
+ unet,
379
+ noise_scheduler,
380
+ loss,
381
+ optimizer,
382
+ network,
383
+ )
384
+
385
+ flush()
386
+
387
+ print("Done.")
388
+
389
+
390
+ # def main(args):
391
+ # config_file = args.config_file
392
+
393
+ # config = config_util.load_config_from_yaml(config_file)
394
+ # if args.name is not None:
395
+ # config.save.name = args.name
396
+ # attributes = []
397
+ # if args.attributes is not None:
398
+ # attributes = args.attributes.split(',')
399
+ # attributes = [a.strip() for a in attributes]
400
+
401
+ # config.network.alpha = args.alpha
402
+ # config.network.rank = args.rank
403
+ # config.save.name += f'_alpha{args.alpha}'
404
+ # config.save.name += f'_rank{config.network.rank }'
405
+ # config.save.name += f'_{config.network.training_method}'
406
+ # config.save.path += f'/{config.save.name}'
407
+
408
+ # prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
409
+
410
+ # device = torch.device(f"cuda:{args.device}")
411
+ # train(config, prompts, device)
412
+
413
+
414
+ def train_xl(target, positive, negative, lr, iterations, config_file, rank, train_method, device, attributes, save_name):
415
+
416
+ config = config_util.load_config_from_yaml(config_file)
417
+ randn = torch.randint(1, 10000000, (1,)).item()
418
+ config.save.name = save_name
419
+
420
+ config.train.lr = float(lr)
421
+ config.train.iterations = int(iterations)
422
+
423
+ if attributes is not None:
424
+ attributes = attributes.split(',')
425
+ attributes = [a.strip() for a in attributes]
426
+ else:
427
+ attributes = []
428
+ config.network.alpha = 1.0
429
+ config.network.rank = int(rank)
430
+ config.network.training_method = train_method
431
+
432
+ # config.save.path += f'/{config.save.name}'
433
+
434
+ prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
435
+
436
+ device = torch.device(device)
437
+ train(config, prompts, device)
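
`train_xl` is the programmatic entry point of this module (the CLI `main` above is commented out). A hedged example of how it might be called, with placeholder hyperparameters; only the config path refers to a file actually added in this commit:

```python
# Hypothetical invocation of demotrain.train_xl; argument values are placeholders.
from trainscripts.textsliders.demotrain import train_xl

train_xl(
    target="male person",
    positive="male person, very old",
    negative="male person, very young",  # presumably used as the unconditional prompt
    lr=2e-4,                             # placeholder learning rate
    iterations=1000,                     # placeholder iteration count
    config_file="trainscripts/textsliders/data/config-xl.yaml",
    rank=4,                              # LoRA rank (alpha is fixed to 1.0 inside train_xl)
    train_method="full",                 # placeholder; valid values come from lora.LoRANetwork
    device="cuda:0",
    attributes=None,
    save_name="age_slider_xl",
)
```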
trainscripts/textsliders/flush.py ADDED
@@ -0,0 +1,5 @@
1
+ import torch
2
+ import gc
3
+
4
+ torch.cuda.empty_cache()
5
+ gc.collect()
trainscripts/textsliders/generate_images_xl.py ADDED
@@ -0,0 +1,513 @@
1
+ import torch
2
+ from PIL import Image
3
+ import argparse
4
+ import os, json, random
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import glob, re
8
+
9
+ from safetensors.torch import load_file
10
+ import matplotlib.image as mpimg
11
+ import copy
12
+ import gc
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ import diffusers
16
+ from diffusers import DiffusionPipeline
17
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
18
+ from diffusers.loaders import AttnProcsLayers
19
+ from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
22
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
23
+ import inspect
24
+ import os
25
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
26
+ from diffusers.pipelines import StableDiffusionXLPipeline
27
+ import random
28
+
29
+ import torch
30
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
31
+ import re
32
+ import argparse
33
+
34
+ def flush():
35
+ torch.cuda.empty_cache()
36
+ gc.collect()
37
+
38
+ @torch.no_grad()
39
+ def call(
40
+ self,
41
+ prompt: Union[str, List[str]] = None,
42
+ prompt_2: Optional[Union[str, List[str]]] = None,
43
+ height: Optional[int] = None,
44
+ width: Optional[int] = None,
45
+ num_inference_steps: int = 50,
46
+ denoising_end: Optional[float] = None,
47
+ guidance_scale: float = 5.0,
48
+ negative_prompt: Optional[Union[str, List[str]]] = None,
49
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
50
+ num_images_per_prompt: Optional[int] = 1,
51
+ eta: float = 0.0,
52
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
53
+ latents: Optional[torch.FloatTensor] = None,
54
+ prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
56
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
57
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
+ output_type: Optional[str] = "pil",
59
+ return_dict: bool = True,
60
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
61
+ callback_steps: int = 1,
62
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
63
+ guidance_rescale: float = 0.0,
64
+ original_size: Optional[Tuple[int, int]] = None,
65
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
66
+ target_size: Optional[Tuple[int, int]] = None,
67
+ negative_original_size: Optional[Tuple[int, int]] = None,
68
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
69
+ negative_target_size: Optional[Tuple[int, int]] = None,
70
+
71
+ network=None,
72
+ start_noise=None,
73
+ scale=None,
74
+ unet=None,
75
+ ):
76
+ r"""
77
+ Function invoked when calling the pipeline for generation.
78
+
79
+ Args:
80
+ prompt (`str` or `List[str]`, *optional*):
81
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
82
+ instead.
83
+ prompt_2 (`str` or `List[str]`, *optional*):
84
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
85
+ used in both text-encoders
86
+ height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
87
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
88
+ Anything below 512 pixels won't work well for
89
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
90
+ and checkpoints that are not specifically fine-tuned on low resolutions.
91
+ width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
92
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
93
+ Anything below 512 pixels won't work well for
94
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
95
+ and checkpoints that are not specifically fine-tuned on low resolutions.
96
+ num_inference_steps (`int`, *optional*, defaults to 50):
97
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
98
+ expense of slower inference.
99
+ denoising_end (`float`, *optional*):
100
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
101
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
102
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
103
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
104
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
105
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
106
+ guidance_scale (`float`, *optional*, defaults to 5.0):
107
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
108
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
109
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
110
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
111
+ usually at the expense of lower image quality.
112
+ negative_prompt (`str` or `List[str]`, *optional*):
113
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
114
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
115
+ less than `1`).
116
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
117
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
118
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
119
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
120
+ The number of images to generate per prompt.
121
+ eta (`float`, *optional*, defaults to 0.0):
122
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
123
+ [`schedulers.DDIMScheduler`], will be ignored for others.
124
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
125
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
126
+ to make generation deterministic.
127
+ latents (`torch.FloatTensor`, *optional*):
128
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
129
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
130
+ tensor will be generated by sampling using the supplied random `generator`.
131
+ prompt_embeds (`torch.FloatTensor`, *optional*):
132
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
133
+ provided, text embeddings will be generated from `prompt` input argument.
134
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
135
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
136
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
137
+ argument.
138
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
139
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
140
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
141
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
142
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
143
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
144
+ input argument.
145
+ output_type (`str`, *optional*, defaults to `"pil"`):
146
+ The output format of the generated image. Choose between
147
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
148
+ return_dict (`bool`, *optional*, defaults to `True`):
149
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
150
+ of a plain tuple.
151
+ callback (`Callable`, *optional*):
152
+ A function that will be called every `callback_steps` steps during inference. The function will be
153
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
154
+ callback_steps (`int`, *optional*, defaults to 1):
155
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
156
+ called at every step.
157
+ cross_attention_kwargs (`dict`, *optional*):
158
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
159
+ `self.processor` in
160
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
161
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
162
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
163
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
164
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
165
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
166
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
167
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
168
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
169
+ explained in section 2.2 of
170
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
171
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
172
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
173
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
174
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
175
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
176
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
177
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
178
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
179
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
180
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
181
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
182
+ micro-conditioning as explained in section 2.2 of
183
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
184
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
185
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
186
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
187
+ micro-conditioning as explained in section 2.2 of
188
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
189
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
190
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
191
+ To negatively condition the generation process based on a target image resolution. It should be the same
192
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
193
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
194
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
195
+
196
+ Examples:
197
+
198
+ Returns:
199
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
200
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
201
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
202
+ """
203
+ # 0. Default height and width to unet
204
+ height = height or self.default_sample_size * self.vae_scale_factor
205
+ width = width or self.default_sample_size * self.vae_scale_factor
206
+
207
+ original_size = original_size or (height, width)
208
+ target_size = target_size or (height, width)
209
+
210
+ # 1. Check inputs. Raise error if not correct
211
+ self.check_inputs(
212
+ prompt,
213
+ prompt_2,
214
+ height,
215
+ width,
216
+ callback_steps,
217
+ negative_prompt,
218
+ negative_prompt_2,
219
+ prompt_embeds,
220
+ negative_prompt_embeds,
221
+ pooled_prompt_embeds,
222
+ negative_pooled_prompt_embeds,
223
+ )
224
+
225
+ # 2. Define call parameters
226
+ if prompt is not None and isinstance(prompt, str):
227
+ batch_size = 1
228
+ elif prompt is not None and isinstance(prompt, list):
229
+ batch_size = len(prompt)
230
+ else:
231
+ batch_size = prompt_embeds.shape[0]
232
+
233
+ device = self._execution_device
234
+
235
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
236
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
237
+ # corresponds to doing no classifier free guidance.
238
+ do_classifier_free_guidance = guidance_scale > 1.0
239
+
240
+ # 3. Encode input prompt
241
+ text_encoder_lora_scale = (
242
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
243
+ )
244
+ (
245
+ prompt_embeds,
246
+ negative_prompt_embeds,
247
+ pooled_prompt_embeds,
248
+ negative_pooled_prompt_embeds,
249
+ ) = self.encode_prompt(
250
+ prompt=prompt,
251
+ prompt_2=prompt_2,
252
+ device=device,
253
+ num_images_per_prompt=num_images_per_prompt,
254
+ do_classifier_free_guidance=do_classifier_free_guidance,
255
+ negative_prompt=negative_prompt,
256
+ negative_prompt_2=negative_prompt_2,
257
+ prompt_embeds=prompt_embeds,
258
+ negative_prompt_embeds=negative_prompt_embeds,
259
+ pooled_prompt_embeds=pooled_prompt_embeds,
260
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
261
+ lora_scale=text_encoder_lora_scale,
262
+ )
263
+
264
+ # 4. Prepare timesteps
265
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
266
+
267
+ timesteps = self.scheduler.timesteps
268
+
269
+ # 5. Prepare latent variables
270
+ num_channels_latents = unet.config.in_channels
271
+ latents = self.prepare_latents(
272
+ batch_size * num_images_per_prompt,
273
+ num_channels_latents,
274
+ height,
275
+ width,
276
+ prompt_embeds.dtype,
277
+ device,
278
+ generator,
279
+ latents,
280
+ )
281
+
282
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
283
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
284
+
285
+ # 7. Prepare added time ids & embeddings
286
+ add_text_embeds = pooled_prompt_embeds
287
+ add_time_ids = self._get_add_time_ids(
288
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
289
+ )
290
+ if negative_original_size is not None and negative_target_size is not None:
291
+ negative_add_time_ids = self._get_add_time_ids(
292
+ negative_original_size,
293
+ negative_crops_coords_top_left,
294
+ negative_target_size,
295
+ dtype=prompt_embeds.dtype,
296
+ )
297
+ else:
298
+ negative_add_time_ids = add_time_ids
299
+
300
+ if do_classifier_free_guidance:
301
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
302
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
303
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
304
+
305
+ prompt_embeds = prompt_embeds.to(device)
306
+ add_text_embeds = add_text_embeds.to(device)
307
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
308
+
309
+ # 8. Denoising loop
310
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
311
+
312
+ # 7.1 Apply denoising_end
313
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
314
+ discrete_timestep_cutoff = int(
315
+ round(
316
+ self.scheduler.config.num_train_timesteps
317
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
318
+ )
319
+ )
320
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
321
+ timesteps = timesteps[:num_inference_steps]
322
+
323
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
324
+ for i, t in enumerate(timesteps):
325
+ if t>start_noise:
326
+ network.set_lora_slider(scale=0)
327
+ else:
328
+ network.set_lora_slider(scale=scale)
329
+ # expand the latents if we are doing classifier free guidance
330
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
331
+
332
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
333
+
334
+ # predict the noise residual
335
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
336
+ with network:
337
+ noise_pred = unet(
338
+ latent_model_input,
339
+ t,
340
+ encoder_hidden_states=prompt_embeds,
341
+ cross_attention_kwargs=cross_attention_kwargs,
342
+ added_cond_kwargs=added_cond_kwargs,
343
+ return_dict=False,
344
+ )[0]
345
+
346
+ # perform guidance
347
+ if do_classifier_free_guidance:
348
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
349
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
350
+
351
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
352
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
353
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
354
+
355
+ # compute the previous noisy sample x_t -> x_t-1
356
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
357
+
358
+ # call the callback, if provided
359
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
360
+ progress_bar.update()
361
+ if callback is not None and i % callback_steps == 0:
362
+ callback(i, t, latents)
363
+
364
+ if not output_type == "latent":
365
+ # make sure the VAE is in float32 mode, as it overflows in float16
366
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
367
+
368
+ if needs_upcasting:
369
+ self.upcast_vae()
370
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
371
+
372
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
373
+
374
+ # cast back to fp16 if needed
375
+ if needs_upcasting:
376
+ self.vae.to(dtype=torch.float16)
377
+ else:
378
+ image = latents
379
+
380
+ if not output_type == "latent":
381
+ # apply watermark if available
382
+ if self.watermark is not None:
383
+ image = self.watermark.apply_watermark(image)
384
+
385
+ image = self.image_processor.postprocess(image, output_type=output_type)
386
+
387
+ # Offload all models
388
+ # self.maybe_free_model_hooks()
389
+
390
+ if not return_dict:
391
+ return (image,)
392
+
393
+ return StableDiffusionXLPipelineOutput(images=image)
394
+
395
+
396
+ def sorted_nicely( l ):
397
+ convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text
398
+ alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ]
399
+ return sorted(l, key = alphanum_key)
400
+
401
+ def flush():
402
+ torch.cuda.empty_cache()
403
+ gc.collect()
404
+
405
+
406
+ if __name__=='__main__':
407
+
408
+ device = 'cuda:0'
409
+ StableDiffusionXLPipeline.__call__ = call
410
+ pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
411
+
412
+ # pipe.__call__ = call
413
+ pipe = pipe.to(device)
414
+
415
+
416
+ parser = argparse.ArgumentParser(
417
+ prog = 'generateImages',
418
+ description = 'Generate Images using Diffusers Code')
419
+ parser.add_argument('--model_name', help='name of model', type=str, required=True)
420
+ parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
421
+ parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None)
422
+ parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
423
+ parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4')
424
+ parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
425
+ parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
426
+ parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000)
427
+ parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
428
+ parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=5)
429
+ parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
430
+ parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4)
431
+ parser.add_argument('--start_noise', help='what time stamp to flip to edited model', type=int, required=False, default=750)
432
+
433
+ args = parser.parse_args()
434
+ lora_weight = args.model_name
435
+ csv_path = args.prompts_path
436
+ save_path = args.save_path
437
+ start_noise = args.start_noise
438
+ from_case = args.from_case
439
+ till_case = args.till_case
440
+
441
+ weight_dtype = torch.float16
442
+ num_images_per_prompt = 1
443
+ scales = [-2, -1, 0, 1, 2]
444
+ scales = [-1, -.5, 0, .5, 1]
445
+ scales = [-2]  # note: this final assignment overrides the candidate lists above
446
+ df = pd.read_csv(csv_path)
447
+
448
+ for scale in scales:
449
+ os.makedirs(f'{save_path}/{os.path.basename(lora_weight)}/{scale}', exist_ok=True)
450
+
451
+ prompts = list(df['prompt'])
452
+ seeds = list(df['evaluation_seed'])
453
+ case_numbers = list(df['case_number'])
454
+ pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',torch_dtype=torch.float16,)
455
+
456
+ # pipe.__call__ = call
457
+ pipe = pipe.to(device)
458
+ unet = pipe.unet
459
+ if 'full' in lora_weight:
460
+ train_method = 'full'
461
+ elif 'noxattn' in lora_weight:
462
+ train_method = 'noxattn'
463
+ else:
464
+ train_method = 'noxattn'
465
+
466
+ network_type = "c3lier"
467
+ if train_method == 'xattn':
468
+ network_type = 'lierla'
469
+
470
+ modules = DEFAULT_TARGET_REPLACE
471
+ if network_type == "c3lier":
472
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
473
+ import os
474
+ model_name = lora_weight
475
+
476
+ name = os.path.basename(model_name)
477
+ rank = 1
478
+ alpha = 4
479
+ if 'rank4' in lora_weight:
480
+ rank = 4
481
+ if 'rank8' in lora_weight:
482
+ rank = 8
483
+ if 'alpha1' in lora_weight:
484
+ alpha = 1.0
485
+ network = LoRANetwork(
486
+ unet,
487
+ rank=rank,
488
+ multiplier=1.0,
489
+ alpha=alpha,
490
+ train_method=train_method,
491
+ ).to(device, dtype=weight_dtype)
492
+ network.load_state_dict(torch.load(lora_weight))
493
+
494
+ for idx, prompt in enumerate(prompts):
495
+ seed = seeds[idx]
496
+ case_number = case_numbers[idx]
497
+
498
+ if not (case_number>=from_case and case_number<=till_case):
499
+ continue
500
+ if os.path.exists(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png'):
501
+ continue
502
+ print(prompt, seed)
503
+ for scale in scales:
504
+ generator = torch.manual_seed(seed)
505
+ images = pipe(prompt, num_images_per_prompt=args.num_samples, num_inference_steps=50, generator=generator, network=network, start_noise=start_noise, scale=scale, unet=unet).images
506
+ for idx, im in enumerate(images):
507
+ im.save(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png')
508
+ del unet, network, pipe
509
+ unet = None
510
+ network = None
511
+ pipe = None
512
+ torch.cuda.empty_cache()
513
+ flush()
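The script above reads `prompt`, `evaluation_seed`, and `case_number` columns from the CSV passed via `--prompts_path`. A minimal sketch of producing such a file (the file name and prompt values are illustrative assumptions):

import pandas as pd

pd.DataFrame(
    {
        "case_number": [0, 1],
        "prompt": ["a portrait photo of a man", "a portrait photo of a woman"],
        "evaluation_seed": [42, 1234],
    }
).to_csv("eval_prompts.csv", index=False)

The resulting path is then supplied together with `--model_name` (the trained slider weights) and `--save_path` (the output folder).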
trainscripts/textsliders/lora.py ADDED
@@ -0,0 +1,258 @@
1
+ # ref:
2
+ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3
+ # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4
+
5
+ import os
6
+ import math
7
+ from typing import Optional, List, Type, Set, Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers import UNet2DConditionModel
12
+ from safetensors.torch import save_file
13
+
14
+
15
+ UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16
+ # "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
17
+ "Attention"
18
+ ]
19
+ UNET_TARGET_REPLACE_MODULE_CONV = [
20
+ "ResnetBlock2D",
21
+ "Downsample2D",
22
+ "Upsample2D",
23
+ "DownBlock2D",
24
+ "UpBlock2D",
25
+
26
+ ] # locon, c3lier
27
+
28
+ LORA_PREFIX_UNET = "lora_unet"
29
+
30
+ DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
31
+
32
+ TRAINING_METHODS = Literal[
33
+ "noxattn", # train all layers except x-attns and time_embed layers
34
+ "innoxattn", # train all layers except self attention layers
35
+ "selfattn", # ESD-u, train only self attention layers
36
+ "xattn", # ESD-x, train only x attention layers
37
+ "full", # train all layers
38
+ "xattn-strict", # q and k values
39
+ "noxattn-hspace",
40
+ "noxattn-hspace-last",
41
+ # "xlayer",
42
+ # "outxattn",
43
+ # "outsattn",
44
+ # "inxattn",
45
+ # "inmidsattn",
46
+ # "selflayer",
47
+ ]
48
+
49
+
50
+ class LoRAModule(nn.Module):
51
+ """
52
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ lora_name,
58
+ org_module: nn.Module,
59
+ multiplier=1.0,
60
+ lora_dim=4,
61
+ alpha=1,
62
+ ):
63
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
64
+ super().__init__()
65
+ self.lora_name = lora_name
66
+ self.lora_dim = lora_dim
67
+
68
+ if "Linear" in org_module.__class__.__name__:
69
+ in_dim = org_module.in_features
70
+ out_dim = org_module.out_features
71
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
72
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
73
+
74
+ elif "Conv" in org_module.__class__.__name__: # 一応
75
+ in_dim = org_module.in_channels
76
+ out_dim = org_module.out_channels
77
+
78
+ self.lora_dim = min(self.lora_dim, in_dim, out_dim)
79
+ if self.lora_dim != lora_dim:
80
+ print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
81
+
82
+ kernel_size = org_module.kernel_size
83
+ stride = org_module.stride
84
+ padding = org_module.padding
85
+ self.lora_down = nn.Conv2d(
86
+ in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
87
+ )
88
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
89
+
90
+ if type(alpha) == torch.Tensor:
91
+ alpha = alpha.detach().numpy()
92
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
93
+ self.scale = alpha / self.lora_dim
94
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
95
+
96
+ # same as microsoft's
97
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
98
+ nn.init.zeros_(self.lora_up.weight)
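+ # lora_up starts at zero, so the LoRA branch contributes nothing until it is trained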
99
+
100
+ self.multiplier = multiplier
101
+ self.org_module = org_module # remove in applying
102
+
103
+ def apply_to(self):
104
+ self.org_forward = self.org_module.forward
105
+ self.org_module.forward = self.forward
106
+ del self.org_module
107
+
108
+ def forward(self, x):
109
+ return (
110
+ self.org_forward(x)
111
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
112
+ )
113
+
114
+
115
+ class LoRANetwork(nn.Module):
116
+ def __init__(
117
+ self,
118
+ unet: UNet2DConditionModel,
119
+ rank: int = 4,
120
+ multiplier: float = 1.0,
121
+ alpha: float = 1.0,
122
+ train_method: TRAINING_METHODS = "full",
123
+ ) -> None:
124
+ super().__init__()
125
+ self.lora_scale = 1
126
+ self.multiplier = multiplier
127
+ self.lora_dim = rank
128
+ self.alpha = alpha
129
+
130
+ # LoRA only
131
+ self.module = LoRAModule
132
+
133
+ # create LoRA modules for the U-Net
134
+ self.unet_loras = self.create_modules(
135
+ LORA_PREFIX_UNET,
136
+ unet,
137
+ DEFAULT_TARGET_REPLACE,
138
+ self.lora_dim,
139
+ self.multiplier,
140
+ train_method=train_method,
141
+ )
142
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
143
+
144
+ # assertion: make sure there are no duplicated LoRA names
145
+ lora_names = set()
146
+ for lora in self.unet_loras:
147
+ assert (
148
+ lora.lora_name not in lora_names
149
+ ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
150
+ lora_names.add(lora.lora_name)
151
+
152
+ # apply the LoRA modules
153
+ for lora in self.unet_loras:
154
+ lora.apply_to()
155
+ self.add_module(
156
+ lora.lora_name,
157
+ lora,
158
+ )
159
+
160
+ del unet
161
+
162
+ torch.cuda.empty_cache()
163
+
164
+ def create_modules(
165
+ self,
166
+ prefix: str,
167
+ root_module: nn.Module,
168
+ target_replace_modules: List[str],
169
+ rank: int,
170
+ multiplier: float,
171
+ train_method: TRAINING_METHODS,
172
+ ) -> list:
173
+ loras = []
174
+ names = []
175
+ for name, module in root_module.named_modules():
176
+ if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習
177
+ if "attn2" in name or "time_embed" in name:
178
+ continue
179
+ elif train_method == "innoxattn": # Cross Attention 以外学習
180
+ if "attn2" in name:
181
+ continue
182
+ elif train_method == "selfattn": # Self Attention のみ学習
183
+ if "attn1" not in name:
184
+ continue
185
+ elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習
186
+ if "attn2" not in name:
187
+ continue
188
+ elif train_method == "full": # 全部学習
189
+ pass
190
+ else:
191
+ raise NotImplementedError(
192
+ f"train_method: {train_method} is not implemented."
193
+ )
194
+ if module.__class__.__name__ in target_replace_modules:
195
+ for child_name, child_module in module.named_modules():
196
+ if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
197
+ if train_method == 'xattn-strict':
198
+ if 'out' in child_name:
199
+ continue
200
+ if train_method == 'noxattn-hspace':
201
+ if 'mid_block' not in name:
202
+ continue
203
+ if train_method == 'noxattn-hspace-last':
204
+ if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
205
+ continue
206
+ lora_name = prefix + "." + name + "." + child_name
207
+ lora_name = lora_name.replace(".", "_")
208
+ # print(f"{lora_name}")
209
+ lora = self.module(
210
+ lora_name, child_module, multiplier, rank, self.alpha
211
+ )
212
+ # print(name, child_name)
213
+ # print(child_module.weight.shape)
214
+ if lora_name not in names:
215
+ loras.append(lora)
216
+ names.append(lora_name)
217
+ # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
218
+ return loras
219
+
220
+ def prepare_optimizer_params(self):
221
+ all_params = []
222
+
223
+ if self.unet_loras: # effectively always the case here
224
+ params = []
225
+ [params.extend(lora.parameters()) for lora in self.unet_loras]
226
+ param_data = {"params": params}
227
+ all_params.append(param_data)
228
+
229
+ return all_params
230
+
231
+ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
232
+ state_dict = self.state_dict()
233
+
234
+ if dtype is not None:
235
+ for key in list(state_dict.keys()):
236
+ v = state_dict[key]
237
+ v = v.detach().clone().to("cpu").to(dtype)
238
+ state_dict[key] = v
239
+
240
+ # for key in list(state_dict.keys()):
241
+ # if not key.startswith("lora"):
242
+ # # drop everything except lora keys
243
+ # del state_dict[key]
244
+
245
+ if os.path.splitext(file)[1] == ".safetensors":
246
+ save_file(state_dict, file, metadata)
247
+ else:
248
+ torch.save(state_dict, file)
249
+ def set_lora_slider(self, scale):
250
+ self.lora_scale = scale
251
+
252
+ def __enter__(self):
253
+ for lora in self.unet_loras:
254
+ lora.multiplier = 1.0 * self.lora_scale
255
+
256
+ def __exit__(self, exc_type, exc_value, tb):
257
+ for lora in self.unet_loras:
258
+ lora.multiplier = 0
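A minimal usage sketch of the LoRANetwork defined above, assuming it is run from `trainscripts/textsliders` so that `lora` is importable; the model id and output file name are illustrative assumptions:

import torch
from diffusers import UNet2DConditionModel
from lora import LoRANetwork

# any diffusers UNet2DConditionModel works; SDXL base is just an example
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
).to("cuda")
unet.requires_grad_(False)

network = LoRANetwork(
    unet, rank=4, multiplier=1.0, alpha=1.0, train_method="noxattn"
).to("cuda", dtype=torch.float16)

network.set_lora_slider(scale=1.5)  # slider strength applied inside the `with` block
with network:
    # UNet forward passes here see the LoRA at multiplier 1.0 * 1.5;
    # on exit the multiplier is reset to 0
    pass

network.save_weights("slider.safetensors", dtype=torch.float16)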
trainscripts/textsliders/model_util.py ADDED
@@ -0,0 +1,278 @@
1
+ from typing import Literal, Union, Optional
2
+
3
+ import torch
4
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5
+ from diffusers import (
6
+ UNet2DConditionModel,
7
+ SchedulerMixin,
8
+ StableDiffusionPipeline,
9
+ StableDiffusionXLPipeline,
10
+ )
11
+ from diffusers.schedulers import (
12
+ DDIMScheduler,
13
+ DDPMScheduler,
14
+ LMSDiscreteScheduler,
15
+ EulerAncestralDiscreteScheduler,
16
+ )
17
+
18
+
19
+ TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
20
+ TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
21
+
22
+ AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
23
+
24
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
25
+
26
+ DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
27
+
28
+
29
+ def load_diffusers_model(
30
+ pretrained_model_name_or_path: str,
31
+ v2: bool = False,
32
+ clip_skip: Optional[int] = None,
33
+ weight_dtype: torch.dtype = torch.float32,
34
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
35
+ # the VAE is not needed here
36
+
37
+ if v2:
38
+ tokenizer = CLIPTokenizer.from_pretrained(
39
+ TOKENIZER_V2_MODEL_NAME,
40
+ subfolder="tokenizer",
41
+ torch_dtype=weight_dtype,
42
+ cache_dir=DIFFUSERS_CACHE_DIR,
43
+ )
44
+ text_encoder = CLIPTextModel.from_pretrained(
45
+ pretrained_model_name_or_path,
46
+ subfolder="text_encoder",
47
+ # default is clip skip 2
48
+ num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
49
+ torch_dtype=weight_dtype,
50
+ cache_dir=DIFFUSERS_CACHE_DIR,
51
+ )
52
+ else:
53
+ tokenizer = CLIPTokenizer.from_pretrained(
54
+ TOKENIZER_V1_MODEL_NAME,
55
+ subfolder="tokenizer",
56
+ torch_dtype=weight_dtype,
57
+ cache_dir=DIFFUSERS_CACHE_DIR,
58
+ )
59
+ text_encoder = CLIPTextModel.from_pretrained(
60
+ pretrained_model_name_or_path,
61
+ subfolder="text_encoder",
62
+ num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
63
+ torch_dtype=weight_dtype,
64
+ cache_dir=DIFFUSERS_CACHE_DIR,
65
+ )
66
+
67
+ unet = UNet2DConditionModel.from_pretrained(
68
+ pretrained_model_name_or_path,
69
+ subfolder="unet",
70
+ torch_dtype=weight_dtype,
71
+ cache_dir=DIFFUSERS_CACHE_DIR,
72
+ )
73
+
74
+ return tokenizer, text_encoder, unet
75
+
76
+
77
+ def load_checkpoint_model(
78
+ checkpoint_path: str,
79
+ v2: bool = False,
80
+ clip_skip: Optional[int] = None,
81
+ weight_dtype: torch.dtype = torch.float32,
82
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
83
+ pipe = StableDiffusionPipeline.from_ckpt(
84
+ checkpoint_path,
85
+ upcast_attention=True if v2 else False,
86
+ torch_dtype=weight_dtype,
87
+ cache_dir=DIFFUSERS_CACHE_DIR,
88
+ )
89
+
90
+ unet = pipe.unet
91
+ tokenizer = pipe.tokenizer
92
+ text_encoder = pipe.text_encoder
93
+ if clip_skip is not None:
94
+ if v2:
95
+ text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
96
+ else:
97
+ text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
98
+
99
+ del pipe
100
+
101
+ return tokenizer, text_encoder, unet
102
+
103
+
104
+ def load_models(
105
+ pretrained_model_name_or_path: str,
106
+ scheduler_name: AVAILABLE_SCHEDULERS,
107
+ v2: bool = False,
108
+ v_pred: bool = False,
109
+ weight_dtype: torch.dtype = torch.float32,
110
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
111
+ if pretrained_model_name_or_path.endswith(
112
+ ".ckpt"
113
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
114
+ tokenizer, text_encoder, unet = load_checkpoint_model(
115
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
116
+ )
117
+ else: # diffusers
118
+ tokenizer, text_encoder, unet = load_diffusers_model(
119
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
120
+ )
121
+
122
+ # the VAE is not needed here
123
+
124
+ scheduler = create_noise_scheduler(
125
+ scheduler_name,
126
+ prediction_type="v_prediction" if v_pred else "epsilon",
127
+ )
128
+
129
+ return tokenizer, text_encoder, unet, scheduler
130
+
131
+
132
+ def load_diffusers_model_xl(
133
+ pretrained_model_name_or_path: str,
134
+ weight_dtype: torch.dtype = torch.float32,
135
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
136
+ # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
137
+
138
+ tokenizers = [
139
+ CLIPTokenizer.from_pretrained(
140
+ pretrained_model_name_or_path,
141
+ subfolder="tokenizer",
142
+ torch_dtype=weight_dtype,
143
+ cache_dir=DIFFUSERS_CACHE_DIR,
144
+ ),
145
+ CLIPTokenizer.from_pretrained(
146
+ pretrained_model_name_or_path,
147
+ subfolder="tokenizer_2",
148
+ torch_dtype=weight_dtype,
149
+ cache_dir=DIFFUSERS_CACHE_DIR,
150
+ pad_token_id=0, # same as open clip
151
+ ),
152
+ ]
153
+
154
+ text_encoders = [
155
+ CLIPTextModel.from_pretrained(
156
+ pretrained_model_name_or_path,
157
+ subfolder="text_encoder",
158
+ torch_dtype=weight_dtype,
159
+ cache_dir=DIFFUSERS_CACHE_DIR,
160
+ ),
161
+ CLIPTextModelWithProjection.from_pretrained(
162
+ pretrained_model_name_or_path,
163
+ subfolder="text_encoder_2",
164
+ torch_dtype=weight_dtype,
165
+ cache_dir=DIFFUSERS_CACHE_DIR,
166
+ ),
167
+ ]
168
+
169
+ unet = UNet2DConditionModel.from_pretrained(
170
+ pretrained_model_name_or_path,
171
+ subfolder="unet",
172
+ torch_dtype=weight_dtype,
173
+ cache_dir=DIFFUSERS_CACHE_DIR,
174
+ )
175
+
176
+ return tokenizers, text_encoders, unet
177
+
178
+
179
+ def load_checkpoint_model_xl(
180
+ checkpoint_path: str,
181
+ weight_dtype: torch.dtype = torch.float32,
182
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
183
+ pipe = StableDiffusionXLPipeline.from_single_file(
184
+ checkpoint_path,
185
+ torch_dtype=weight_dtype,
186
+ cache_dir=DIFFUSERS_CACHE_DIR,
187
+ )
188
+
189
+ unet = pipe.unet
190
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
191
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
192
+ if len(text_encoders) == 2:
193
+ text_encoders[1].pad_token_id = 0
194
+
195
+ del pipe
196
+
197
+ return tokenizers, text_encoders, unet
198
+
199
+
200
+ def load_models_xl(
201
+ pretrained_model_name_or_path: str,
202
+ scheduler_name: AVAILABLE_SCHEDULERS,
203
+ weight_dtype: torch.dtype = torch.float32,
204
+ ) -> tuple[
205
+ list[CLIPTokenizer],
206
+ list[SDXL_TEXT_ENCODER_TYPE],
207
+ UNet2DConditionModel,
208
+ SchedulerMixin,
209
+ ]:
210
+ if pretrained_model_name_or_path.endswith(
211
+ ".ckpt"
212
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
213
+ (
214
+ tokenizers,
215
+ text_encoders,
216
+ unet,
217
+ ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
218
+ else: # diffusers
219
+ (
220
+ tokenizers,
221
+ text_encoders,
222
+ unet,
223
+ ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
224
+
225
+ scheduler = create_noise_scheduler(scheduler_name)
226
+
227
+ return tokenizers, text_encoders, unet, scheduler
228
+
229
+
230
+ def create_noise_scheduler(
231
+ scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
232
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
233
+ ) -> SchedulerMixin:
234
+ # Honestly, not sure which scheduler is best. The original implementation allowed choosing DDIM, DDPM, or LMS, but it is unclear which works best.
235
+
236
+ name = scheduler_name.lower().replace(" ", "_")
237
+ if name == "ddim":
238
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
239
+ scheduler = DDIMScheduler(
240
+ beta_start=0.00085,
241
+ beta_end=0.012,
242
+ beta_schedule="scaled_linear",
243
+ num_train_timesteps=1000,
244
+ clip_sample=False,
245
+ prediction_type=prediction_type, # is this right?
246
+ )
247
+ elif name == "ddpm":
248
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
249
+ scheduler = DDPMScheduler(
250
+ beta_start=0.00085,
251
+ beta_end=0.012,
252
+ beta_schedule="scaled_linear",
253
+ num_train_timesteps=1000,
254
+ clip_sample=False,
255
+ prediction_type=prediction_type,
256
+ )
257
+ elif name == "lms":
258
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
259
+ scheduler = LMSDiscreteScheduler(
260
+ beta_start=0.00085,
261
+ beta_end=0.012,
262
+ beta_schedule="scaled_linear",
263
+ num_train_timesteps=1000,
264
+ prediction_type=prediction_type,
265
+ )
266
+ elif name == "euler_a":
267
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
268
+ scheduler = EulerAncestralDiscreteScheduler(
269
+ beta_start=0.00085,
270
+ beta_end=0.012,
271
+ beta_schedule="scaled_linear",
272
+ num_train_timesteps=1000,
273
+ prediction_type=prediction_type,
274
+ )
275
+ else:
276
+ raise ValueError(f"Unknown scheduler name: {name}")
277
+
278
+ return scheduler
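A minimal sketch of loading an SDXL model with the helpers above (the model id is an illustrative assumption; a local single-file `.safetensors` checkpoint is also accepted):

import torch
from model_util import load_models_xl

tokenizers, text_encoders, unet, scheduler = load_models_xl(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler_name="ddim",
    weight_dtype=torch.float16,
)
print(type(scheduler).__name__)  # DDIMScheduler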
trainscripts/textsliders/prompt_util.py ADDED
@@ -0,0 +1,183 @@
1
+ from typing import Literal, Optional, Union, List
2
+
3
+ import yaml
4
+ from pathlib import Path
5
+
6
+
7
+ from pydantic import BaseModel, root_validator
8
+ import torch
9
+ import copy
10
+
11
+ ACTION_TYPES = Literal[
12
+ "erase",
13
+ "enhance",
14
+ ]
15
+
16
+
17
+ # XL needs two kinds of embeddings (text and pooled)
18
+ class PromptEmbedsXL:
19
+ text_embeds: torch.FloatTensor
20
+ pooled_embeds: torch.FloatTensor
21
+
22
+ def __init__(self, *args) -> None:
23
+ self.text_embeds = args[0]
24
+ self.pooled_embeds = args[1]
25
+
26
+
27
+ # SDv1.x / SDv2.x use a FloatTensor; XL uses PromptEmbedsXL
28
+ PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29
+
30
+
31
+ class PromptEmbedsCache: # cache embeddings so they can be reused
32
+ prompts: dict[str, PROMPT_EMBEDDING] = {}
33
+
34
+ def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35
+ self.prompts[__name] = __value
36
+
37
+ def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38
+ if __name in self.prompts:
39
+ return self.prompts[__name]
40
+ else:
41
+ return None
42
+
43
+
44
+ class PromptSettings(BaseModel): # the entries from the prompts yaml
45
+ target: str
46
+ positive: str = None # if None, target will be used
47
+ unconditional: str = "" # default is ""
48
+ neutral: str = None # if None, unconditional will be used
49
+ action: ACTION_TYPES = "erase" # default is "erase"
50
+ guidance_scale: float = 1.0 # default is 1.0
51
+ resolution: int = 512 # default is 512
52
+ dynamic_resolution: bool = False # default is False
53
+ batch_size: int = 1 # default is 1
54
+ dynamic_crops: bool = False # default is False. only used when model is XL
55
+
56
+ @root_validator(pre=True)
57
+ def fill_prompts(cls, values):
58
+ keys = values.keys()
59
+ if "target" not in keys:
60
+ raise ValueError("target must be specified")
61
+ if "positive" not in keys:
62
+ values["positive"] = values["target"]
63
+ if "unconditional" not in keys:
64
+ values["unconditional"] = ""
65
+ if "neutral" not in keys:
66
+ values["neutral"] = values["unconditional"]
67
+
68
+ return values
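+ # e.g. PromptSettings(target="male person") ends up with positive="male person",
+ # unconditional="", neutral="" after this validator runs, plus the default action "erase"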
69
+
70
+
71
+ class PromptEmbedsPair:
72
+ target: PROMPT_EMBEDDING # not want to generate the concept
73
+ positive: PROMPT_EMBEDDING # generate the concept
74
+ unconditional: PROMPT_EMBEDDING # uncondition (default should be empty)
75
+ neutral: PROMPT_EMBEDDING # base condition (default should be empty)
76
+
77
+ guidance_scale: float
78
+ resolution: int
79
+ dynamic_resolution: bool
80
+ batch_size: int
81
+ dynamic_crops: bool
82
+
83
+ loss_fn: torch.nn.Module
84
+ action: ACTION_TYPES
85
+
86
+ def __init__(
87
+ self,
88
+ loss_fn: torch.nn.Module,
89
+ target: PROMPT_EMBEDDING,
90
+ positive: PROMPT_EMBEDDING,
91
+ unconditional: PROMPT_EMBEDDING,
92
+ neutral: PROMPT_EMBEDDING,
93
+ settings: PromptSettings,
94
+ ) -> None:
95
+ self.loss_fn = loss_fn
96
+ self.target = target
97
+ self.positive = positive
98
+ self.unconditional = unconditional
99
+ self.neutral = neutral
100
+
101
+ self.guidance_scale = settings.guidance_scale
102
+ self.resolution = settings.resolution
103
+ self.dynamic_resolution = settings.dynamic_resolution
104
+ self.batch_size = settings.batch_size
105
+ self.dynamic_crops = settings.dynamic_crops
106
+ self.action = settings.action
107
+
108
+ def _erase(
109
+ self,
110
+ target_latents: torch.FloatTensor, # "van gogh"
111
+ positive_latents: torch.FloatTensor, # "van gogh"
112
+ unconditional_latents: torch.FloatTensor, # ""
113
+ neutral_latents: torch.FloatTensor, # ""
114
+ ) -> torch.FloatTensor:
115
+ """Target latents are going not to have the positive concept."""
116
+ return self.loss_fn(
117
+ target_latents,
118
+ neutral_latents
119
+ - self.guidance_scale * (positive_latents - unconditional_latents)
120
+ )
121
+
122
+
123
+ def _enhance(
124
+ self,
125
+ target_latents: torch.FloatTensor, # "van gogh"
126
+ positive_latents: torch.FloatTensor, # "van gogh"
127
+ unconditional_latents: torch.FloatTensor, # ""
128
+ neutral_latents: torch.FloatTensor, # ""
129
+ ):
130
+ """Target latents are going to have the positive concept."""
131
+ return self.loss_fn(
132
+ target_latents,
133
+ neutral_latents
134
+ + self.guidance_scale * (positive_latents - unconditional_latents)
135
+ )
136
+
137
+ def loss(
138
+ self,
139
+ **kwargs,
140
+ ):
141
+ if self.action == "erase":
142
+ return self._erase(**kwargs)
143
+
144
+ elif self.action == "enhance":
145
+ return self._enhance(**kwargs)
146
+
147
+ else:
148
+ raise ValueError("action must be erase or enhance")
149
+
150
+
151
+ def load_prompts_from_yaml(path, target, positive, negative, attributes = []):
152
+ with open(path, "r") as f:
153
+ prompts = yaml.safe_load(f)
154
+ new = []
155
+ for prompt in prompts:
156
+ copy_ = copy.deepcopy(prompt)
157
+ copy_['target'] = target
158
+ copy_['positive'] = positive
159
+ copy_['neutral'] = target
160
+ copy_['unconditional'] = negative
161
+ new.append(copy_)
162
+ prompts = new
163
+ print(prompts)
164
+ if len(prompts) == 0:
165
+ raise ValueError("prompts file is empty")
166
+ if len(attributes)!=0:
167
+ newprompts = []
168
+ for i in range(len(prompts)):
169
+ for att in attributes:
170
+ copy_ = copy.deepcopy(prompts[i])
171
+ copy_['target'] = att + ' ' + copy_['target']
172
+ copy_['positive'] = att + ' ' + copy_['positive']
173
+ copy_['neutral'] = att + ' ' + copy_['neutral']
174
+ copy_['unconditional'] = att + ' ' + copy_['unconditional']
175
+ newprompts.append(copy_)
176
+ else:
177
+ newprompts = copy.deepcopy(prompts)
178
+
179
+ print(newprompts)
180
+ print(len(prompts), len(newprompts))
181
+ prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
182
+
183
+ return prompt_settings
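A minimal sketch of `load_prompts_from_yaml`, assuming it is run from `trainscripts/textsliders`; the prompt strings are illustrative, and the yaml file is only used as a template whose entries get overwritten:

from prompt_util import load_prompts_from_yaml

settings = load_prompts_from_yaml(
    "data/prompts.yaml",
    target="male person",
    positive="male person, very old",
    negative="male person, very young",
    attributes=[],  # e.g. ["white", "black", "asian"] to prepend an attribute to every prompt
)
print(settings[0].target, "|", settings[0].positive, "|", settings[0].unconditional)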
trainscripts/textsliders/ptp_utils.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ from PIL import Image, ImageDraw, ImageFont
18
+ import cv2
19
+ from typing import Optional, Union, Tuple, List, Callable, Dict
20
+ from IPython.display import display
21
+ from tqdm.notebook import tqdm
22
+
23
+
24
+ def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
25
+ h, w, c = image.shape
26
+ offset = int(h * .2)
27
+ img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
28
+ font = cv2.FONT_HERSHEY_SIMPLEX
29
+ # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
30
+ img[:h] = image
31
+ textsize = cv2.getTextSize(text, font, 1, 2)[0]
32
+ text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
33
+ cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
34
+ return img
35
+
36
+
37
+ def view_images(images, num_rows=1, offset_ratio=0.02):
38
+ if type(images) is list:
39
+ num_empty = len(images) % num_rows
40
+ elif images.ndim == 4:
41
+ num_empty = images.shape[0] % num_rows
42
+ else:
43
+ images = [images]
44
+ num_empty = 0
45
+
46
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
47
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
48
+ num_items = len(images)
49
+
50
+ h, w, c = images[0].shape
51
+ offset = int(h * offset_ratio)
52
+ num_cols = num_items // num_rows
53
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
54
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
55
+ for i in range(num_rows):
56
+ for j in range(num_cols):
57
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
58
+ i * num_cols + j]
59
+
60
+ pil_img = Image.fromarray(image_)
61
+ display(pil_img)
62
+
63
+
64
+ def diffusion_step(unet, model, controller, latents, context, t, guidance_scale, low_resource=False):
65
+ if low_resource:
66
+ noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
67
+ noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
68
+ else:
69
+ latents_input = torch.cat([latents] * 2)
70
+ noise_pred = unet(latents_input, t, encoder_hidden_states=context)["sample"]
71
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
72
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
73
+ latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
74
+ latents = controller.step_callback(latents)
75
+ return latents
76
+
77
+
78
+ def latent2image(vae, latents):
79
+ latents = 1 / 0.18215 * latents
80
+ image = vae.decode(latents)['sample']
81
+ image = (image / 2 + 0.5).clamp(0, 1)
82
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
83
+ image = (image * 255).astype(np.uint8)
84
+ return image
85
+
86
+
87
+ def init_latent(latent, model, height, width, generator, batch_size):
88
+ if latent is None:
89
+ latent = torch.randn(
90
+ (1, model.unet.in_channels, height // 8, width // 8),
91
+ generator=generator,
92
+ )
93
+ latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
94
+ return latent, latents
95
+
96
+
97
+ @torch.no_grad()
98
+ def text2image_ldm(
99
+ model,
100
+ prompt: List[str],
101
+ controller,
102
+ num_inference_steps: int = 50,
103
+ guidance_scale: Optional[float] = 7.,
104
+ generator: Optional[torch.Generator] = None,
105
+ latent: Optional[torch.FloatTensor] = None,
106
+ ):
107
+ register_attention_control(model, controller)
108
+ height = width = 256
109
+ batch_size = len(prompt)
110
+
111
+ uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
112
+ uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
113
+
114
+ text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
115
+ text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
116
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
117
+ context = torch.cat([uncond_embeddings, text_embeddings])
118
+
119
+ model.scheduler.set_timesteps(num_inference_steps)
120
+ for t in tqdm(model.scheduler.timesteps):
121
+ latents = diffusion_step(model.unet, model, controller, latents, context, t, guidance_scale)  # pass the UNet explicitly to match diffusion_step's signature
122
+
123
+ image = latent2image(model.vqvae, latents)
124
+
125
+ return image, latent
126
+
127
+
128
+ @torch.no_grad()
129
+ def text2image_ldm_stable(
130
+ model,
131
+ prompt: List[str],
132
+ controller,
133
+ num_inference_steps: int = 50,
134
+ guidance_scale: float = 7.5,
135
+ generator: Optional[torch.Generator] = None,
136
+ latent: Optional[torch.FloatTensor] = None,
137
+ low_resource: bool = False,
138
+ ):
139
+ register_attention_control(model, controller)
140
+ height = width = 512
141
+ batch_size = len(prompt)
142
+
143
+ text_input = model.tokenizer(
144
+ prompt,
145
+ padding="max_length",
146
+ max_length=model.tokenizer.model_max_length,
147
+ truncation=True,
148
+ return_tensors="pt",
149
+ )
150
+ text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
151
+ max_length = text_input.input_ids.shape[-1]
152
+ uncond_input = model.tokenizer(
153
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
154
+ )
155
+ uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
156
+
157
+ context = [uncond_embeddings, text_embeddings]
158
+ if not low_resource:
159
+ context = torch.cat(context)
160
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
161
+
162
+ # set timesteps
163
+ extra_set_kwargs = {"offset": 1}
164
+ model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
165
+ for t in tqdm(model.scheduler.timesteps):
166
+ latents = diffusion_step(model.unet, model, controller, latents, context, t, guidance_scale, low_resource)  # pass the UNet explicitly to match diffusion_step's signature
167
+
168
+ image = latent2image(model.vae, latents)
169
+
170
+ return image, latent
171
+
172
+
173
+ def register_attention_control(model, controller):
174
+ def ca_forward(self, place_in_unet):
175
+ to_out = self.to_out
176
+ if type(to_out) is torch.nn.modules.container.ModuleList:
177
+ to_out = self.to_out[0]
178
+ else:
179
+ to_out = self.to_out
180
+
181
+ def forward(x, context=None, mask=None):
182
+ batch_size, sequence_length, dim = x.shape
183
+ h = self.heads
184
+ q = self.to_q(x)
185
+ is_cross = context is not None
186
+ context = context if is_cross else x
187
+ k = self.to_k(context)
188
+ v = self.to_v(context)
189
+ q = self.reshape_heads_to_batch_dim(q)
190
+ k = self.reshape_heads_to_batch_dim(k)
191
+ v = self.reshape_heads_to_batch_dim(v)
192
+
193
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
194
+
195
+ if mask is not None:
196
+ mask = mask.reshape(batch_size, -1)
197
+ max_neg_value = -torch.finfo(sim.dtype).max
198
+ mask = mask[:, None, :].repeat(h, 1, 1)
199
+ sim.masked_fill_(~mask, max_neg_value)
200
+
201
+ # attention, what we cannot get enough of
202
+ attn = sim.softmax(dim=-1)
203
+ attn = controller(attn, is_cross, place_in_unet)
204
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
205
+ out = self.reshape_batch_dim_to_heads(out)
206
+ return to_out(out)
207
+
208
+ return forward
209
+
210
+ class DummyController:
211
+
212
+ def __call__(self, *args):
213
+ return args[0]
214
+
215
+ def __init__(self):
216
+ self.num_att_layers = 0
217
+
218
+ if controller is None:
219
+ controller = DummyController()
220
+
221
+ def register_recr(net_, count, place_in_unet):
222
+ if net_.__class__.__name__ == 'CrossAttention':
223
+ net_.forward = ca_forward(net_, place_in_unet)
224
+ return count + 1
225
+ elif hasattr(net_, 'children'):
226
+ for net__ in net_.children():
227
+ count = register_recr(net__, count, place_in_unet)
228
+ return count
229
+
230
+ cross_att_count = 0
231
+ sub_nets = model.unet.named_children()
232
+ for net in sub_nets:
233
+ if "down" in net[0]:
234
+ cross_att_count += register_recr(net[1], 0, "down")
235
+ elif "up" in net[0]:
236
+ cross_att_count += register_recr(net[1], 0, "up")
237
+ elif "mid" in net[0]:
238
+ cross_att_count += register_recr(net[1], 0, "mid")
239
+
240
+ controller.num_att_layers = cross_att_count
241
+
242
+
243
+ def get_word_inds(text: str, word_place: int, tokenizer):
244
+ split_text = text.split(" ")
245
+ if type(word_place) is str:
246
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
247
+ elif type(word_place) is int:
248
+ word_place = [word_place]
249
+ out = []
250
+ if len(word_place) > 0:
251
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
252
+ cur_len, ptr = 0, 0
253
+
254
+ for i in range(len(words_encode)):
255
+ cur_len += len(words_encode[i])
256
+ if ptr in word_place:
257
+ out.append(i + 1)
258
+ if cur_len >= len(split_text[ptr]):
259
+ ptr += 1
260
+ cur_len = 0
261
+ return np.array(out)
262
+
263
+
264
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
265
+ word_inds: Optional[torch.Tensor]=None):
266
+ if type(bounds) is float:
267
+ bounds = 0, bounds
268
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
269
+ if word_inds is None:
270
+ word_inds = torch.arange(alpha.shape[2])
271
+ alpha[: start, prompt_ind, word_inds] = 0
272
+ alpha[start: end, prompt_ind, word_inds] = 1
273
+ alpha[end:, prompt_ind, word_inds] = 0
274
+ return alpha
275
+
276
+
277
+ def get_time_words_attention_alpha(prompts, num_steps,
278
+ cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
279
+ tokenizer, max_num_words=77):
280
+ if type(cross_replace_steps) is not dict:
281
+ cross_replace_steps = {"default_": cross_replace_steps}
282
+ if "default_" not in cross_replace_steps:
283
+ cross_replace_steps["default_"] = (0., 1.)
284
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
285
+ for i in range(len(prompts) - 1):
286
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
287
+ i)
288
+ for key, item in cross_replace_steps.items():
289
+ if key != "default_":
290
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
291
+ for i, ind in enumerate(inds):
292
+ if len(ind) > 0:
293
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
294
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
295
+ return alpha_time_words
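A small sketch of `get_word_inds`, which maps a word in a prompt to its token positions; the tokenizer id is an illustrative assumption, and importing `ptp_utils` also pulls in `cv2`, `IPython`, and `tqdm`:

from transformers import CLIPTokenizer
from ptp_utils import get_word_inds

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
print(get_word_inds("a photo of a dog", "dog", tokenizer))  # [5], shifted by one for the BOS token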
trainscripts/textsliders/train_lora.py ADDED
@@ -0,0 +1,419 @@
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+
15
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
+ import train_util
17
+ import model_util
18
+ import prompt_util
19
+ from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings
20
+ import debug_util
21
+ import config_util
22
+ from config_util import RootConfig
23
+
24
+ import wandb
25
+
26
+
27
+ def flush():
28
+ torch.cuda.empty_cache()
29
+ gc.collect()
30
+
31
+
32
+ def train(
33
+ config: RootConfig,
34
+ prompts: list[PromptSettings],
35
+ device: int
36
+ ):
37
+
38
+ metadata = {
39
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
40
+ "config": config.json(),
41
+ }
42
+ save_path = Path(config.save.path)
43
+
44
+ modules = DEFAULT_TARGET_REPLACE
45
+ if config.network.type == "c3lier":
46
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
47
+
48
+ if config.logging.verbose:
49
+ print(metadata)
50
+
51
+ if config.logging.use_wandb:
52
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
53
+
54
+ weight_dtype = config_util.parse_precision(config.train.precision)
55
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
56
+
57
+ tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
58
+ config.pretrained_model.name_or_path,
59
+ scheduler_name=config.train.noise_scheduler,
60
+ v2=config.pretrained_model.v2,
61
+ v_pred=config.pretrained_model.v_pred,
62
+ )
63
+
64
+ text_encoder.to(device, dtype=weight_dtype)
65
+ text_encoder.eval()
66
+
67
+ unet.to(device, dtype=weight_dtype)
68
+ unet.enable_xformers_memory_efficient_attention()
69
+ unet.requires_grad_(False)
70
+ unet.eval()
71
+
72
+ network = LoRANetwork(
73
+ unet,
74
+ rank=config.network.rank,
75
+ multiplier=1.0,
76
+ alpha=config.network.alpha,
77
+ train_method=config.network.training_method,
78
+ ).to(device, dtype=weight_dtype)
79
+
80
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
81
+ #optimizer_args
82
+ optimizer_kwargs = {}
83
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
84
+ for arg in config.train.optimizer_args.split(" "):
85
+ key, value = arg.split("=")
86
+ value = ast.literal_eval(value)
87
+ optimizer_kwargs[key] = value
88
+
89
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
90
+ lr_scheduler = train_util.get_lr_scheduler(
91
+ config.train.lr_scheduler,
92
+ optimizer,
93
+ max_iterations=config.train.iterations,
94
+ lr_min=config.train.lr / 100,
95
+ )
96
+ criteria = torch.nn.MSELoss()
97
+
98
+ print("Prompts")
99
+ for settings in prompts:
100
+ print(settings)
101
+
102
+ # debug
103
+ debug_util.check_requires_grad(network)
104
+ debug_util.check_training_mode(network)
105
+
106
+ cache = PromptEmbedsCache()
107
+ prompt_pairs: list[PromptEmbedsPair] = []
108
+
109
+ with torch.no_grad():
110
+ for settings in prompts:
111
+ print(settings)
112
+ for prompt in [
113
+ settings.target,
114
+ settings.positive,
115
+ settings.neutral,
116
+ settings.unconditional,
117
+ ]:
118
+ print(prompt)
119
+ if isinstance(prompt, list):
120
+ if prompt == settings.positive:
121
+ key_setting = 'positive'
122
+ else:
123
+ key_setting = 'attributes'
124
+ if len(prompt) == 0:
125
+ cache[key_setting] = []
126
+ else:
127
+ if cache[key_setting] is None:
128
+ cache[key_setting] = train_util.encode_prompts(
129
+ tokenizer, text_encoder, prompt
130
+ )
131
+ else:
132
+ if cache[prompt] is None:
133
+ cache[prompt] = train_util.encode_prompts(
134
+ tokenizer, text_encoder, [prompt]
135
+ )
136
+
137
+ prompt_pairs.append(
138
+ PromptEmbedsPair(
139
+ criteria,
140
+ cache[settings.target],
141
+ cache[settings.positive],
142
+ cache[settings.unconditional],
143
+ cache[settings.neutral],
144
+ settings,
145
+ )
146
+ )
147
+
148
+ del tokenizer
149
+ del text_encoder
150
+
151
+ flush()
152
+
153
+ pbar = tqdm(range(config.train.iterations))
154
+
155
+ for i in pbar:
156
+ with torch.no_grad():
157
+ noise_scheduler.set_timesteps(
158
+ config.train.max_denoising_steps, device=device
159
+ )
160
+
161
+ optimizer.zero_grad()
162
+
163
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
164
+ torch.randint(0, len(prompt_pairs), (1,)).item()
165
+ ]
166
+
167
+ # chosen at random from 1 ~ 49 (i.e. 1 to max_denoising_steps - 1)
168
+ timesteps_to = torch.randint(
169
+ 1, config.train.max_denoising_steps, (1,)
170
+ ).item()
171
+
172
+ height, width = (
173
+ prompt_pair.resolution,
174
+ prompt_pair.resolution,
175
+ )
176
+ if prompt_pair.dynamic_resolution:
177
+ height, width = train_util.get_random_resolution_in_bucket(
178
+ prompt_pair.resolution
179
+ )
180
+
181
+ if config.logging.verbose:
182
+ print("guidance_scale:", prompt_pair.guidance_scale)
183
+ print("resolution:", prompt_pair.resolution)
184
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
185
+ if prompt_pair.dynamic_resolution:
186
+ print("bucketed resolution:", (height, width))
187
+ print("batch_size:", prompt_pair.batch_size)
188
+
189
+ latents = train_util.get_initial_latents(
190
+ noise_scheduler, prompt_pair.batch_size, height, width, 1
191
+ ).to(device, dtype=weight_dtype)
192
+
193
+ with network:
194
+ # returns latents that have been slightly denoised
195
+ denoised_latents = train_util.diffusion(
196
+ unet,
197
+ noise_scheduler,
198
+ latents, # pass in the plain noise latents
199
+ train_util.concat_embeddings(
200
+ prompt_pair.unconditional,
201
+ prompt_pair.target,
202
+ prompt_pair.batch_size,
203
+ ),
204
+ start_timesteps=0,
205
+ total_timesteps=timesteps_to,
206
+ guidance_scale=3,
207
+ )
208
+
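+ # switch the scheduler back to the full 1000-step training schedule so the partially denoised
+ # latents can be scored at the timestep that corresponds to the sampled inference step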
209
+ noise_scheduler.set_timesteps(1000)
210
+
211
+ current_timestep = noise_scheduler.timesteps[
212
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
213
+ ]
214
+
215
+ # outside of "with network:", only the empty (zero-scale) LoRA is active
216
+ positive_latents = train_util.predict_noise(
217
+ unet,
218
+ noise_scheduler,
219
+ current_timestep,
220
+ denoised_latents,
221
+ train_util.concat_embeddings(
222
+ prompt_pair.unconditional,
223
+ prompt_pair.positive,
224
+ prompt_pair.batch_size,
225
+ ),
226
+ guidance_scale=1,
227
+ ).to(device, dtype=weight_dtype)
228
+
229
+ neutral_latents = train_util.predict_noise(
230
+ unet,
231
+ noise_scheduler,
232
+ current_timestep,
233
+ denoised_latents,
234
+ train_util.concat_embeddings(
235
+ prompt_pair.unconditional,
236
+ prompt_pair.neutral,
237
+ prompt_pair.batch_size,
238
+ ),
239
+ guidance_scale=1,
240
+ ).to(device, dtype=weight_dtype)
241
+ unconditional_latents = train_util.predict_noise(
242
+ unet,
243
+ noise_scheduler,
244
+ current_timestep,
245
+ denoised_latents,
246
+ train_util.concat_embeddings(
247
+ prompt_pair.unconditional,
248
+ prompt_pair.unconditional,
249
+ prompt_pair.batch_size,
250
+ ),
251
+ guidance_scale=1,
252
+ ).to(device, dtype=weight_dtype)
253
+
254
+
255
+ #########################
256
+ if config.logging.verbose:
257
+ print("positive_latents:", positive_latents[0, 0, :5, :5])
258
+ print("neutral_latents:", neutral_latents[0, 0, :5, :5])
259
+ print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
260
+
261
+ with network:
262
+ target_latents = train_util.predict_noise(
263
+ unet,
264
+ noise_scheduler,
265
+ current_timestep,
266
+ denoised_latents,
267
+ train_util.concat_embeddings(
268
+ prompt_pair.unconditional,
269
+ prompt_pair.target,
270
+ prompt_pair.batch_size,
271
+ ),
272
+ guidance_scale=1,
273
+ ).to(device, dtype=weight_dtype)
274
+
275
+ #########################
276
+
277
+ if config.logging.verbose:
278
+ print("target_latents:", target_latents[0, 0, :5, :5])
279
+
280
+ positive_latents.requires_grad = False
281
+ neutral_latents.requires_grad = False
282
+ unconditional_latents.requires_grad = False
283
+
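+ # the LoRA-modified prediction (target_latents) is trained against the frozen model's
+ # positive / neutral / unconditional predictions; the exact combination is defined by
+ # PromptEmbedsPair.loss in prompt_util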
284
+ loss = prompt_pair.loss(
285
+ target_latents=target_latents,
286
+ positive_latents=positive_latents,
287
+ neutral_latents=neutral_latents,
288
+ unconditional_latents=unconditional_latents,
289
+ )
290
+
291
+ # multiply by 1000; otherwise the value keeps showing as 0.000... which is not very informative
292
+ pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}")
293
+ if config.logging.use_wandb:
294
+ wandb.log(
295
+ {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]}
296
+ )
297
+
298
+ loss.backward()
299
+ optimizer.step()
300
+ lr_scheduler.step()
301
+
302
+ del (
303
+ positive_latents,
304
+ neutral_latents,
305
+ unconditional_latents,
306
+ target_latents,
307
+ latents,
308
+ )
309
+ flush()
310
+
311
+ if (
312
+ i % config.save.per_steps == 0
313
+ and i != 0
314
+ and i != config.train.iterations - 1
315
+ ):
316
+ print("Saving...")
317
+ save_path.mkdir(parents=True, exist_ok=True)
318
+ network.save_weights(
319
+ save_path / f"{config.save.name}_{i}steps.pt",
320
+ dtype=save_weight_dtype,
321
+ )
322
+
323
+ print("Saving...")
324
+ save_path.mkdir(parents=True, exist_ok=True)
325
+ network.save_weights(
326
+ save_path / f"{config.save.name}_last.pt",
327
+ dtype=save_weight_dtype,
328
+ )
329
+
330
+ del (
331
+ unet,
332
+ noise_scheduler,
333
+ loss,
334
+ optimizer,
335
+ network,
336
+ )
337
+
338
+ flush()
339
+
340
+ print("Done.")
341
+
342
+
343
+ def main(args):
344
+ config_file = args.config_file
345
+
346
+ config = config_util.load_config_from_yaml(config_file)
347
+ if args.name is not None:
348
+ config.save.name = args.name
349
+ attributes = []
350
+ if args.attributes is not None:
351
+ attributes = args.attributes.split(',')
352
+ attributes = [a.strip() for a in attributes]
353
+
354
+ config.network.alpha = args.alpha
355
+ config.network.rank = args.rank
356
+ config.save.name += f'_alpha{args.alpha}'
357
+ config.save.name += f'_rank{config.network.rank}'
358
+ config.save.name += f'_{config.network.training_method}'
359
+ config.save.path += f'/{config.save.name}'
360
+
361
+ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
362
+ device = torch.device(f"cuda:{args.device}")
363
+
364
+ train(config=config, prompts=prompts, device=device)
365
+
366
+
367
+ if __name__ == "__main__":
368
+ parser = argparse.ArgumentParser()
369
+ parser.add_argument(
370
+ "--config_file",
371
+ required=False,
372
+ default='data/config.yaml',
373
+ help="Config file for training.",
374
+ )
375
+ # config_file 'data/config.yaml'
376
+ parser.add_argument(
377
+ "--alpha",
378
+ type=float,
379
+ required=True,
380
+ help="LoRA weight.",
381
+ )
382
+ # --alpha 1.0
383
+ parser.add_argument(
384
+ "--rank",
385
+ type=int,
386
+ required=False,
387
+ help="Rank of LoRA.",
388
+ default=4,
389
+ )
390
+ # --rank 4
391
+ parser.add_argument(
392
+ "--device",
393
+ type=int,
394
+ required=False,
395
+ default=0,
396
+ help="Device to train on.",
397
+ )
398
+ # --device 0
399
+ parser.add_argument(
400
+ "--name",
401
+ type=str,
402
+ required=False,
403
+ default=None,
404
+ help="Device to train on.",
405
+ )
406
+ # --name 'eyesize_slider'
407
+ parser.add_argument(
408
+ "--attributes",
409
+ type=str,
410
+ required=False,
411
+ default=None,
412
+ help="attritbutes to disentangle (comma seperated string)",
413
+ )
414
+
415
+ # --attributes 'male, female'
416
+
417
+ args = parser.parse_args()
418
+
419
+ main(args)
trainscripts/textsliders/train_lora_xl.py ADDED
@@ -0,0 +1,463 @@
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+
15
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
+ import train_util
17
+ import model_util
18
+ import prompt_util
19
+ from prompt_util import (
20
+ PromptEmbedsCache,
21
+ PromptEmbedsPair,
22
+ PromptSettings,
23
+ PromptEmbedsXL,
24
+ )
25
+ import debug_util
26
+ import config_util
27
+ from config_util import RootConfig
28
+
29
+ import wandb
30
+
31
+ NUM_IMAGES_PER_PROMPT = 1
32
+
33
+
34
+ def flush():
35
+ torch.cuda.empty_cache()
36
+ gc.collect()
37
+
38
+
39
+ def train(
40
+ config: RootConfig,
41
+ prompts: list[PromptSettings],
42
+ device,
43
+ ):
44
+ metadata = {
45
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
46
+ "config": config.json(),
47
+ }
48
+ save_path = Path(config.save.path)
49
+
50
+ modules = DEFAULT_TARGET_REPLACE
51
+ if config.network.type == "c3lier":
52
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
53
+
54
+ if config.logging.verbose:
55
+ print(metadata)
56
+
57
+ if config.logging.use_wandb:
58
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
59
+
60
+ weight_dtype = config_util.parse_precision(config.train.precision)
61
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
62
+
63
+ (
64
+ tokenizers,
65
+ text_encoders,
66
+ unet,
67
+ noise_scheduler,
68
+ ) = model_util.load_models_xl(
69
+ config.pretrained_model.name_or_path,
70
+ scheduler_name=config.train.noise_scheduler,
71
+ )
72
+
73
+ for text_encoder in text_encoders:
74
+ text_encoder.to(device, dtype=weight_dtype)
75
+ text_encoder.requires_grad_(False)
76
+ text_encoder.eval()
77
+
78
+ unet.to(device, dtype=weight_dtype)
79
+ if config.other.use_xformers:
80
+ unet.enable_xformers_memory_efficient_attention()
81
+ unet.requires_grad_(False)
82
+ unet.eval()
83
+
84
+ network = LoRANetwork(
85
+ unet,
86
+ rank=config.network.rank,
87
+ multiplier=1.0,
88
+ alpha=config.network.alpha,
89
+ train_method=config.network.training_method,
90
+ ).to(device, dtype=weight_dtype)
91
+
92
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
93
+ #optimizer_args
94
+ optimizer_kwargs = {}
95
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
96
+ for arg in config.train.optimizer_args.split(" "):
97
+ key, value = arg.split("=")
98
+ value = ast.literal_eval(value)
99
+ optimizer_kwargs[key] = value
100
+
101
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
102
+ lr_scheduler = train_util.get_lr_scheduler(
103
+ config.train.lr_scheduler,
104
+ optimizer,
105
+ max_iterations=config.train.iterations,
106
+ lr_min=config.train.lr / 100,
107
+ )
108
+ criteria = torch.nn.MSELoss()
109
+
110
+ print("Prompts")
111
+ for settings in prompts:
112
+ print(settings)
113
+
114
+ # debug
115
+ debug_util.check_requires_grad(network)
116
+ debug_util.check_training_mode(network)
117
+
118
+ cache = PromptEmbedsCache()
119
+ prompt_pairs: list[PromptEmbedsPair] = []
120
+
121
+ with torch.no_grad():
122
+ for settings in prompts:
123
+ print(settings)
124
+ for prompt in [
125
+ settings.target,
126
+ settings.positive,
127
+ settings.neutral,
128
+ settings.unconditional,
129
+ ]:
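+ # SDXL conditioning needs both the per-token text embeddings and the pooled embedding,
+ # so both are cached together as a PromptEmbedsXL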
130
+ if cache[prompt] is None:
131
+ tex_embs, pool_embs = train_util.encode_prompts_xl(
132
+ tokenizers,
133
+ text_encoders,
134
+ [prompt],
135
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
136
+ )
137
+ cache[prompt] = PromptEmbedsXL(
138
+ tex_embs,
139
+ pool_embs
140
+ )
141
+
142
+ prompt_pairs.append(
143
+ PromptEmbedsPair(
144
+ criteria,
145
+ cache[settings.target],
146
+ cache[settings.positive],
147
+ cache[settings.unconditional],
148
+ cache[settings.neutral],
149
+ settings,
150
+ )
151
+ )
152
+
153
+ # release the tokenizers and text encoders; all prompt embeddings are already cached
154
+ del tokenizers, text_encoders
155
+
156
+ flush()
157
+
158
+ pbar = tqdm(range(config.train.iterations))
159
+
160
+ loss = None
161
+
162
+ for i in pbar:
163
+ with torch.no_grad():
164
+ noise_scheduler.set_timesteps(
165
+ config.train.max_denoising_steps, device=device
166
+ )
167
+
168
+ optimizer.zero_grad()
169
+
170
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
171
+ torch.randint(0, len(prompt_pairs), (1,)).item()
172
+ ]
173
+
174
+ # chosen at random from 1 ~ 49 (i.e. 1 to max_denoising_steps - 1)
175
+ timesteps_to = torch.randint(
176
+ 1, config.train.max_denoising_steps, (1,)
177
+ ).item()
178
+
179
+ height, width = prompt_pair.resolution, prompt_pair.resolution
180
+ if prompt_pair.dynamic_resolution:
181
+ height, width = train_util.get_random_resolution_in_bucket(
182
+ prompt_pair.resolution
183
+ )
184
+
185
+ if config.logging.verbose:
186
+ print("gudance_scale:", prompt_pair.guidance_scale)
187
+ print("resolution:", prompt_pair.resolution)
188
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
189
+ if prompt_pair.dynamic_resolution:
190
+ print("bucketed resolution:", (height, width))
191
+ print("batch_size:", prompt_pair.batch_size)
192
+ print("dynamic_crops:", prompt_pair.dynamic_crops)
193
+
194
+ latents = train_util.get_initial_latents(
195
+ noise_scheduler, prompt_pair.batch_size, height, width, 1
196
+ ).to(device, dtype=weight_dtype)
197
+
198
+ add_time_ids = train_util.get_add_time_ids(
199
+ height,
200
+ width,
201
+ dynamic_crops=prompt_pair.dynamic_crops,
202
+ dtype=weight_dtype,
203
+ ).to(device, dtype=weight_dtype)
204
+
205
+ with network:
206
+ # returns latents that have been slightly denoised
207
+ denoised_latents = train_util.diffusion_xl(
208
+ unet,
209
+ noise_scheduler,
210
+ latents, # pass in the plain noise latents
211
+ text_embeddings=train_util.concat_embeddings(
212
+ prompt_pair.unconditional.text_embeds,
213
+ prompt_pair.target.text_embeds,
214
+ prompt_pair.batch_size,
215
+ ),
216
+ add_text_embeddings=train_util.concat_embeddings(
217
+ prompt_pair.unconditional.pooled_embeds,
218
+ prompt_pair.target.pooled_embeds,
219
+ prompt_pair.batch_size,
220
+ ),
221
+ add_time_ids=train_util.concat_embeddings(
222
+ add_time_ids, add_time_ids, prompt_pair.batch_size
223
+ ),
224
+ start_timesteps=0,
225
+ total_timesteps=timesteps_to,
226
+ guidance_scale=3,
227
+ )
228
+
229
+ noise_scheduler.set_timesteps(1000)
230
+
231
+ current_timestep = noise_scheduler.timesteps[
232
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
233
+ ]
234
+
235
+ # outside of "with network:", only the empty (zero-scale) LoRA is active
236
+ positive_latents = train_util.predict_noise_xl(
237
+ unet,
238
+ noise_scheduler,
239
+ current_timestep,
240
+ denoised_latents,
241
+ text_embeddings=train_util.concat_embeddings(
242
+ prompt_pair.unconditional.text_embeds,
243
+ prompt_pair.positive.text_embeds,
244
+ prompt_pair.batch_size,
245
+ ),
246
+ add_text_embeddings=train_util.concat_embeddings(
247
+ prompt_pair.unconditional.pooled_embeds,
248
+ prompt_pair.positive.pooled_embeds,
249
+ prompt_pair.batch_size,
250
+ ),
251
+ add_time_ids=train_util.concat_embeddings(
252
+ add_time_ids, add_time_ids, prompt_pair.batch_size
253
+ ),
254
+ guidance_scale=1,
255
+ ).to(device, dtype=weight_dtype)
256
+ neutral_latents = train_util.predict_noise_xl(
257
+ unet,
258
+ noise_scheduler,
259
+ current_timestep,
260
+ denoised_latents,
261
+ text_embeddings=train_util.concat_embeddings(
262
+ prompt_pair.unconditional.text_embeds,
263
+ prompt_pair.neutral.text_embeds,
264
+ prompt_pair.batch_size,
265
+ ),
266
+ add_text_embeddings=train_util.concat_embeddings(
267
+ prompt_pair.unconditional.pooled_embeds,
268
+ prompt_pair.neutral.pooled_embeds,
269
+ prompt_pair.batch_size,
270
+ ),
271
+ add_time_ids=train_util.concat_embeddings(
272
+ add_time_ids, add_time_ids, prompt_pair.batch_size
273
+ ),
274
+ guidance_scale=1,
275
+ ).to(device, dtype=weight_dtype)
276
+ unconditional_latents = train_util.predict_noise_xl(
277
+ unet,
278
+ noise_scheduler,
279
+ current_timestep,
280
+ denoised_latents,
281
+ text_embeddings=train_util.concat_embeddings(
282
+ prompt_pair.unconditional.text_embeds,
283
+ prompt_pair.unconditional.text_embeds,
284
+ prompt_pair.batch_size,
285
+ ),
286
+ add_text_embeddings=train_util.concat_embeddings(
287
+ prompt_pair.unconditional.pooled_embeds,
288
+ prompt_pair.unconditional.pooled_embeds,
289
+ prompt_pair.batch_size,
290
+ ),
291
+ add_time_ids=train_util.concat_embeddings(
292
+ add_time_ids, add_time_ids, prompt_pair.batch_size
293
+ ),
294
+ guidance_scale=1,
295
+ ).to(device, dtype=weight_dtype)
296
+
297
+ if config.logging.verbose:
298
+ print("positive_latents:", positive_latents[0, 0, :5, :5])
299
+ print("neutral_latents:", neutral_latents[0, 0, :5, :5])
300
+ print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
301
+
302
+ with network:
303
+ target_latents = train_util.predict_noise_xl(
304
+ unet,
305
+ noise_scheduler,
306
+ current_timestep,
307
+ denoised_latents,
308
+ text_embeddings=train_util.concat_embeddings(
309
+ prompt_pair.unconditional.text_embeds,
310
+ prompt_pair.target.text_embeds,
311
+ prompt_pair.batch_size,
312
+ ),
313
+ add_text_embeddings=train_util.concat_embeddings(
314
+ prompt_pair.unconditional.pooled_embeds,
315
+ prompt_pair.target.pooled_embeds,
316
+ prompt_pair.batch_size,
317
+ ),
318
+ add_time_ids=train_util.concat_embeddings(
319
+ add_time_ids, add_time_ids, prompt_pair.batch_size
320
+ ),
321
+ guidance_scale=1,
322
+ ).to(device, dtype=weight_dtype)
323
+
324
+ if config.logging.verbose:
325
+ print("target_latents:", target_latents[0, 0, :5, :5])
326
+
327
+ positive_latents.requires_grad = False
328
+ neutral_latents.requires_grad = False
329
+ unconditional_latents.requires_grad = False
330
+
331
+ loss = prompt_pair.loss(
332
+ target_latents=target_latents,
333
+ positive_latents=positive_latents,
334
+ neutral_latents=neutral_latents,
335
+ unconditional_latents=unconditional_latents,
336
+ )
337
+
338
+ # multiply by 1000; otherwise the value keeps showing as 0.000... which is not very informative
339
+ pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}")
340
+ if config.logging.use_wandb:
341
+ wandb.log(
342
+ {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]}
343
+ )
344
+
345
+ loss.backward()
346
+ optimizer.step()
347
+ lr_scheduler.step()
348
+
349
+ del (
350
+ positive_latents,
351
+ neutral_latents,
352
+ unconditional_latents,
353
+ target_latents,
354
+ latents,
355
+ )
356
+ flush()
357
+
358
+ if (
359
+ i % config.save.per_steps == 0
360
+ and i != 0
361
+ and i != config.train.iterations - 1
362
+ ):
363
+ print("Saving...")
364
+ save_path.mkdir(parents=True, exist_ok=True)
365
+ network.save_weights(
366
+ save_path / f"{config.save.name}_{i}steps.pt",
367
+ dtype=save_weight_dtype,
368
+ )
369
+
370
+ print("Saving...")
371
+ save_path.mkdir(parents=True, exist_ok=True)
372
+ network.save_weights(
373
+ save_path / f"{config.save.name}_last.pt",
374
+ dtype=save_weight_dtype,
375
+ )
376
+
377
+ del (
378
+ unet,
379
+ noise_scheduler,
380
+ loss,
381
+ optimizer,
382
+ network,
383
+ )
384
+
385
+ flush()
386
+
387
+ print("Done.")
388
+
389
+
390
+ def main(args):
391
+ config_file = args.config_file
392
+
393
+ config = config_util.load_config_from_yaml(config_file)
394
+ if args.name is not None:
395
+ config.save.name = args.name
396
+ attributes = []
397
+ if args.attributes is not None:
398
+ attributes = args.attributes.split(',')
399
+ attributes = [a.strip() for a in attributes]
400
+
401
+ config.network.alpha = args.alpha
402
+ config.network.rank = args.rank
403
+ config.save.name += f'_alpha{args.alpha}'
404
+ config.save.name += f'_rank{config.network.rank}'
405
+ config.save.name += f'_{config.network.training_method}'
406
+ config.save.path += f'/{config.save.name}'
407
+
408
+ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
409
+
410
+ device = torch.device(f"cuda:{args.device}")
411
+ train(config, prompts, device)
412
+
413
+
414
+ if __name__ == "__main__":
415
+ parser = argparse.ArgumentParser()
416
+ parser.add_argument(
417
+ "--config_file",
418
+ required=True,
419
+ help="Config file for training.",
420
+ )
421
+ # config_file 'data/config.yaml'
422
+ parser.add_argument(
423
+ "--alpha",
424
+ type=float,
425
+ required=True,
426
+ help="LoRA weight.",
427
+ )
428
+ # --alpha 1.0
429
+ parser.add_argument(
430
+ "--rank",
431
+ type=int,
432
+ required=False,
433
+ help="Rank of LoRA.",
434
+ default=4,
435
+ )
436
+ # --rank 4
437
+ parser.add_argument(
438
+ "--device",
439
+ type=int,
440
+ required=False,
441
+ default=0,
442
+ help="Device to train on.",
443
+ )
444
+ # --device 0
445
+ parser.add_argument(
446
+ "--name",
447
+ type=str,
448
+ required=False,
449
+ default=None,
450
+ help="Device to train on.",
451
+ )
452
+ # --name 'eyesize_slider'
453
+ parser.add_argument(
454
+ "--attributes",
455
+ type=str,
456
+ required=False,
457
+ default=None,
458
+ help="attritbutes to disentangle (comma seperated string)",
459
+ )
460
+
461
+ args = parser.parse_args()
462
+
463
+ main(args)
trainscripts/textsliders/train_util.py ADDED
@@ -0,0 +1,419 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers import UNet2DConditionModel, SchedulerMixin
7
+
8
+ from trainscripts.textsliders.model_util import SDXL_TEXT_ENCODER_TYPE
9
+
10
+ from tqdm import tqdm
11
+
12
+ UNET_IN_CHANNELS = 4 # in_channels of the Stable Diffusion UNet is fixed at 4; same for XL.
13
+ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
14
+
15
+ UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
16
+ TEXT_ENCODER_2_PROJECTION_DIM = 1280
17
+ UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
18
+
19
+
20
+ def get_random_noise(
21
+ batch_size: int, height: int, width: int, generator: torch.Generator = None
22
+ ) -> torch.Tensor:
23
+ return torch.randn(
24
+ (
25
+ batch_size,
26
+ UNET_IN_CHANNELS,
27
+ height // VAE_SCALE_FACTOR, # not sure the height/width order is right, but either way it causes no real problem, so leave it as is
28
+ width // VAE_SCALE_FACTOR,
29
+ ),
30
+ generator=generator,
31
+ device="cpu",
32
+ )
33
+
34
+
35
+ # https://www.crosslabs.org/blog/diffusion-with-offset-noise
36
+ def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
37
+ latents = latents + noise_offset * torch.randn(
38
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
39
+ )
40
+ return latents
41
+
42
+
43
+ def get_initial_latents(
44
+ scheduler: SchedulerMixin,
45
+ n_imgs: int,
46
+ height: int,
47
+ width: int,
48
+ n_prompts: int,
49
+ generator=None,
50
+ ) -> torch.Tensor:
51
+ noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
52
+ n_prompts, 1, 1, 1
53
+ )
54
+
55
+ latents = noise * scheduler.init_noise_sigma
56
+
57
+ return latents
58
+
59
+
60
+ def text_tokenize(
61
+ tokenizer: CLIPTokenizer, # one tokenizer normally, two for XL!
62
+ prompts: list[str],
63
+ ):
64
+ return tokenizer(
65
+ prompts,
66
+ padding="max_length",
67
+ max_length=tokenizer.model_max_length,
68
+ truncation=True,
69
+ return_tensors="pt",
70
+ ).input_ids
71
+
72
+
73
+ def text_encode(text_encoder: CLIPTextModel, tokens):
74
+ return text_encoder(tokens.to(text_encoder.device))[0]
75
+
76
+
77
+ def encode_prompts(
78
+ tokenizer: CLIPTokenizer,
79
+ text_encoder: CLIPTextModel,
80
+ prompts: list[str],
81
+ ):
82
+
83
+ text_tokens = text_tokenize(tokenizer, prompts)
84
+ text_embeddings = text_encode(text_encoder, text_tokens)
85
+
86
+
87
+
88
+ return text_embeddings
89
+
90
+
91
+ # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
92
+ def text_encode_xl(
93
+ text_encoder: SDXL_TEXT_ENCODER_TYPE,
94
+ tokens: torch.FloatTensor,
95
+ num_images_per_prompt: int = 1,
96
+ ):
97
+ prompt_embeds = text_encoder(
98
+ tokens.to(text_encoder.device), output_hidden_states=True
99
+ )
100
+ pooled_prompt_embeds = prompt_embeds[0]
101
+ prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
102
+
103
+ bs_embed, seq_len, _ = prompt_embeds.shape
104
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
105
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
106
+
107
+ return prompt_embeds, pooled_prompt_embeds
108
+
109
+
110
+ def encode_prompts_xl(
111
+ tokenizers: list[CLIPTokenizer],
112
+ text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
113
+ prompts: list[str],
114
+ num_images_per_prompt: int = 1,
115
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
116
+ # text_encoder and text_encoder_2's penultimate layer's output
117
+ text_embeds_list = []
118
+ pooled_text_embeds = None # always text_encoder_2's pool
119
+
120
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
121
+ text_tokens_input_ids = text_tokenize(tokenizer, prompts)
122
+ text_embeds, pooled_text_embeds = text_encode_xl(
123
+ text_encoder, text_tokens_input_ids, num_images_per_prompt
124
+ )
125
+
126
+ text_embeds_list.append(text_embeds)
127
+
128
+ bs_embed = pooled_text_embeds.shape[0]
129
+ pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
130
+ bs_embed * num_images_per_prompt, -1
131
+ )
132
+
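+ # concatenate the two encoders' sequence embeddings along the feature axis; the pooled
+ # embedding comes from text_encoder_2 (the last encoder in the loop)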
133
+ return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
134
+
135
+
136
+ def concat_embeddings(
137
+ unconditional: torch.FloatTensor,
138
+ conditional: torch.FloatTensor,
139
+ n_imgs: int,
140
+ ):
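+ # stack [unconditional, conditional] along the batch axis and repeat each n_imgs times,
+ # giving the classifier-free-guidance batch that predict_noise() expects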
141
+ return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
142
+
143
+
144
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
145
+ def predict_noise(
146
+ unet: UNet2DConditionModel,
147
+ scheduler: SchedulerMixin,
148
+ timestep: int, # current timestep
149
+ latents: torch.FloatTensor,
150
+ text_embeddings: torch.FloatTensor, # unconditional and conditional text embeddings concatenated
151
+ guidance_scale=7.5,
152
+ ) -> torch.FloatTensor:
153
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
154
+ latent_model_input = torch.cat([latents] * 2)
155
+
156
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
157
+
158
+ # predict the noise residual
159
+ noise_pred = unet(
160
+ latent_model_input,
161
+ timestep,
162
+ encoder_hidden_states=text_embeddings,
163
+ ).sample
164
+
165
+ # perform guidance
166
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
167
+ guided_target = noise_pred_uncond + guidance_scale * (
168
+ noise_pred_text - noise_pred_uncond
169
+ )
170
+
171
+ return guided_target
172
+
173
+
174
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
175
+ @torch.no_grad()
176
+ def diffusion(
177
+ unet: UNet2DConditionModel,
178
+ scheduler: SchedulerMixin,
179
+ latents: torch.FloatTensor, # latents that are still pure noise
180
+ text_embeddings: torch.FloatTensor,
181
+ total_timesteps: int = 1000,
182
+ start_timesteps=0,
183
+ **kwargs,
184
+ ):
185
+ # latents_steps = []
186
+
187
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
188
+ noise_pred = predict_noise(
189
+ unet, scheduler, timestep, latents, text_embeddings, **kwargs
190
+ )
191
+
192
+ # compute the previous noisy sample x_t -> x_t-1
193
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
194
+
195
+ # return latents_steps
196
+ return latents
197
+
198
+
199
+ def rescale_noise_cfg(
200
+ noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
201
+ ):
202
+ """
203
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
204
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
205
+ """
206
+ std_text = noise_pred_text.std(
207
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
208
+ )
209
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
210
+ # rescale the results from guidance (fixes overexposure)
211
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
212
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
213
+ noise_cfg = (
214
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
215
+ )
216
+
217
+ return noise_cfg
218
+
219
+
220
+ def predict_noise_xl(
221
+ unet: UNet2DConditionModel,
222
+ scheduler: SchedulerMixin,
223
+ timestep: int, # current timestep
224
+ latents: torch.FloatTensor,
225
+ text_embeddings: torch.FloatTensor, # unconditional and conditional text embeddings concatenated
226
+ add_text_embeddings: torch.FloatTensor, # the pooled text embeddings
227
+ add_time_ids: torch.FloatTensor,
228
+ guidance_scale=7.5,
229
+ guidance_rescale=0.7,
230
+ ) -> torch.FloatTensor:
231
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
232
+ latent_model_input = torch.cat([latents] * 2)
233
+
234
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
235
+
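+ # SDXL additionally conditions on the pooled text embedding and the size/crop "time ids"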
236
+ added_cond_kwargs = {
237
+ "text_embeds": add_text_embeddings,
238
+ "time_ids": add_time_ids,
239
+ }
240
+
241
+ # predict the noise residual
242
+ noise_pred = unet(
243
+ latent_model_input,
244
+ timestep,
245
+ encoder_hidden_states=text_embeddings,
246
+ added_cond_kwargs=added_cond_kwargs,
247
+ ).sample
248
+
249
+ # perform guidance
250
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
251
+ guided_target = noise_pred_uncond + guidance_scale * (
252
+ noise_pred_text - noise_pred_uncond
253
+ )
254
+
255
+ # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
256
+ guided_target = rescale_noise_cfg(
257
+ guided_target, noise_pred_text, guidance_rescale=guidance_rescale
258
+ )
259
+
260
+ return guided_target
261
+
262
+
263
+ @torch.no_grad()
264
+ def diffusion_xl(
265
+ unet: UNet2DConditionModel,
266
+ scheduler: SchedulerMixin,
267
+ latents: torch.FloatTensor, # latents that are still pure noise
268
+ text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
269
+ add_text_embeddings: torch.FloatTensor, # the pooled text embeddings
270
+ add_time_ids: torch.FloatTensor,
271
+ guidance_scale: float = 1.0,
272
+ total_timesteps: int = 1000,
273
+ start_timesteps=0,
274
+ ):
275
+ # latents_steps = []
276
+
277
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
278
+ noise_pred = predict_noise_xl(
279
+ unet,
280
+ scheduler,
281
+ timestep,
282
+ latents,
283
+ text_embeddings,
284
+ add_text_embeddings,
285
+ add_time_ids,
286
+ guidance_scale=guidance_scale,
287
+ guidance_rescale=0.7,
288
+ )
289
+
290
+ # compute the previous noisy sample x_t -> x_t-1
291
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
292
+
293
+ # return latents_steps
294
+ return latents
295
+
296
+
297
+ # for XL
298
+ def get_add_time_ids(
299
+ height: int,
300
+ width: int,
301
+ dynamic_crops: bool = False,
302
+ dtype: torch.dtype = torch.float32,
303
+ ):
304
+ if dynamic_crops:
305
+ # random float scale between 1 and 3
306
+ random_scale = torch.rand(1).item() * 2 + 1
307
+ original_size = (int(height * random_scale), int(width * random_scale))
308
+ # random position
309
+ crops_coords_top_left = (
310
+ torch.randint(0, original_size[0] - height, (1,)).item(),
311
+ torch.randint(0, original_size[1] - width, (1,)).item(),
312
+ )
313
+ target_size = (height, width)
314
+ else:
315
+ original_size = (height, width)
316
+ crops_coords_top_left = (0, 0)
317
+ target_size = (height, width)
318
+
319
+ # this is expected as 6
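+ # layout: (original_h, original_w, crop_top, crop_left, target_h, target_w)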
320
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
321
+
322
+ # this is expected as 2816
323
+ passed_add_embed_dim = (
324
+ UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
325
+ + TEXT_ENCODER_2_PROJECTION_DIM # + 1280
326
+ )
327
+ if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
328
+ raise ValueError(
329
+ f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
330
+ )
331
+
332
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
333
+ return add_time_ids
334
+
335
+
336
+ def get_optimizer(name: str):
337
+ name = name.lower()
338
+
339
+ if name.startswith("dadapt"):
340
+ import dadaptation
341
+
342
+ if name == "dadaptadam":
343
+ return dadaptation.DAdaptAdam
344
+ elif name == "dadaptlion":
345
+ return dadaptation.DAdaptLion
346
+ else:
347
+ raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
348
+
349
+ elif name.endswith("8bit"): # 検証してない
350
+ import bitsandbytes as bnb
351
+
352
+ if name == "adam8bit":
353
+ return bnb.optim.Adam8bit
354
+ elif name == "lion8bit":
355
+ return bnb.optim.Lion8bit
356
+ else:
357
+ raise ValueError("8bit optimizer must be adam8bit or lion8bit")
358
+
359
+ else:
360
+ if name == "adam":
361
+ return torch.optim.Adam
362
+ elif name == "adamw":
363
+ return torch.optim.AdamW
364
+ elif name == "lion":
365
+ from lion_pytorch import Lion
366
+
367
+ return Lion
368
+ elif name == "prodigy":
369
+ import prodigyopt
370
+
371
+ return prodigyopt.Prodigy
372
+ else:
373
+ raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
374
+
375
+
376
+ def get_lr_scheduler(
377
+ name: Optional[str],
378
+ optimizer: torch.optim.Optimizer,
379
+ max_iterations: Optional[int],
380
+ lr_min: Optional[float],
381
+ **kwargs,
382
+ ):
383
+ if name == "cosine":
384
+ return torch.optim.lr_scheduler.CosineAnnealingLR(
385
+ optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
386
+ )
387
+ elif name == "cosine_with_restarts":
388
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
389
+ optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
390
+ )
391
+ elif name == "step":
392
+ return torch.optim.lr_scheduler.StepLR(
393
+ optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
394
+ )
395
+ elif name == "constant":
396
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
397
+ elif name == "linear":
398
+ return torch.optim.lr_scheduler.LinearLR(
399
+ optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
400
+ )
401
+ else:
402
+ raise ValueError(
403
+ "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
404
+ )
405
+
406
+
407
+ def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
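+ # sample a (height, width) pair of multiples of 64 in [bucket_resolution // 2, bucket_resolution)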
408
+ max_resolution = bucket_resolution
409
+ min_resolution = bucket_resolution // 2
410
+
411
+ step = 64
412
+
413
+ min_step = min_resolution // step
414
+ max_step = max_resolution // step
415
+
416
+ height = torch.randint(min_step, max_step, (1,)).item() * step
417
+ width = torch.randint(min_step, max_step, (1,)).item() * step
418
+
419
+ return height, width
utils.py ADDED
@@ -0,0 +1,391 @@
1
+ import torch
2
+ from PIL import Image
3
+ import argparse
4
+ import os, json, random
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import glob, re
8
+
9
+ from safetensors.torch import load_file
10
+ import matplotlib.image as mpimg
11
+ import copy
12
+ import gc
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ import diffusers
16
+ from diffusers import DiffusionPipeline
17
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
18
+ from diffusers.loaders import AttnProcsLayers
19
+ from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
22
+
23
+ import inspect
24
+ import os
25
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
26
+ from diffusers.pipelines import StableDiffusionXLPipeline
+ # rescale_noise_cfg is used in the guidance_rescale branch below; this assumes the installed
+ # diffusers version exposes it from the SDXL pipeline module
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
27
+ import random
28
+
29
+ import torch
30
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
31
+
32
+ def flush():
33
+ torch.cuda.empty_cache()
34
+ gc.collect()
35
+
36
+ @torch.no_grad()
37
+ def call(
38
+ self,
39
+ prompt: Union[str, List[str]] = None,
40
+ prompt_2: Optional[Union[str, List[str]]] = None,
41
+ height: Optional[int] = None,
42
+ width: Optional[int] = None,
43
+ num_inference_steps: int = 50,
44
+ denoising_end: Optional[float] = None,
45
+ guidance_scale: float = 5.0,
46
+ negative_prompt: Optional[Union[str, List[str]]] = None,
47
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
48
+ num_images_per_prompt: Optional[int] = 1,
49
+ eta: float = 0.0,
50
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
51
+ latents: Optional[torch.FloatTensor] = None,
52
+ prompt_embeds: Optional[torch.FloatTensor] = None,
53
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
54
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
56
+ output_type: Optional[str] = "pil",
57
+ return_dict: bool = True,
58
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
59
+ callback_steps: int = 1,
60
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
61
+ guidance_rescale: float = 0.0,
62
+ original_size: Optional[Tuple[int, int]] = None,
63
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
64
+ target_size: Optional[Tuple[int, int]] = None,
65
+ negative_original_size: Optional[Tuple[int, int]] = None,
66
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
67
+ negative_target_size: Optional[Tuple[int, int]] = None,
68
+
69
+ network=None,
70
+ start_noise=None,
71
+ scale=None,
72
+ unet=None,
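+ # slider-specific additions: the LoRA network, the timestep threshold (start_noise) above which
+ # the slider stays at scale 0, the slider scale to apply afterwards, and the UNet to run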
73
+ ):
74
+ r"""
75
+ Function invoked when calling the pipeline for generation.
76
+
77
+ Args:
78
+ prompt (`str` or `List[str]`, *optional*):
79
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
80
+ instead.
81
+ prompt_2 (`str` or `List[str]`, *optional*):
82
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
83
+ used in both text-encoders
84
+ height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
85
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
86
+ Anything below 512 pixels won't work well for
87
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
88
+ and checkpoints that are not specifically fine-tuned on low resolutions.
89
+ width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
90
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
91
+ Anything below 512 pixels won't work well for
92
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
93
+ and checkpoints that are not specifically fine-tuned on low resolutions.
94
+ num_inference_steps (`int`, *optional*, defaults to 50):
95
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
96
+ expense of slower inference.
97
+ denoising_end (`float`, *optional*):
98
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
99
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
100
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
101
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
102
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
103
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
104
+ guidance_scale (`float`, *optional*, defaults to 5.0):
105
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
106
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
107
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
108
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
109
+ usually at the expense of lower image quality.
110
+ negative_prompt (`str` or `List[str]`, *optional*):
111
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
112
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
113
+ less than `1`).
114
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
115
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
116
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
117
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
118
+ The number of images to generate per prompt.
119
+ eta (`float`, *optional*, defaults to 0.0):
120
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
121
+ [`schedulers.DDIMScheduler`], will be ignored for others.
122
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
123
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
124
+ to make generation deterministic.
125
+ latents (`torch.FloatTensor`, *optional*):
126
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
127
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
128
+ tensor will be generated by sampling using the supplied random `generator`.
129
+ prompt_embeds (`torch.FloatTensor`, *optional*):
130
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
131
+ provided, text embeddings will be generated from `prompt` input argument.
132
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
133
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
134
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
135
+ argument.
136
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
137
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
138
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
139
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
140
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
141
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
142
+ input argument.
143
+ output_type (`str`, *optional*, defaults to `"pil"`):
144
+ The output format of the generated image. Choose between
145
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
146
+ return_dict (`bool`, *optional*, defaults to `True`):
147
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
148
+ of a plain tuple.
149
+ callback (`Callable`, *optional*):
150
+ A function that will be called every `callback_steps` steps during inference. The function will be
151
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
152
+ callback_steps (`int`, *optional*, defaults to 1):
153
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
154
+ called at every step.
155
+ cross_attention_kwargs (`dict`, *optional*):
156
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
157
+ `self.processor` in
158
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
159
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
160
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
161
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
162
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
163
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
164
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
165
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
166
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
167
+ explained in section 2.2 of
168
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
169
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
170
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
171
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
172
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
173
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
174
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
175
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
176
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
177
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
178
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
179
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
180
+ micro-conditioning as explained in section 2.2 of
181
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
182
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
183
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
184
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
185
+ micro-conditioning as explained in section 2.2 of
186
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
187
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
188
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
189
+ To negatively condition the generation process based on a target image resolution. It should be as same
190
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
191
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
192
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
193
+
194
+ Examples:
195
+
196
+ Returns:
197
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
198
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
199
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
200
+ """
201
+ # 0. Default height and width to unet
202
+ height = height or self.default_sample_size * self.vae_scale_factor
203
+ width = width or self.default_sample_size * self.vae_scale_factor
204
+
205
+ original_size = original_size or (height, width)
206
+ target_size = target_size or (height, width)
207
+
208
+ # 1. Check inputs. Raise error if not correct
209
+ self.check_inputs(
210
+ prompt,
211
+ prompt_2,
212
+ height,
213
+ width,
214
+ callback_steps,
215
+ negative_prompt,
216
+ negative_prompt_2,
217
+ prompt_embeds,
218
+ negative_prompt_embeds,
219
+ pooled_prompt_embeds,
220
+ negative_pooled_prompt_embeds,
221
+ )
222
+
223
+ # 2. Define call parameters
224
+ if prompt is not None and isinstance(prompt, str):
225
+ batch_size = 1
226
+ elif prompt is not None and isinstance(prompt, list):
227
+ batch_size = len(prompt)
228
+ else:
229
+ batch_size = prompt_embeds.shape[0]
230
+
231
+ device = self._execution_device
232
+
233
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
234
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
235
+ # corresponds to doing no classifier free guidance.
236
+ do_classifier_free_guidance = guidance_scale > 1.0
237
+
238
+ # 3. Encode input prompt
239
+ text_encoder_lora_scale = (
240
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
241
+ )
242
+ (
243
+ prompt_embeds,
244
+ negative_prompt_embeds,
245
+ pooled_prompt_embeds,
246
+ negative_pooled_prompt_embeds,
247
+ ) = self.encode_prompt(
248
+ prompt=prompt,
249
+ prompt_2=prompt_2,
250
+ device=device,
251
+ num_images_per_prompt=num_images_per_prompt,
252
+ do_classifier_free_guidance=do_classifier_free_guidance,
253
+ negative_prompt=negative_prompt,
254
+ negative_prompt_2=negative_prompt_2,
255
+ prompt_embeds=prompt_embeds,
256
+ negative_prompt_embeds=negative_prompt_embeds,
257
+ pooled_prompt_embeds=pooled_prompt_embeds,
258
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
259
+ lora_scale=text_encoder_lora_scale,
260
+ )
261
+
262
+ # 4. Prepare timesteps
263
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
264
+
265
+ timesteps = self.scheduler.timesteps
266
+
267
+ # 5. Prepare latent variables
268
+ num_channels_latents = unet.config.in_channels
269
+ latents = self.prepare_latents(
270
+ batch_size * num_images_per_prompt,
271
+ num_channels_latents,
272
+ height,
273
+ width,
274
+ prompt_embeds.dtype,
275
+ device,
276
+ generator,
277
+ latents,
278
+ )
279
+
280
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
281
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
282
+
283
+ # 7. Prepare added time ids & embeddings
284
+ add_text_embeds = pooled_prompt_embeds
285
+ add_time_ids = self._get_add_time_ids(
286
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
287
+ )
288
+ if negative_original_size is not None and negative_target_size is not None:
289
+ negative_add_time_ids = self._get_add_time_ids(
290
+ negative_original_size,
291
+ negative_crops_coords_top_left,
292
+ negative_target_size,
293
+ dtype=prompt_embeds.dtype,
294
+ )
295
+ else:
296
+ negative_add_time_ids = add_time_ids
297
+
298
+ if do_classifier_free_guidance:
299
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
300
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
301
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
302
+
303
+ prompt_embeds = prompt_embeds.to(device)
304
+ add_text_embeds = add_text_embeds.to(device)
305
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
306
+
307
+ # 8. Denoising loop
308
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
309
+
310
+ # 7.1 Apply denoising_end
311
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
312
+ discrete_timestep_cutoff = int(
313
+ round(
314
+ self.scheduler.config.num_train_timesteps
315
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
316
+ )
317
+ )
318
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
319
+ timesteps = timesteps[:num_inference_steps]
320
+ latents = latents.to(unet.dtype)
321
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
322
+ for i, t in enumerate(timesteps):
323
+ if t > start_noise:
324
+ network.set_lora_slider(scale=0)
325
+ else:
326
+ network.set_lora_slider(scale=scale)
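+ # the slider LoRA is disabled (scale 0) for timesteps above start_noise and set to the
+ # requested scale for the remaining steps, leaving the early high-noise structure to the base model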
327
+ # expand the latents if we are doing classifier free guidance
328
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
329
+
330
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
331
+
332
+ # predict the noise residual
333
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
334
+ with network:
335
+ noise_pred = unet(
336
+ latent_model_input,
337
+ t,
338
+ encoder_hidden_states=prompt_embeds,
339
+ cross_attention_kwargs=cross_attention_kwargs,
340
+ added_cond_kwargs=added_cond_kwargs,
341
+ return_dict=False,
342
+ )[0]
343
+
344
+ # perform guidance
345
+ if do_classifier_free_guidance:
346
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
347
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
348
+
349
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
350
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
351
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
352
+
353
+ # compute the previous noisy sample x_t -> x_t-1
354
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
355
+
356
+ # call the callback, if provided
357
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
358
+ progress_bar.update()
359
+ if callback is not None and i % callback_steps == 0:
360
+ callback(i, t, latents)
361
+
362
+ if not output_type == "latent":
363
+ # make sure the VAE is in float32 mode, as it overflows in float16
364
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
365
+
366
+ if needs_upcasting:
367
+ self.upcast_vae()
368
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
369
+
370
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
371
+
372
+ # cast back to fp16 if needed
373
+ if needs_upcasting:
374
+ self.vae.to(dtype=torch.float16)
375
+ else:
376
+ image = latents
377
+
378
+ if not output_type == "latent":
379
+ # apply watermark if available
380
+ if self.watermark is not None:
381
+ image = self.watermark.apply_watermark(image)
382
+
383
+ image = self.image_processor.postprocess(image, output_type=output_type)
384
+
385
+ # Offload all models
386
+ # self.maybe_free_model_hooks()
387
+
388
+ if not return_dict:
389
+ return (image,)
390
+
391
+ return StableDiffusionXLPipelineOutput(images=image)