# convert_lora_i2v_to_fc.py
"""Convert a Wan i2v_14B LoRA so it can be loaded on top of i2v_14B_FC.

Some base-model layers (listed in BASE_LAYERS_TO_SKIP_LORA) have different
weight shapes in the two models, so LoRA tensors that target them cannot be
applied to the FC model.  This script copies every other tensor, together with
the source file's safetensors metadata, into a new file.

Usage:
    python convert_lora_i2v_to_fc.py SOURCE.safetensors TARGET.safetensors
"""

import argparse
import os
import re  # Regular expressions might be useful for more complex key parsing if needed
import sys
import traceback

# torch/safetensors are only needed to read and write tensor files.  Guarding
# the import keeps get_base_layer_name() importable without them and lets
# convert_lora() report a readable error instead of crashing at import time.
try:
    import torch  # noqa: F401  (backend for safetensors' framework="pt")
    import safetensors  # Need this for safe_open
    import safetensors.torch
except ImportError:
    safetensors = None

# !!! IMPORTANT: Updated based on the output of analyze_wan_models.py !!!
# Base-model layers whose shapes differ between i2v_14B and i2v_14B_FC.
# LoRA tensors that target any of these layers are dropped.  Names follow the
# base model's module naming; when matching, get_base_layer_name() treats "_"
# and "." as equivalent, so either spelling works here.
# If your LoRA keys use a prefix not in DEFAULT_LORA_KEY_PREFIXES, add it there.
BASE_LAYERS_TO_SKIP_LORA = {
    "patch_embedding",  # The layer name from the analysis output
    # Add other layers here ONLY if the analysis revealed more mismatches
}
# !!! END IMPORTANT SECTION !!!

# Prefixes stripped from LoRA keys before the module path is read; only the
# first matching prefix is removed.  Underscore-joined keys look like
# "lora_unet_blocks_0_attn_qkv.lora_down.weight"; dot-separated keys look like
# "diffusion_model.blocks.0.attn.qkv.lora_A.weight".
DEFAULT_LORA_KEY_PREFIXES = (
    "lora_transformer_",
    "lora_unet_",
    "diffusion_model.",
    "transformer.",
)

# Suffixes naming the LoRA-specific tensor rather than the base layer.  Only
# the first matching suffix is removed; none of these is a suffix of another,
# so their order does not matter.
_KNOWN_LORA_SUFFIXES = (
    ".lora_up.weight",
    ".lora_down.weight",
    "_lora_up.weight",  # Include underscore variants just in case
    "_lora_down.weight",
    ".lora_A.weight",
    ".lora_B.weight",
    ".alpha",
)


def get_base_layer_name(lora_key: str, prefixes=DEFAULT_LORA_KEY_PREFIXES) -> str:
    """Infer the base-model layer name that a LoRA tensor key refers to.

    The first matching prefix in *prefixes* and the first matching entry of
    _KNOWN_LORA_SUFFIXES are stripped, then underscores are replaced with
    periods.  If the result equals a layer in BASE_LAYERS_TO_SKIP_LORA, or
    starts with one followed by ".", that layer's spelling from the set is
    returned, so ``name in BASE_LAYERS_TO_SKIP_LORA`` is a reliable test.

    For any other layer the result is only a best-effort guess: replacing
    every underscore also splits names that genuinely contain one
    (``cross_attn`` becomes ``cross.attn``).

    >>> get_base_layer_name("lora_transformer_patch_embedding_down.weight")
    'patch_embedding'
    >>> get_base_layer_name("lora_transformer_blocks_0_attn_qkv.alpha")
    'blocks.0.attn.qkv'
    >>> get_base_layer_name("diffusion_model.patch_embedding.lora_A.weight")
    'patch_embedding'

    Args:
        lora_key: A key from the LoRA state dictionary.
        prefixes: Candidate key prefixes, tried in order.  Defaults to the
            DEFAULT_LORA_KEY_PREFIXES tuple (immutable, so the default is
            never shared mutable state).

    Returns:
        The inferred base-model layer name.
    """
    cleaned_key = lora_key
    for prefix in prefixes:
        if cleaned_key.startswith(prefix):
            cleaned_key = cleaned_key[len(prefix):]
            break  # Assume only one prefix matches

    for suffix in _KNOWN_LORA_SUFFIXES:
        if cleaned_key.endswith(suffix):
            cleaned_key = cleaned_key[: -len(suffix)]
            break

    # Underscore-joined keys encode the module path with "_"; normalise to ".".
    cleaned_key = cleaned_key.replace("_", ".")

    # Map anything at or below a skipped layer back to its canonical spelling.
    # Longest names first (ties by name) so the most specific entry wins
    # deterministically regardless of set iteration order.
    for layer in sorted(BASE_LAYERS_TO_SKIP_LORA, key=lambda name: (-len(name), name)):
        dotted_layer = layer.replace("_", ".")
        if cleaned_key == dotted_layer or cleaned_key.startswith(dotted_layer + "."):
            return layer
    return cleaned_key


def _is_same_path(path_a: str, path_b: str) -> bool:
    """Return True if both paths resolve (symlinks, '.', case on Windows) to one location."""
    def resolve(path: str) -> str:
        return os.path.normcase(os.path.realpath(path))

    return resolve(path_a) == resolve(path_b)


def _load_lora(path: str):
    """Read every tensor and the metadata header from a .safetensors file.

    Returns:
        (tensors, metadata): a dict of key -> tensor, and the file's metadata
        dict ({} when the file has none).
    """
    tensors = {}
    with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
        metadata = handle.metadata() or {}  # metadata() returns None if absent
        for key in handle.keys():
            tensors[key] = handle.get_tensor(key)
    return tensors, metadata


def _split_lora_tensors(state_dict):
    """Partition tensors by whether their base layer is in BASE_LAYERS_TO_SKIP_LORA.

    Returns:
        (kept, skipped_keys, base_name_map): the tensors to keep (key order
        preserved), the list of dropped keys, and key -> inferred base layer
        name for every key (used for reporting).
    """
    kept = {}
    skipped_keys = []
    base_name_map = {}
    for key, tensor in state_dict.items():
        base_layer_name = get_base_layer_name(key)
        base_name_map[key] = base_layer_name
        if base_layer_name in BASE_LAYERS_TO_SKIP_LORA:
            skipped_keys.append(key)
        else:
            kept[key] = tensor
    return kept, skipped_keys, base_name_map


def _print_skipped_keys(skipped_keys, base_name_map, max_print: int = 15) -> None:
    """Print up to *max_print* skipped keys (sorted) with their inferred base layer."""
    if not skipped_keys:
        print("    None")
        return
    for key in sorted(skipped_keys)[:max_print]:
        print(f"    - {key} (Base Layer Identified: {base_name_map.get(key, 'N/A')})")
    if len(skipped_keys) > max_print:
        print(f"    ... and {len(skipped_keys) - max_print} more.")


def convert_lora(source_lora_path: str, target_lora_path: str) -> bool:
    """Write a copy of a LoRA without the tensors for incompatible base layers.

    Converts an i2v_14B LoRA to be compatible with i2v_14B_FC by removing the
    LoRA tensors whose base layer (per get_base_layer_name) is listed in
    BASE_LAYERS_TO_SKIP_LORA.  The source file's metadata is copied unchanged.

    Args:
        source_lora_path: Path to the input LoRA file (.safetensors).
        target_lora_path: Path to save the converted LoRA file (.safetensors).
            Missing parent directories are created.  Must not resolve to the
            same file as *source_lora_path*.

    Returns:
        True if the converted file was written, False on any error (errors
        are printed to stderr).
    """
    print(f"Loading source LoRA from: {source_lora_path}")
    if not os.path.exists(source_lora_path):
        print(f"Error: Source file not found: {source_lora_path}", file=sys.stderr)
        return False
    if safetensors is None:
        print(
            "Error: the 'torch' and 'safetensors' packages are required to convert LoRA files.",
            file=sys.stderr,
        )
        return False
    if _is_same_path(source_lora_path, target_lora_path):
        # Writing over the file we read from would destroy the source.
        print(
            f"Error: Source and target paths cannot be the same ('{source_lora_path}').",
            file=sys.stderr,
        )
        return False

    try:
        source_lora_state_dict, metadata = _load_lora(source_lora_path)
    except Exception as exc:  # CLI boundary: report any read failure and stop.
        print(f"Error loading LoRA file: {exc}", file=sys.stderr)
        traceback.print_exc()
        return False

    print(f"Successfully loaded {len(source_lora_state_dict)} tensors.")
    if metadata:
        # Metadata values can be very large; list the keys only.
        print(f"Found metadata with {len(metadata)} entries: {sorted(metadata)}")
    else:
        print("No metadata found.")

    print("\nConverting LoRA weights...")
    print(f"Will skip LoRA weights targeting these base layers: {BASE_LAYERS_TO_SKIP_LORA}")
    target_lora_state_dict, skipped_keys, base_name_map = _split_lora_tensors(
        source_lora_state_dict
    )

    # --- Reporting ---
    print("\nConversion Summary:")
    print(f"  - Total Tensors in Source: {len(source_lora_state_dict)}")
    print(f"  - Kept {len(target_lora_state_dict)} LoRA weight tensors.")
    print(
        f"  - Skipped {len(skipped_keys)} LoRA weight tensors "
        "(due to incompatible base layer shape):"
    )
    _print_skipped_keys(skipped_keys, base_name_map)
    if not skipped_keys:
        # Nothing matched: either the LoRA never trained these layers, or its
        # key naming is not understood by get_base_layer_name().
        print(
            "Warning: no tensors matched BASE_LAYERS_TO_SKIP_LORA; the output will contain "
            "the same tensors as the input. If this LoRA does train "
            f"{sorted(BASE_LAYERS_TO_SKIP_LORA)}, check its key naming against "
            "DEFAULT_LORA_KEY_PREFIXES."
        )

    # --- Saving ---
    print(f"\nSaving converted LoRA ({len(target_lora_state_dict)} tensors) to: {target_lora_path}")
    try:
        os.makedirs(os.path.dirname(os.path.abspath(target_lora_path)), exist_ok=True)
        # Save the filtered state dictionary with the original metadata
        safetensors.torch.save_file(target_lora_state_dict, target_lora_path, metadata=metadata)
    except Exception as exc:  # CLI boundary: report any write failure.
        print(f"Error saving converted LoRA file: {exc}", file=sys.stderr)
        return False
    print("Conversion successful!")
    return True


def main(argv=None) -> int:
    """Parse command-line arguments, validate them, and run the conversion.

    Args:
        argv: Argument list (without the program name); defaults to sys.argv[1:].

    Returns:
        Process exit status: 0 on success, 1 on any error.
    """
    parser = argparse.ArgumentParser(
        description="Convert Wan i2v_14B LoRA to i2v_14B_FC LoRA by removing incompatible patch_embedding weights.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("source_lora", type=str, help="Path to the source i2v_14B LoRA file (.safetensors).")
    parser.add_argument("target_lora", type=str, help="Path to save the converted i2v_14B_FC LoRA file (.safetensors).")
    args = parser.parse_args(argv)

    # --- Input Validation ---
    # Errors abort; warnings are reported and the conversion still runs.
    if not os.path.exists(args.source_lora):
        print(f"Error: Source LoRA file not found at '{args.source_lora}'", file=sys.stderr)
        return 1
    if not args.source_lora.lower().endswith(".safetensors"):
        print(f"Warning: Source file '{args.source_lora}' does not have a .safetensors extension.")
    if _is_same_path(args.source_lora, args.target_lora):
        print(
            f"Error: Source and target paths cannot be the same ('{args.source_lora}'). "
            "Choose a different target path.",
            file=sys.stderr,
        )
        return 1
    if os.path.exists(args.target_lora):
        # TODO: optionally add a --force flag or prompt the user here.
        print(f"Warning: Target file '{args.target_lora}' already exists and will be overwritten.")

    return 0 if convert_lora(args.source_lora, args.target_lora) else 1


if __name__ == "__main__":
    sys.exit(main())