# Source: Framepack-H111 / networks/lora_framepack.py
# (uploaded by rahul7star as part of "Upload 303 files", commit e0336bc)
# LoRA module for FramePack
import ast
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import logging
# Module-level logger (not referenced by the functions defined in this file).
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time and configures the process-wide
# root logger at INFO level whenever this module is imported. Library modules
# usually leave logging configuration to the application entry point — confirm
# this import-time side effect is intended.
logging.basicConfig(level=logging.INFO)
import networks.lora as lora
# Target-module list handed to networks.lora's create_network /
# create_network_from_weights by the factory functions in this module.
# Presumably matched against module class names of the HunyuanVideo transformer
# (double- and single-stream blocks) to decide where LoRA layers are injected —
# confirm against networks.lora.
FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"]
def _normalize_exclude_patterns(raw_patterns) -> List[str]:
    """Return a fresh list of regex strings built from ``raw_patterns``.

    Accepted inputs:
      * ``None`` -> empty list.
      * a ``str`` holding a Python literal (e.g. ``"['.*attn.*']"``, the form
        network args arrive in from the command line); it is parsed with
        ``ast.literal_eval``.
      * an already-parsed ``list`` or ``tuple`` of patterns.
    A literal that evaluates to a single string is treated as a one-element list.

    Raises:
        ValueError / SyntaxError: propagated from ``ast.literal_eval`` when the
            string is not a valid Python literal.
        ValueError: when the parsed value is not a string, list or tuple.
    """
    if raw_patterns is None:
        return []
    # Only strings need parsing; calling literal_eval on a list raises ValueError.
    parsed = ast.literal_eval(raw_patterns) if isinstance(raw_patterns, str) else raw_patterns
    if isinstance(parsed, str):
        return [parsed]
    if isinstance(parsed, (list, tuple)):
        # Copy so the caller's object is never mutated by later appends.
        return list(parsed)
    raise ValueError(
        f"exclude_patterns must be a list of regex strings, got {type(parsed).__name__}: {parsed!r}"
    )


def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoRA network for the FramePack transformer.

    Thin wrapper around ``lora.create_network``. It fixes the target modules to
    ``FRAMEPACK_TARGET_REPLACE_MODULES`` and the name prefix to ``"lora_unet"``,
    and it always adds an exclusion pattern for modules whose name contains
    ``norm``.

    Args:
        multiplier, network_dim, network_alpha, vae, text_encoders, unet,
        neuron_dropout: forwarded unchanged to ``lora.create_network``.
        **kwargs: forwarded to ``lora.create_network``. ``exclude_patterns``
            may be a Python-literal string (as passed via network args) or a
            list/tuple of regex strings. It is replaced by a new list with the
            ``norm`` exclusion appended.

    Returns:
        Whatever ``lora.create_network`` returns.

    Raises:
        ValueError / SyntaxError: if ``exclude_patterns`` cannot be parsed into
            a list of patterns.
    """
    exclude_patterns = _normalize_exclude_patterns(kwargs.get("exclude_patterns", None))

    # exclude any module whose name contains 'norm'
    exclude_patterns.append(r".*(norm).*")

    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        FRAMEPACK_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
    """Build a FramePack LoRA network from a weights dict.

    Delegates to ``lora.create_network_from_weights``. The target-module list
    is fixed to ``FRAMEPACK_TARGET_REPLACE_MODULES``. The remaining arguments
    are forwarded positionally, in their original order. Extra keyword
    arguments are passed through untouched.

    Args:
        multiplier: forwarded unchanged.
        weights_sd: mapping of parameter name -> tensor, forwarded unchanged.
        text_encoders: forwarded unchanged (defaults to None).
        unet: forwarded unchanged (defaults to None).
        for_inference: forwarded unchanged (defaults to False).

    Returns:
        The network produced by ``lora.create_network_from_weights``.
    """
    # Keep these positional: the callee's parameter names are defined elsewhere.
    forwarded = (
        FRAMEPACK_TARGET_REPLACE_MODULES,
        multiplier,
        weights_sd,
        text_encoders,
        unet,
        for_inference,
    )
    network = lora.create_network_from_weights(*forwarded, **kwargs)
    return network