# Framepack-H111 / funconvert_lora.py
# convert_lora_i2v_to_fc.py
import torch
import safetensors.torch
import safetensors # Need this for safe_open
import argparse
import os
import re # Regular expressions might be useful for more complex key parsing if needed
# !!! IMPORTANT: Updated based on the output of analyze_wan_models.py !!!
# Base layer names identified as having a shape mismatch between the two model variants.
# Check your LoRA file's keys if they use a different prefix (e.g., 'transformer.');
# the entries below assume the base names inferred from the LoRA keys match these.
BASE_LAYERS_TO_SKIP_LORA = {
"patch_embedding", # The layer name from the analysis output
# Add other layers here ONLY if the analysis revealed more mismatches
}
# !!! END IMPORTANT SECTION !!!
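# Example (hypothetical layer name, shown only to illustrate the format): if the
# analysis reported another mismatched layer, add its base name to the set above, e.g.
#   BASE_LAYERS_TO_SKIP_LORA = {"patch_embedding", "text_embedding"}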
def get_base_layer_name(lora_key: str, prefixes = ["lora_transformer_", "lora_unet_"]):
"""
Attempts to extract the base model layer name from a LoRA key.
Handles common prefixes and suffixes. Adjust prefixes if needed.
Example: "lora_transformer_patch_embedding_down.weight" -> "patch_embedding"
"lora_transformer_blocks_0_attn_qkv.alpha" -> "blocks.0.attn.qkv"
Args:
lora_key (str): The key from the LoRA state dictionary.
prefixes (list[str]): A list of potential prefixes used in LoRA keys.
Returns:
str: The inferred base model layer name.
"""
cleaned_key = lora_key
# Remove known prefixes
for prefix in prefixes:
if cleaned_key.startswith(prefix):
cleaned_key = cleaned_key[len(prefix):]
break # Assume only one prefix matches
# Remove known suffixes
# Order matters slightly if one suffix is part of another; list longer ones first if needed
known_suffixes = [
".lora_up.weight",
".lora_down.weight",
"_lora_up.weight", # Include underscore variants just in case
"_lora_down.weight",
".alpha"
]
for suffix in known_suffixes:
if cleaned_key.endswith(suffix):
cleaned_key = cleaned_key[:-len(suffix)]
break
# Replace underscores used by some training scripts with periods for consistency
# if the original model uses periods (like typical PyTorch modules).
# Adjust this logic if the base model itself uses underscores extensively.
cleaned_key = cleaned_key.replace("_", ".")
# Specific fix for the target layer if prefix/suffix removal was incomplete or ambiguous
# This is somewhat heuristic and might need adjustment based on exact LoRA key naming.
if cleaned_key.startswith("patch.embedding"): # Handle case where prefix removal was incomplete
# Map potential variants back to the canonical name found in analysis
cleaned_key = "patch_embedding"
elif cleaned_key == "patch.embedding.weight": # If suffix removal left .weight attached somehow
cleaned_key = "patch_embedding"
# Add elif clauses here if other specific key mappings are needed
return cleaned_key
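# Illustrative input/output pairs for get_base_layer_name with the defaults above
# (traced by hand for this sketch; not part of the original script):
#   "lora_transformer_blocks_0_attn_qkv.alpha"      -> "blocks.0.attn.qkv"
#   "lora_unet_patch_embedding.lora_down.weight"    -> "patch_embedding"
#   "lora_transformer_patch_embedding_down.weight"  -> "patch_embedding"  (via the heuristic fix-up)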
def convert_lora(source_lora_path: str, target_lora_path: str):
"""
Converts an i2v_14B LoRA to be compatible with i2v_14B_FC by
removing LoRA weights associated with layers that have incompatible shapes.
Args:
source_lora_path (str): Path to the input LoRA file (.safetensors).
target_lora_path (str): Path to save the converted LoRA file (.safetensors).
"""
print(f"Loading source LoRA from: {source_lora_path}")
if not os.path.exists(source_lora_path):
print(f"Error: Source file not found: {source_lora_path}")
return
try:
        # Load tensors and metadata with safe_open so the file's metadata can be preserved on save
source_lora_state_dict = {}
metadata = {}
with safetensors.safe_open(source_lora_path, framework="pt", device="cpu") as f:
metadata = f.metadata() # Get metadata if it exists
if metadata is None: # Ensure metadata is a dict even if empty
metadata = {}
for key in f.keys():
source_lora_state_dict[key] = f.get_tensor(key) # Load tensors
print(f"Successfully loaded {len(source_lora_state_dict)} tensors.")
if metadata:
print(f"Found metadata: {metadata}")
else:
print("No metadata found.")
except Exception as e:
print(f"Error loading LoRA file: {e}")
import traceback
traceback.print_exc()
return
target_lora_state_dict = {}
skipped_keys = []
kept_keys = []
base_name_map = {} # Store mapping for reporting
print(f"\nConverting LoRA weights...")
print(f"Will skip LoRA weights targeting these base layers: {BASE_LAYERS_TO_SKIP_LORA}")
# Iterate through the loaded tensors
for key, tensor in source_lora_state_dict.items():
# Use the helper function to extract the base layer name
base_layer_name = get_base_layer_name(key)
base_name_map[key] = base_layer_name # Store for reporting purposes
# Check if the identified base layer name should be skipped
if base_layer_name in BASE_LAYERS_TO_SKIP_LORA:
skipped_keys.append(key)
else:
# Keep the tensor if its base layer is not in the skip list
target_lora_state_dict[key] = tensor
kept_keys.append(key)
# --- Reporting ---
print(f"\nConversion Summary:")
print(f" - Total Tensors in Source: {len(source_lora_state_dict)}")
print(f" - Kept {len(kept_keys)} LoRA weight tensors.")
print(f" - Skipped {len(skipped_keys)} LoRA weight tensors (due to incompatible base layer shape):")
if skipped_keys:
        max_print = 15 # Cap on how many skipped keys to print before truncating the list
skipped_sorted = sorted(skipped_keys) # Sort for consistent output order
for i, key in enumerate(skipped_sorted):
base_name = base_name_map.get(key, "N/A") # Get the identified base name
print(f" - {key} (Base Layer Identified: {base_name})")
if i >= max_print -1 and len(skipped_keys) > max_print:
print(f" ... and {len(skipped_keys) - max_print} more.")
break
else:
print(" None")
# --- Saving ---
print(f"\nSaving converted LoRA ({len(target_lora_state_dict)} tensors) to: {target_lora_path}")
try:
# Save the filtered state dictionary with the original metadata
safetensors.torch.save_file(target_lora_state_dict, target_lora_path, metadata=metadata)
print("Conversion successful!")
except Exception as e:
print(f"Error saving converted LoRA file: {e}")
if __name__ == "__main__":
# Setup argument parser
parser = argparse.ArgumentParser(
description="Convert Wan i2v_14B LoRA to i2v_14B_FC LoRA by removing incompatible patch_embedding weights.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("source_lora", type=str, help="Path to the source i2v_14B LoRA file (.safetensors).")
parser.add_argument("target_lora", type=str, help="Path to save the converted i2v_14B_FC LoRA file (.safetensors).")
# Parse arguments
args = parser.parse_args()
    # --- Input Validation ---
    if not os.path.exists(args.source_lora):
        print(f"Error: Source LoRA file not found at '{args.source_lora}'")
    elif args.source_lora == args.target_lora:
        print(f"Error: Source and target paths cannot be the same ('{args.source_lora}'). Choose a different target path.")
    else:
        # Non-fatal checks: warn, but still run the conversion.
        if not args.source_lora.lower().endswith(".safetensors"):
            print(f"Warning: Source file '{args.source_lora}' does not have a .safetensors extension.")
        if os.path.exists(args.target_lora):
            print(f"Warning: Target file '{args.target_lora}' already exists and will be overwritten.")
            # Optionally add a --force flag or prompt the user here
        # Run the conversion once basic checks pass
        convert_lora(args.source_lora, args.target_lora)
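# Example command-line invocation (paths are illustrative; substitute your own files):
#   python funconvert_lora.py wan_i2v_14B_lora.safetensors wan_i2v_14B_FC_lora.safetensors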