feat: modify the call method to support handling multiple networks and scale parameters
app.py
CHANGED
@@ -63,18 +63,18 @@ def merge_lora_networks(networks):
     return base_network
 
 # Modify the call method to support passing a networks parameter
-def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
-    if networks is not None and scales is not None:
-        for network, scale in zip(networks, scales):
-            for name, param in network.named_parameters():
-                if name in self.unet.state_dict():
-                    self.unet.state_dict()[name].add_(param * scale)
-                else:
-                    self.unet.state_dict()[name] = param.clone() * scale
-    return self.__original_call__(*args, **kwargs)
-
-StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
-StableDiffusionXLPipeline.__call__ = rw_sd_call
+# def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
+#     if networks is not None and scales is not None:
+#         for network, scale in zip(networks, scales):
+#             for name, param in network.named_parameters():
+#                 if name in self.unet.state_dict():
+#                     self.unet.state_dict()[name].add_(param * scale)
+#                 else:
+#                     self.unet.state_dict()[name] = param.clone() * scale
+#     return self.__original_call__(*args, **kwargs)
+
+# StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
+# StableDiffusionXLPipeline.__call__ = rw_sd_call
 
 class Demo:
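For reference, the pattern the commented-out rw_sd_call relied on is additive weight merging through unet.state_dict(). Below is a minimal, self-contained sketch of that pattern with an explicit restore step; the helper names and the stand-in torch.nn.Linear module are illustrative only and not part of this repository.

import torch

def merge_deltas(module: torch.nn.Module, deltas: dict, scale: float) -> dict:
    """Add scaled deltas into matching weights in place; return backups for undoing."""
    state = module.state_dict()  # these tensors share storage with the module's parameters
    backups = {}
    for name, delta in deltas.items():
        if name in state:
            backups[name] = state[name].clone()  # remember the original weight
            state[name].add_(delta * scale)      # in-place additive merge
    return backups

def restore(module: torch.nn.Module, backups: dict) -> None:
    """Copy the saved originals back, undoing merge_deltas."""
    state = module.state_dict()
    for name, original in backups.items():
        state[name].copy_(original)

# Tiny usage example with a stand-in module.
layer = torch.nn.Linear(4, 4)
saved = merge_deltas(layer, {"weight": torch.randn(4, 4)}, scale=0.5)
restore(layer, saved)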
utils.py
CHANGED
@@ -66,9 +66,11 @@ def call(
     negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
     negative_target_size: Optional[Tuple[int, int]] = None,
 
     network=None,
+    networks=None,
     start_noise=None,
     scale=None,
+    scales=None,
     unet=None,
 ):
     r"""
@@ -318,12 +320,21 @@ def call(
     num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
     timesteps = timesteps[:num_inference_steps]
     latents = latents.to(unet.dtype)
+
+    # Normalize network/scale into lists
+    if network is not None:
+        networks = [network]
+    if scale is not None:
+        scales = [scale]
+
     with self.progress_bar(total=num_inference_steps) as progress_bar:
-        for i, t in enumerate(timesteps):
-            if t > start_noise:
-                network.set_lora_slider(scale=0)
-            else:
-                network.set_lora_slider(scale=scale)
+        for i, t in enumerate(timesteps):
+            # Iterate over all networks and set each scale
+            if networks is not None and scales is not None:
+                for _network, _scale in zip(networks, scales):
+                    with _network:
+                        _network.set_lora_slider(scale=0 if t > start_noise else _scale)
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
@@ -331,15 +342,18 @@ def call(
 
             # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-            with network:
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                    return_dict=False,
-                )[0]
+            # Apply multiple LoRA networks
+            if networks is not None and scales is not None:
+                for _network, _scale in zip(networks, scales):
+                    with _network:
+                        noise_pred = self.unet(
+                            latent_model_input,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            cross_attention_kwargs=cross_attention_kwargs,
+                            added_cond_kwargs=added_cond_kwargs,
+                            return_dict=False,
+                        )[0]
 
             # perform guidance
             if do_classifier_free_guidance:
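To make the per-timestep gating above concrete: for timesteps larger than start_noise every slider network is set to scale 0, and once t falls to or below start_noise each network receives its own value from scales. The short standalone sketch below mirrors that logic with a stub network class; the stub, the concept names, and the numeric values are assumptions for illustration only.

class SliderNetworkStub:
    """Stand-in for a LoRA slider network: records the scale it was given."""
    def __init__(self, name):
        self.name = name
        self.scale = None

    def set_lora_slider(self, scale):
        self.scale = scale

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


def set_scales_for_timestep(t, start_noise, networks, scales):
    """Mirror of the gating in call(): zero above start_noise, per-network scale below."""
    for _network, _scale in zip(networks, scales):
        with _network:
            _network.set_lora_slider(scale=0 if t > start_noise else _scale)
    return [(n.name, n.scale) for n in networks]


networks = [SliderNetworkStub("age"), SliderNetworkStub("smile")]
scales = [2.0, -1.0]

print(set_scales_for_timestep(t=900, start_noise=800, networks=networks, scales=scales))
# [('age', 0), ('smile', 0)]      -> sliders disabled early in denoising
print(set_scales_for_timestep(t=400, start_noise=800, networks=networks, scales=scales))
# [('age', 2.0), ('smile', -1.0)] -> per-network strengths once t <= start_noise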