# convert_hunyuan_to_framepack.py
# Convert various LoRA formats (HunyuanVideo, Diffusion Pipe-like, Musubi Tuner)
# into the FramePack LoRA format.
import argparse
import logging
import os
import sys

import torch
from safetensors.torch import load_file, save_file
# Configure logging similarly to the utility module.
logger = logging.getLogger(__name__)

# Avoid re-configuring if basicConfig was already called by the imported module.
if not logging.root.handlers:
    logging.basicConfig(level=logging.INFO)
# framepack_lora_inf_utils.py is assumed to be in the same directory.
try:
    from framepack_lora_inf_utils import (
        convert_hunyuan_to_framepack,
        convert_from_diffusion_pipe_or_something,
    )
except ImportError:
    logger.error("Error: Could not import conversion functions from framepack_lora_inf_utils.")
    logger.error("Please make sure framepack_lora_inf_utils.py is in the same directory as this script.")
    sys.exit(1)
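
# Example invocation (file names are illustrative):
#   python convert_hunyuan_to_framepack.py \
#       --input_lora  my_hunyuan_lora.safetensors \
#       --output_lora my_framepack_lora.safetensors
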
def main():
    """
    Parse arguments and perform the LoRA conversion, auto-detecting the input
    format (Musubi Tuner, Diffusion Pipe-like, or HunyuanVideo).
    """
    parser = argparse.ArgumentParser(description="Convert various LoRA formats to FramePack format.")
    parser.add_argument(
        "--input_lora",
        type=str,
        required=True,
        help="Path to the input LoRA .safetensors file (Hunyuan, Diffusion Pipe-like, or Musubi).",
    )
    parser.add_argument(
        "--output_lora",
        type=str,
        required=True,
        help="Path to save the converted FramePack LoRA .safetensors file.",
    )
    args = parser.parse_args()
    input_file = args.input_lora
    output_file = args.output_lora

    # Validate the input file.
    if not os.path.exists(input_file):
        logger.error(f"Input file not found: {input_file}")
        sys.exit(1)
    if not input_file.lower().endswith(".safetensors"):
        logger.warning(f"Input file '{input_file}' does not end with .safetensors. Proceeding anyway.")

    # Ensure the output directory exists.
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"Created output directory: {output_dir}")
        except OSError as e:
            logger.error(f"Error creating output directory {output_dir}: {e}")
            sys.exit(1)

    # Ensure the output file name ends with .safetensors.
    if not output_file.lower().endswith(".safetensors"):
        output_file += ".safetensors"
        logger.warning(f"Output file name did not end with .safetensors. Appended: {output_file}")
logger.info(f"Loading input LoRA file: {input_file}")
loaded_lora_sd = None
try:
# Load the state dictionary from the input .safetensors file
loaded_lora_sd = load_file(input_file)
logger.info(f"Input LoRA loaded successfully. Found {len(loaded_lora_sd)} keys.")
except Exception as e:
logger.error(f"Error loading input LoRA file {input_file}: {e}")
exit(1)

    # --- Determine the LoRA format and apply conversion(s) ---
    # This follows the logic flow of merge_lora_to_state_dict.
    processed_lora_sd = None  # Holds the state dict after any conversions.
    lora_keys = list(loaded_lora_sd.keys()) if loaded_lora_sd else []
    if not lora_keys:
        logger.error("Input LoRA file was empty or failed to load keys.")
        sys.exit(1)

    # 1. Check for Musubi Tuner format (first key starts with "lora_unet_").
    is_musubi = lora_keys[0].startswith("lora_unet_")
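    # For reference, Musubi Tuner keys typically look like the following
    # (illustrative example, not taken from a specific file):
    #   lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
    #   lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight
    #   lora_unet_double_blocks_0_img_attn_qkv.alpha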
    if is_musubi:
        logger.info("Detected Musubi Tuner format based on the first key.")
        # Keep the original state dict for now; a Musubi-format file may still
        # contain Hunyuan patterns, which the check below handles.
        current_lora_sd_to_check = loaded_lora_sd
    else:
        # 2. Not Musubi (based on the first key), so check for Diffusion Pipe format.
        diffusion_pipe_pattern_found = False
        transformer_prefixes = ["diffusion_model", "transformer"]
        lora_suffix_A_or_B_found = False
        # Find the first key with .lora_A or .lora_B and check its prefix.
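        # For reference, Diffusion-Pipe/PEFT-style keys typically look like the
        # following (illustrative examples, not taken from a specific file):
        #   diffusion_model.double_blocks.0.img_attn_qkv.lora_A.weight
        #   transformer.single_blocks.0.linear1.lora_B.weight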
        for key in lora_keys:
            if ".lora_A." in key or ".lora_B." in key:
                lora_suffix_A_or_B_found = True
                pfx = key.split(".")[0]
                if pfx in transformer_prefixes:
                    diffusion_pipe_pattern_found = True
                    break  # Found the required pattern.
        if diffusion_pipe_pattern_found:
            logger.info(f"Detected Diffusion Pipe-like format based on keys like '{pfx}.*.lora_A/B.*'. Attempting conversion...")
            target_prefix_for_diffusers_conversion = "lora_unet_"
            try:
                # Apply the Diffusion Pipe conversion.
                current_lora_sd_to_check = convert_from_diffusion_pipe_or_something(
                    loaded_lora_sd, target_prefix_for_diffusers_conversion
                )
                logger.info("Converted from Diffusion Pipe format.")
            except Exception as e:
                logger.error(f"Error during Diffusion Pipe conversion: {e}", exc_info=True)  # Log the traceback.
                current_lora_sd_to_check = None  # Conversion failed; treat as unprocessable.
        else:
            # Neither Musubi nor Diffusion Pipe: the format is unrecognized so far.
            logger.warning("LoRA file format not recognized from common patterns (Musubi, Diffusion Pipe-like). Checking for Hunyuan anyway...")
            current_lora_sd_to_check = loaded_lora_sd  # Keep the original SD to check for Hunyuan keys next.

    # 3. Check for the Hunyuan pattern (double_blocks/single_blocks) in the *current* state dict.
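    # For reference, keys that trigger this check contain block names such as
    # the following (illustrative examples):
    #   double_blocks.0.img_attn_qkv.lora_A.weight
    #   single_blocks.10.linear1.lora_B.weight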
    if current_lora_sd_to_check is not None:
        keys_for_hunyuan_check = list(current_lora_sd_to_check.keys())
        is_hunyuan_pattern_found = any("double_blocks" in key or "single_blocks" in key for key in keys_for_hunyuan_check)
        if is_hunyuan_pattern_found:
            logger.info("Detected HunyuanVideo format based on keys (double_blocks/single_blocks). Attempting conversion...")
            try:
                # Apply the Hunyuan conversion.
                processed_lora_sd = convert_hunyuan_to_framepack(current_lora_sd_to_check)
                logger.info("Converted from HunyuanVideo format.")
            except Exception as e:
                logger.error(f"Error during HunyuanVideo conversion: {e}", exc_info=True)  # Log the traceback.
                processed_lora_sd = None  # Conversion failed; treat as unprocessable.
        else:
            # No Hunyuan pattern found: current_lora_sd_to_check is the final state dict
            # (either the original Musubi SD or the SD converted from Diffusion Pipe).
            processed_lora_sd = current_lora_sd_to_check
            if not is_musubi and not diffusion_pipe_pattern_found:
                # Neither the Musubi nor the Diffusion Pipe pattern was found initially,
                # and no Hunyuan pattern was found either, so the format is truly unrecognized.
                logger.warning("Input LoRA does not match Musubi, Diffusion Pipe-like, or Hunyuan patterns.")
                # Log the first few keys to help debug unrecognized formats.
                logger.info(f"First keys of the input LoRA: {lora_keys[:20]}...")
                processed_lora_sd = None  # Mark as unprocessable.

    # --- Final check and saving ---
    if not processed_lora_sd:
        logger.error("Could not convert the input LoRA file to a recognized FramePack format.")
        logger.error("Input file format not recognized or conversion failed.")
        # Log the first few keys if conversion did not happen because no format matched.
        if loaded_lora_sd is not None:
            logger.info(f"First keys of the input LoRA: {lora_keys[:20]}...")
        sys.exit(1)  # Exit if conversion failed or produced no data.

    logger.info(f"Conversion complete. Converted state dictionary contains {len(processed_lora_sd)} keys.")
    logger.info(f"Saving converted LoRA file to: {output_file}")

    # --- Workaround: clone tensors so none of them share memory ---
    # The conversion functions may leave views or otherwise-shared tensors in the
    # dictionary, and safetensors' save_file refuses to serialize tensors that
    # share storage. Cloning each tensor gives every key its own contiguous
    # memory, which avoids the error across safetensors versions.
    logger.info("Cloning tensors to avoid shared-memory errors when saving...")
    processed_lora_sd_copy = {}
    for key, tensor in processed_lora_sd.items():
        if isinstance(tensor, torch.Tensor):
            # A detached clone owns its own memory and carries no autograd graph.
            processed_lora_sd_copy[key] = tensor.detach().clone()
        else:
            # Keep non-tensor items (e.g. a scalar alpha value) as-is.
            processed_lora_sd_copy[key] = tensor
    logger.info("Tensor copies created.")
    # --- End of workaround ---

    try:
        # Save the cloned dictionary; every tensor now has its own storage.
        save_file(processed_lora_sd_copy, output_file)
        logger.info(f"Successfully saved converted LoRA to {output_file}")
    except Exception as e:
        # If a shared-memory error still appears here, the clone workaround did not
        # cover this case or the error comes from another source; inspect the
        # offending keys and consider upgrading safetensors.
        logger.error(f"Error saving converted LoRA file {output_file}: {e}", exc_info=True)  # Log the traceback.
        sys.exit(1)

if __name__ == "__main__":
    main()