import re
import logging
from typing import List

import fitz  # pip install pymupdf
import torch
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, pipeline
from unidecode import unidecode

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BERTRetriever:
    """
    BERT-based evidence retrieval using extractive question answering.
    """

    def __init__(self, model_name: str = "deepset/deberta-v3-large-squad2"):
        """
        Initialize the BERT evidence retriever.

        Args:
            model_name: HuggingFace model for question answering
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.qa_pipeline = pipeline(
            "question-answering",
            model=model_name,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        # Maximum context length for the model
        self.max_length = self.tokenizer.model_max_length
        logger.info(f"Initialized BERT retriever with model: {model_name}")

    def _extract_and_clean_text(self, pdf_file: str) -> str:
        """
        Extract and clean text from a PDF file.

        Args:
            pdf_file: Path to PDF file

        Returns:
            Cleaned text from the PDF
        """
        # Get PDF file as binary
        with open(pdf_file, mode="rb") as f:
            pdf_file_bytes = f.read()

        # Extract text from the PDF
        pdf_doc = fitz.open(stream=pdf_file_bytes, filetype="pdf")
        pdf_text = ""
        for page_num in range(pdf_doc.page_count):
            page = pdf_doc.load_page(page_num)
            pdf_text += page.get_text("text")
        pdf_doc.close()

        # Clean text
        # Remove hyphens at end of lines
        clean_text = re.sub(r"-\n", "", pdf_text)
        # Replace remaining newline characters with a space
        clean_text = re.sub(r"\n", " ", clean_text)
        # Replace unicode with ASCII
        clean_text = unidecode(clean_text)
        return clean_text
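
    # Illustrative (hypothetical) effect of the cleaning steps above:
    #   extracted text "transfor-\nmer models\nare effective"
    #   becomes       "transformer models are effective"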

    def _chunk_text(self, text: str, max_chunk_size: int = 3000) -> List[str]:
        """
        Split text into chunks that fit within the model context window.

        Args:
            text: Input text to chunk
            max_chunk_size: Maximum size per chunk, in characters

        Returns:
            List of text chunks
        """
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            # Check if adding this sentence would exceed the limit
            if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence
        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks
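
    # Note: chunks are sized in characters, not tokens. If a chunk still
    # exceeds the model's token limit, the Hugging Face question-answering
    # pipeline splits the context internally with a sliding window
    # (doc_stride), so over-long chunks degrade gracefully rather than fail.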

    def _format_claim_as_question(self, claim: str) -> str:
        """
        Convert a claim into question format for better QA performance.

        Args:
            claim: Input claim

        Returns:
            Question formatted for the QA model
        """
        # Simple heuristics to convert claims to questions
        claim = claim.strip()
        # If already a question, return as is
        if claim.endswith("?"):
            return claim
        # Convert common claim patterns to questions
        if claim.lower().startswith(("the ", "a ", "an ")):
            return f"What evidence supports that {claim.lower()}?"
        elif "is" in claim.lower() or "are" in claim.lower():
            return f"Is it true that {claim.lower()}?"
        elif "can" in claim.lower() or "could" in claim.lower():
            return f"{claim}?"
        else:
            return f"What evidence supports the claim that {claim.lower()}?"

    def retrieve_evidence(self, pdf_file: str, claim: str, top_k: int = 5) -> str:
        """
        Retrieve evidence from a PDF using BERT-based question answering.

        Args:
            pdf_file: Path to PDF file
            claim: Claim to find evidence for
            top_k: Number of evidence passages to retrieve

        Returns:
            Combined evidence text
        """
        try:
            # Extract and clean text from PDF
            clean_text = self._extract_and_clean_text(pdf_file)
            # Convert claim to question format
            question = self._format_claim_as_question(claim)
            # Split text into manageable chunks
            chunks = self._chunk_text(clean_text)

            # Get answers from each chunk
            answers = []
            for i, chunk in enumerate(chunks):
                try:
                    result = self.qa_pipeline(
                        question=question, context=chunk, max_answer_len=200, top_k=1
                    )
                    # Handle both single answer and list of answers
                    if isinstance(result, list):
                        result = result[0]
                    if result["score"] > 0.1:  # Confidence threshold
                        # Extract surrounding context for better evidence; the
                        # pipeline reports character offsets of the answer span
                        start_idx = max(0, result["start"] - 100)
                        end_idx = min(len(chunk), result["end"] + 100)
                        context = chunk[start_idx:end_idx].strip()
                        answers.append(
                            {"text": context, "score": result["score"], "chunk_idx": i}
                        )
                except Exception as e:
                    logger.warning(f"Error processing chunk {i}: {str(e)}")
                    continue

            # Sort by confidence score and take the top k
            answers.sort(key=lambda x: x["score"], reverse=True)
            top_answers = answers[:top_k]

            # Combine evidence passages
            if top_answers:
                evidence_texts = [answer["text"] for answer in top_answers]
                combined_evidence = " ".join(evidence_texts)
                return combined_evidence
            else:
                logger.warning("No evidence found with sufficient confidence")
                return "No relevant evidence found in the document."
        except Exception as e:
            logger.error(f"Error in BERT evidence retrieval: {str(e)}")
            return f"Error retrieving evidence: {str(e)}"


def retrieve_with_deberta(pdf_file: str, claim: str, top_k: int = 5) -> str:
    """
    Wrapper function for DeBERTa-based evidence retrieval.
    Compatible with the existing BM25S interface.

    Args:
        pdf_file: Path to PDF file
        claim: Claim to find evidence for
        top_k: Number of evidence passages to retrieve

    Returns:
        Retrieved evidence text
    """
    # Initialize retriever (in production, this should be cached)
    retriever = BERTRetriever()
    return retriever.retrieve_evidence(pdf_file, claim, top_k)
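

# A minimal caching sketch, not part of the original interface: reloading the
# DeBERTa checkpoint on every call is expensive, so a small module-level cache
# (the names below are hypothetical) keeps one retriever per model alive.
_RETRIEVER_CACHE: dict = {}


def _get_cached_retriever(model_name: str = "deepset/deberta-v3-large-squad2") -> BERTRetriever:
    # Build the retriever once per model name and reuse it afterwards
    if model_name not in _RETRIEVER_CACHE:
        _RETRIEVER_CACHE[model_name] = BERTRetriever(model_name)
    return _RETRIEVER_CACHE[model_name]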


# Alternative lightweight model for faster inference
class DistilBERTRetriever(BERTRetriever):
    """
    Lightweight version using smaller, faster models.
    """

    def __init__(self):
        super().__init__(model_name="distilbert-base-cased-distilled-squad")


def retrieve_with_distilbert(pdf_file: str, claim: str, top_k: int = 5) -> str:
    """
    Fast DistilBERT-based evidence retrieval.

    Args:
        pdf_file: Path to PDF file
        claim: Claim to find evidence for
        top_k: Number of evidence passages to retrieve

    Returns:
        Retrieved evidence text
    """
    retriever = DistilBERTRetriever()
    return retriever.retrieve_evidence(pdf_file, claim, top_k)
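

# Usage sketch ("paper.pdf" and the claim below are placeholder assumptions,
# not files or claims referenced by this module). NLTK's "punkt" sentence
# tokenizer must be downloaded once before sent_tokenize will work; newer NLTK
# versions may also require the "punkt_tab" resource.
if __name__ == "__main__":
    import nltk

    nltk.download("punkt", quiet=True)
    evidence = retrieve_with_distilbert(
        pdf_file="paper.pdf",
        claim="The proposed method outperforms the baseline",
        top_k=3,
    )
    print(evidence)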