File size: 10,931 Bytes
6c482f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
Domain-Specific Calibration Module for LLMs

This module implements calibration techniques for improving uncertainty estimates
across different domains, focusing on temperature scaling and domain adaptation.
"""

import numpy as np
import torch
from typing import List, Dict, Any, Union, Optional, Tuple
from scipy.optimize import minimize_scalar

class Calibrator:
    """
    Common interface shared by every calibration strategy in this module.

    A calibrator is first fitted on pairs of (confidence, was-correct) and can
    then map raw confidence scores to calibrated ones. The base class only
    records a name and a fitted flag; ``fit`` and ``calibrate`` must be
    provided by subclasses.
    """

    def __init__(self, name: str):
        """
        Record the method name and mark the calibrator as not yet fitted.

        Args:
            name: Identifier of the calibration method
        """
        self.is_fitted = False
        self.name = name

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """
        Learn calibration parameters from labelled confidence scores.

        Args:
            confidences: Raw confidence scores
            accuracies: Whether each corresponding prediction was correct

        Raises:
            NotImplementedError: Always, in the base class
        """
        message = "Subclasses must implement this method"
        raise NotImplementedError(message)

    def calibrate(self, confidences: List[float]) -> List[float]:
        """
        Map raw confidence scores to calibrated confidence scores.

        Args:
            confidences: Raw confidence scores

        Returns:
            Calibrated confidence scores (in subclasses)

        Raises:
            NotImplementedError: Always, in the base class
        """
        message = "Subclasses must implement this method"
        raise NotImplementedError(message)


# Probabilities are clipped into [_PROB_EPS, 1 - _PROB_EPS] so that the log and
# logit transforms below never see exactly 0 or 1.
_PROB_EPS = 1e-10


def _temperature_scale(confidences, temperature: float) -> np.ndarray:
    """
    Apply temperature scaling to probability-valued confidence scores.

    Each confidence p is converted to its log-odds ``log(p) - log(1 - p)``,
    divided by ``temperature`` and mapped back through the logistic sigmoid.
    ``temperature > 1`` pulls scores toward 0.5 (less confident),
    ``temperature < 1`` pushes them toward 0 or 1, and ``temperature == 1``
    is the identity (up to clipping). Unlike dividing the probability itself
    by T, this never saturates scores and preserves their ordering.

    Args:
        confidences: Sequence or array of confidence scores
        temperature: Strictly positive temperature parameter

    Returns:
        Float array of scaled confidences, clipped to [eps, 1 - eps]

    Raises:
        ValueError: If ``temperature`` is not positive
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    probs = np.clip(np.asarray(confidences, dtype=float), _PROB_EPS, 1.0 - _PROB_EPS)
    # log1p(-p) is more accurate than log(1 - p) for p close to 1.
    logits = np.log(probs) - np.log1p(-probs)
    scaled = 1.0 / (1.0 + np.exp(-logits / temperature))
    return np.clip(scaled, _PROB_EPS, 1.0 - _PROB_EPS)


class TemperatureScaling(Calibrator):
    """
    Calibration using temperature scaling.

    A single scalar temperature T is fitted by minimising the binary negative
    log-likelihood of the correctness labels. Calibration then divides each
    confidence's log-odds by T (see ``_temperature_scale``).
    """

    def __init__(self):
        """Initialize the temperature scaling calibrator with T = 1 (identity)."""
        super().__init__("temperature_scaling")
        self.temperature = 1.0

    def _nll_loss(self, temperature: float, confidences: np.ndarray, accuracies: np.ndarray) -> float:
        """
        Calculate the binary negative log-likelihood for a candidate temperature.

        Args:
            temperature: Temperature parameter
            confidences: Array of confidence scores
            accuracies: Array of 0/1 accuracy indicators (as floats)

        Returns:
            Mean binary cross-entropy of the scaled confidences
        """
        scaled = _temperature_scale(confidences, temperature)
        loss = -np.mean(
            accuracies * np.log(scaled) +
            (1.0 - accuracies) * np.log1p(-scaled)
        )
        return float(loss)

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """
        Fit the temperature parameter to the provided data.

        Args:
            confidences: List (or array) of confidence scores
            accuracies: List (or array) of boolean accuracy indicators

        Raises:
            ValueError: If either input is missing/empty or their lengths differ
        """
        # Explicit None/len checks: truthiness of a multi-element numpy array raises.
        if (
            confidences is None
            or accuracies is None
            or len(confidences) == 0
            or len(confidences) != len(accuracies)
        ):
            raise ValueError("Confidences and accuracies must have the same non-zero length")

        conf_array = np.asarray(confidences, dtype=float)
        acc_array = np.asarray(accuracies, dtype=float)

        # Search T in a bounded interval; the NLL is a smooth scalar function of T.
        result = minimize_scalar(
            lambda t: self._nll_loss(t, conf_array, acc_array),
            bounds=(0.1, 10.0),
            method='bounded'
        )

        self.temperature = float(result.x)
        self.is_fitted = True

        print(f"Fitted temperature parameter: {self.temperature:.4f}")

    def calibrate(self, confidences: List[float]) -> List[float]:
        """
        Calibrate the provided confidence scores using temperature scaling.

        Args:
            confidences: List (or array) of confidence scores

        Returns:
            Calibrated confidence scores as a list of floats

        Raises:
            ValueError: If the calibrator has not been fitted
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")

        return _temperature_scale(confidences, self.temperature).tolist()


class DomainAdaptiveCalibration(Calibrator):
    """
    Calibration using domain-adaptive temperature scaling.

    Separate temperatures are kept for a source and a target domain. The target
    temperature is either fitted on target data or derived from the source
    temperature with a fixed heuristic shift factor.
    """

    def __init__(self, source_domain: str, target_domain: str):
        """
        Initialize the domain-adaptive calibrator.

        Args:
            source_domain: Source domain name
            target_domain: Target domain name
        """
        super().__init__("domain_adaptive_calibration")
        self.source_domain = source_domain
        self.target_domain = target_domain
        self.source_temperature = 1.0
        self.target_temperature = 1.0
        self.domain_shift_factor = 1.0

    def fit(
        self,
        source_confidences: List[float],
        source_accuracies: List[bool],
        target_confidences: Optional[List[float]] = None,
        target_accuracies: Optional[List[bool]] = None
    ) -> None:
        """
        Fit the domain-adaptive calibrator to the provided data.

        Args:
            source_confidences: Confidence scores from the source domain
            source_accuracies: Boolean accuracy indicators from the source domain
            target_confidences: Confidence scores from the target domain (if available)
            target_accuracies: Boolean accuracy indicators from the target domain (if available)

        Raises:
            ValueError: If the source data (or supplied target data) is empty or
                has mismatched lengths
        """
        source_calibrator = TemperatureScaling()
        source_calibrator.fit(source_confidences, source_accuracies)
        self.source_temperature = source_calibrator.temperature

        # Explicit None/len checks so numpy arrays are accepted as target data.
        has_target_data = (
            target_confidences is not None
            and target_accuracies is not None
            and len(target_confidences) > 0
            and len(target_accuracies) > 0
        )

        if has_target_data:
            target_calibrator = TemperatureScaling()
            target_calibrator.fit(target_confidences, target_accuracies)
            self.target_temperature = target_calibrator.temperature

            # Ratio of fitted temperatures; source_temperature >= 0.1 by the fit bounds.
            self.domain_shift_factor = self.target_temperature / self.source_temperature
        else:
            # Heuristic default when no labelled target data is available:
            # assume the target domain is slightly more uncertain (T scaled by 1.2).
            self.domain_shift_factor = 1.2
            self.target_temperature = self.source_temperature * self.domain_shift_factor

        self.is_fitted = True

        print(f"Fitted source temperature: {self.source_temperature:.4f}")
        print(f"Fitted target temperature: {self.target_temperature:.4f}")
        print(f"Domain shift factor: {self.domain_shift_factor:.4f}")

    def calibrate(self, confidences: List[float], domain: Optional[str] = None) -> List[float]:
        """
        Calibrate the provided confidence scores using domain-adaptive calibration.

        Args:
            confidences: List (or array) of confidence scores
            domain: ``'source'`` (or the configured source-domain name) selects
                the source temperature; anything else, including ``None``,
                selects the target temperature

        Returns:
            Calibrated confidence scores as a list of floats

        Raises:
            ValueError: If the calibrator has not been fitted
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")

        use_source = domain is not None and domain in ("source", self.source_domain)
        temperature = self.source_temperature if use_source else self.target_temperature

        return _temperature_scale(confidences, temperature).tolist()


class EnsembleCalibration(Calibrator):
    """
    Calibration using a weighted ensemble of calibration methods.

    Every member calibrator is fitted on the same data. Calibrated scores are the
    weighted average of the members' outputs, with weights normalised to sum to 1.
    """

    def __init__(self, calibrators: List[Calibrator], weights: Optional[List[float]] = None):
        """
        Initialize the ensemble calibrator.

        Args:
            calibrators: Non-empty list of calibrator instances
            weights: Non-negative weight per calibrator (None for equal weights);
                normalised to sum to 1

        Raises:
            ValueError: If ``calibrators`` is empty, the number of weights does not
                match, any weight is negative, or the weights sum to zero
        """
        super().__init__("ensemble_calibration")
        if len(calibrators) == 0:
            # Previously surfaced as ZeroDivisionError when computing equal weights.
            raise ValueError("Ensemble calibration requires at least one calibrator")
        self.calibrators = calibrators

        if weights is None:
            self.weights = [1.0 / len(calibrators)] * len(calibrators)
        else:
            if len(weights) != len(calibrators):
                raise ValueError("Number of weights must match number of calibrators")
            # Weights define a convex combination, which keeps outputs in [0, 1].
            if any(w < 0 for w in weights):
                raise ValueError("Ensemble weights must be non-negative")
            total = sum(weights)
            if total <= 0:
                raise ValueError("Ensemble weights must sum to a positive value")
            self.weights = [w / total for w in weights]

    def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
        """
        Fit all calibrators in the ensemble on the same data.

        Args:
            confidences: List of confidence scores
            accuracies: List of boolean accuracy indicators
        """
        for calibrator in self.calibrators:
            calibrator.fit(confidences, accuracies)

        self.is_fitted = True

    def calibrate(self, confidences: List[float]) -> List[float]:
        """
        Calibrate the provided confidence scores using the ensemble.

        Args:
            confidences: List of confidence scores

        Returns:
            Weighted average of each member's calibrated scores, one per input

        Raises:
            ValueError: If the ensemble is not fitted, or a member returns a
                different number of scores than it was given
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before calibration")

        all_calibrated = [calibrator.calibrate(confidences) for calibrator in self.calibrators]

        expected = len(confidences)
        if any(len(scores) != expected for scores in all_calibrated):
            raise ValueError("Every calibrator must return one score per input confidence")

        # zip(*...) yields, per input position, the tuple of member outputs.
        return [
            sum(weight * score for weight, score in zip(self.weights, column))
            for column in zip(*all_calibrated)
        ]


# Factory: build a calibrator instance from a method name plus keyword options.
def create_calibrator(method: str, **kwargs) -> Calibrator:
    """
    Construct the calibrator registered under ``method``.

    Args:
        method: One of ``"temperature_scaling"``, ``"domain_adaptive"`` or
            ``"ensemble"``
        **kwargs: Method-specific options: ``source_domain``/``target_domain``
            for domain-adaptive calibration; ``calibrators`` (and optional
            ``weights``) for the ensemble

    Returns:
        A new, unfitted calibrator instance

    Raises:
        ValueError: If ``method`` is unknown or a required option is missing
    """
    def build_temperature_scaling() -> Calibrator:
        return TemperatureScaling()

    def build_domain_adaptive() -> Calibrator:
        required = ("source_domain", "target_domain")
        if not all(key in kwargs for key in required):
            raise ValueError("Domain-adaptive calibration requires source_domain and target_domain")
        return DomainAdaptiveCalibration(kwargs["source_domain"], kwargs["target_domain"])

    def build_ensemble() -> Calibrator:
        if "calibrators" not in kwargs:
            raise ValueError("Ensemble calibration requires a list of calibrators")
        return EnsembleCalibration(kwargs["calibrators"], kwargs.get("weights"))

    builders = {
        "temperature_scaling": build_temperature_scaling,
        "domain_adaptive": build_domain_adaptive,
        "ensemble": build_ensemble,
    }

    # Membership is tested against a tuple (equality-based) so an unhashable
    # ``method`` still yields the ValueError below rather than a TypeError.
    if method not in tuple(builders):
        raise ValueError(f"Unsupported calibration method: {method}")
    return builders[method]()