File size: 10,931 Bytes
6c482f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
"""
Domain-Specific Calibration Module for LLMs
This module implements calibration techniques for improving uncertainty estimates
across different domains, focusing on temperature scaling and domain adaptation.
"""
import numpy as np
import torch
from typing import List, Dict, Any, Union, Optional, Tuple
from scipy.optimize import minimize_scalar
class Calibrator:
    """Common interface for every calibration method in this module.

    Concrete subclasses learn their parameters in ``fit`` from pairs of
    (confidence, was-correct) and map raw confidences to calibrated ones in
    ``calibrate``. Instances start out unfitted.
    """

    def __init__(self, name: str):
        """
        Store the method identifier and mark the calibrator as unfitted.

        Args:
            name: Identifier of the calibration method.
        """
        self.name: str = name
        self.is_fitted: bool = False

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """
        Learn calibration parameters from labelled confidences.

        Args:
            confidences: Raw confidence scores.
            accuracies: Whether each corresponding prediction was correct.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        message = "Subclasses must implement this method"
        raise NotImplementedError(message)

    def calibrate(self, confidences: List[float]) -> List[float]:
        """
        Map raw confidence scores to calibrated ones.

        Args:
            confidences: Raw confidence scores.

        Returns:
            Calibrated confidence scores (in a subclass implementation).

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        message = "Subclasses must implement this method"
        raise NotImplementedError(message)
class TemperatureScaling(Calibrator):
    """Calibration using temperature scaling.

    A single scalar temperature ``T`` is chosen by minimising the mean binary
    cross-entropy between the scaled confidences and the 0/1 correctness
    labels.

    NOTE(review): the scaling is applied directly to probabilities
    (``confidence / T``, then clipped into ``[1e-10, 1 - 1e-10]``), not to
    logits as in classic temperature scaling (``sigmoid(logit / T)``).
    ``T < 1`` raises confidences, ``T > 1`` lowers them, and any value pushed
    past 1 saturates at the upper clip bound.
    """

    def __init__(self, bounds: Tuple[float, float] = (0.1, 10.0)):
        """
        Initialize the temperature scaling calibrator.

        Args:
            bounds: ``(lower, upper)`` search interval for the temperature in
                ``fit``. Must satisfy ``0 < lower < upper``. The default is the
                range this class has always searched.

        Raises:
            ValueError: If ``bounds`` is not a valid positive interval.
        """
        super().__init__("temperature_scaling")
        lower, upper = bounds
        # A non-positive temperature would divide by zero or flip signs.
        if not 0 < lower < upper:
            raise ValueError("Temperature bounds must satisfy 0 < lower < upper")
        self.bounds = (float(lower), float(upper))
        self.temperature = 1.0

    def _nll_loss(self, temperature: float, confidences: np.ndarray, accuracies: np.ndarray) -> float:
        """
        Calculate the mean binary cross-entropy for a candidate temperature.

        Args:
            temperature: Temperature parameter (positive).
            confidences: Array of confidence scores.
            accuracies: Array of 0/1 correctness indicators (as floats).

        Returns:
            Mean negative log likelihood of ``accuracies`` under the scaled
            confidences.
        """
        # Clip so the logs below never see exactly 0 or 1.
        scaled_confidences = np.clip(confidences / temperature, 1e-10, 1.0 - 1e-10)
        loss = -np.mean(
            accuracies * np.log(scaled_confidences)
            + (1 - accuracies) * np.log(1 - scaled_confidences)
        )
        return float(loss)

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """
        Fit the temperature parameter to the provided data.

        The temperature is searched within ``self.bounds`` with SciPy's bounded
        scalar minimiser. The fitted value is printed to stdout.

        Args:
            confidences: Confidence scores (list or 1-D array).
            accuracies: Boolean accuracy indicators, same length as
                ``confidences``.

        Raises:
            ValueError: If the inputs are missing, empty, of different lengths,
                or contain non-finite confidences.
        """
        # Explicit None/len checks instead of truthiness: `not arr` raises
        # "truth value of an array is ambiguous" for numpy arrays.
        if (
            confidences is None
            or accuracies is None
            or len(confidences) == 0
            or len(confidences) != len(accuracies)
        ):
            raise ValueError("Confidences and accuracies must have the same non-zero length")
        conf_array = np.asarray(confidences, dtype=float)
        acc_array = np.asarray(accuracies, dtype=float)
        # NaN/inf would make the loss NaN and the "optimum" meaningless.
        if not np.all(np.isfinite(conf_array)):
            raise ValueError("Confidences must be finite numbers")

        result = minimize_scalar(
            lambda t: self._nll_loss(t, conf_array, acc_array),
            bounds=self.bounds,
            method="bounded",
        )
        # minimize_scalar returns a numpy scalar; store a plain Python float.
        self.temperature = float(result.x)
        self.is_fitted = True
        print(f"Fitted temperature parameter: {self.temperature:.4f}")

    def calibrate(self, confidences: List[float]) -> List[float]:
        """
        Calibrate the provided confidence scores using temperature scaling.

        Args:
            confidences: Confidence scores.

        Returns:
            ``confidence / temperature`` for each input, clipped into
            ``[1e-10, 1 - 1e-10]``.

        Raises:
            ValueError: If ``fit`` has not been called.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")
        eps = 1e-10
        return [min(max(conf / self.temperature, eps), 1.0 - eps) for conf in confidences]
class DomainAdaptiveCalibration(Calibrator):
    """Calibration using domain-adaptive temperature scaling.

    A temperature is fitted on labelled source-domain data with
    ``TemperatureScaling``. If labelled target-domain data is also supplied,
    a second temperature is fitted on it. Otherwise the target temperature is
    extrapolated as ``source_temperature * default_shift_factor``.
    """

    def __init__(self, source_domain: str, target_domain: str, default_shift_factor: float = 1.2):
        """
        Initialize the domain-adaptive calibrator.

        Args:
            source_domain: Source domain name. It is stored on the instance but
                not used by the calibration math in this class.
            target_domain: Target domain name. It is stored on the instance but
                not used by the calibration math in this class.
            default_shift_factor: Ratio ``target_temperature / source_temperature``
                assumed when no labelled target data is given to ``fit``. The
                default of 1.2 is a heuristic assuming the target domain is
                slightly more uncertain than the source.

        Raises:
            ValueError: If ``default_shift_factor`` is not positive.
        """
        super().__init__("domain_adaptive_calibration")
        if default_shift_factor <= 0:
            raise ValueError("default_shift_factor must be positive")
        self.source_domain = source_domain
        self.target_domain = target_domain
        self.default_shift_factor = float(default_shift_factor)
        self.source_temperature = 1.0
        self.target_temperature = 1.0
        self.domain_shift_factor = 1.0

    @staticmethod
    def _has_data(values: Any) -> bool:
        """Return True when *values* is not None and has at least one element."""
        # len() rather than truthiness so numpy arrays are accepted.
        return values is not None and len(values) > 0

    def fit(
        self,
        source_confidences: List[float],
        source_accuracies: List[bool],
        target_confidences: Optional[List[float]] = None,
        target_accuracies: Optional[List[bool]] = None
    ) -> None:
        """
        Fit the domain-adaptive calibrator to the provided data.

        Target data is used only when both target arguments are non-empty.
        Otherwise, for example when only unlabelled target confidences are
        passed, the heuristic ``default_shift_factor`` is applied. Fitted
        values are printed to stdout.

        Args:
            source_confidences: Confidence scores from the source domain.
            source_accuracies: Boolean accuracy indicators from the source domain.
            target_confidences: Confidence scores from the target domain (if available).
            target_accuracies: Boolean accuracy indicators from the target domain
                (if available).

        Raises:
            ValueError: Propagated from ``TemperatureScaling.fit`` for empty or
                mismatched source (or supplied target) data.
        """
        source_calibrator = TemperatureScaling()
        source_calibrator.fit(source_confidences, source_accuracies)
        self.source_temperature = float(source_calibrator.temperature)

        if self._has_data(target_confidences) and self._has_data(target_accuracies):
            target_calibrator = TemperatureScaling()
            target_calibrator.fit(target_confidences, target_accuracies)
            self.target_temperature = float(target_calibrator.temperature)
            # Source temperature is bounded away from zero by the fit's search range.
            self.domain_shift_factor = self.target_temperature / self.source_temperature
        else:
            # No labelled target data: fall back to the configured heuristic.
            self.domain_shift_factor = self.default_shift_factor
            self.target_temperature = self.source_temperature * self.domain_shift_factor

        self.is_fitted = True
        print(f"Fitted source temperature: {self.source_temperature:.4f}")
        print(f"Fitted target temperature: {self.target_temperature:.4f}")
        print(f"Domain shift factor: {self.domain_shift_factor:.4f}")

    def calibrate(self, confidences: List[float], domain: Optional[str] = None) -> List[float]:
        """
        Calibrate the provided confidence scores using domain-adaptive calibration.

        Args:
            confidences: Confidence scores.
            domain: Pass the literal ``"source"`` to use the source temperature.
                Any other value, including None, uses the target temperature.

        Returns:
            ``confidence / temperature`` for each input, clipped into
            ``[1e-10, 1 - 1e-10]``.

        Raises:
            ValueError: If ``fit`` has not been called.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")
        temperature = self.source_temperature if domain == "source" else self.target_temperature
        eps = 1e-10
        return [min(max(conf / temperature, eps), 1.0 - eps) for conf in confidences]
class EnsembleCalibration(Calibrator):
    """Calibration using a weighted average of several calibration methods."""

    def __init__(self, calibrators: List[Calibrator], weights: Optional[List[float]] = None):
        """
        Initialize the ensemble calibrator.

        Args:
            calibrators: Calibrator instances to combine. The list is copied,
                so later changes to the caller's list do not desynchronise it
                from ``weights``.
            weights: One weight per calibrator, or None for equal weights.
                Weights are normalised by their sum.

        Raises:
            ValueError: If no calibrators are given, the number of weights
                does not match, or the weights sum to zero.
        """
        super().__init__("ensemble_calibration")
        self.calibrators = list(calibrators)
        if not self.calibrators:
            # Would otherwise fail with ZeroDivisionError below.
            raise ValueError("Ensemble calibration requires at least one calibrator")

        if weights is None:
            count = len(self.calibrators)
            self.weights = [1.0 / count] * count
        else:
            if len(weights) != len(self.calibrators):
                raise ValueError("Number of weights must match number of calibrators")
            total = sum(weights)
            if total == 0:
                raise ValueError("Weights must not sum to zero")
            self.weights = [w / total for w in weights]

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """
        Fit every calibrator in the ensemble on the same data.

        Args:
            confidences: Confidence scores.
            accuracies: Boolean accuracy indicators.
        """
        for calibrator in self.calibrators:
            calibrator.fit(confidences, accuracies)
        self.is_fitted = True

    def calibrate(self, confidences: List[float]) -> List[float]:
        """
        Calibrate the provided confidence scores using the ensemble.

        Args:
            confidences: Confidence scores.

        Returns:
            For each input, the weighted sum of the member calibrators' outputs.

        Raises:
            ValueError: If ``fit`` has not been called, or a member calibrator
                returns a different number of scores than it was given.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")

        expected = len(confidences)
        per_calibrator = []
        for calibrator in self.calibrators:
            scores = calibrator.calibrate(confidences)
            if len(scores) != expected:
                raise ValueError(
                    f"Calibrator {calibrator.name} returned {len(scores)} scores for {expected} inputs"
                )
            per_calibrator.append(scores)

        # zip(*...) yields, per input position, the tuple of each member's score.
        return [
            sum(weight * value for weight, value in zip(self.weights, column))
            for column in zip(*per_calibrator)
        ]
# Factory function to create calibrators
def create_calibrator(method: str, **kwargs) -> Calibrator:
    """
    Create a calibrator based on the specified method.

    Each method is accepted under its short name and under the ``name`` the
    resulting calibrator reports, so ``create_calibrator(cal.name, ...)``
    round-trips:

    * ``"temperature_scaling"``
    * ``"domain_adaptive"`` / ``"domain_adaptive_calibration"`` (requires
      ``source_domain`` and ``target_domain``)
    * ``"ensemble"`` / ``"ensemble_calibration"`` (requires ``calibrators``;
      optional ``weights``)

    Args:
        method: Name of the calibration method.
        **kwargs: Additional arguments for the calibrator. Keys not listed
            above are ignored.

    Returns:
        Calibrator instance.

    Raises:
        ValueError: If the method is unknown or a required argument is missing.
    """
    if method == "temperature_scaling":
        return TemperatureScaling()
    if method in ("domain_adaptive", "domain_adaptive_calibration"):
        if "source_domain" not in kwargs or "target_domain" not in kwargs:
            raise ValueError("Domain-adaptive calibration requires source_domain and target_domain")
        return DomainAdaptiveCalibration(kwargs["source_domain"], kwargs["target_domain"])
    if method in ("ensemble", "ensemble_calibration"):
        if "calibrators" not in kwargs:
            raise ValueError("Ensemble calibration requires a list of calibrators")
        return EnsembleCalibration(kwargs["calibrators"], kwargs.get("weights"))
    raise ValueError(f"Unsupported calibration method: {method}")
|