"""
LLM Interface Module for Cross-Domain Uncertainty Quantification

This module provides a unified interface for interacting with large language models,
supporting multiple model architectures and uncertainty quantification methods.
"""

import torch
from typing import List, Dict, Any, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm

class LLMInterface:
    """Interface for interacting with large language models with uncertainty quantification."""
    
    def __init__(
        self, 
        model_name: str,
        model_type: str = "causal",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        cache_dir: Optional[str] = None,
        max_length: int = 512,
        temperature: float = 1.0,
        top_p: float = 1.0,
        num_beams: int = 1
    ):
        """
        Initialize the LLM interface.
        
        Args:
            model_name: Name of the Hugging Face model to use
            model_type: Type of model ('causal' or 'seq2seq')
            device: Device to run the model on ('cpu' or 'cuda')
            cache_dir: Directory to cache models
            max_length: Maximum total sequence length (prompt plus generated tokens for causal models)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            num_beams: Number of beams for beam search
        """
        self.model_name = model_name
        self.model_type = model_type
        self.device = device
        self.cache_dir = cache_dir
        self.max_length = max_length
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, 
            cache_dir=cache_dir
        )
        
        # Load model based on type
        if model_type == "causal":
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
        elif model_type == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        
        # Response cache for efficiency
        self.response_cache = {}
    
    def generate(
        self, 
        prompt: str,
        num_samples: int = 1,
        return_logits: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate responses from the model with uncertainty quantification.
        
        Args:
            prompt: Input text prompt
            num_samples: Number of samples to generate (for Monte Carlo-style uncertainty estimation)
            return_logits: Whether to return token logits
            **kwargs: Additional generation parameters
            
        Returns:
            Dictionary containing:
                - response: The generated text
                - samples: Multiple samples if num_samples > 1
                - logits: Token logits if return_logits is True
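        
        Example (illustrative sketch; ``"gpt2"`` is only a placeholder checkpoint,
        any causal model name works the same way)::
        
            llm = LLMInterface("gpt2")
            out = llm.generate("The capital of France is", num_samples=3)
            # out["response"] is the first sample; out["samples"] holds all three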
        """
        # Check cache first. Note that repeated identical calls return the
        # same cached samples, even when sampling is enabled.
        cache_key = (prompt, num_samples, return_logits, str(kwargs))
        if cache_key in self.response_cache:
            return self.response_cache[cache_key]
        
        # Prepare inputs
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        # Set generation parameters. Note that `max_length` counts the prompt
        # tokens as well as the newly generated ones for causal models.
        gen_kwargs = {
            "max_length": self.max_length,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
            "do_sample": self.temperature > 0,  # greedy/beam decoding when temperature is 0
            "pad_token_id": self.tokenizer.eos_token_id  # fall back to EOS when no pad token is set
        }
        gen_kwargs.update(kwargs)
        
        # Generate multiple samples if requested
        samples = []
        all_logits = []
        
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    output_scores=return_logits,
                    return_dict_in_generate=True,
                    **gen_kwargs
                )
            
            # Extract generated tokens
            if self.model_type == "causal":
                gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
            else:
                gen_tokens = outputs.sequences[0]
            
            # Decode tokens to text
            gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
            samples.append(gen_text)
            
            # Extract logits if requested
            if return_logits and hasattr(outputs, "scores"):
                all_logits.append([score.cpu().numpy() for score in outputs.scores])
        
        # Prepare result
        result = {
            "response": samples[0],  # Primary response is first sample
            "samples": samples
        }
        
        if return_logits:
            result["logits"] = all_logits
        
        # Cache result
        self.response_cache[cache_key] = result
        return result
    
    def batch_generate(
        self, 
        prompts: List[str],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Generate responses for a batch of prompts.
        
        Args:
            prompts: List of input text prompts
            **kwargs: Additional generation parameters
            
        Returns:
            List of generation results for each prompt
        """
        results = []
        for prompt in tqdm(prompts, desc="Generating responses"):
            results.append(self.generate(prompt, **kwargs))
        return results
    
    def clear_cache(self):
        """Clear the response cache."""
        self.response_cache = {}
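

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not exercised by the module itself. The checkpoint
# name "gpt2", the prompts, and the parameter values below are assumptions for
# demonstration; substitute any causal or seq2seq Hugging Face model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm = LLMInterface("gpt2", model_type="causal", max_length=64, temperature=0.8)

    # One prompt, three stochastic samples for Monte Carlo-style uncertainty.
    out = llm.generate("The capital of France is", num_samples=3)
    print("Primary response:", out["response"])
    print("Number of samples:", len(out["samples"]))

    # Batched generation; extra keyword arguments are forwarded to generate().
    batch = llm.batch_generate(["2 + 2 =", "The sky is"], num_samples=2)
    print("Samples per prompt:", [len(r["samples"]) for r in batch])

    # Drop cached responses once they are no longer needed.
    llm.clear_cache()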