Spaces:
Running
Running
# convert_lora.py | |
import argparse | |
import os | |
import re | |
import torch | |
from safetensors.torch import load_file, save_file | |
import logging | |
# Configure logging similar to the utility file | |
logger = logging.getLogger(__name__) | |
# Avoid re-configuring if basicConfig was already called by the imported module | |
if not logging.root.handlers: | |
logging.basicConfig(level=logging.INFO) | |
# Assuming framepack_lora_inf_utils.py is in the same directory | |
try: | |
from framepack_lora_inf_utils import ( | |
convert_hunyuan_to_framepack, | |
convert_from_diffusion_pipe_or_something, | |
) | |
except ImportError: | |
logger.error("Error: Could not import conversion functions from framepack_lora_inf_utils.") | |
logger.error("Please make sure framepack_lora_inf_utils.py is in the same directory as this script.") | |
exit(1) | |
def main(): | |
""" | |
Main function to parse arguments and perform the LoRA conversion, | |
detecting the input format (Hunyuan or Diffusion Pipe-like). | |
""" | |
parser = argparse.ArgumentParser(description="Convert various LoRA formats to FramePack format.") | |
parser.add_argument( | |
"--input_lora", | |
type=str, | |
required=True, | |
help="Path to the input LoRA .safetensors file (Hunyuan, Diffusion Pipe-like, or Musubi).", | |
) | |
parser.add_argument( | |
"--output_lora", | |
type=str, | |
required=True, | |
help="Path to save the converted FramePack LoRA .safetensors file.", | |
) | |
args = parser.parse_args() | |
input_file = args.input_lora | |
output_file = args.output_lora | |
# Validate input file | |
if not os.path.exists(input_file): | |
logger.error(f"Input file not found: {input_file}") | |
exit(1) | |
if not input_file.lower().endswith(".safetensors"): | |
logger.warning(f"Input file '{input_file}' does not end with .safetensors. Proceeding anyway.") | |
# Ensure output directory exists | |
output_dir = os.path.dirname(output_file) | |
if output_dir and not os.path.exists(output_dir): | |
try: | |
os.makedirs(output_dir, exist_ok=True) | |
logger.info(f"Created output directory: {output_dir}") | |
except OSError as e: | |
logger.error(f"Error creating output directory {output_dir}: {e}") | |
exit(1) | |
# Ensure output file ends with .safetensors | |
if not output_file.lower().endswith(".safetensors"): | |
output_file += ".safetensors" | |
logger.warning(f"Output file name did not end with .safetensors. Appended: {output_file}") | |
logger.info(f"Loading input LoRA file: {input_file}") | |
loaded_lora_sd = None | |
try: | |
# Load the state dictionary from the input .safetensors file | |
loaded_lora_sd = load_file(input_file) | |
logger.info(f"Input LoRA loaded successfully. Found {len(loaded_lora_sd)} keys.") | |
except Exception as e: | |
logger.error(f"Error loading input LoRA file {input_file}: {e}") | |
exit(1) | |
# --- Determine LoRA format and apply conversion(s) --- | |
# Following the logic flow from merge_lora_to_state_dict | |
processed_lora_sd = None # This will hold the SD after potential conversions | |
lora_keys = list(loaded_lora_sd.keys()) if loaded_lora_sd else [] | |
if not lora_keys: | |
logger.error("Input LoRA file was empty or failed to load keys.") | |
exit(1) | |
# 1. Check if it's Musubi Tuner format (first key starts with "lora_unet_") | |
is_musubi = lora_keys[0].startswith("lora_unet_") | |
if is_musubi: | |
logger.info("Detected Musubi Tuner format based on first key.") | |
# Keep the original SD for now, as Musubi format might still contain Hunyuan patterns | |
current_lora_sd_to_check = loaded_lora_sd | |
else: | |
# 2. If not Musubi (based on first key), check for Diffusion Pipe format | |
diffusion_pipe_pattern_found = False | |
transformer_prefixes = ["diffusion_model", "transformer"] | |
lora_suffix_A_or_B_found = False | |
# Find the first key with .lora_A or .lora_B and check its prefix | |
for key in lora_keys: | |
if ".lora_A." in key or ".lora_B." in key: | |
lora_suffix_A_or_B_found = True | |
pfx = key.split(".")[0] | |
if pfx in transformer_prefixes: | |
diffusion_pipe_pattern_found = True | |
break # Found the required pattern | |
if diffusion_pipe_pattern_found: | |
logger.info(f"Detected Diffusion Pipe (?) format based on keys like '{pfx}.*.lora_A/B.'. Attempting conversion...") | |
target_prefix_for_diffusers_conversion = "lora_unet_" | |
try: | |
# Apply the Diffusion Pipe conversion | |
current_lora_sd_to_check = convert_from_diffusion_pipe_or_something(loaded_lora_sd, target_prefix_for_diffusers_conversion) | |
logger.info("Converted from Diffusion Pipe format.") | |
except Exception as e: | |
logger.error(f"Error during Diffusion Pipe conversion: {e}", exc_info=True) # Log traceback | |
current_lora_sd_to_check = None # Conversion failed, treat as unprocessable | |
else: | |
# If not Musubi and not Diffusion Pipe, the format is unrecognized initially | |
logger.warning(f"LoRA file format not recognized based on common patterns (Musubi, Diffusion Pipe-like). Checking for Hunyuan anyway...") | |
current_lora_sd_to_check = loaded_lora_sd # Keep the original SD to check for Hunyuan keys next | |
# 3. Check for Hunyuan pattern (double_blocks/single_blocks) in the *current* state dict | |
if current_lora_sd_to_check is not None: | |
keys_for_hunyuan_check = list(current_lora_sd_to_check.keys()) | |
is_hunyuan_pattern_found = any("double_blocks" in key or "single_blocks" in key for key in keys_for_hunyuan_check) | |
if is_hunyuan_pattern_found: | |
logger.info("Detected HunyuanVideo format based on keys (double_blocks/single_blocks). Attempting conversion...") | |
try: | |
# Apply the Hunyuan conversion | |
processed_lora_sd = convert_hunyuan_to_framepack(current_lora_sd_to_check) | |
logger.info("Converted from HunyuanVideo format.") | |
except Exception as e: | |
logger.error(f"Error during HunyuanVideo conversion: {e}", exc_info=True) # Log traceback | |
processed_lora_sd = None # Conversion failed, treat as unprocessable | |
else: | |
# If Hunyuan pattern is not found, the current_lora_sd_to_check is the final state | |
# (either the original Musubi SD, or the SD converted from Diffusion Pipe). | |
processed_lora_sd = current_lora_sd_to_check | |
if not is_musubi and not diffusion_pipe_pattern_found: | |
# If we reached here and neither Musubi nor Diffusion Pipe patterns were initially found, | |
# and no Hunyuan pattern was found either, then the format is truly unrecognized. | |
logger.warning("Input LoRA does not match Musubi, Diffusion Pipe-like, or Hunyuan patterns.") | |
# Log keys to help debugging unrecognized formats | |
logger.info(f"Input LoRA keys start with: {lora_keys[:20]}...") # Show first few keys | |
processed_lora_sd = None # Mark as unprocessable | |
# --- Final check and saving --- | |
if processed_lora_sd is None or not processed_lora_sd: | |
logger.error("Could not convert the input LoRA file to a recognized FramePack format.") | |
logger.error("Input file format not recognized or conversion failed.") | |
# Log keys if conversion didn't happen due to format not matching | |
if loaded_lora_sd is not None: | |
logger.info(f"Input LoRA keys start with: {lora_keys[:20]}...") # Show first few keys | |
exit(1) # Exit if conversion failed or no data resulted | |
logger.info(f"Conversion complete. Converted state dictionary contains {len(processed_lora_sd)} keys.") | |
logger.info(f"Saving converted LoRA file to: {output_file}") | |
# --- WORKAROUND for older safetensors versions that don't support allow_shared=True --- | |
# The conversion functions might create shared tensors in the dictionary. | |
# Older safetensors versions require tensors to be distinct objects for save_file. | |
# We create a deep copy of tensors to satisfy this requirement. | |
# The recommended fix is to upgrade safetensors (pip install --upgrade safetensors) | |
# and use allow_shared=True in save_file. | |
logger.info("Checking for shared tensors and creating copies for saving (workaround for older safetensors)...") | |
processed_lora_sd_copy = {} | |
for key, tensor in processed_lora_sd.items(): | |
if isinstance(tensor, torch.Tensor): | |
# Create a new tensor with copied data, detached from any computation graph | |
processed_lora_sd_copy[key] = tensor.clone().detach() | |
else: | |
# Keep non-tensor items (like alpha which might be a scalar number) as is | |
processed_lora_sd_copy[key] = tensor | |
logger.info("Deep copy complete.") | |
# --- END OF WORKAROUND --- | |
try: | |
# Save using the deep-copied dictionary. | |
# This works with older safetensors versions (pre-0.3.0) | |
# If you upgraded safetensors to 0.3.0 or later, you could use: | |
# save_file(processed_lora_sd, output_file, allow_shared=True) | |
save_file(processed_lora_sd_copy, output_file) | |
logger.info(f"Successfully saved converted LoRA to {output_file}") | |
except Exception as e: | |
# Note: If you still get a shared memory error here, it implies the deep copy | |
# workaround didn't fully resolve it for your specific setup, OR the error | |
# is coming from a different source. Upgrading safetensors is then highly recommended. | |
logger.error(f"Error saving converted LoRA file {output_file}: {e}", exc_info=True) # Log traceback | |
exit(1) | |
if __name__ == "__main__": | |
main() |