Vanisper commited on
Commit
dfcfb88
·
1 Parent(s): a819b10

feat: 修改调用方法以支持传递网络和缩放参数

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -62,6 +62,20 @@ def merge_lora_networks(networks):
62
  base_network.state_dict()[name] = param.clone()
63
  return base_network
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  class Demo:
66
 
67
  def __init__(self) -> None:
@@ -382,7 +396,7 @@ class Demo:
382
  alpha=alpha,
383
  train_method=train_method,
384
  ).to(self.device, dtype=self.weight_dtype)
385
- network.load_state_dict(torch.load(model_path))
386
  networks.append(network)
387
 
388
  # 设置种子
 
62
  base_network.state_dict()[name] = param.clone()
63
  return base_network
64
 
65
+ # 修改 call 方法以支持传递 networks 参数
66
+ def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
67
+ if networks is not None and scales is not None:
68
+ for network, scale in zip(networks, scales):
69
+ for name, param in network.named_parameters():
70
+ if name in self.unet.state_dict():
71
+ self.unet.state_dict()[name].add_(param * scale)
72
+ else:
73
+ self.unet.state_dict()[name] = param.clone() * scale
74
+ return self.__original_call__(*args, **kwargs)
75
+
76
+ StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
77
+ StableDiffusionXLPipeline.__call__ = rw_sd_call
78
+
79
  class Demo:
80
 
81
  def __init__(self) -> None:
 
396
  alpha=alpha,
397
  train_method=train_method,
398
  ).to(self.device, dtype=self.weight_dtype)
399
+ network.load_state_dict(torch.load(model_path, weights_only=True))
400
  networks.append(network)
401
 
402
  # 设置种子