"""
LLM Interface Module for Cross-Domain Uncertainty Quantification
This module provides a unified interface for interacting with large language models,
supporting multiple model architectures and uncertainty quantification methods.
"""
import torch
import numpy as np
from typing import List, Dict, Any, Union, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm


class LLMInterface:
    """Interface for interacting with large language models, with support for uncertainty quantification."""

    def __init__(
self,
model_name: str,
model_type: str = "causal",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
cache_dir: Optional[str] = None,
max_length: int = 512,
temperature: float = 1.0,
top_p: float = 1.0,
num_beams: int = 1
):
"""
Initialize the LLM interface.
Args:
model_name: Name of the Hugging Face model to use
model_type: Type of model ('causal' or 'seq2seq')
device: Device to run the model on ('cpu' or 'cuda')
cache_dir: Directory to cache models
max_length: Maximum length of generated sequences
temperature: Sampling temperature
top_p: Nucleus sampling parameter
num_beams: Number of beams for beam search
"""
self.model_name = model_name
self.model_type = model_type
self.device = device
self.cache_dir = cache_dir
self.max_length = max_length
self.temperature = temperature
self.top_p = top_p
self.num_beams = num_beams
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=cache_dir
)
        # Some tokenizers (e.g. GPT-2) define no pad token; fall back to EOS so that
        # padding and generation do not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model based on type
if model_type == "causal":
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=cache_dir,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
elif model_type == "seq2seq":
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
cache_dir=cache_dir,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
else:
raise ValueError(f"Unsupported model type: {model_type}")
# Response cache for efficiency
self.response_cache = {}

    def generate(
self,
prompt: str,
num_samples: int = 1,
return_logits: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
Generate responses from the model with uncertainty quantification.
Args:
prompt: Input text prompt
num_samples: Number of samples to generate (for MC methods)
return_logits: Whether to return token logits
**kwargs: Additional generation parameters
Returns:
Dictionary containing:
- response: The generated text
- samples: Multiple samples if num_samples > 1
- logits: Token logits if return_logits is True
"""
# Check cache first
cache_key = (prompt, num_samples, return_logits, str(kwargs))
if cache_key in self.response_cache:
return self.response_cache[cache_key]
# Prepare inputs
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
# Set generation parameters
gen_kwargs = {
"max_length": self.max_length,
"temperature": self.temperature,
"top_p": self.top_p,
"num_beams": self.num_beams,
"do_sample": self.temperature > 0,
"pad_token_id": self.tokenizer.eos_token_id
}
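        # Caller-supplied kwargs override the defaults above (e.g. a per-call temperature
        # or a different number of beams).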
gen_kwargs.update(kwargs)
# Generate multiple samples if requested
samples = []
all_logits = []
for _ in range(num_samples):
with torch.no_grad():
outputs = self.model.generate(
**inputs,
output_scores=return_logits,
return_dict_in_generate=True,
**gen_kwargs
)
# Extract generated tokens
if self.model_type == "causal":
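                # Causal LMs echo the prompt, so keep only the newly generated continuation.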
gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
else:
gen_tokens = outputs.sequences[0]
# Decode tokens to text
gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
samples.append(gen_text)
# Extract logits if requested
if return_logits and hasattr(outputs, "scores"):
all_logits.append([score.cpu().numpy() for score in outputs.scores])
# Prepare result
result = {
"response": samples[0], # Primary response is first sample
"samples": samples
}
if return_logits:
result["logits"] = all_logits
# Cache result
self.response_cache[cache_key] = result
return result

    def batch_generate(
self,
prompts: List[str],
**kwargs
) -> List[Dict[str, Any]]:
"""
Generate responses for a batch of prompts.
Args:
prompts: List of input text prompts
**kwargs: Additional generation parameters
Returns:
List of generation results for each prompt
"""
results = []
for prompt in tqdm(prompts, desc="Generating responses"):
results.append(self.generate(prompt, **kwargs))
return results

    def clear_cache(self):
"""Clear the response cache."""
self.response_cache = {}
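

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name "gpt2" is only an illustrative choice;
    # any Hugging Face causal (or seq2seq) checkpoint should work the same way.
    llm = LLMInterface("gpt2", model_type="causal", max_length=64)

    # Single prompt with three Monte Carlo samples and token-level logits for uncertainty scoring.
    out = llm.generate("Paris is the capital of", num_samples=3, return_logits=True)
    print(out["response"])
    print(f"samples: {len(out['samples'])}, logits returned: {'logits' in out}")

    # Sequential generation over a small batch of prompts.
    batch = llm.batch_generate(["2 + 2 =", "The opposite of hot is"], num_samples=1)
    print([r["response"] for r in batch])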