Vanisper committed
Commit a5812ce · 1 Parent(s): ee38d7b

chore: update dependencies and optimize the network merging logic

Files changed (5)
  1. app.py +17 -2
  2. install_deps.bat +2 -7
  3. requirements.txt +10 -11
  4. utils.bak +413 -0
  5. utils.py +17 -39
app.py CHANGED
@@ -50,6 +50,19 @@ SHARED_UI_WARNING = f'''## Note - Training in this shared UI may be slow. You
 '''
 
 
+def merge_lora_networks(networks):
+    if not networks:
+        return None
+
+    base_network = networks[0]
+    for network in networks[1:]:
+        for name, param in network.named_parameters():
+            if name in base_network.state_dict():
+                base_network.state_dict()[name].add_(param)
+            else:
+                base_network.state_dict()[name] = param.clone()
+    return base_network
+
 class Demo:
 
     def __init__(self) -> None:
@@ -354,12 +367,14 @@ class Demo:
         ).to(self.device, dtype=self.weight_dtype)
         network.load_state_dict(torch.load(model_path))
         networks.append(network)
+
+        __network__ = merge_lora_networks(networks)
 
         generator = torch.manual_seed(seed)
-        edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, networks=networks, start_noise=int(start_noise), scale=float(scale), unet=unet, guidance_scale=self.guidance_scale).images[0]
+        edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=__network__, start_noise=int(start_noise), scale=float(scale), unet=unet, guidance_scale=self.guidance_scale).images[0]
 
         generator = torch.manual_seed(seed)
-        original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, networks=networks, start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0]
+        original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=__network__, start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0]
 
         del unet, networks
         unet = None
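
For context on the app.py change above: the new merge_lora_networks helper folds the parameters of every later slider network into the first one by in-place addition, so the pipeline now receives a single merged network instead of a list. A minimal, self-contained sketch of that additive merge, using two stand-in nn.Linear modules rather than the repository's LoRA classes (hypothetical example, not part of the commit):

import torch
import torch.nn as nn

net_a = nn.Linear(4, 4, bias=False)   # stands in for networks[0], the merge target
net_b = nn.Linear(4, 4, bias=False)   # stands in for a second slider network
expected = net_a.weight.detach() + net_b.weight.detach()

with torch.no_grad():                 # keep autograd out of the in-place update
    base = net_a
    for name, param in net_b.named_parameters():
        if name in base.state_dict():
            # state_dict() tensors share storage with the module's parameters,
            # so this in-place add rewrites net_a's weights directly
            base.state_dict()[name].add_(param)

print(torch.allclose(net_a.weight, expected))  # True: matching parameters were summed

Because the merge happens in place on networks[0], the merged object keeps that network's class and attributes; only overlapping parameter tensors are combined.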
install_deps.bat CHANGED
@@ -1,7 +1,2 @@
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
-pip install -U xformers --index-url https://download.pytorch.org/whl/cu124
-pip install 'https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-win_amd64.whl'
-
-pip install -r requirements-win.txt
-
-pip install --upgrade gradio
+pip install -r requirements.txt
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 bitsandbytes==0.41.1
 dadaptation==3.1
-diffusers==0.32.2
+diffusers==0.20.2
 ipython==8.7.0
 lion_pytorch==0.1.2
 lpips==0.1.4
@@ -11,17 +11,16 @@ opencv_python_headless==4.7.0.68
 pandas==1.5.2
 Pillow==10.1.0
 prodigyopt==1.0
-pydantic==2.10.5
+pydantic==1.10.3
 PyYAML==6.0.1
 Requests==2.31.0
-safetensors==0.5.2
-torch==2.5.1
-torchvision==0.20.1
-xformers
+safetensors==0.3.1
+torch==2.0.1
+torchvision==0.15.2
 tqdm==4.64.1
-transformers==4.48.1
+transformers==4.27.4
 wandb==0.12.21
-accelerate==1.3.0
-gradio==5.12.0
-gradio_client==1.5.4
-huggingface-hub==0.27.1
+accelerate==0.16.0
+xformers
+gradio
+huggingface-hub==0.23.5
utils.bak ADDED
@@ -0,0 +1,413 @@
+import torch
+from PIL import Image
+import argparse
+import os, json, random
+import pandas as pd
+import matplotlib.pyplot as plt
+import glob, re
+
+from safetensors.torch import load_file
+import matplotlib.image as mpimg
+import copy
+import gc
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import DiffusionPipeline
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
+from typing import Any, Dict, List, Optional, Tuple, Union
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+
+import inspect
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from diffusers.pipelines import StableDiffusionXLPipeline
+import random
+
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+@torch.no_grad()
+def call(
+    self,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 50,
+    denoising_end: Optional[float] = None,
+    guidance_scale: float = 5.0,
+    negative_prompt: Optional[Union[str, List[str]]] = None,
+    negative_prompt_2: Optional[Union[str, List[str]]] = None,
+    num_images_per_prompt: Optional[int] = 1,
+    eta: float = 0.0,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+    callback_steps: int = 1,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    guidance_rescale: float = 0.0,
+    original_size: Optional[Tuple[int, int]] = None,
+    crops_coords_top_left: Tuple[int, int] = (0, 0),
+    target_size: Optional[Tuple[int, int]] = None,
+    negative_original_size: Optional[Tuple[int, int]] = None,
+    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+    negative_target_size: Optional[Tuple[int, int]] = None,
+
+    network=None,
+    networks=None,
+    start_noise=None,
+    scale=None,
+    scales=None,
+    unet=None,
+):
+    r"""
+    Function invoked when calling the pipeline for generation.
+
+    Args:
+        prompt (`str` or `List[str]`, *optional*):
+            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+            instead.
+        prompt_2 (`str` or `List[str]`, *optional*):
+            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+            used in both text-encoders
+        height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
+            The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            Anything below 512 pixels won't work well for
+            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+            and checkpoints that are not specifically fine-tuned on low resolutions.
+        width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
+            The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            Anything below 512 pixels won't work well for
+            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+            and checkpoints that are not specifically fine-tuned on low resolutions.
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+            expense of slower inference.
+        denoising_end (`float`, *optional*):
+            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+            completed before it is intentionally prematurely terminated. As a result, the returned sample will
+            still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+            scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+            "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+        guidance_scale (`float`, *optional*, defaults to 5.0):
+            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+            `guidance_scale` is defined as `w` of equation 2. of [Imagen
+            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+            usually at the expense of lower image quality.
+        negative_prompt (`str` or `List[str]`, *optional*):
+            The prompt or prompts not to guide the image generation. If not defined, one has to pass
+            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+            less than `1`).
+        negative_prompt_2 (`str` or `List[str]`, *optional*):
+            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        eta (`float`, *optional*, defaults to 0.0):
+            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+            [`schedulers.DDIMScheduler`], will be ignored for others.
+        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+            to make generation deterministic.
+        latents (`torch.FloatTensor`, *optional*):
+            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+            tensor will ge generated by sampling using the supplied random `generator`.
+        prompt_embeds (`torch.FloatTensor`, *optional*):
+            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+            provided, text embeddings will be generated from `prompt` input argument.
+        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+            argument.
+        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+            If not provided, pooled text embeddings will be generated from `prompt` input argument.
+        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+            input argument.
+        output_type (`str`, *optional*, defaults to `"pil"`):
+            The output format of the generate image. Choose between
+            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+        return_dict (`bool`, *optional*, defaults to `True`):
+            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+            of a plain tuple.
+        callback (`Callable`, *optional*):
+            A function that will be called every `callback_steps` steps during inference. The function will be
+            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+        callback_steps (`int`, *optional*, defaults to 1):
+            The frequency at which the `callback` function will be called. If not specified, the callback will be
+            called at every step.
+        cross_attention_kwargs (`dict`, *optional*):
+            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+            `self.processor` in
+            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+        guidance_rescale (`float`, *optional*, defaults to 0.7):
+            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+            Guidance rescale factor should fix overexposure when using zero terminal SNR.
+        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+            `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+            explained in section 2.2 of
+            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+            For most cases, `target_size` should be set to the desired height and width of the generated image. If
+            not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+        negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+            To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+            micro-conditioning as explained in section 2.2 of
+            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+            micro-conditioning as explained in section 2.2 of
+            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+        negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+            To negatively condition the generation process based on a target image resolution. It should be as same
+            as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+
+    Examples:
+
+    Returns:
+        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+        `tuple`. When returning a tuple, the first element is a list with the generated images.
+    """
+    # 0. Default height and width to unet
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    original_size = original_size or (height, width)
+    target_size = target_size or (height, width)
+
+    # 1. Check inputs. Raise error if not correct
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        height,
+        width,
+        callback_steps,
+        negative_prompt,
+        negative_prompt_2,
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    )
+
+    # 2. Define call parameters
+    if prompt is not None and isinstance(prompt, str):
+        batch_size = 1
+    elif prompt is not None and isinstance(prompt, list):
+        batch_size = len(prompt)
+    else:
+        batch_size = prompt_embeds.shape[0]
+
+    device = self._execution_device
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    do_classifier_free_guidance = guidance_scale > 1.0
+
+    # 3. Encode input prompt
+    text_encoder_lora_scale = (
+        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+    )
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = self.encode_prompt(
+        prompt=prompt,
+        prompt_2=prompt_2,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        do_classifier_free_guidance=do_classifier_free_guidance,
+        negative_prompt=negative_prompt,
+        negative_prompt_2=negative_prompt_2,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        lora_scale=text_encoder_lora_scale,
+    )
+
+    # 4. Prepare timesteps
+    self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+    timesteps = self.scheduler.timesteps
+
+    # 5. Prepare latent variables
+    num_channels_latents = unet.config.in_channels
+    latents = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+
+    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+    # 7. Prepare added time ids & embeddings
+    add_text_embeds = pooled_prompt_embeds
+    # Make sure text_encoder_projection_dim is initialized correctly
+    if self.text_encoder_2 is None:
+        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+    else:
+        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+    add_time_ids = self._get_add_time_ids(
+        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype,
+        text_encoder_projection_dim=text_encoder_projection_dim
+    )
+    if negative_original_size is not None and negative_target_size is not None:
+        negative_add_time_ids = self._get_add_time_ids(
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim
+        )
+    else:
+        negative_add_time_ids = add_time_ids
+
+    if do_classifier_free_guidance:
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+    prompt_embeds = prompt_embeds.to(device)
+    add_text_embeds = add_text_embeds.to(device)
+    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+
+    # 8. Denoising loop
+    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+    # 7.1 Apply denoising_end
+    if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+        discrete_timestep_cutoff = int(
+            round(
+                self.scheduler.config.num_train_timesteps
+                - (denoising_end * self.scheduler.config.num_train_timesteps)
+            )
+        )
+        num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+        timesteps = timesteps[:num_inference_steps]
+    latents = latents.to(unet.dtype)
+
+    # Normalize network/scale into lists
+    if network is not None:
+        networks = [network]
+    if scale is not None:
+        scales = [scale]
+
+    with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            # Iterate over all networks and set each scale
+            if networks is not None and scales is not None:
+                for _network, _scale in zip(networks, scales):
+                    with _network:
+                        _network.set_lora_slider(scale=0 if t > start_noise else float(_scale))
+
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # predict the noise residual
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            # Apply multiple LoRA networks
+            if networks is not None and scales is not None:
+                for _network, _scale in zip(networks, scales):
+                    with _network:
+                        noise_pred = self.unet(
+                            latent_model_input,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            cross_attention_kwargs=cross_attention_kwargs,
+                            added_cond_kwargs=added_cond_kwargs,
+                            return_dict=False,
+                        )[0]
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            if do_classifier_free_guidance and guidance_rescale > 0.0:
+                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+            # call the callback, if provided
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                progress_bar.update()
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
+
+    if not output_type == "latent":
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+        if needs_upcasting:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+    else:
+        image = latents
+
+    if not output_type == "latent":
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+    # Offload all models
+    # self.maybe_free_model_hooks()
+
+    if not return_dict:
+        return (image,)
+
+    return StableDiffusionXLPipelineOutput(images=image)
utils.py CHANGED
@@ -66,11 +66,9 @@ def call(
     negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
     negative_target_size: Optional[Tuple[int, int]] = None,
 
-    network=None,
-    networks=None,
+    network=None,
     start_noise=None,
     scale=None,
-    scales=None,
     unet=None,
 ):
     r"""
@@ -284,14 +282,8 @@ def call(
 
     # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
-    # Make sure text_encoder_projection_dim is initialized correctly
-    if self.text_encoder_2 is None:
-        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
-    else:
-        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
     add_time_ids = self._get_add_time_ids(
-        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype,
-        text_encoder_projection_dim=text_encoder_projection_dim
+        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
     )
     if negative_original_size is not None and negative_target_size is not None:
         negative_add_time_ids = self._get_add_time_ids(
@@ -299,7 +291,6 @@ def call(
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
-            text_encoder_projection_dim=text_encoder_projection_dim
        )
    else:
        negative_add_time_ids = add_time_ids
@@ -313,7 +304,6 @@ def call(
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
-
    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
@@ -328,21 +318,12 @@
    num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
    timesteps = timesteps[:num_inference_steps]
    latents = latents.to(unet.dtype)
-
-    # Normalize network/scale into lists
-    if network is not None:
-        networks = [network]
-    if scale is not None:
-        scales = [scale]
-
    with self.progress_bar(total=num_inference_steps) as progress_bar:
-        for i, t in enumerate(timesteps):
-            # Iterate over all networks and set each scale
-            if networks is not None and scales is not None:
-                for _network, _scale in zip(networks, scales):
-                    with _network:
-                        _network.set_lora_slider(scale=0 if t > start_noise else float(_scale))
-
+        for i, t in enumerate(timesteps):
+            if t>start_noise:
+                network.set_lora_slider(scale=0)
+            else:
+                network.set_lora_slider(scale=scale)
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
@@ -350,18 +331,15 @@
 
            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-            # Apply multiple LoRA networks
-            if networks is not None and scales is not None:
-                for _network, _scale in zip(networks, scales):
-                    with _network:
-                        noise_pred = self.unet(
-                            latent_model_input,
-                            t,
-                            encoder_hidden_states=prompt_embeds,
-                            cross_attention_kwargs=cross_attention_kwargs,
-                            added_cond_kwargs=added_cond_kwargs,
-                            return_dict=False,
-                        )[0]
+            with network:
+                noise_pred = unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
 
            # perform guidance
            if do_classifier_free_guidance:
@@ -410,4 +388,4 @@ def call(
    if not return_dict:
        return (image,)
 
-    return StableDiffusionXLPipelineOutput(images=image)
+    return StableDiffusionXLPipelineOutput(images=image)
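
After this change, utils.py assumes exactly one network object: inside the denoising loop it calls network.set_lora_slider(scale=...) with scale 0 while t > start_noise and the requested scale afterwards, and it wraps the UNet forward pass in a `with network:` block. A small self-contained sketch of that contract and the timestep gating, using a stand-in class rather than the repository's LoRA network (hypothetical, for illustration only):

class DummySliderNetwork:
    """Stand-in exposing only the interface call() relies on."""
    def set_lora_slider(self, scale):
        self.scale = scale            # a real slider network would rescale its LoRA deltas here
    def __enter__(self):
        return self                   # a real network would enable its injected LoRA layers here
    def __exit__(self, exc_type, exc, tb):
        return False                  # ...and disable them again on exit

network = DummySliderNetwork()
start_noise, scale = 750, 1.5
for t in [999, 900, 800, 700, 600]:   # illustrative timesteps, highest noise first
    if t > start_noise:
        network.set_lora_slider(scale=0)
    else:
        network.set_lora_slider(scale=scale)
    with network:
        pass                          # the UNet forward pass runs here in utils.py
    print(t, network.scale)           # 999 0, 900 0, 800 0, 700 1.5, 600 1.5

Any merged network produced by merge_lora_networks in app.py therefore only needs to satisfy this context-manager plus set_lora_slider interface for the simplified pipeline call to work.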