Spaces:
Running
Running
feat: 修改调用方法以支持传递网络和缩放参数
Browse files
app.py
CHANGED
@@ -62,6 +62,20 @@ def merge_lora_networks(networks):
|
|
62 |
base_network.state_dict()[name] = param.clone()
|
63 |
return base_network
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
class Demo:
|
66 |
|
67 |
def __init__(self) -> None:
|
@@ -382,7 +396,7 @@ class Demo:
|
|
382 |
alpha=alpha,
|
383 |
train_method=train_method,
|
384 |
).to(self.device, dtype=self.weight_dtype)
|
385 |
-
network.load_state_dict(torch.load(model_path))
|
386 |
networks.append(network)
|
387 |
|
388 |
# 设置种子
|
|
|
62 |
base_network.state_dict()[name] = param.clone()
|
63 |
return base_network
|
64 |
|
65 |
+
# 修改 call 方法以支持传递 networks 参数
|
66 |
+
def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
|
67 |
+
if networks is not None and scales is not None:
|
68 |
+
for network, scale in zip(networks, scales):
|
69 |
+
for name, param in network.named_parameters():
|
70 |
+
if name in self.unet.state_dict():
|
71 |
+
self.unet.state_dict()[name].add_(param * scale)
|
72 |
+
else:
|
73 |
+
self.unet.state_dict()[name] = param.clone() * scale
|
74 |
+
return self.__original_call__(*args, **kwargs)
|
75 |
+
|
76 |
+
StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
|
77 |
+
StableDiffusionXLPipeline.__call__ = rw_sd_call
|
78 |
+
|
79 |
class Demo:
|
80 |
|
81 |
def __init__(self) -> None:
|
|
|
396 |
alpha=alpha,
|
397 |
train_method=train_method,
|
398 |
).to(self.device, dtype=self.weight_dtype)
|
399 |
+
network.load_state_dict(torch.load(model_path, weights_only=True))
|
400 |
networks.append(network)
|
401 |
|
402 |
# 设置种子
|