Spaces:
Running
Running
# LoRA module for FramePack | |
import ast | |
from typing import Dict, List, Optional | |
import torch | |
import torch.nn as nn | |
import logging | |
# Configure the root logger at INFO level (logging.basicConfig is a no-op if the
# root logger already has handlers), then fetch this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
import networks.lora as lora | |
# Module class names passed as the target list to lora.create_network and
# lora.create_network_from_weights (presumably the blocks whose sub-layers get
# LoRA adapters -- see networks.lora to confirm the matching rule).
FRAMEPACK_TARGET_REPLACE_MODULES = [
    "HunyuanVideoTransformerBlock",
    "HunyuanVideoSingleTransformerBlock",
]
def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoRA network for the FramePack transformer blocks.

    Delegates to ``lora.create_network`` with the target module classes fixed
    to ``FRAMEPACK_TARGET_REPLACE_MODULES`` and the prefix ``"lora_unet"``.
    Before delegating, it appends the default exclude pattern
    ``r".*(norm).*"`` to the user's exclude patterns so that modules with
    ``norm`` in their name are excluded.

    Args:
        multiplier, network_dim, network_alpha, vae, text_encoders, unet:
            Forwarded positionally, unchanged, to ``lora.create_network``.
        neuron_dropout: Forwarded unchanged as the ``neuron_dropout`` keyword.
        **kwargs: Extra network arguments forwarded to ``lora.create_network``.
            ``exclude_patterns`` is normalized to a fresh list first. It may be:

            - ``None``, absent, or a blank string: no user patterns.
            - A string holding a Python literal, such as ``"['.*attn.*']"``
              (a list) or ``"'.*attn.*'"`` (a single pattern).
            - An already-built list or tuple of pattern strings.

    Returns:
        The value returned by ``lora.create_network``.

    Raises:
        ValueError, SyntaxError: If ``exclude_patterns`` is a non-blank string
            that ``ast.literal_eval`` cannot parse.
    """
    raw_patterns = kwargs.get("exclude_patterns", None)
    if isinstance(raw_patterns, str):
        # Parse literals only (ast.literal_eval never executes code); a blank
        # string means "no user patterns" instead of raising SyntaxError.
        raw_patterns = ast.literal_eval(raw_patterns) if raw_patterns.strip() else None

    if raw_patterns is None:
        exclude_patterns = []
    elif isinstance(raw_patterns, str):
        # e.g. "'.*attn.*'" parses to a bare str: keep it as one pattern
        # instead of letting list() split it into characters.
        exclude_patterns = [raw_patterns]
    else:
        # Always build a fresh list. Appending to the caller's object would
        # mutate it, and tuples have no append().
        exclude_patterns = list(raw_patterns)

    # Exclude any module whose name contains 'norm'.
    exclude_patterns.append(r".*(norm).*")
    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        FRAMEPACK_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    """Build a FramePack LoRA network from a weights dict (``weights_sd``).

    The target module classes are fixed to ``FRAMEPACK_TARGET_REPLACE_MODULES``.
    All other parameters are forwarded positionally, in their declared order,
    to ``lora.create_network_from_weights``, and ``kwargs`` is passed through
    as keyword arguments.

    Returns:
        The network returned by ``lora.create_network_from_weights``.
    """
    target_modules = FRAMEPACK_TARGET_REPLACE_MODULES
    # Keep the exact positional order the delegate expects.
    forwarded = (multiplier, weights_sd, text_encoders, unet, for_inference)
    network = lora.create_network_from_weights(target_modules, *forwarded, **kwargs)
    return network