# LoRA module for Wan2.1

import ast
from typing import Dict, List, Optional
import torch
import torch.nn as nn

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import networks.lora as lora


# Passed to networks.lora as the list of target modules for this architecture.
# NOTE(review): presumably these are class names of the blocks whose
# sub-layers receive LoRA adapters -- confirm against networks.lora.
WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]

# Pattern that is always appended to the caller-supplied exclude patterns.
_WAN_DEFAULT_EXCLUDE_PATTERN = r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*"


def _normalize_exclude_patterns(raw) -> List[str]:
    """Return a fresh list of exclude patterns built from a user-supplied value.

    Accepts ``None`` (no patterns), a string holding a Python literal (as it
    arrives from a command-line network argument), or an already-parsed
    iterable of patterns. The result is always a new list, so appending to it
    never mutates an object owned by the caller.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        # literal_eval only evaluates literals, so this is safe on user input.
        raw = ast.literal_eval(raw)
        if isinstance(raw, str):
            # A single quoted pattern, e.g. "'.*foo.*'": treat as one entry
            # instead of splitting it into characters.
            return [raw]
    return list(raw)


def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoRA network for Wan2.1 via ``lora.create_network``.

    All positional arguments and ``neuron_dropout`` are forwarded unchanged to
    ``lora.create_network`` together with ``WAN_TARGET_REPLACE_MODULES`` and
    the prefix ``"lora_unet"``.

    Args:
        multiplier: Forwarded to ``lora.create_network``.
        network_dim: Forwarded to ``lora.create_network``.
        network_alpha: Forwarded to ``lora.create_network``.
        vae: Forwarded to ``lora.create_network``.
        text_encoders: Forwarded to ``lora.create_network``.
        unet: Forwarded to ``lora.create_network``.
        neuron_dropout: Forwarded to ``lora.create_network``.
        **kwargs: Extra network arguments. ``exclude_patterns`` may be
            ``None``, a string containing a Python list literal, or an
            iterable of patterns; it is replaced by a list that additionally
            contains the Wan2.1 default exclude pattern before forwarding.

    Returns:
        Whatever ``lora.create_network`` returns.
    """
    # add default exclude patterns
    exclude_patterns = _normalize_exclude_patterns(kwargs.get("exclude_patterns", None))

    # exclude if 'patch_embedding', 'text_embedding', 'time_embedding',
    # 'time_projection', 'norm' or 'head' is in the name
    exclude_patterns.append(_WAN_DEFAULT_EXCLUDE_PATTERN)

    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        WAN_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )


def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    """Create a Wan2.1 LoRA network from an existing weight state dict.

    Thin wrapper around ``lora.create_network_from_weights`` that supplies
    ``WAN_TARGET_REPLACE_MODULES`` as the target module list; every other
    argument is forwarded unchanged.

    Args:
        multiplier: Forwarded to ``lora.create_network_from_weights``.
        weights_sd: Mapping from parameter name to tensor, forwarded as-is.
        text_encoders: Forwarded as-is; may be ``None``.
        unet: Forwarded as-is; may be ``None``.
        for_inference: Forwarded as-is.
        **kwargs: Extra arguments forwarded as-is.

    Returns:
        The ``lora.LoRANetwork`` returned by ``lora.create_network_from_weights``.
    """
    return lora.create_network_from_weights(
        WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
    )