Spaces:
Running
Running
from pathlib import Path | |
from typing import Optional | |
from diffusers.loaders.lora_pipeline import _fetch_state_dict | |
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers | |
def load_lora(
    transformer,
    lora_path: Path,
    weight_name: Optional[str] = "pytorch_lora_weights.safetensors",
    diffuser_lora: bool = False,
):
    """
    Load LoRA weights into the transformer model.

    The state dict is fetched with diffusers' private ``_fetch_state_dict``
    helper, optionally converted from the Hunyuan Video LoRA layout, and then
    handed to ``transformer.load_lora_adapter``.

    Args:
        transformer: The transformer model to which LoRA weights will be
            applied. It must provide a ``load_lora_adapter`` method.
        lora_path (Path): Location of the LoRA weights; passed as the first
            argument to ``_fetch_state_dict``.
        weight_name (Optional[str]): Forwarded to ``_fetch_state_dict`` as
            ``weight_name``. Presumably the weights file name to look for
            under ``lora_path`` -- confirm against the installed diffusers.
        diffuser_lora (bool): When False (the default), the fetched state
            dict is passed through ``_convert_hunyuan_video_lora_to_diffusers``
            before loading. When True, it is loaded as fetched.

    Returns:
        The same ``transformer`` object that was passed in, after
        ``load_lora_adapter`` has been called on it.
    """
    # ``_fetch_state_dict`` is a private diffusers helper with a long
    # all-positional signature. Everything after the first two arguments is
    # passed by keyword so each value is self-describing and a reordering of
    # the private signature cannot silently misbind arguments.
    # NOTE(review): keyword names follow diffusers' loaders/lora_base.py;
    # verify them against the pinned diffusers version.
    state_dict = _fetch_state_dict(
        lora_path,
        weight_name,
        use_safetensors=True,
        local_files_only=True,
        cache_dir=None,
        force_download=None,
        proxies=None,
        token=None,
        revision=None,
        subfolder=None,
        user_agent=None,
        allow_pickle=None,
    )

    if not diffuser_lora:
        # Anything not flagged as diffusers-format is treated as a Hunyuan
        # Video LoRA and converted to the diffusers layout.
        print("Not a diffusers lora, assuming Hunyuan.")
        state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

    transformer.load_lora_adapter(state_dict, network_alphas=None)
    print("LoRA weights loaded successfully.")
    return transformer