|
""" |
|
Domain-Specific Calibration Module for LLMs |
|
|
|
This module implements calibration techniques for improving uncertainty estimates |
|
across different domains, focusing on temperature scaling and domain adaptation. |
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
from typing import List, Dict, Any, Union, Optional, Tuple |
|
from scipy.optimize import minimize_scalar |
|
|
|
class Calibrator:
    """Common interface shared by the calibration strategies in this module.

    Concrete strategies override :meth:`fit` to learn their parameters from
    (confidence, correctness) pairs and :meth:`calibrate` to map raw
    confidence scores to adjusted ones.

    Attributes:
        name: Identifier of the calibration method.
        is_fitted: Starts as ``False``; subclasses set it once ``fit`` succeeds.
    """

    def __init__(self, name: str):
        """Record the method identifier and start in the unfitted state.

        Args:
            name: Identifier of the calibration method.
        """
        self.name, self.is_fitted = name, False

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """Learn calibration parameters from labelled confidences.

        Args:
            confidences: Confidence scores produced by the model.
            accuracies: Whether each corresponding prediction was correct.

        Raises:
            NotImplementedError: Always; concrete subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def calibrate(self, confidences: List[float]) -> List[float]:
        """Map raw confidence scores to calibrated ones.

        Args:
            confidences: Confidence scores produced by the model.

        Returns:
            The calibrated confidence scores.

        Raises:
            NotImplementedError: Always; concrete subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")
|
|
|
|
|
class TemperatureScaling(Calibrator):
    """Calibration using temperature scaling.

    Binary confidences ``p`` are rescaled in logit space::

        p_calibrated = sigmoid(logit(p) / T)

    ``T > 1`` softens confidences toward 0.5, ``T < 1`` sharpens them and
    ``T == 1`` is the identity (up to clipping). Working in logit space keeps
    every output strictly inside (0, 1) and preserves the ranking of the
    inputs. Dividing the probability itself by ``T`` does not: with ``T < 1``
    every ``p >= T`` would saturate at the upper clip value.
    """

    # Confidences are clipped to [_EPS, 1 - _EPS] so their logits stay finite.
    _EPS = 1e-10

    def __init__(self, bounds: Tuple[float, float] = (0.1, 10.0)):
        """Initialize the temperature scaling calibrator.

        Args:
            bounds: ``(low, high)`` search interval for the temperature used by
                :meth:`fit`. Defaults to ``(0.1, 10.0)``.

        Raises:
            ValueError: If ``bounds`` does not satisfy ``0 < low < high``.
        """
        super().__init__("temperature_scaling")
        low, high = bounds
        if not 0.0 < low < high:
            raise ValueError("bounds must satisfy 0 < low < high")
        self.bounds = (float(low), float(high))
        self.temperature = 1.0

    @classmethod
    def _logits(cls, confidences) -> np.ndarray:
        """Return ``logit(p)`` of the clipped confidences as a float array."""
        probs = np.clip(np.asarray(confidences, dtype=float), cls._EPS, 1.0 - cls._EPS)
        return np.log(probs) - np.log1p(-probs)

    @classmethod
    def apply_temperature(cls, confidences: List[float], temperature: float) -> List[float]:
        """Apply logit-space temperature scaling with a given temperature.

        Args:
            confidences: Confidence scores (list or array-like) in [0, 1].
            temperature: Positive temperature ``T``.

        Returns:
            Scaled confidences as a list of floats, clipped to
            ``[1e-10, 1 - 1e-10]``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        scaled_logits = cls._logits(confidences) / temperature
        # sigmoid(z) == exp(-log(1 + exp(-z))); logaddexp avoids overflow for large |z|.
        probs = np.exp(-np.logaddexp(0.0, -scaled_logits))
        return np.clip(probs, cls._EPS, 1.0 - cls._EPS).tolist()

    def _nll_loss(self, temperature: float, confidences: np.ndarray, accuracies: np.ndarray) -> float:
        """Calculate the binary negative log likelihood for a temperature.

        Args:
            temperature: Temperature parameter (positive).
            confidences: Array of confidence scores.
            accuracies: Array of 0/1 correctness indicators.

        Returns:
            Mean negative log likelihood of ``accuracies`` under the scaled
            confidences.
        """
        scaled_logits = self._logits(confidences) / temperature
        # -log(sigmoid(z)) == logaddexp(0, -z) and -log(1 - sigmoid(z)) == logaddexp(0, z),
        # so the loss is evaluated in logit space without log(0) issues.
        loss = np.mean(
            accuracies * np.logaddexp(0.0, -scaled_logits)
            + (1.0 - accuracies) * np.logaddexp(0.0, scaled_logits)
        )
        return float(loss)

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """Fit the temperature parameter to the provided data.

        The temperature minimizing the NLL is searched within ``self.bounds``
        using SciPy's bounded scalar minimizer. The result is printed.

        Args:
            confidences: Confidence scores (list or array-like).
            accuracies: Boolean correctness indicators, one per confidence.

        Raises:
            ValueError: If either input is missing or empty, or their lengths differ.
        """
        # len() instead of truthiness so NumPy arrays are accepted too.
        if (
            confidences is None
            or accuracies is None
            or len(confidences) == 0
            or len(confidences) != len(accuracies)
        ):
            raise ValueError("Confidences and accuracies must have the same non-zero length")

        conf_array = np.asarray(confidences, dtype=float)
        acc_array = np.asarray(accuracies, dtype=float)

        result = minimize_scalar(
            lambda t: self._nll_loss(t, conf_array, acc_array),
            bounds=self.bounds,
            method="bounded",
        )

        self.temperature = float(result.x)
        self.is_fitted = True

        print(f"Fitted temperature parameter: {self.temperature:.4f}")

    def calibrate(self, confidences: List[float]) -> List[float]:
        """Calibrate the provided confidence scores using the fitted temperature.

        Args:
            confidences: Confidence scores (list or array-like).

        Returns:
            Calibrated confidence scores as a list of floats.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")
        return self.apply_temperature(confidences, self.temperature)


class DomainAdaptiveCalibration(Calibrator):
    """Calibration using domain-adaptive techniques.

    Separate temperatures are kept for a source and a target domain. Both are
    fitted with :class:`TemperatureScaling`. When no labelled target data is
    available, the target temperature is extrapolated from the source one by
    ``default_shift_factor``.
    """

    def __init__(self, source_domain: str, target_domain: str, default_shift_factor: float = 1.2):
        """Initialize the domain-adaptive calibrator.

        Args:
            source_domain: Source domain name (stored as an attribute).
            target_domain: Target domain name (stored as an attribute).
            default_shift_factor: Target/source temperature ratio assumed when
                :meth:`fit` receives no target data. The default 1.2 is a
                heuristic prior that softens target-domain confidences.
                NOTE(review): tune per deployment.

        Raises:
            ValueError: If ``default_shift_factor`` is not positive.
        """
        super().__init__("domain_adaptive_calibration")
        if default_shift_factor <= 0:
            raise ValueError("default_shift_factor must be positive")
        self.source_domain = source_domain
        self.target_domain = target_domain
        self.default_shift_factor = float(default_shift_factor)
        self.source_temperature = 1.0
        self.target_temperature = 1.0
        self.domain_shift_factor = 1.0

    def fit(
        self,
        source_confidences: List[float],
        source_accuracies: List[bool],
        target_confidences: Optional[List[float]] = None,
        target_accuracies: Optional[List[bool]] = None,
    ) -> None:
        """Fit the domain-adaptive calibrator to the provided data.

        Args:
            source_confidences: Confidence scores from the source domain.
            source_accuracies: Boolean correctness indicators from the source domain.
            target_confidences: Confidence scores from the target domain (if available).
            target_accuracies: Boolean correctness indicators from the target domain
                (if available).

        Raises:
            ValueError: If the source data is invalid, or if only one of the two
                target arguments is non-empty.
        """
        source_calibrator = TemperatureScaling()
        source_calibrator.fit(source_confidences, source_accuracies)
        self.source_temperature = source_calibrator.temperature

        # Explicit None/len checks: truthiness is ambiguous for NumPy arrays.
        has_target_conf = target_confidences is not None and len(target_confidences) > 0
        has_target_acc = target_accuracies is not None and len(target_accuracies) > 0

        if has_target_conf and has_target_acc:
            target_calibrator = TemperatureScaling()
            target_calibrator.fit(target_confidences, target_accuracies)
            self.target_temperature = target_calibrator.temperature
            self.domain_shift_factor = self.target_temperature / self.source_temperature
        elif has_target_conf or has_target_acc:
            # Half of the target data is a caller error; don't silently fall back to the prior.
            raise ValueError("target_confidences and target_accuracies must be provided together")
        else:
            self.domain_shift_factor = self.default_shift_factor
            self.target_temperature = self.source_temperature * self.domain_shift_factor

        self.is_fitted = True

        print(f"Fitted source temperature: {self.source_temperature:.4f}")
        print(f"Fitted target temperature: {self.target_temperature:.4f}")
        print(f"Domain shift factor: {self.domain_shift_factor:.4f}")

    def calibrate(self, confidences: List[float], domain: Optional[str] = None) -> List[float]:
        """Calibrate confidence scores with the source or target temperature.

        Args:
            confidences: Confidence scores (list or array-like).
            domain: ``"source"`` selects the source temperature. Any other value,
                including the default ``None``, selects the target temperature.

        Returns:
            Calibrated confidence scores as a list of floats.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")

        temperature = self.source_temperature if domain == "source" else self.target_temperature
        # Reuse TemperatureScaling's formula so fitting and application stay consistent.
        return TemperatureScaling.apply_temperature(confidences, temperature)
|
|
|
|
|
class EnsembleCalibration(Calibrator):
    """Calibration using a weighted average of several calibration methods."""

    def __init__(self, calibrators: List[Calibrator], weights: Optional[List[float]] = None):
        """Initialize the ensemble calibrator.

        Args:
            calibrators: Calibrator instances to combine. The list is copied so
                later changes to the caller's list cannot desynchronize it from
                the weights.
            weights: One non-negative weight per calibrator, normalized to sum
                to 1. ``None`` means equal weights.

        Raises:
            ValueError: If ``calibrators`` is empty, the weight count does not
                match, any weight is negative, or the weights sum to zero.
        """
        super().__init__("ensemble_calibration")
        self.calibrators = list(calibrators)
        if not self.calibrators:
            raise ValueError("Ensemble calibration requires at least one calibrator")

        count = len(self.calibrators)
        if weights is None:
            self.weights = [1.0 / count] * count
        else:
            if len(weights) != count:
                raise ValueError("Number of weights must match number of calibrators")
            if any(w < 0 for w in weights):
                raise ValueError("Weights must be non-negative")
            total = sum(weights)
            if total <= 0:
                raise ValueError("Weights must have a positive sum")
            self.weights = [w / total for w in weights]

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """Fit all calibrators in the ensemble on the same data.

        Args:
            confidences: Confidence scores.
            accuracies: Boolean correctness indicators.
        """
        # Reset first so a failed re-fit doesn't leave a half-updated ensemble marked fitted.
        self.is_fitted = False
        for calibrator in self.calibrators:
            calibrator.fit(confidences, accuracies)
        self.is_fitted = True

    def calibrate(self, confidences: List[float]) -> List[float]:
        """Calibrate confidence scores as a weighted average of the members.

        Args:
            confidences: Confidence scores.

        Returns:
            One weighted-average calibrated score per input, as a list of floats.

        Raises:
            ValueError: If the ensemble is not fitted, or a member returns a
                different number of scores than it was given.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")

        outputs = [calibrator.calibrate(confidences) for calibrator in self.calibrators]
        expected = len(confidences)
        if any(len(out) != expected for out in outputs):
            raise ValueError("Every calibrator must return one value per input confidence")

        # Rows are per-calibrator outputs; weights @ rows gives one average per input.
        stacked = np.asarray(outputs, dtype=float).reshape(len(outputs), expected)
        return (np.asarray(self.weights, dtype=float) @ stacked).tolist()
|
|
|
|
|
|
|
def create_calibrator(method: str, **kwargs) -> Calibrator:
    """Build a calibrator instance from its method name.

    Args:
        method: One of ``"temperature_scaling"``, ``"domain_adaptive"`` or
            ``"ensemble"``.
        **kwargs: Method-specific arguments: ``source_domain`` and
            ``target_domain`` for ``"domain_adaptive"``; ``calibrators`` and an
            optional ``weights`` for ``"ensemble"``. Unused keys are ignored.

    Returns:
        A newly constructed calibrator.

    Raises:
        ValueError: If ``method`` is unknown or a required keyword argument is
            missing.
    """

    def build_temperature_scaling():
        return TemperatureScaling()

    def build_domain_adaptive():
        if not ("source_domain" in kwargs and "target_domain" in kwargs):
            raise ValueError("Domain-adaptive calibration requires source_domain and target_domain")
        return DomainAdaptiveCalibration(kwargs["source_domain"], kwargs["target_domain"])

    def build_ensemble():
        if "calibrators" not in kwargs:
            raise ValueError("Ensemble calibration requires a list of calibrators")
        return EnsembleCalibration(kwargs["calibrators"], kwargs.get("weights"))

    # Checked in order with ==, exactly like an if/elif chain.
    builders = (
        ("temperature_scaling", build_temperature_scaling),
        ("domain_adaptive", build_domain_adaptive),
        ("ensemble", build_ensemble),
    )
    for known_method, build in builders:
        if method == known_method:
            return build()

    raise ValueError(f"Unsupported calibration method: {method}")
|
|