# Framepack-H111 / networks/lora_wan.py
# LoRA module for Wan2.1
import ast
from typing import Dict, List, Optional

import torch
import torch.nn as nn

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import networks.lora as lora

# Wan2.1 DiT blocks whose child layers are wrapped with LoRA modules
WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]


def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    # add default exclude patterns
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        # exclude_patterns arrives as the string form of a Python list of regex patterns
        exclude_patterns = ast.literal_eval(exclude_patterns)

    # exclude the embeddings, time projection, norms and output head from LoRA targets
    exclude_patterns.append(r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*")

    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        WAN_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
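

# Usage sketch (illustrative addition, not part of the original module): builds a trainable
# LoRA network over the Wan2.1 DiT. The helper name and the extra exclude pattern below are
# hypothetical; only create_arch_network and lora.LoRANetwork come from this codebase.
# Note that exclude_patterns must be passed as the *string* form of a list, since it is
# parsed with ast.literal_eval above.
def _example_build_lora_network(
    vae: nn.Module, text_encoders: List[nn.Module], wan_model: nn.Module
) -> lora.LoRANetwork:
    network = create_arch_network(
        multiplier=1.0,
        network_dim=32,
        network_alpha=16.0,
        vae=vae,
        text_encoders=text_encoders,
        unet=wan_model,  # the Wan2.1 DiT containing the WanAttentionBlock modules
        neuron_dropout=None,
        exclude_patterns="['.*cross_attn.*']",  # optional extra excludes, string form of a list
    )
    return network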


def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    return lora.create_network_from_weights(
        WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
    )
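

# Usage sketch (illustrative addition, not part of the original module): rebuilds a LoRA
# network from a saved state dict for inference. The helper name and the checkpoint path
# are assumptions; safetensors.torch.load_file is assumed to be available in the
# surrounding environment. The caller is expected to apply or merge the returned network
# into the model afterwards.
def _example_load_lora_for_inference(wan_model: nn.Module, lora_path: str) -> lora.LoRANetwork:
    from safetensors.torch import load_file  # assumed dependency

    weights_sd = load_file(lora_path)  # e.g. a *.safetensors file produced by training
    network = create_arch_network_from_weights(
        multiplier=1.0,
        weights_sd=weights_sd,
        text_encoders=None,
        unet=wan_model,
        for_inference=True,
    )
    return network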