Vanisper commited on
Commit
49a7fff
·
1 Parent(s): dfcfb88

feat: 修改调用方法以支持多个网络和缩放参数的处理

Browse files
Files changed (2) hide show
  1. app.py +12 -12
  2. utils.py +29 -15
app.py CHANGED
@@ -63,18 +63,18 @@ def merge_lora_networks(networks):
63
  return base_network
64
 
65
  # 修改 call 方法以支持传递 networks 参数
66
- def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
67
- if networks is not None and scales is not None:
68
- for network, scale in zip(networks, scales):
69
- for name, param in network.named_parameters():
70
- if name in self.unet.state_dict():
71
- self.unet.state_dict()[name].add_(param * scale)
72
- else:
73
- self.unet.state_dict()[name] = param.clone() * scale
74
- return self.__original_call__(*args, **kwargs)
75
-
76
- StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
77
- StableDiffusionXLPipeline.__call__ = rw_sd_call
78
 
79
  class Demo:
80
 
 
63
  return base_network
64
 
65
  # 修改 call 方法以支持传递 networks 参数
66
+ # def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
67
+ # if networks is not None and scales is not None:
68
+ # for network, scale in zip(networks, scales):
69
+ # for name, param in network.named_parameters():
70
+ # if name in self.unet.state_dict():
71
+ # self.unet.state_dict()[name].add_(param * scale)
72
+ # else:
73
+ # self.unet.state_dict()[name] = param.clone() * scale
74
+ # return self.__original_call__(*args, **kwargs)
75
+
76
+ # StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
77
+ # StableDiffusionXLPipeline.__call__ = rw_sd_call
78
 
79
  class Demo:
80
 
utils.py CHANGED
@@ -66,9 +66,11 @@ def call(
66
  negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
67
  negative_target_size: Optional[Tuple[int, int]] = None,
68
 
69
- network=None,
 
70
  start_noise=None,
71
  scale=None,
 
72
  unet=None,
73
  ):
74
  r"""
@@ -318,12 +320,21 @@ def call(
318
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
319
  timesteps = timesteps[:num_inference_steps]
320
  latents = latents.to(unet.dtype)
 
 
 
 
 
 
 
321
  with self.progress_bar(total=num_inference_steps) as progress_bar:
322
- for i, t in enumerate(timesteps):
323
- if t>start_noise:
324
- network.set_lora_slider(scale=0)
325
- else:
326
- network.set_lora_slider(scale=scale)
 
 
327
  # expand the latents if we are doing classifier free guidance
328
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
329
 
@@ -331,15 +342,18 @@ def call(
331
 
332
  # predict the noise residual
333
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
334
- with network:
335
- noise_pred = unet(
336
- latent_model_input,
337
- t,
338
- encoder_hidden_states=prompt_embeds,
339
- cross_attention_kwargs=cross_attention_kwargs,
340
- added_cond_kwargs=added_cond_kwargs,
341
- return_dict=False,
342
- )[0]
 
 
 
343
 
344
  # perform guidance
345
  if do_classifier_free_guidance:
 
66
  negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
67
  negative_target_size: Optional[Tuple[int, int]] = None,
68
 
69
+ network=None,
70
+ networks=None,
71
  start_noise=None,
72
  scale=None,
73
+ scales=None,
74
  unet=None,
75
  ):
76
  r"""
 
320
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
321
  timesteps = timesteps[:num_inference_steps]
322
  latents = latents.to(unet.dtype)
323
+
324
+ # 统一处理 network,scale | 处理成 list
325
+ if network is not None:
326
+ networks = [network]
327
+ if scale is not None:
328
+ scales = [scale]
329
+
330
  with self.progress_bar(total=num_inference_steps) as progress_bar:
331
+ for i, t in enumerate(timesteps):
332
+ # 遍历所有网络,设置 scale
333
+ if networks is not None and scales is not None:
334
+ for _network, _scale in zip(networks, scales):
335
+ with _network:
336
+ _network.set_lora_slider(scale=0 if t > start_noise else _scale)
337
+
338
  # expand the latents if we are doing classifier free guidance
339
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
340
 
 
342
 
343
  # predict the noise residual
344
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
345
+ # 应用多个 LoRA 网络
346
+ if networks is not None and scales is not None:
347
+ for _network, _scale in zip(networks, scales):
348
+ with _network:
349
+ noise_pred = self.unet(
350
+ latent_model_input,
351
+ t,
352
+ encoder_hidden_states=prompt_embeds,
353
+ cross_attention_kwargs=cross_attention_kwargs,
354
+ added_cond_kwargs=added_cond_kwargs,
355
+ return_dict=False,
356
+ )[0]
357
 
358
  # perform guidance
359
  if do_classifier_free_guidance: