# LoRA module for Wan2.1

import ast
from typing import Dict, List, Optional

import torch
import torch.nn as nn

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import networks.lora as lora
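
# LoRA is applied to layers inside these transformer block classes
# (see networks.lora.create_network for the module-matching logic).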
WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]


def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    # add default exclude patterns; user-supplied patterns arrive as the string
    # representation of a Python list, so parse them with ast.literal_eval
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)

    # exclude embedding, time projection, norm, and head layers from LoRA
    exclude_patterns.append(r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*")

    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        WAN_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
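

# Rebuild a LoRANetwork from a saved state dict (e.g. for merging or
# inference); per-module ranks and alphas are recovered from the weights
# rather than passed explicitly.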
def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    return lora.create_network_from_weights(
        WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
    )
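

# Usage sketch (illustrative only; `vae`, `text_encoders`, and `unet` stand in
# for models loaded elsewhere in the trainer, and the values shown are not
# defaults from this module):
#
#     network = create_arch_network(
#         multiplier=1.0,
#         network_dim=32,
#         network_alpha=16.0,
#         vae=vae,
#         text_encoders=text_encoders,
#         unet=unet,
#         neuron_dropout=None,
#         exclude_patterns="['.*cross_attn.*']",  # optional; parsed with ast.literal_eval
#     )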