"""
Evaluation Framework for Cross-Domain Uncertainty Quantification

This module provides functionality for evaluating uncertainty quantification
methods across different domains, including metrics for uncertainty quality
and cross-domain performance.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Any, Union, Optional, Tuple
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
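
# Illustrative usage sketch (not executed; assumes model confidences in [0, 1]
# and per-example correctness flags are already available):
#
#     calib = CalibrationEvaluator()
#     sel = SelectivePredictionEvaluator()
#     calib.evaluate(confidences=[0.9, 0.6, 0.8], correctness=[True, False, True])
#     sel.evaluate(uncertainties=[0.1, 0.4, 0.2], correctness=[True, False, True])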


class UncertaintyEvaluator:
    """Evaluator for uncertainty quantification methods."""

    def __init__(self, name: str):
        """
        Initialize the uncertainty evaluator.

        Args:
            name: Name of the evaluation method
        """
        self.name = name

    def evaluate(
        self,
        uncertainties: List[float],
        correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate uncertainty estimates against correctness.

        Args:
            uncertainties: List of uncertainty scores (higher means more uncertain)
            correctness: List of boolean correctness indicators

        Returns:
            Dictionary of evaluation metrics
        """
        raise NotImplementedError("Subclasses must implement this method")


class CalibrationEvaluator(UncertaintyEvaluator):
    """Evaluator for calibration quality."""

    def __init__(self):
        """Initialize the calibration evaluator."""
        super().__init__("calibration_evaluator")

    def expected_calibration_error(
        self,
        confidences: List[float],
        correctness: List[bool],
        num_bins: int = 10
    ) -> float:
        """
        Calculate Expected Calibration Error (ECE).

        Args:
            confidences: List of confidence scores in [0, 1]
            correctness: List of boolean correctness indicators
            num_bins: Number of bins for binning confidences

        Returns:
            Expected Calibration Error
        """
        if len(confidences) != len(correctness):
            raise ValueError("Confidences and correctness must have the same length")

        if not confidences:
            return 0.0

        # Assign each confidence to one of num_bins equal-width bins on [0, 1];
        # num_bins + 1 edges keep this consistent with the reliability diagram.
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_indices = np.digitize(confidences, bin_edges[:-1])
        ece = 0.0

        for bin_idx in range(1, num_bins + 1):
            bin_mask = (bin_indices == bin_idx)
            if np.any(bin_mask):
                bin_confidences = np.array(confidences)[bin_mask]
                bin_correctness = np.array(correctness)[bin_mask]
                bin_confidence = np.mean(bin_confidences)
                bin_accuracy = np.mean(bin_correctness)
                bin_size = np.sum(bin_mask)

                # Weight each bin's confidence/accuracy gap by its share of samples.
                ece += (bin_size / len(confidences)) * np.abs(bin_confidence - bin_accuracy)

        return float(ece)
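
    # For reference, a sketch of the quantity computed above (bin boundaries
    # are an implementation choice), with illustrative numbers:
    #   ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)|
    # e.g. confidences [0.85, 0.81, 0.3], correctness [True, False, True],
    # num_bins=10: the [0.8, 0.9) bin contributes (2/3) * |0.83 - 0.5| and the
    # [0.3, 0.4) bin contributes (1/3) * |0.3 - 1.0|, giving ECE ≈ 0.45.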

    def maximum_calibration_error(
        self,
        confidences: List[float],
        correctness: List[bool],
        num_bins: int = 10
    ) -> float:
        """
        Calculate Maximum Calibration Error (MCE).

        Args:
            confidences: List of confidence scores in [0, 1]
            correctness: List of boolean correctness indicators
            num_bins: Number of bins for binning confidences

        Returns:
            Maximum Calibration Error
        """
        if len(confidences) != len(correctness):
            raise ValueError("Confidences and correctness must have the same length")

        if not confidences:
            return 0.0

        # Same binning scheme as expected_calibration_error.
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_indices = np.digitize(confidences, bin_edges[:-1])
        max_ce = 0.0

        for bin_idx in range(1, num_bins + 1):
            bin_mask = (bin_indices == bin_idx)
            if np.any(bin_mask):
                bin_confidences = np.array(confidences)[bin_mask]
                bin_correctness = np.array(correctness)[bin_mask]
                bin_confidence = np.mean(bin_confidences)
                bin_accuracy = np.mean(bin_correctness)

                # Track the worst per-bin confidence/accuracy gap.
                ce = np.abs(bin_confidence - bin_accuracy)
                max_ce = max(max_ce, ce)

        return float(max_ce)
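
    # In the same notation as the ECE sketch above:
    #   MCE = max_b |acc(B_b) - conf(B_b)|
    # i.e. the worst-calibrated bin rather than the size-weighted average.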

    def evaluate(
        self,
        confidences: List[float],
        correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate calibration quality.

        Args:
            confidences: List of confidence scores
            correctness: List of boolean correctness indicators

        Returns:
            Dictionary of calibration metrics:
            - ece: Expected Calibration Error
            - mce: Maximum Calibration Error
        """
        return {
            "ece": self.expected_calibration_error(confidences, correctness),
            "mce": self.maximum_calibration_error(confidences, correctness)
        }
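
    # Illustrative usage (hypothetical values; confidences are assumed to be
    # probabilities in [0, 1]):
    #
    #     evaluator = CalibrationEvaluator()
    #     metrics = evaluator.evaluate([0.95, 0.7, 0.6], [True, True, False])
    #     # -> {"ece": ..., "mce": ...}; both lie in [0, 1], lower is better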

    def plot_reliability_diagram(
        self,
        confidences: List[float],
        correctness: List[bool],
        num_bins: int = 10,
        title: str = "Reliability Diagram",
        save_path: Optional[str] = None
    ) -> None:
        """
        Plot a reliability diagram for calibration visualization.

        Args:
            confidences: List of confidence scores
            correctness: List of boolean correctness indicators
            num_bins: Number of bins for binning confidences
            title: Title for the plot
            save_path: Path to save the plot (None to display)
        """
        if len(confidences) != len(correctness):
            raise ValueError("Confidences and correctness must have the same length")

        # Bin confidences exactly as in the calibration metrics above.
        bin_edges = np.linspace(0, 1, num_bins + 1)
        bin_indices = np.digitize(confidences, bin_edges[:-1])

        # Per-bin statistics; empty bins are plotted with zero height.
        bin_accuracies = []
        bin_confidences = []
        bin_sizes = []

        for bin_idx in range(1, num_bins + 1):
            bin_mask = (bin_indices == bin_idx)
            if np.any(bin_mask):
                bin_confidences.append(np.mean(np.array(confidences)[bin_mask]))
                bin_accuracies.append(np.mean(np.array(correctness)[bin_mask]))
                bin_sizes.append(np.sum(bin_mask))
            else:
                bin_confidences.append(0)
                bin_accuracies.append(0)
                bin_sizes.append(0)

        plt.figure(figsize=(10, 6))
        ax1 = plt.gca()

        # Diagonal reference line for perfect calibration.
        ax1.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')

        # Observed accuracy per confidence bin.
        ax1.bar(
            bin_edges[:-1],
            bin_accuracies,
            width=1 / num_bins,
            align='edge',
            alpha=0.7,
            label='Observed Accuracy'
        )

        # Confidence histogram on a secondary y-axis.
        ax2 = ax1.twinx()
        ax2.hist(
            confidences,
            bins=bin_edges,
            alpha=0.3,
            color='gray',
            label='Confidence Histogram'
        )

        ece = self.expected_calibration_error(confidences, correctness, num_bins)
        mce = self.maximum_calibration_error(confidences, correctness, num_bins)

        # Label the primary axes explicitly; twinx() makes ax2 the current axes,
        # so pyplot-level calls here would otherwise target the wrong axis.
        ax1.set_title(f"{title}\nECE: {ece:.4f}, MCE: {mce:.4f}")
        ax1.set_xlabel('Confidence')
        ax1.set_ylabel('Accuracy')
        ax2.set_ylabel('Count')

        # Merge legend entries from both axes into a single legend.
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='best')

        if save_path:
            plt.savefig(save_path)
            plt.close()
        else:
            plt.tight_layout()
            plt.show()


class SelectivePredictionEvaluator(UncertaintyEvaluator):
    """Evaluator for selective prediction performance."""

    def __init__(self):
        """Initialize the selective prediction evaluator."""
        super().__init__("selective_prediction_evaluator")

    def evaluate(
        self,
        uncertainties: List[float],
        correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate selective prediction performance.

        Args:
            uncertainties: List of uncertainty scores (higher means more uncertain)
            correctness: List of boolean correctness indicators

        Returns:
            Dictionary of selective prediction metrics:
            - auroc: Area Under ROC Curve for predicting errors
            - auprc: Area Under Precision-Recall Curve for predicting errors
            - uncertainty_error_correlation: Correlation between uncertainty and errors
        """
        if len(uncertainties) != len(correctness):
            raise ValueError("Uncertainties and correctness must have the same length")

        if not uncertainties:
            return {
                "auroc": 0.5,
                "auprc": 0.5,
                "uncertainty_error_correlation": 0.0
            }

        # Treat errors (incorrect predictions) as the positive class.
        errors = [1 - int(c) for c in correctness]

        # AUROC is undefined when only one class is present; fall back to chance level.
        try:
            auroc = roc_auc_score(errors, uncertainties)
        except ValueError:
            auroc = 0.5

        try:
            precision, recall, _ = precision_recall_curve(errors, uncertainties)
            auprc = auc(recall, precision)
        except ValueError:
            auprc = 0.5

        # Pearson correlation between uncertainty and errors; undefined (NaN)
        # when either input is constant, so fall back to 0.0 in that case.
        if np.std(uncertainties) == 0 or np.std(errors) == 0:
            uncertainty_error_correlation = 0.0
        else:
            uncertainty_error_correlation = np.corrcoef(uncertainties, errors)[0, 1]

        return {
            "auroc": float(auroc),
            "auprc": float(auprc),
            "uncertainty_error_correlation": float(uncertainty_error_correlation)
        }
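
    # Illustrative usage (hypothetical values): if high uncertainty lines up
    # with errors, AUROC approaches 1.0; chance level is 0.5.
    #
    #     sel = SelectivePredictionEvaluator()
    #     sel.evaluate(uncertainties=[0.9, 0.2, 0.8, 0.1],
    #                  correctness=[False, True, False, True])
    #     # -> auroc == 1.0 here, since every error has higher uncertainty
    #     #    than every correct prediction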

    def plot_selective_prediction_curve(
        self,
        uncertainties: List[float],
        correctness: List[bool],
        title: str = "Selective Prediction Performance",
        save_path: Optional[str] = None
    ) -> None:
        """
        Plot a selective prediction (coverage vs. accuracy) curve.

        Args:
            uncertainties: List of uncertainty scores (higher means more uncertain)
            correctness: List of boolean correctness indicators
            title: Title for the plot
            save_path: Path to save the plot (None to display)
        """
        if len(uncertainties) != len(correctness):
            raise ValueError("Uncertainties and correctness must have the same length")

        # Order predictions from most to least confident (lowest uncertainty first).
        sorted_indices = np.argsort(uncertainties)
        sorted_correctness = np.array(correctness)[sorted_indices]

        # Accuracy on the retained subset at each coverage level; coverage 0 is
        # plotted as accuracy 1.0 by convention.
        coverages = np.linspace(0, 1, 100)
        accuracies = []

        for coverage in coverages:
            n_samples = int(coverage * len(sorted_correctness))
            if n_samples == 0:
                accuracies.append(1.0)
            else:
                accuracies.append(np.mean(sorted_correctness[:n_samples]))

        plt.figure(figsize=(10, 6))
        plt.plot(coverages, accuracies, 'b-', linewidth=2, label='Selective Prediction')

        # Baseline: overall accuracy, i.e. selecting examples at random.
        plt.plot([0, 1], [np.mean(correctness), np.mean(correctness)], 'k--', label='Random Selection')

        metrics = self.evaluate(uncertainties, correctness)

        plt.title(f"{title}\nAUROC: {metrics['auroc']:.4f}")
        plt.xlabel('Coverage')
        plt.ylabel('Accuracy')
        plt.legend(loc='best')

        if save_path:
            plt.savefig(save_path)
            plt.close()
        else:
            plt.tight_layout()
            plt.show()
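
    # Reading the curve (sketch): at coverage c, only the fraction c of examples
    # with the lowest uncertainty is retained, and the y-value is the accuracy
    # on that subset. E.g. with 4 predictions whose 2 most confident are both
    # correct, the curve stays at accuracy 1.0 for coverage <= 0.5.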


class CrossDomainEvaluator:
    """Evaluator for cross-domain uncertainty performance."""

    def __init__(self):
        """Initialize the cross-domain evaluator."""
        self.name = "cross_domain_evaluator"
        self.calibration_evaluator = CalibrationEvaluator()
        self.selective_prediction_evaluator = SelectivePredictionEvaluator()

    def evaluate_domain_transfer(
        self,
        source_uncertainties: List[float],
        source_correctness: List[bool],
        target_uncertainties: List[float],
        target_correctness: List[bool]
    ) -> Dict[str, float]:
        """
        Evaluate domain transfer performance.

        Args:
            source_uncertainties: List of uncertainty scores from source domain
            source_correctness: List of boolean correctness indicators from source domain
            target_uncertainties: List of uncertainty scores from target domain
            target_correctness: List of boolean correctness indicators from target domain

        Returns:
            Dictionary of domain transfer metrics:
            - source_auroc: AUROC in source domain
            - target_auroc: AUROC in target domain
            - transfer_degradation: Degradation in AUROC from source to target
            - source_ece: ECE in source domain
            - target_ece: ECE in target domain
            - calibration_shift: Shift in calibration from source to target
        """
        # Source-domain metrics. Calibration expects confidences, so uncertainties
        # in [0, 1] are converted via confidence = 1 - uncertainty.
        source_selective = self.selective_prediction_evaluator.evaluate(
            source_uncertainties, source_correctness
        )
        source_calibration = self.calibration_evaluator.evaluate(
            [1 - u for u in source_uncertainties], source_correctness
        )

        # Target-domain metrics.
        target_selective = self.selective_prediction_evaluator.evaluate(
            target_uncertainties, target_correctness
        )
        target_calibration = self.calibration_evaluator.evaluate(
            [1 - u for u in target_uncertainties], target_correctness
        )

        # Positive values indicate worse performance in the target domain.
        transfer_degradation = source_selective["auroc"] - target_selective["auroc"]
        calibration_shift = target_calibration["ece"] - source_calibration["ece"]

        return {
            "source_auroc": source_selective["auroc"],
            "target_auroc": target_selective["auroc"],
            "transfer_degradation": float(transfer_degradation),
            "source_ece": source_calibration["ece"],
            "target_ece": target_calibration["ece"],
            "calibration_shift": float(calibration_shift)
        }
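
    # Illustrative usage (hypothetical values; uncertainties assumed in [0, 1]):
    #
    #     cross = CrossDomainEvaluator()
    #     cross.evaluate_domain_transfer(
    #         source_uncertainties=[0.1, 0.3, 0.8],
    #         source_correctness=[True, True, False],
    #         target_uncertainties=[0.2, 0.6, 0.4],
    #         target_correctness=[True, False, False],
    #     )
    #     # positive transfer_degradation / calibration_shift indicate that
    #     # uncertainty quality drops when moving from source to target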

    def evaluate_all_domains(
        self,
        domain_results: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate uncertainty performance across all domains.

        Args:
            domain_results: Dictionary mapping domain names to results.
                Each result should contain:
                - uncertainties: List of uncertainty scores