# convert_lora.py
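"""Convert HunyuanVideo, Diffusion-Pipe-style, or Musubi Tuner LoRA files to FramePack format.

Example invocation (file names are illustrative):
    python convert_lora.py --input_lora hunyuan_lora.safetensors --output_lora framepack_lora.safetensors
"""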
import argparse
import logging
import os
import sys

import torch
from safetensors.torch import load_file, save_file

# Configure logging similar to the utility file
logger = logging.getLogger(__name__)
# Avoid re-configuring if basicConfig was already called by the imported module
if not logging.root.handlers:
    logging.basicConfig(level=logging.INFO)
# Assuming framepack_lora_inf_utils.py is in the same directory
try:
    from framepack_lora_inf_utils import (
        convert_hunyuan_to_framepack,
        convert_from_diffusion_pipe_or_something,
    )
except ImportError:
    logger.error("Error: Could not import conversion functions from framepack_lora_inf_utils.")
    logger.error("Please make sure framepack_lora_inf_utils.py is in the same directory as this script.")
    sys.exit(1)
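# As used below, both helpers take a flat {key: tensor} LoRA state dict and return
# the converted dict; convert_from_diffusion_pipe_or_something also takes a target
# key prefix for the renamed keys.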
def main():
    """
    Main function to parse arguments and perform the LoRA conversion,
    detecting the input format (Hunyuan or Diffusion Pipe-like).
    """
    parser = argparse.ArgumentParser(description="Convert various LoRA formats to FramePack format.")
    parser.add_argument(
        "--input_lora",
        type=str,
        required=True,
        help="Path to the input LoRA .safetensors file (Hunyuan, Diffusion Pipe-like, or Musubi).",
    )
    parser.add_argument(
        "--output_lora",
        type=str,
        required=True,
        help="Path to save the converted FramePack LoRA .safetensors file.",
    )
    args = parser.parse_args()

    input_file = args.input_lora
    output_file = args.output_lora
    # Validate input file
    if not os.path.exists(input_file):
        logger.error(f"Input file not found: {input_file}")
        sys.exit(1)
    if not input_file.lower().endswith(".safetensors"):
        logger.warning(f"Input file '{input_file}' does not end with .safetensors. Proceeding anyway.")

    # Ensure output directory exists
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"Created output directory: {output_dir}")
        except OSError as e:
            logger.error(f"Error creating output directory {output_dir}: {e}")
            sys.exit(1)

    # Ensure output file ends with .safetensors
    if not output_file.lower().endswith(".safetensors"):
        output_file += ".safetensors"
        logger.warning(f"Output file name did not end with .safetensors. Appended: {output_file}")
logger.info(f"Loading input LoRA file: {input_file}")
loaded_lora_sd = None
try:
# Load the state dictionary from the input .safetensors file
loaded_lora_sd = load_file(input_file)
logger.info(f"Input LoRA loaded successfully. Found {len(loaded_lora_sd)} keys.")
except Exception as e:
logger.error(f"Error loading input LoRA file {input_file}: {e}")
exit(1)
    # --- Determine LoRA format and apply conversion(s) ---
    # Following the logic flow from merge_lora_to_state_dict
    processed_lora_sd = None  # This will hold the SD after potential conversions
    lora_keys = list(loaded_lora_sd.keys()) if loaded_lora_sd else []
    if not lora_keys:
        logger.error("Input LoRA file was empty or failed to load keys.")
        sys.exit(1)

    # 1. Check if it's Musubi Tuner format (first key starts with "lora_unet_")
    is_musubi = lora_keys[0].startswith("lora_unet_")
    diffusion_pipe_pattern_found = False  # Initialized here so the check after the Hunyuan step is always safe
    if is_musubi:
        logger.info("Detected Musubi Tuner format based on first key.")
        # Keep the original SD for now, as Musubi format might still contain Hunyuan patterns
        current_lora_sd_to_check = loaded_lora_sd
    else:
        # 2. If not Musubi (based on first key), check for Diffusion Pipe format
        transformer_prefixes = ["diffusion_model", "transformer"]
        matched_prefix = None
        # Scan keys with .lora_A or .lora_B until one has a known transformer prefix
        for key in lora_keys:
            if ".lora_A." in key or ".lora_B." in key:
                pfx = key.split(".")[0]
                if pfx in transformer_prefixes:
                    diffusion_pipe_pattern_found = True
                    matched_prefix = pfx
                    break  # Found the required pattern
        if diffusion_pipe_pattern_found:
            logger.info(f"Detected Diffusion Pipe-like format based on keys such as '{matched_prefix}.*.lora_A/B'. Attempting conversion...")
            target_prefix_for_diffusers_conversion = "lora_unet_"
            try:
                # Apply the Diffusion Pipe conversion
                current_lora_sd_to_check = convert_from_diffusion_pipe_or_something(loaded_lora_sd, target_prefix_for_diffusers_conversion)
                logger.info("Converted from Diffusion Pipe format.")
            except Exception as e:
                logger.error(f"Error during Diffusion Pipe conversion: {e}", exc_info=True)  # Log traceback
                current_lora_sd_to_check = None  # Conversion failed, treat as unprocessable
        else:
            # If not Musubi and not Diffusion Pipe, the format is unrecognized at this point
            logger.warning("LoRA file format not recognized based on common patterns (Musubi, Diffusion Pipe-like). Checking for Hunyuan anyway...")
            current_lora_sd_to_check = loaded_lora_sd  # Keep the original SD to check for Hunyuan keys next
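    # For orientation, hypothetical key names each branch is meant to match
    # (exact module paths vary by trainer; these are illustrative only):
    #   Musubi Tuner:    lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
    #   Diffusion Pipe:  diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight
    #   HunyuanVideo:    any key containing "double_blocks" or "single_blocks"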
    # 3. Check for Hunyuan pattern (double_blocks/single_blocks) in the *current* state dict
    if current_lora_sd_to_check is not None:
        keys_for_hunyuan_check = list(current_lora_sd_to_check.keys())
        is_hunyuan_pattern_found = any("double_blocks" in key or "single_blocks" in key for key in keys_for_hunyuan_check)
        if is_hunyuan_pattern_found:
            logger.info("Detected HunyuanVideo format based on keys (double_blocks/single_blocks). Attempting conversion...")
            try:
                # Apply the Hunyuan conversion
                processed_lora_sd = convert_hunyuan_to_framepack(current_lora_sd_to_check)
                logger.info("Converted from HunyuanVideo format.")
            except Exception as e:
                logger.error(f"Error during HunyuanVideo conversion: {e}", exc_info=True)  # Log traceback
                processed_lora_sd = None  # Conversion failed, treat as unprocessable
        else:
            # If the Hunyuan pattern is not found, current_lora_sd_to_check is the final state
            # (either the original Musubi SD, or the SD converted from Diffusion Pipe).
            processed_lora_sd = current_lora_sd_to_check
            if not is_musubi and not diffusion_pipe_pattern_found:
                # Neither Musubi nor Diffusion Pipe patterns were found initially, and no
                # Hunyuan pattern was found either, so the format is truly unrecognized.
                logger.warning("Input LoRA does not match Musubi, Diffusion Pipe-like, or Hunyuan patterns.")
                # Log keys to help debug unrecognized formats
                logger.info(f"Input LoRA keys start with: {lora_keys[:20]}...")  # Show first few keys
                processed_lora_sd = None  # Mark as unprocessable
    # --- Final check and saving ---
    if processed_lora_sd is None or not processed_lora_sd:
        logger.error("Could not convert the input LoRA file to a recognized FramePack format.")
        logger.error("Input file format not recognized or conversion failed.")
        # Log keys if conversion didn't happen due to format not matching
        if loaded_lora_sd is not None:
            logger.info(f"Input LoRA keys start with: {lora_keys[:20]}...")  # Show first few keys
        sys.exit(1)  # Exit if conversion failed or no data resulted

    logger.info(f"Conversion complete. Converted state dictionary contains {len(processed_lora_sd)} keys.")
    logger.info(f"Saving converted LoRA file to: {output_file}")
    # --- WORKAROUND for shared tensors ---
    # The conversion functions might leave tensors in the dictionary that share
    # underlying storage (e.g. views of the same memory). safetensors' save_file
    # refuses to serialize tensors that share memory, so we clone every tensor
    # into its own storage before saving.
    logger.info("Cloning tensors to break any shared storage before saving...")
    processed_lora_sd_copy = {}
    for key, tensor in processed_lora_sd.items():
        if isinstance(tensor, torch.Tensor):
            # Create a new tensor with copied data, detached from any computation graph
            processed_lora_sd_copy[key] = tensor.clone().detach()
        else:
            # Keep non-tensor items (like alpha, which might be a scalar number) as is
            processed_lora_sd_copy[key] = tensor
    logger.info("Deep copy complete.")
    # --- END OF WORKAROUND ---
    try:
        # Save the cloned dictionary, in which no tensors share memory.
        save_file(processed_lora_sd_copy, output_file)
        logger.info(f"Successfully saved converted LoRA to {output_file}")
    except Exception as e:
        # If a shared-memory error still appears here, the clone workaround did not
        # cover the offending tensors, or the error has another source entirely.
        logger.error(f"Error saving converted LoRA file {output_file}: {e}", exc_info=True)  # Log traceback
        sys.exit(1)
if __name__ == "__main__":
    main()
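# Quick post-conversion check (hypothetical file name, run separately):
#   python -c "from safetensors.torch import load_file; sd = load_file('framepack_lora.safetensors'); print(len(sd), sorted(sd)[:5])"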