import os
import json
import logging
import re
import shutil
from typing import Dict

import torch
from safetensors.torch import load_file, save_file

logger = logging.getLogger("PeftMerger")
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


def normalize_key(key: str) -> str:
    """Normalize a LoRA key so it matches the base model's naming scheme."""
    key = key.replace("transformer.double_blocks", "transformer_blocks")
    key = key.replace("transformer.single_blocks", "single_transformer_blocks")
    key = re.sub(r'\.+', '.', key)  # Collapse repeated dots
    if key.endswith('.'):
        key = key[:-1]
    return key
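# For reference, a worked example of normalize_key (the key name below is
# illustrative, not taken from a real checkpoint):
#
#   normalize_key("transformer.double_blocks.3..img_attn_qkv.")
#   -> "transformer_blocks.3.img_attn_qkv"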
def merge_lora_weights(base_weights: Dict[str, torch.Tensor],
                       lora_weights: Dict[str, torch.Tensor],
                       alpha: float = 1.0) -> Dict[str, torch.Tensor]:
    """Merge LoRA weights into base model weights."""
    merged = base_weights.copy()

    # Log a few keys from each side for debugging
    logger.info(f"Base model keys (first 5): {list(base_weights.keys())[:5]}")
    logger.info(f"LoRA keys (first 5): {list(lora_weights.keys())[:5]}")

    for key in lora_weights.keys():
        # Each LoRA module contributes an (A, B) pair; iterate over the
        # A matrices and look up the matching B matrix.
        if '.lora_A.weight' not in key:
            continue

        logger.info(f"Processing LoRA key: {key}")
        base_key = key.replace('.lora_A.weight', '')
        lora_a = lora_weights[key]
        lora_b = lora_weights[base_key + '.lora_B.weight']

        # Normalize after fetching both A and B weights
        normalized_key = normalize_key(base_key)
        logger.info(f"Normalized key: {normalized_key}")

        # Every handled module lives in a numbered transformer block; bail out
        # early if no block index can be parsed instead of crashing on None.
        block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key)
        if block_match is None:
            logger.warning(f"Could not find a block index in: {normalized_key}")
            continue
        block_num = block_match.group(1)

        if 'img_attn_qkv' in base_key or 'txt_attn_qkv' in base_key:
            # Fused QKV projection: compute the low-rank delta B @ A, then
            # split it row-wise into separate q, k, v updates.
            delta = torch.matmul(lora_b, lora_a)
            q, k, v = torch.chunk(delta, 3, dim=0)
            if 'img_attn_qkv' in base_key:
                q_key = f'transformer_blocks.{block_num}.attn.to_q.weight'
                k_key = f'transformer_blocks.{block_num}.attn.to_k.weight'
                v_key = f'transformer_blocks.{block_num}.attn.to_v.weight'
            else:
                q_key = f'transformer_blocks.{block_num}.attn.add_q_proj.weight'
                k_key = f'transformer_blocks.{block_num}.attn.add_k_proj.weight'
                v_key = f'transformer_blocks.{block_num}.attn.add_v_proj.weight'
            if all(name in merged for name in (q_key, k_key, v_key)):
                merged[q_key] = merged[q_key] + alpha * q
                merged[k_key] = merged[k_key] + alpha * k
                merged[v_key] = merged[v_key] + alpha * v
                logger.info(f"Updated keys: {q_key}, {k_key}, {v_key}")
            else:
                missing = [name for name in (q_key, k_key, v_key) if name not in merged]
                logger.warning(f"Missing some keys: {missing}")
        else:
            # Single-projection modules map one-to-one onto a base model key.
            if 'img_attn_proj' in base_key:
                model_key = f'transformer_blocks.{block_num}.attn.to_out.0.weight'
            elif 'txt_attn_proj' in base_key:
                model_key = f'transformer_blocks.{block_num}.attn.to_add_out.weight'
            elif 'img_mlp.fc1' in base_key:
                model_key = f'transformer_blocks.{block_num}.ff.net.0.proj.weight'
            elif 'img_mlp.fc2' in base_key:
                model_key = f'transformer_blocks.{block_num}.ff.net.2.weight'
            elif 'txt_mlp.fc1' in base_key:
                model_key = f'transformer_blocks.{block_num}.ff_context.net.0.proj.weight'
            elif 'txt_mlp.fc2' in base_key:
                model_key = f'transformer_blocks.{block_num}.ff_context.net.2.weight'
            else:
                logger.warning(f"Unhandled LoRA key: {base_key}")
                continue
            if model_key in merged:
                merged[model_key] = merged[model_key] + alpha * torch.matmul(lora_b, lora_a)
                logger.info(f"Updated key: {model_key}")
            else:
                logger.warning(f"Missing key: {model_key}")

    return merged
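# The merge above applies the standard LoRA update W' = W + alpha * (B @ A).
# A toy illustration of the shapes involved (all sizes are illustrative only):
#
#   lora_a = torch.randn(4, 64)             # (rank, in_features)
#   lora_b = torch.randn(192, 4)            # (out_features, rank)
#   delta = torch.matmul(lora_b, lora_a)    # (192, 64) low-rank update
#   q, k, v = torch.chunk(delta, 3, dim=0)  # three (64, 64) deltas for fused QKV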
def save_sharded_model(weights: Dict[str, torch.Tensor], index_data: dict,
                       output_dir: str, base_model_path: str):
    """Save merged weights in the same sharded format as the original."""
    os.makedirs(output_dir, exist_ok=True)

    # Copy all non-safetensors files (config, tokenizer, etc.) from the
    # original directory
    index_dir = os.path.dirname(os.path.abspath(base_model_path))
    for file in os.listdir(index_dir):
        if not file.endswith('.safetensors'):
            src = os.path.join(index_dir, file)
            dst = os.path.join(output_dir, file)
            if os.path.isfile(src):
                shutil.copy2(src, dst)
            elif os.path.isdir(src):
                shutil.copytree(src, dst)

    # Group weights by shard according to the index's weight map
    weight_map = index_data['weight_map']
    shard_weights = {}
    for key, shard in weight_map.items():
        if shard not in shard_weights:
            shard_weights[shard] = {}
        if key in weights:
            shard_weights[shard][key] = weights[key]

    # Save each shard
    for shard, shard_dict in shard_weights.items():
        if not shard_dict:  # Skip empty shards
            continue
        shard_path = os.path.join(output_dir, shard)
        logger.info(f"Saving shard {shard} with {len(shard_dict)} tensors")
        save_file(shard_dict, shard_path)


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, required=True)
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--alpha", type=float, default=1.0)
    args = parser.parse_args()

    # Load base model index
    logger.info("Loading base model index...")
    with open(args.base_model, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']

    # Load base weights shard by shard
    logger.info("Loading base model weights...")
    base_dir = os.path.dirname(args.base_model)
    base_weights = {}
    for part_file in set(weight_map.values()):
        part_path = os.path.join(base_dir, part_file)
        logger.info(f"Loading from {part_path}")
        base_weights.update(load_file(part_path))

    # Load LoRA
    logger.info("Loading LoRA weights...")
    lora_weights = load_file(args.adapter)

    # Merge
    logger.info(f"Merging with alpha={args.alpha}")
    merged_weights = merge_lora_weights(base_weights, lora_weights, args.alpha)

    # Save in sharded format
    logger.info(f"Saving merged model to {args.output}")
    save_sharded_model(merged_weights, index_data, args.output, args.base_model)
    logger.info("Done!")


if __name__ == '__main__':
    main()
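# Example invocation (script filename and all paths are hypothetical;
# --base_model points at the sharded checkpoint's *.safetensors.index.json,
# not at a weights file):
#
#   python merge_peft.py \
#       --base_model /path/to/model.safetensors.index.json \
#       --adapter /path/to/adapter_model.safetensors \
#       --output /path/to/merged_model \
#       --alpha 1.0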