""" | |
Evaluation Framework for Cross-Domain Uncertainty Quantification | |
This module provides functionality for evaluating uncertainty quantification methods | |
across different domains, including metrics for uncertainty quality and cross-domain performance. | |
""" | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from typing import List, Dict, Any, Union, Optional, Tuple | |
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc | |
class UncertaintyEvaluator:
    """Evaluator for uncertainty quantification methods."""

    def __init__(self, name: str):
        """
        Initialize the uncertainty evaluator.

        Args:
            name: Name of the evaluation method
        """
        self.name = name

    def evaluate(
        self,
        uncertainties: List[float],
        correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate uncertainty estimates against correctness.

        Args:
            uncertainties: List of uncertainty scores (higher means more uncertain)
            correctness: List of boolean correctness indicators

        Returns:
            Dictionary of evaluation metrics
        """
        raise NotImplementedError("Subclasses must implement this method")


class CalibrationEvaluator(UncertaintyEvaluator):
    """Evaluator for calibration quality."""

    def __init__(self):
        """Initialize the calibration evaluator."""
        super().__init__("calibration_evaluator")
    def expected_calibration_error(
        self,
        confidences: List[float],
        correctness: List[bool],
        num_bins: int = 10
    ) -> float:
        """
        Calculate Expected Calibration Error (ECE).

        Args:
            confidences: List of confidence scores in [0, 1]
            correctness: List of boolean correctness indicators
            num_bins: Number of bins for binning confidences

        Returns:
            Expected Calibration Error
        """
        if len(confidences) != len(correctness):
            raise ValueError("Confidences and correctness must have the same length")
        if not confidences:
            return 0.0
        # Bin confidences into num_bins equal-width bins over [0, 1].
        # Using num_bins + 1 edges and dropping the last one gives bin indices
        # 1..num_bins, with a confidence of exactly 1.0 falling into the top bin.
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_indices = np.digitize(confidences, bin_edges[:-1])
        ece = 0.0
        for bin_idx in range(1, num_bins + 1):
            bin_mask = (bin_indices == bin_idx)
            if np.any(bin_mask):
                bin_confidences = np.array(confidences)[bin_mask]
                bin_correctness = np.array(correctness)[bin_mask]
                bin_confidence = np.mean(bin_confidences)
                bin_accuracy = np.mean(bin_correctness)
                bin_size = np.sum(bin_mask)
                # Weighted absolute difference between confidence and accuracy
                ece += (bin_size / len(confidences)) * np.abs(bin_confidence - bin_accuracy)
        return float(ece)
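    # Worked example of the ECE formula above (toy numbers, purely illustrative):
    # with num_bins=2, confidences [0.9, 0.7, 0.8, 0.3] and correctness
    # [True, False, True, True], the bin [0.5, 1.0] holds three samples with
    # mean confidence 0.80 and accuracy 2/3, and the bin [0.0, 0.5) holds one
    # sample with confidence 0.30 and accuracy 1.0, so
    # ECE = (3/4) * |0.80 - 0.667| + (1/4) * |0.30 - 1.0| = 0.10 + 0.175 = 0.275,
    # while MCE (below) would be max(0.133, 0.70) = 0.70.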
    def maximum_calibration_error(
        self,
        confidences: List[float],
        correctness: List[bool],
        num_bins: int = 10
    ) -> float:
        """
        Calculate Maximum Calibration Error (MCE).

        Args:
            confidences: List of confidence scores in [0, 1]
            correctness: List of boolean correctness indicators
            num_bins: Number of bins for binning confidences

        Returns:
            Maximum Calibration Error
        """
        if len(confidences) != len(correctness):
            raise ValueError("Confidences and correctness must have the same length")
        if not confidences:
            return 0.0
        # Bin confidences the same way as in expected_calibration_error
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_indices = np.digitize(confidences, bin_edges[:-1])
        max_ce = 0.0
        for bin_idx in range(1, num_bins + 1):
            bin_mask = (bin_indices == bin_idx)
            if np.any(bin_mask):
                bin_confidences = np.array(confidences)[bin_mask]
                bin_correctness = np.array(correctness)[bin_mask]
                bin_confidence = np.mean(bin_confidences)
                bin_accuracy = np.mean(bin_correctness)
                # MCE is the worst-case gap between confidence and accuracy
                ce = np.abs(bin_confidence - bin_accuracy)
                max_ce = max(max_ce, ce)
        return float(max_ce)
    def evaluate(
        self,
        confidences: List[float],
        correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate calibration quality.

        Args:
            confidences: List of confidence scores
            correctness: List of boolean correctness indicators

        Returns:
            Dictionary of calibration metrics:
            - ece: Expected Calibration Error
            - mce: Maximum Calibration Error
        """
        return {
            "ece": self.expected_calibration_error(confidences, correctness),
            "mce": self.maximum_calibration_error(confidences, correctness)
        }
    def plot_reliability_diagram(
        self,
        confidences: List[float],
        correctness: List[bool],
        num_bins: int = 10,
        title: str = "Reliability Diagram",
        save_path: Optional[str] = None
    ) -> None:
        """
        Plot a reliability diagram for calibration visualization.

        Args:
            confidences: List of confidence scores
            correctness: List of boolean correctness indicators
            num_bins: Number of bins for binning confidences
            title: Title for the plot
            save_path: Path to save the plot (None to display)
        """
        if len(confidences) != len(correctness):
            raise ValueError("Confidences and correctness must have the same length")
        # Create bins
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_indices = np.digitize(confidences, bin_edges[:-1])
        # Calculate accuracy and confidence for each bin
        bin_accuracies = []
        bin_confidences = []
        bin_sizes = []
        for bin_idx in range(1, num_bins + 1):
            bin_mask = (bin_indices == bin_idx)
            if np.any(bin_mask):
                bin_confidences.append(np.mean(np.array(confidences)[bin_mask]))
                bin_accuracies.append(np.mean(np.array(correctness)[bin_mask]))
                bin_sizes.append(np.sum(bin_mask))
            else:
                bin_confidences.append(0)
                bin_accuracies.append(0)
                bin_sizes.append(0)
        # Plot reliability diagram on an explicit primary axis, so labels and
        # legend handles are not attached to the twin axis created below
        plt.figure(figsize=(10, 6))
        ax1 = plt.gca()
        # Plot perfect calibration line
        ax1.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
        # Plot bin accuracies vs. confidences
        ax1.bar(
            bin_edges[:-1],
            bin_accuracies,
            width=1 / num_bins,
            align='edge',
            alpha=0.7,
            label='Observed Accuracy'
        )
        # Plot confidence histogram on a secondary y-axis
        ax2 = ax1.twinx()
        ax2.hist(
            confidences,
            bins=bin_edges,
            alpha=0.3,
            color='gray',
            label='Confidence Histogram'
        )
        # Calculate ECE and MCE and add them to the title
        ece = self.expected_calibration_error(confidences, correctness, num_bins)
        mce = self.maximum_calibration_error(confidences, correctness, num_bins)
        ax1.set_title(f"{title}\nECE: {ece:.4f}, MCE: {mce:.4f}")
        # Add axis labels
        ax1.set_xlabel('Confidence')
        ax1.set_ylabel('Accuracy')
        ax2.set_ylabel('Count')
        # Combine legends from both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='best')
        # Save or display the plot
        if save_path:
            plt.savefig(save_path)
            plt.close()
        else:
            plt.tight_layout()
            plt.show()
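
# Minimal usage sketch for CalibrationEvaluator (illustrative only: the toy
# scores, the output path, and the helper name `_demo_calibration_evaluator`
# are assumptions, not part of the original module API).
def _demo_calibration_evaluator() -> None:
    evaluator = CalibrationEvaluator()
    confidences = [0.95, 0.80, 0.70, 0.60, 0.30]
    correctness = [True, True, False, True, False]
    metrics = evaluator.evaluate(confidences, correctness)
    print(f"ECE: {metrics['ece']:.4f}, MCE: {metrics['mce']:.4f}")
    # Writing to a file avoids blocking on plt.show() in headless environments.
    evaluator.plot_reliability_diagram(
        confidences, correctness, num_bins=5, save_path="reliability_diagram.png"
    )
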
class SelectivePredictionEvaluator(UncertaintyEvaluator):
    """Evaluator for selective prediction performance."""

    def __init__(self):
        """Initialize the selective prediction evaluator."""
        super().__init__("selective_prediction_evaluator")

    def evaluate(
        self,
        uncertainties: List[float],
        correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate selective prediction performance.

        Args:
            uncertainties: List of uncertainty scores (higher means more uncertain)
            correctness: List of boolean correctness indicators

        Returns:
            Dictionary of selective prediction metrics:
            - auroc: Area Under ROC Curve for predicting errors
            - auprc: Area Under Precision-Recall Curve for predicting errors
            - uncertainty_error_correlation: Correlation between uncertainty and errors
        """
        if len(uncertainties) != len(correctness):
            raise ValueError("Uncertainties and correctness must have the same length")
        if not uncertainties:
            return {
                "auroc": 0.5,
                "auprc": 0.5,
                "uncertainty_error_correlation": 0.0
            }
        # Convert correctness to errors (1 for error, 0 for correct)
        errors = [1 - int(c) for c in correctness]
        # Calculate AUROC for predicting errors
        try:
            auroc = roc_auc_score(errors, uncertainties)
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present
            # (all predictions correct or all wrong)
            auroc = 0.5
        # Calculate AUPRC for predicting errors
        try:
            precision, recall, _ = precision_recall_curve(errors, uncertainties)
            auprc = auc(recall, precision)
        except ValueError:
            # Handle the degenerate case where the curve is undefined
            auprc = 0.5
        # Calculate correlation between uncertainty and errors
        # (np.corrcoef returns NaN when either input is constant)
        uncertainty_error_correlation = np.corrcoef(uncertainties, errors)[0, 1]
        if np.isnan(uncertainty_error_correlation):
            uncertainty_error_correlation = 0.0
        return {
            "auroc": float(auroc),
            "auprc": float(auprc),
            "uncertainty_error_correlation": float(uncertainty_error_correlation)
        }
    def plot_selective_prediction_curve(
        self,
        uncertainties: List[float],
        correctness: List[bool],
        title: str = "Selective Prediction Performance",
        save_path: Optional[str] = None
    ) -> None:
        """
        Plot a selective prediction curve.

        Args:
            uncertainties: List of uncertainty scores (higher means more uncertain)
            correctness: List of boolean correctness indicators
            title: Title for the plot
            save_path: Path to save the plot (None to display)
        """
        if len(uncertainties) != len(correctness):
            raise ValueError("Uncertainties and correctness must have the same length")
        # Sort by uncertainty (ascending), so the most certain predictions come first
        sorted_indices = np.argsort(uncertainties)
        sorted_correctness = np.array(correctness)[sorted_indices]
        # Calculate cumulative accuracy at different coverage levels
        coverages = np.linspace(0, 1, 100)
        accuracies = []
        for coverage in coverages:
            n_samples = int(coverage * len(sorted_correctness))
            if n_samples == 0:
                # By convention, report perfect accuracy at (near-)zero coverage
                accuracies.append(1.0)
            else:
                accuracies.append(np.mean(sorted_correctness[:n_samples]))
        # Plot selective prediction curve
        plt.figure(figsize=(10, 6))
        plt.plot(coverages, accuracies, 'b-', linewidth=2, label='Selective Accuracy')
        # Add reference line for random selection
        plt.plot([0, 1], [np.mean(correctness), np.mean(correctness)], 'k--',
                 label='Random Selection')
        # Calculate AUROC and add it to the title
        metrics = self.evaluate(uncertainties, correctness)
        plt.title(f"{title}\nAUROC: {metrics['auroc']:.4f}")
        # Add labels and legend
        plt.xlabel('Coverage')
        plt.ylabel('Accuracy')
        plt.legend(loc='best')
        # Save or display the plot
        if save_path:
            plt.savefig(save_path)
            plt.close()
        else:
            plt.tight_layout()
            plt.show()
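
# Minimal usage sketch for SelectivePredictionEvaluator (illustrative only:
# the toy scores, the output path, and the helper name
# `_demo_selective_prediction_evaluator` are assumptions, not module API).
def _demo_selective_prediction_evaluator() -> None:
    evaluator = SelectivePredictionEvaluator()
    uncertainties = [0.1, 0.7, 0.2, 0.9, 0.4]
    correctness = [True, False, True, False, True]
    metrics = evaluator.evaluate(uncertainties, correctness)
    print(
        f"AUROC: {metrics['auroc']:.4f}, AUPRC: {metrics['auprc']:.4f}, "
        f"corr: {metrics['uncertainty_error_correlation']:.4f}"
    )
    evaluator.plot_selective_prediction_curve(
        uncertainties, correctness, save_path="selective_prediction_curve.png"
    )
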
class CrossDomainEvaluator:
    """Evaluator for cross-domain uncertainty performance."""

    def __init__(self):
        """Initialize the cross-domain evaluator."""
        self.name = "cross_domain_evaluator"
        self.calibration_evaluator = CalibrationEvaluator()
        self.selective_prediction_evaluator = SelectivePredictionEvaluator()

    def evaluate_domain_transfer(
        self,
        source_uncertainties: List[float],
        source_correctness: List[bool],
        target_uncertainties: List[float],
        target_correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate domain transfer performance.

        Args:
            source_uncertainties: List of uncertainty scores from source domain
            source_correctness: List of boolean correctness indicators from source domain
            target_uncertainties: List of uncertainty scores from target domain
            target_correctness: List of boolean correctness indicators from target domain

        Returns:
            Dictionary of domain transfer metrics:
            - source_auroc: AUROC in source domain
            - target_auroc: AUROC in target domain
            - transfer_degradation: Degradation in AUROC from source to target
            - source_ece: ECE in source domain
            - target_ece: ECE in target domain
            - calibration_shift: Shift in calibration from source to target
        """
        # Evaluate source domain. Calibration expects confidences, so
        # uncertainties (assumed to lie in [0, 1]) are converted via 1 - u.
        source_selective = self.selective_prediction_evaluator.evaluate(
            source_uncertainties, source_correctness
        )
        source_calibration = self.calibration_evaluator.evaluate(
            [1 - u for u in source_uncertainties], source_correctness
        )
        # Evaluate target domain
        target_selective = self.selective_prediction_evaluator.evaluate(
            target_uncertainties, target_correctness
        )
        target_calibration = self.calibration_evaluator.evaluate(
            [1 - u for u in target_uncertainties], target_correctness
        )
        # Calculate transfer metrics
        transfer_degradation = source_selective["auroc"] - target_selective["auroc"]
        calibration_shift = target_calibration["ece"] - source_calibration["ece"]
        return {
            "source_auroc": source_selective["auroc"],
            "target_auroc": target_selective["auroc"],
            "transfer_degradation": float(transfer_degradation),
            "source_ece": source_calibration["ece"],
            "target_ece": target_calibration["ece"],
            "calibration_shift": float(calibration_shift)
        }
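    # Illustrative call pattern for evaluate_domain_transfer (toy values are
    # assumptions, not real results):
    #   evaluator = CrossDomainEvaluator()
    #   transfer = evaluator.evaluate_domain_transfer(
    #       source_uncertainties=[0.1, 0.3, 0.8],
    #       source_correctness=[True, True, False],
    #       target_uncertainties=[0.2, 0.6, 0.9],
    #       target_correctness=[True, False, False],
    #   )
    #   A positive transfer_degradation means error detection (AUROC) got worse
    #   on the target domain; a positive calibration_shift means ECE got worse.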
    def evaluate_all_domains(
        self,
        domain_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate uncertainty performance across all domains.

        Args:
            domain_results: Dictionary mapping domain names to results
                Each result should contain:
                - uncertainties: List of uncertainty scores