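"""Merge a PEFT/LoRA adapter into a sharded, diffusers-style base model.

For every LoRA pair (A, B) in the adapter, the matching base weight W is
updated as W' = W + alpha * (B @ A), and the result is re-saved in the
original sharded safetensors layout.
"""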
import argparse
import json
import logging
import os
import re
import shutil
from typing import Dict

import torch
from safetensors.torch import load_file, save_file
logger = logging.getLogger("PeftMerger")
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
def normalize_key(key: str) -> str:
    """Normalize a LoRA key to match the base model's naming scheme."""
    key = key.replace("transformer.double_blocks", "transformer_blocks")
    key = key.replace("transformer.single_blocks", "single_transformer_blocks")
    key = re.sub(r"\.+", ".", key)  # collapse repeated dots
    return key.rstrip(".")
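# Illustrative example (hypothetical key):
#   normalize_key("transformer.double_blocks.3.img_attn_qkv")
#   -> "transformer_blocks.3.img_attn_qkv"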
def merge_lora_weights(base_weights: Dict[str, torch.Tensor],
                       lora_weights: Dict[str, torch.Tensor],
                       alpha: float = 1.0) -> Dict[str, torch.Tensor]:
    """Merge LoRA weights into base model weights as W' = W + alpha * (B @ A)."""
    merged = base_weights.copy()
    # Log a few keys from each side so naming mismatches are easy to spot
    logger.info(f"Base model keys (first 5): {list(base_weights.keys())[:5]}")
    logger.info(f"LoRA keys (first 5): {list(lora_weights.keys())[:5]}")
    # Iterate over lora_A keys and look up the matching lora_B
    for key in lora_weights.keys():
        if '.lora_A.weight' not in key:
            continue
        logger.info(f"Processing LoRA key: {key}")
        base_key = key.replace('.lora_A.weight', '')
        b_key = base_key + '.lora_B.weight'
        if b_key not in lora_weights:
            logger.warning(f"No matching lora_B weight for {base_key}, skipping")
            continue
        lora_a = lora_weights[key]
        lora_b = lora_weights[b_key]
        # Normalize after retrieving both A and B weights
        normalized_key = normalize_key(base_key)
        logger.info(f"Normalized key: {normalized_key}")
        block_match = re.search(r'transformer_blocks\.(\d+)', normalized_key)
        if block_match is None:
            logger.warning(f"No block index found in {normalized_key}, skipping")
            continue
        block_num = block_match.group(1)
        # Fused QKV projections: split the LoRA delta into Q, K, V chunks
        if 'img_attn_qkv' in base_key or 'txt_attn_qkv' in base_key:
            weights = torch.matmul(lora_b, lora_a)
            q, k, v = torch.chunk(weights, 3, dim=0)
            if 'img_attn_qkv' in base_key:
                q_key = f'transformer_blocks.{block_num}.attn.to_q.weight'
                k_key = f'transformer_blocks.{block_num}.attn.to_k.weight'
                v_key = f'transformer_blocks.{block_num}.attn.to_v.weight'
            else:
                q_key = f'transformer_blocks.{block_num}.attn.add_q_proj.weight'
                k_key = f'transformer_blocks.{block_num}.attn.add_k_proj.weight'
                v_key = f'transformer_blocks.{block_num}.attn.add_v_proj.weight'
            qkv_keys = (q_key, k_key, v_key)
            if all(name in merged for name in qkv_keys):
                merged[q_key] = merged[q_key] + alpha * q
                merged[k_key] = merged[k_key] + alpha * k
                merged[v_key] = merged[v_key] + alpha * v
                logger.info(f"Updated keys: {q_key}, {k_key}, {v_key}")
            else:
                logger.warning(f"Missing some keys: {[name for name in qkv_keys if name not in merged]}")
        else:
            # Single-weight projections: map the LoRA module name to the
            # corresponding base model key suffix (first match wins)
            suffix_map = {
                'img_attn_proj': 'attn.to_out.0.weight',
                'txt_attn_proj': 'attn.to_add_out.weight',
                'img_mlp.fc1': 'ff.net.0.proj.weight',
                'img_mlp.fc2': 'ff.net.2.weight',
                'txt_mlp.fc1': 'ff_context.net.0.proj.weight',
                'txt_mlp.fc2': 'ff_context.net.2.weight',
            }
            for lora_name, model_suffix in suffix_map.items():
                if lora_name in base_key:
                    model_key = f'transformer_blocks.{block_num}.{model_suffix}'
                    if model_key in merged:
                        weights = torch.matmul(lora_b, lora_a)
                        merged[model_key] = merged[model_key] + alpha * weights
                        logger.info(f"Updated key: {model_key}")
                    else:
                        logger.warning(f"Missing key: {model_key}")
                    break
    return merged
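# A minimal, self-contained sketch of the update rule used above; it is not
# called anywhere in the script, and the shapes are hypothetical.
def _lora_delta_sketch():
    """Demonstrate W' = W + alpha * (B @ A) on random tensors."""
    rank, d_in, d_out = 8, 64, 64
    base = torch.zeros(d_out, d_in)
    lora_a = torch.randn(rank, d_in)   # A: (rank, in_features)
    lora_b = torch.randn(d_out, rank)  # B: (out_features, rank)
    alpha = 1.0
    merged = base + alpha * torch.matmul(lora_b, lora_a)
    assert merged.shape == base.shape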
def save_sharded_model(weights: Dict[str, torch.Tensor],
                       index_data: dict,
                       output_dir: str,
                       base_model_path: str):
    """Save merged weights in the same sharded format as the original."""
    os.makedirs(output_dir, exist_ok=True)
    # Copy all non-safetensors files (index JSON, configs, etc.) from the
    # original model directory
    index_dir = os.path.dirname(os.path.abspath(base_model_path))
    for file in os.listdir(index_dir):
        if not file.endswith('.safetensors'):
            src = os.path.join(index_dir, file)
            dst = os.path.join(output_dir, file)
            if os.path.isfile(src):
                shutil.copy2(src, dst)
            elif os.path.isdir(src):
                shutil.copytree(src, dst, dirs_exist_ok=True)
    # Group weights by shard, following the original weight map
    weight_map = index_data['weight_map']
    shard_weights: Dict[str, Dict[str, torch.Tensor]] = {}
    for key, shard in weight_map.items():
        if key in weights:
            shard_weights.setdefault(shard, {})[key] = weights[key]
    # Save each shard
    for shard, shard_dict in shard_weights.items():
        shard_path = os.path.join(output_dir, shard)
        logger.info(f"Saving shard {shard} with {len(shard_dict)} tensors")
        save_file(shard_dict, shard_path)
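# For reference, the index JSON passed as --base_model looks roughly like
# this (shard file names are illustrative):
# {
#   "metadata": {"total_size": ...},
#   "weight_map": {
#     "transformer_blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
#     ...
#   }
# }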
def main():
    parser = argparse.ArgumentParser(description="Merge a LoRA adapter into a sharded base model")
    parser.add_argument("--base_model", type=str, required=True,
                        help="Path to the base model's .safetensors.index.json file")
    parser.add_argument("--adapter", type=str, required=True,
                        help="Path to the LoRA adapter .safetensors file")
    parser.add_argument("--output", type=str, required=True,
                        help="Output directory for the merged model")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="Scaling factor applied to the LoRA delta")
    args = parser.parse_args()
    # Load the base model index
    logger.info("Loading base model index...")
    with open(args.base_model, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    # Load base weights from every shard referenced by the index
    logger.info("Loading base model weights...")
    base_dir = os.path.dirname(args.base_model)
    base_weights = {}
    for part_file in set(weight_map.values()):
        part_path = os.path.join(base_dir, part_file)
        logger.info(f"Loading from {part_path}")
        base_weights.update(load_file(part_path))
    # Load LoRA adapter
    logger.info("Loading LoRA weights...")
    lora_weights = load_file(args.adapter)
    # Merge
    logger.info(f"Merging with alpha={args.alpha}")
    merged_weights = merge_lora_weights(base_weights, lora_weights, args.alpha)
    # Save in the original sharded format
    logger.info(f"Saving merged model to {args.output}")
    save_sharded_model(merged_weights, index_data, args.output, args.base_model)
    logger.info("Done!")
if __name__ == '__main__':
    main()
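# Example invocation (script name and paths are placeholders):
#   python merge_lora.py \
#     --base_model /path/to/model/diffusion_pytorch_model.safetensors.index.json \
#     --adapter /path/to/adapter_model.safetensors \
#     --output /path/to/merged_model \
#     --alpha 1.0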