""" Domain-Specific Calibration Module for LLMs This module implements calibration techniques for improving uncertainty estimates across different domains, focusing on temperature scaling and domain adaptation. """ import numpy as np import torch from typing import List, Dict, Any, Union, Optional, Tuple from scipy.optimize import minimize_scalar class Calibrator: """Base class for calibration methods.""" def __init__(self, name: str): """ Initialize the calibrator. Args: name: Name of the calibration method """ self.name = name self.is_fitted = False def fit(self, confidences: List[float], accuracies: List[bool]) -> None: """ Fit the calibrator to the provided data. Args: confidences: List of confidence scores accuracies: List of boolean accuracy indicators """ raise NotImplementedError("Subclasses must implement this method") def calibrate(self, confidences: List[float]) -> List[float]: """ Calibrate the provided confidence scores. Args: confidences: List of confidence scores Returns: Calibrated confidence scores """ raise NotImplementedError("Subclasses must implement this method") class TemperatureScaling(Calibrator): """Calibration using temperature scaling.""" def __init__(self): """Initialize the temperature scaling calibrator.""" super().__init__("temperature_scaling") self.temperature = 1.0 def _nll_loss(self, temperature: float, confidences: np.ndarray, accuracies: np.ndarray) -> float: """ Calculate negative log likelihood loss for temperature scaling. Args: temperature: Temperature parameter confidences: Array of confidence scores accuracies: Array of boolean accuracy indicators Returns: Negative log likelihood loss """ # Apply temperature scaling scaled_confidences = np.clip(confidences / temperature, 1e-10, 1.0 - 1e-10) # Calculate binary cross-entropy loss loss = -np.mean( accuracies * np.log(scaled_confidences) + (1 - accuracies) * np.log(1 - scaled_confidences) ) return loss def fit(self, confidences: List[float], accuracies: List[bool]) -> None: """ Fit the temperature parameter to the provided data. Args: confidences: List of confidence scores accuracies: List of boolean accuracy indicators """ if not confidences or len(confidences) != len(accuracies): raise ValueError("Confidences and accuracies must have the same non-zero length") # Convert to numpy arrays conf_array = np.array(confidences) acc_array = np.array(accuracies, dtype=float) # Optimize temperature parameter result = minimize_scalar( lambda t: self._nll_loss(t, conf_array, acc_array), bounds=(0.1, 10.0), method='bounded' ) self.temperature = result.x self.is_fitted = True print(f"Fitted temperature parameter: {self.temperature:.4f}") def calibrate(self, confidences: List[float]) -> List[float]: """ Calibrate the provided confidence scores using temperature scaling. Args: confidences: List of confidence scores Returns: Calibrated confidence scores """ if not self.is_fitted: raise ValueError("Calibrator must be fitted before calibration") # Apply temperature scaling calibrated = [min(max(conf / self.temperature, 1e-10), 1.0 - 1e-10) for conf in confidences] return calibrated class DomainAdaptiveCalibration(Calibrator): """Calibration using domain-adaptive techniques.""" def __init__(self, source_domain: str, target_domain: str): """ Initialize the domain-adaptive calibrator. Args: source_domain: Source domain name target_domain: Target domain name """ super().__init__("domain_adaptive_calibration") self.source_domain = source_domain self.target_domain = target_domain self.source_temperature = 1.0 self.target_temperature = 1.0 self.domain_shift_factor = 1.0 def fit( self, source_confidences: List[float], source_accuracies: List[bool], target_confidences: Optional[List[float]] = None, target_accuracies: Optional[List[bool]] = None ) -> None: """ Fit the domain-adaptive calibrator to the provided data. Args: source_confidences: List of confidence scores from source domain source_accuracies: List of boolean accuracy indicators from source domain target_confidences: List of confidence scores from target domain (if available) target_accuracies: List of boolean accuracy indicators from target domain (if available) """ # Fit source domain temperature source_calibrator = TemperatureScaling() source_calibrator.fit(source_confidences, source_accuracies) self.source_temperature = source_calibrator.temperature # If target domain data is available, fit target temperature if target_confidences and target_accuracies: target_calibrator = TemperatureScaling() target_calibrator.fit(target_confidences, target_accuracies) self.target_temperature = target_calibrator.temperature # Calculate domain shift factor self.domain_shift_factor = self.target_temperature / self.source_temperature else: # Default domain shift factor based on heuristics # This is a simplified approach; in a real system, this would be more sophisticated self.domain_shift_factor = 1.2 # Assuming target domain is slightly more uncertain self.target_temperature = self.source_temperature * self.domain_shift_factor self.is_fitted = True print(f"Fitted source temperature: {self.source_temperature:.4f}") print(f"Fitted target temperature: {self.target_temperature:.4f}") print(f"Domain shift factor: {self.domain_shift_factor:.4f}") def calibrate(self, confidences: List[float], domain: str = None) -> List[float]: """ Calibrate the provided confidence scores using domain-adaptive calibration. Args: confidences: List of confidence scores domain: Domain of the confidences ('source' or 'target', defaults to target) Returns: Calibrated confidence scores """ if not self.is_fitted: raise ValueError("Calibrator must be fitted before calibration") # Determine which temperature to use if domain == "source": temperature = self.source_temperature else: temperature = self.target_temperature # Apply temperature scaling calibrated = [min(max(conf / temperature, 1e-10), 1.0 - 1e-10) for conf in confidences] return calibrated class EnsembleCalibration(Calibrator): """Calibration using an ensemble of calibration methods.""" def __init__(self, calibrators: List[Calibrator], weights: Optional[List[float]] = None): """ Initialize the ensemble calibrator. Args: calibrators: List of calibrator instances weights: List of weights for each calibrator (None for equal weights) """ super().__init__("ensemble_calibration") self.calibrators = calibrators # Initialize weights if weights is None: self.weights = [1.0 / len(calibrators)] * len(calibrators) else: if len(weights) != len(calibrators): raise ValueError("Number of weights must match number of calibrators") # Normalize weights total = sum(weights) self.weights = [w / total for w in weights] def fit(self, confidences: List[float], accuracies: List[bool]) -> None: """ Fit all calibrators in the ensemble. Args: confidences: List of confidence scores accuracies: List of boolean accuracy indicators """ for calibrator in self.calibrators: calibrator.fit(confidences, accuracies) self.is_fitted = True def calibrate(self, confidences: List[float]) -> List[float]: """ Calibrate the provided confidence scores using the ensemble. Args: confidences: List of confidence scores Returns: Calibrated confidence scores """ if not self.is_fitted: raise ValueError("Calibrator must be fitted before calibration") # Get calibrated confidences from each calibrator all_calibrated = [] for calibrator in self.calibrators: all_calibrated.append(calibrator.calibrate(confidences)) # Combine calibrated confidences using weights calibrated = [] for i in range(len(confidences)): weighted_sum = sum(self.weights[j] * all_calibrated[j][i] for j in range(len(self.calibrators))) calibrated.append(weighted_sum) return calibrated # Factory function to create calibrators def create_calibrator(method: str, **kwargs) -> Calibrator: """ Create a calibrator based on the specified method. Args: method: Name of the calibration method **kwargs: Additional arguments for the calibrator Returns: Calibrator instance """ if method == "temperature_scaling": return TemperatureScaling() elif method == "domain_adaptive": if "source_domain" not in kwargs or "target_domain" not in kwargs: raise ValueError("Domain-adaptive calibration requires source_domain and target_domain") return DomainAdaptiveCalibration(kwargs["source_domain"], kwargs["target_domain"]) elif method == "ensemble": if "calibrators" not in kwargs: raise ValueError("Ensemble calibration requires a list of calibrators") return EnsembleCalibration(kwargs["calibrators"], kwargs.get("weights")) else: raise ValueError(f"Unsupported calibration method: {method}")