import argparse
import logging
import gc
import math
from pathlib import Path
from typing import Dict, Set, Tuple, List, Any

import torch
from safetensors import safe_open
from tqdm import tqdm
import numpy as np

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Use MPS if available on Mac, otherwise CUDA or CPU
if torch.backends.mps.is_available():
    DEFAULT_DEVICE = "mps"
elif torch.cuda.is_available():
    DEFAULT_DEVICE = "cuda"
else:
    DEFAULT_DEVICE = "cpu"


def get_tensor_keys(filepath: Path) -> Set[str]:
    """Gets all tensor keys from a safetensors file without loading tensors.

    Args:
        filepath: Path to the .safetensors file.

    Returns:
        Set of tensor key names found in the file's header.

    Raises:
        Exception: Re-raises whatever safe_open raised, after logging it.
    """
    keys = set()
    try:
        # safe_open reads only the header, so no tensor data is loaded here.
        with safe_open(filepath, framework="pt", device="cpu") as f:
            keys = set(f.keys())
        logging.debug(f"Found {len(keys)} keys in {filepath.name}")
        return keys
    except Exception as e:
        logging.error(f"Error opening or reading keys from {filepath}: {e}")
        raise


def compare_tensors(
    key: str, file1: Path, file2: Path, device: torch.device, atol: float
) -> Tuple[bool, float, float, float]:
    """
    Loads and compares a single tensor from two files.

    Args:
        key: The tensor key to compare.
        file1: Path to the first safetensors file.
        file2: Path to the second safetensors file.
        device: The torch device to use for comparison.
        atol: Absolute tolerance for torch.allclose check.

    Returns:
        Tuple containing:
        - is_close: Boolean indicating if tensors are close within tolerance.
        - mean_abs_diff: Mean absolute difference (NaN on failure).
        - max_abs_diff: Maximum absolute difference (NaN on failure).
        - cosine_sim: Cosine similarity (-2.0 if not applicable/error,
          -1.0 if the similarity computation itself failed or was non-finite).
    """
    # Initialize variables to handle potential early returns
    t1, t2, diff = None, None, None
    mean_abs_diff = float('nan')
    max_abs_diff = float('nan')
    cosine_sim = -2.0  # Use -2.0 to indicate not computed or error
    is_close = False

    try:
        # Use safe_open for lazy loading: tensors are materialized only on
        # get_tensor(), so each comparison holds at most two tensors at once.
        with safe_open(file1, framework="pt", device="cpu") as f1, \
             safe_open(file2, framework="pt", device="cpu") as f2:

            if key not in f1.keys():
                logging.warning(f"Key '{key}' missing in Model 1 ({file1.name}). Skipping comparison for this key.")
                # No need to return here; fall through and return the default
                # failure values, letting the finally block handle cleanup.
            elif key not in f2.keys():
                logging.warning(f"Key '{key}' missing in Model 2 ({file2.name}). Skipping comparison for this key.")
                # Load t1 so the finally block's cleanup path is exercised
                t1 = f1.get_tensor(key)
            else:
                # Both keys exist, proceed with loading
                t1 = f1.get_tensor(key)
                t2 = f2.get_tensor(key)

                # --- Basic Checks ---
                if t1.shape != t2.shape:
                    logging.warning(
                        f"Shape mismatch for key '{key}': {t1.shape} vs {t2.shape}. Cannot compare."
                    )
                    # Return values indicating mismatch; t1/t2 will be
                    # cleaned up by the finally block.
                    return False, float('nan'), float('nan'), -2.0

                if t1.dtype != t2.dtype:
                    logging.warning(
                        f"Dtype mismatch for key '{key}': {t1.dtype} vs {t2.dtype}. Will attempt cast for comparison."
                    )
                    # Attempt comparison anyway; results may be less meaningful.
                    try:
                        t2 = t2.to(t1.dtype)
                    except Exception as cast_e:
                        logging.error(f"Could not cast tensor '{key}' for comparison: {cast_e}")
                        return False, float('nan'), float('nan'), -2.0

                # --- Move to device for computation ---
                try:
                    # Move original tensors (or casted t2)
                    t1_dev = t1.to(device)
                    t2_dev = t2.to(device)
                except Exception as move_e:
                    # Fall back to CPU if the device move fails (e.g. OOM on
                    # MPS/CUDA). Rebinding `device` also makes the finally
                    # block skip the now-irrelevant cache clearing.
                    logging.error(f"Could not move tensor '{key}' to device '{device}': {move_e}. Trying CPU.")
                    device = torch.device('cpu')
                    t1_dev = t1.to(device)
                    t2_dev = t2.to(device)

                # --- Comparison Metrics ---
                with torch.no_grad():
                    # Use float32 for difference calculation stability
                    diff = torch.abs(t1_dev.float() - t2_dev.float())
                    mean_abs_diff = torch.mean(diff).item()
                    max_abs_diff = torch.max(diff).item()

                    # rtol=0: for FP16-style comparisons closeness is driven
                    # purely by the absolute tolerance.
                    is_close = torch.allclose(t1_dev, t2_dev, atol=atol, rtol=0)

                    # Cosine Similarity (avoid for scalars, ensure vectors are flat)
                    if t1_dev.numel() > 1:
                        try:
                            # Flatten and cast to float for cosine similarity
                            cos_sim_val = torch.nn.functional.cosine_similarity(
                                t1_dev.flatten().float(), t2_dev.flatten().float(), dim=0
                            ).item()
                            # Handle potential NaN/Inf from zero vectors etc.
                            cosine_sim = cos_sim_val if math.isfinite(cos_sim_val) else -1.0
                        except Exception as cs_err:
                            logging.warning(f"Could not compute cosine similarity for '{key}': {cs_err}")
                            cosine_sim = -1.0  # Indicate computation error
                    elif t1_dev.numel() == 1:
                        # Define similarity for scalars: 1.0 if equal, else 0.0
                        cosine_sim = 1.0 if torch.allclose(t1_dev, t2_dev) else 0.0

                # Clean up device tensors explicitly after use
                del t1_dev, t2_dev

    except Exception as e:
        logging.error(f"Unhandled error comparing tensor '{key}': {e}", exc_info=True)
        # Return default failure values
        return False, float('nan'), float('nan'), -2.0
    finally:
        # Clear potential tensor references so the per-key memory footprint
        # stays bounded across thousands of comparisons.
        if t1 is not None:
            del t1
        if t2 is not None:
            del t2
        if diff is not None:
            del diff
        # Aggressive garbage collection and cache clearing
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        elif device.type == 'mps':
            try:
                # Newer pytorch versions have empty_cache for mps
                torch.mps.empty_cache()
            except AttributeError:
                pass  # Ignore if not available

    # Return the calculated values if comparison was successful
    return is_close, mean_abs_diff, max_abs_diff, cosine_sim


def compare_models(file1_path: Path, file2_path: Path, device_str: str, atol: float, top_n_diff: int):
    """
    Compares two safetensors models weight by weight.

    Args:
        file1_path: Path to the first model file.
        file2_path: Path to the second model file.
        device_str: Device string ('cpu', 'cuda', 'mps').
        atol: Absolute tolerance for closeness check.
        top_n_diff: Number of most different tensors to report.
    """
    if not file1_path.is_file():
        logging.error(f"File not found: {file1_path}")
        return
    if not file2_path.is_file():
        logging.error(f"File not found: {file2_path}")
        return

    try:
        device = torch.device(device_str)
        logging.info(f"Using device: {device}")
    except Exception as e:
        logging.warning(f"Could not select device '{device_str}', falling back to CPU. Error: {e}")
        device = torch.device("cpu")

    logging.info(f"Comparing Model 1: {file1_path.name}")
    logging.info(f"          Model 2: {file2_path.name}")
    logging.info(f"Absolute tolerance (atol) for closeness: {atol}")

    try:
        keys1 = get_tensor_keys(file1_path)
        keys2 = get_tensor_keys(file2_path)
    except Exception:
        return  # Error already logged by get_tensor_keys

    common_keys = sorted(list(keys1.intersection(keys2)))
    unique_keys1 = sorted(list(keys1 - keys2))
    unique_keys2 = sorted(list(keys2 - keys1))

    logging.info(f"Found {len(common_keys)} common tensor keys.")
    if unique_keys1:
        logging.warning(f"{len(unique_keys1)} keys unique to Model 1 ({file1_path.name}): {unique_keys1[:10]}{'...' if len(unique_keys1) > 10 else ''}")
    if unique_keys2:
        logging.warning(f"{len(unique_keys2)} keys unique to Model 2 ({file2_path.name}): {unique_keys2[:10]}{'...' if len(unique_keys2) > 10 else ''}")

    if not common_keys:
        logging.error("No common keys found between models. Cannot compare.")
        return

    results: List[Dict[str, Any]] = []
    close_count = 0
    compared_count = 0  # Track how many comparisons were attempted
    valid_comparisons = 0  # Track successful comparisons with numerical results
    mismatched_shape_keys = []
    comparison_error_keys = []
    all_mean_abs_diffs = []
    all_max_abs_diffs = []
    all_cosine_sims = []

    logging.info("Starting tensor comparison...")
    for key in tqdm(common_keys, desc="Comparing Tensors"):
        compared_count += 1
        is_close, mean_ad, max_ad, cos_sim = compare_tensors(
            key, file1_path, file2_path, device, atol
        )

        # Check for comparison failure (NaN or -2)
        if math.isnan(mean_ad) or math.isnan(max_ad) or cos_sim == -2.0:
            # Check if it was specifically a shape mismatch (common case).
            # Re-check shapes briefly - less efficient but simple for logging.
            try:
                with safe_open(file1_path, framework="pt", device="cpu") as f1, \
                     safe_open(file2_path, framework="pt", device="cpu") as f2:
                    # BUG FIX: safe_open handles do not expose get_shape(key);
                    # the shape is read lazily via get_slice(key).get_shape().
                    # The previous call raised AttributeError, so every shape
                    # mismatch was misfiled under comparison_error_keys.
                    t1_shape = f1.get_slice(key).get_shape()
                    t2_shape = f2.get_slice(key).get_shape()
                    if t1_shape != t2_shape:
                        mismatched_shape_keys.append(key)
                    else:
                        comparison_error_keys.append(key)  # Other error
            except Exception:
                comparison_error_keys.append(key)  # Error getting shapes or other issue
            logging.debug(f"Skipping results aggregation for key '{key}' due to comparison errors/mismatch.")
            continue  # Skip adding results for this key

        # If we reach here, comparison was numerically successful
        valid_comparisons += 1
        if is_close:
            close_count += 1
        all_mean_abs_diffs.append(mean_ad)
        all_max_abs_diffs.append(max_ad)
        # Store cosine similarity if validly computed: allow -1 (computation
        # issue like a zero vector) but not -2 (no calc attempted/major error).
        if cos_sim >= -1.0:
            all_cosine_sims.append(cos_sim)

        results.append({
            "key": key,
            "is_close": is_close,
            "mean_abs_diff": mean_ad,
            "max_abs_diff": max_ad,
            "cosine_sim": cos_sim
        })

    # --- Summary ---
    logging.info("\n--- Comparison Summary ---")
    logging.info(f"Attempted comparison for {compared_count} common keys.")
    if mismatched_shape_keys:
        logging.warning(f"Found {len(mismatched_shape_keys)} keys with mismatched shapes (skipped): {mismatched_shape_keys[:5]}{'...' if len(mismatched_shape_keys) > 5 else ''}")
    if comparison_error_keys:
        logging.error(f"Encountered errors during comparison for {len(comparison_error_keys)} keys (skipped): {comparison_error_keys[:5]}{'...' if len(comparison_error_keys) > 5 else ''}")

    if valid_comparisons == 0:
        logging.error("No tensors could be validly compared numerically (check for shape mismatches or errors).")
        return

    logging.info(f"Successfully compared {valid_comparisons} tensors numerically.")
    logging.info(f"Tensors within tolerance (atol={atol}): {close_count} / {valid_comparisons} ({close_count/valid_comparisons:.2%})")

    # Calculate overall stats only on valid comparisons
    avg_mean_ad = np.mean(all_mean_abs_diffs) if all_mean_abs_diffs else float('nan')
    avg_max_ad = np.mean(all_max_abs_diffs) if all_max_abs_diffs else float('nan')
    overall_max_ad = np.max(all_max_abs_diffs) if all_max_abs_diffs else float('nan')
    overall_max_ad_key = max(results, key=lambda x: x.get('max_abs_diff', -float('inf')))['key'] if results else 'N/A'

    # Only non-negative sims go into the avg/min stats; -1 marks a
    # computation problem, not a real similarity.
    valid_cosine_sims = [cs for cs in all_cosine_sims if cs >= 0]
    avg_cosine_sim = np.mean(valid_cosine_sims) if valid_cosine_sims else float('nan')
    min_cosine_sim = np.min(valid_cosine_sims) if valid_cosine_sims else float('nan')

    logging.info(f"Average Mean Absolute Difference (MAD): {avg_mean_ad:.6g}")
    logging.info(f"Average Max Absolute Difference: {avg_max_ad:.6g}")
    logging.info(f"Overall Maximum Absolute Difference: {overall_max_ad:.6g} (found in tensor '{overall_max_ad_key}')")
    logging.info(f"Average Cosine Similarity (valid>=0): {avg_cosine_sim:.6f}" if not math.isnan(avg_cosine_sim) else "Average Cosine Similarity (valid>=0): N/A")
    logging.info(f"Minimum Cosine Similarity (valid>=0): {min_cosine_sim:.6f}" if not math.isnan(min_cosine_sim) else "Minimum Cosine Similarity (valid>=0): N/A")

    # --- Top Differences ---
    # Sort by max absolute difference descending (handle potential missing keys)
    results.sort(key=lambda x: x.get("max_abs_diff", -float('inf')), reverse=True)

    logging.info(f"\n--- Top {min(top_n_diff, len(results))} Tensors by Max Absolute Difference (Numerically Compared Only) ---")
    for i in range(min(top_n_diff, len(results))):
        res = results[i]
        # Ensure keys exist before accessing
        key = res.get('key', 'ERROR_MISSING_KEY')
        max_ad_val = res.get('max_abs_diff', float('nan'))
        mean_ad_val = res.get('mean_abs_diff', float('nan'))
        cos_sim_val = res.get('cosine_sim', float('nan'))
        close_val = res.get('is_close', 'N/A')
        logging.info(
            f"{i+1}. Key: {key:<50} "
            f"MaxAD: {max_ad_val:.6g} | "
            f"MeanAD: {mean_ad_val:.6g} | "
            f"CosSim: {cos_sim_val:.4f} | "
            f"Close: {close_val}"
        )

    # --- Interpretation for LoRA ---
    logging.info("\n--- LoRA Compatibility Interpretation ---")
    # Prioritize architectural differences over numerical ones
    if unique_keys1 or unique_keys2 or mismatched_shape_keys:
        logging.error("Models have architectural differences (unique keys or mismatched shapes found). Direct LoRA swapping is NOT recommended.")
        if unique_keys1 or unique_keys2:
            logging.warning("  - Different sets of weights exist.")
        if mismatched_shape_keys:
            logging.warning(f"  - Mismatched shapes found for keys like: {mismatched_shape_keys[0]}")
    elif comparison_error_keys:
        logging.warning("Some tensors could not be compared due to errors (other than shape mismatch). Check logs. LoRA compatibility might be affected.")
    else:
        # Assess based on numerical differences if architecture matches
        logging.info("Models appear to have the same architecture (matching keys and shapes). Assessing numerical similarity:")
        if avg_mean_ad < 1e-5 and overall_max_ad < 1e-3:
            logging.info("  -> Differences are very small. Models appear highly similar. High LoRA compatibility expected.")
        elif avg_mean_ad < 1e-4 and overall_max_ad < 5e-3:
            logging.info("  -> Differences are small. Models appear quite similar. Good LoRA compatibility expected.")
        elif avg_mean_ad < 1e-3 and overall_max_ad < 1e-2:
            logging.info("  -> Moderate differences detected. LoRAs might work but performance could vary, especially if targeting layers with larger differences.")
        else:
            logging.warning("  -> Significant numerical differences detected (Average MAD > 1e-3 or Overall MaxAD > 0.01). LoRA compatibility is questionable. Performance may degrade even with matching architecture.")

        # Stricter threshold when the architectures already match
        if not math.isnan(min_cosine_sim) and min_cosine_sim < 0.98:
            logging.warning(f"  -> Some tensors have lower cosine similarity (min >= 0: {min_cosine_sim:.4f}), indicating potential directional differences. This could affect LoRA.")


def main():
    """Parse CLI arguments and run the model comparison."""
    parser = argparse.ArgumentParser(
        description="Compare weights between two safetensors model files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "model1_path", type=str, help="Path to the first .safetensors model file."
    )
    parser.add_argument(
        "model2_path", type=str, help="Path to the second .safetensors model file."
    )
    parser.add_argument(
        "--device",
        type=str,
        default=DEFAULT_DEVICE,
        choices=["cpu", "cuda", "mps"],
        help="Device to use for tensor comparisons ('cuda'/'mps' recommended if available).",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=1e-4,  # A reasonable default for FP16 comparison
        help="Absolute tolerance (atol) for torch.allclose check to consider tensors 'close'.",
    )
    parser.add_argument(
        "--top_n_diff",
        type=int,
        default=10,
        help="Report details for the top N tensors with the largest maximum absolute difference.",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Enable debug logging."
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    compare_models(
        Path(args.model1_path),
        Path(args.model2_path),
        args.device,
        args.atol,
        args.top_n_diff,
    )


if __name__ == "__main__":
    main()