# TextLSRDemo / app.py
# Hugging Face Space by SiddharthAK (commit 372cab2, "Update app.py", 20.4 kB).
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm # Still useful for model loading progress if desired, but not strictly necessary for this simplified version
import os # Still useful for general purpose, but not explicitly used in this simplified version
# --- Model Loading ---
# Every model is loaded best-effort: a failure is reported on the console and the
# affected tokenizer/model globals stay None, which the encode functions check
# before doing any work.


def _load_masked_lm(repo_id, success_message, error_label, hint_message):
    """Load a tokenizer / masked-LM pair from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id passed to ``from_pretrained``.
        success_message: printed once both objects are loaded.
        error_label: model name used in the "Error loading ... model" line.
        hint_message: printed after the error line when loading fails.

    Returns:
        ``(tokenizer, model)``. Each element stays ``None`` if the exception was
        raised before it was assigned. A loaded model is switched to eval mode.
    """
    tokenizer = None
    model = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForMaskedLM.from_pretrained(repo_id)
        model.eval()
        print(success_message)
    except Exception as e:  # keep the demo running with whatever did load
        print(f"Error loading {error_label} model: {e}")
        print(hint_message)
    return tokenizer, model


# SPLADE CoCondenser self-distil: term weighting plus vocabulary expansion.
tokenizer_splade, model_splade = _load_masked_lm(
    "naver/splade-cocondenser-selfdistil",
    "SPLADE-cocondenser-distil model loaded successfully!",
    "SPLADE-cocondenser-distil",
    "Please ensure you have accepted any user access agreements on the Hugging "
    "Face Hub page for 'naver/splade-cocondenser-selfdistil'.",
)

# SPLADE v3 Lexical: weights restricted to terms present in the input.
splade_lexical_model_name = "naver/splade-v3-lexical"
tokenizer_splade_lexical, model_splade_lexical = _load_masked_lm(
    splade_lexical_model_name,
    f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!",
    "SPLADE-v3-Lexical",
    f"Please ensure '{splade_lexical_model_name}' is accessible "
    "(check Hugging Face Hub for potential agreements).",
)

# SPLADE v3 Doc: shown in this demo as a binary bag-of-words representation.
splade_doc_model_name = "naver/splade-v3-doc"
tokenizer_splade_doc, model_splade_doc = _load_masked_lm(
    splade_doc_model_name,
    f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!",
    "SPLADE-v3-Doc",
    f"Please ensure '{splade_doc_model_name}' is accessible "
    "(check Hugging Face Hub for potential agreements).",
)
# --- Helper: lexical bag-of-words mask (batched; used here with batch size 1) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """Build a binary bag-of-words mask over the vocabulary for each sequence.

    Args:
        input_ids_batch: integer tensor of shape ``(batch_size, sequence_length)``.
        vocab_size: number of columns of the returned mask.
        tokenizer: object exposing ``pad_token_id``, ``cls_token_id``,
            ``sep_token_id``, ``mask_token_id`` and ``unk_token_id``. Those ids
            (ignoring any that are ``None``) are never set in the mask.

    Returns:
        Float tensor of shape ``(batch_size, vocab_size)`` on the same device as
        ``input_ids_batch``: 1.0 at every non-special token id that occurs in the
        corresponding row, 0.0 elsewhere.
    """
    # Loop-invariant: the excluded ids are the same for every token, so build the
    # set once (O(1) membership) instead of a fresh list per token.
    special_ids = {
        token_id
        for token_id in (
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.mask_token_id,
            tokenizer.unk_token_id,
        )
        if token_id is not None
    }
    bow_masks = torch.zeros(
        input_ids_batch.shape[0], vocab_size, device=input_ids_batch.device
    )
    # A single tensor -> Python conversion for the whole batch.
    for row, token_ids in enumerate(input_ids_batch.tolist()):
        present = {token_id for token_id in token_ids if token_id not in special_ids}
        if present:
            bow_masks[row, sorted(present)] = 1
    return bow_masks
# --- Core Representation Functions (return formatted Markdown strings, Explorer tab) ---
# These functions take a single text input.
def get_splade_cocondenser_representation(text):
    """Encode ``text`` with SPLADE-cocondenser-distil and describe its sparse vector.

    The vector is the maximum over sequence positions (``dim=1``) of
    ``log(1 + relu(logits))``, with padded positions zeroed by the attention mask.

    Returns:
        A Markdown string listing the decoded non-special terms by descending
        weight, followed by the non-zero count and sparsity. If the model is not
        loaded or its output has no ``logits``, an error message instead.
    """
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)

    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."

    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()  # single input: drop the batch dimension

    # flatten() keeps the result 1-D even for a single hit, so tolist() always
    # yields a list. The previous squeeze() + truthiness fallback turned a lone
    # hit at token id 0 into an empty list.
    indices = torch.nonzero(splade_vector).flatten().cpu().tolist()
    values = splade_vector[indices].cpu().tolist()

    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]") and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    parts = ["SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"]
    if sorted_representation:
        parts.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in sorted_representation)
    else:
        parts.append("No significant terms found for this input.\n")
    parts.append("\n--- Raw SPLADE Vector Info ---\n")
    parts.append(f"Total non-zero terms in vector: {len(indices)}\n")
    parts.append(f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n")
    return "".join(parts)
def get_splade_lexical_representation(text):
    """Encode ``text`` with SPLADE-v3-Lexical and describe its sparse vector.

    The pooling matches the cocondenser model (max over positions of
    ``log(1 + relu(logits))`` masked by the attention mask). The vector is then
    multiplied by ``create_lexical_bow_mask`` of the input ids, so only non-special
    terms that occur in the input can keep a non-zero weight.

    Returns:
        A Markdown string listing terms by descending weight, followed by the
        non-zero count and sparsity. If the model is not loaded or its output has
        no ``logits``, an error message instead.
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."

    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()  # single input: drop the batch dimension

    # Keep only weights of tokens that actually appear in the input.
    bow_mask = create_lexical_bow_mask(
        inputs["input_ids"], tokenizer_splade_lexical.vocab_size, tokenizer_splade_lexical
    ).squeeze()
    splade_vector = splade_vector * bow_mask

    # flatten() keeps the result 1-D even for a single hit (see the squeeze() pitfall).
    indices = torch.nonzero(splade_vector).flatten().cpu().tolist()
    values = splade_vector[indices].cpu().tolist()

    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]") and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    parts = ["SPLADE-v3-Lexical Representation (Weighting):\n"]
    if sorted_representation:
        parts.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in sorted_representation)
    else:
        parts.append("No significant terms found for this input.\n")
    parts.append("\n--- Raw SPLADE Vector Info ---\n")
    parts.append(f"Total non-zero terms in vector: {len(indices)}\n")
    parts.append(f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n")
    return "".join(parts)
def get_splade_doc_representation(text):
    """Describe the binary bag-of-words representation of ``text`` (SPLADE-v3-Doc).

    Every non-special input token id activates with value 1 (via
    ``create_lexical_bow_mask``). Decoded terms are listed alphabetically; only
    the first 50 are shown.

    Returns:
        A Markdown string with the term list, activated-term count and sparsity.
        If the model is not loaded or its output has no ``logits``, an error
        message instead.
    """
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
    # NOTE(review): the forward pass only serves the 'logits' sanity check below;
    # the binary vector is derived from the input ids alone.
    with torch.no_grad():
        output = model_splade_doc(**inputs)
    if not hasattr(output, "logits"):
        return "Model output structure not as expected. 'logits' not found."

    binary_splade_vector = create_lexical_bow_mask(
        inputs["input_ids"], tokenizer_splade_doc.vocab_size, tokenizer_splade_doc
    ).squeeze()  # single input: drop the batch dimension

    # flatten() keeps the result 1-D even for a single hit (see the squeeze() pitfall).
    indices = torch.nonzero(binary_splade_vector).flatten().cpu().tolist()

    terms = set()
    for token_id in indices:
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]") and decoded_token.strip():
            terms.add(decoded_token)
    sorted_terms = sorted(terms)  # alphabetical for readability

    max_listed = 50  # cap the display for very long inputs
    parts = ["SPLADE-v3-Doc Representation (Binary):\n"]
    if not sorted_terms:
        parts.append("No significant terms found for this input.\n")
    else:
        parts.extend(f"- **{term}**\n" for term in sorted_terms[:max_listed])
        if len(sorted_terms) > max_listed:
            parts.append(f"...and {len(sorted_terms) - max_listed} more terms.\n")
    parts.append("\n--- Raw Binary Sparse Vector Info ---\n")
    parts.append(f"Total activated terms: {len(indices)}\n")
    parts.append(f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n")
    return "".join(parts)
# --- Unified prediction entry point for the Explorer tab ---
def predict_representation_explorer(model_choice, text):
    """Route ``text`` to the representation renderer matching ``model_choice``.

    Returns the renderer's Markdown output, or "Please select a model." when
    ``model_choice`` equals none of the known labels.
    """
    # Lambdas defer each renderer lookup until its label actually matches.
    routes = (
        ("SPLADE-cocondenser-distil (weighting and expansion)",
         lambda: get_splade_cocondenser_representation(text)),
        ("SPLADE-v3-Lexical (weighting)",
         lambda: get_splade_lexical_representation(text)),
        ("SPLADE-v3-Doc (binary)",
         lambda: get_splade_doc_representation(text)),
    )
    for label, render in routes:
        if model_choice == label:
            return render()
    return "Please select a model."
# --- Core Representation Functions (return raw tensors, Dot Product tab) ---
def get_splade_cocondenser_vector(text):
    """Return the raw SPLADE-cocondenser-distil vector for ``text``.

    Returns:
        The max over sequence positions (``dim=1``) of ``log(1 + relu(logits))``
        masked by the attention mask, with size-1 dimensions squeezed away.
        ``None`` if the model is not loaded or its output lacks ``logits``.
    """
    if tokenizer_splade is None or model_splade is None:
        return None

    tokenized = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade.device) for name, tensor in tokenized.items()}
    with torch.no_grad():
        result = model_splade(**encoded)

    if not hasattr(result, "logits"):
        return None

    activations = torch.log(1 + torch.relu(result.logits))
    masked_activations = activations * encoded["attention_mask"].unsqueeze(-1)
    return torch.max(masked_activations, dim=1).values.squeeze()
def get_splade_lexical_vector(text):
    """Return the raw SPLADE-v3-Lexical vector for ``text``.

    Same pooling as the cocondenser model, then multiplied element-wise by
    ``create_lexical_bow_mask`` of the input ids so only non-special input
    tokens keep their weight.

    Returns:
        The masked vector (size-1 dimensions squeezed away), or ``None`` if the
        model is not loaded or its output lacks ``logits``.
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    tokenized = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade_lexical.device) for name, tensor in tokenized.items()}
    with torch.no_grad():
        result = model_splade_lexical(**encoded)

    if not hasattr(result, "logits"):
        return None

    activations = torch.log(1 + torch.relu(result.logits))
    masked_activations = activations * encoded["attention_mask"].unsqueeze(-1)
    pooled = torch.max(masked_activations, dim=1).values.squeeze()
    lexical_mask = create_lexical_bow_mask(
        encoded["input_ids"], tokenizer_splade_lexical.vocab_size, tokenizer_splade_lexical
    ).squeeze()
    return pooled * lexical_mask
def get_splade_doc_vector(text):
    """Return the binary bag-of-words vector of ``text`` for SPLADE-v3-Doc.

    The model is still run, but only to confirm its output exposes ``logits``;
    the returned vector comes from ``create_lexical_bow_mask`` over the input ids.

    Returns:
        The binary mask with size-1 dimensions squeezed away, or ``None`` if the
        model is not loaded or its output lacks ``logits``.
    """
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return None

    tokenized = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade_doc.device) for name, tensor in tokenized.items()}
    with torch.no_grad():
        result = model_splade_doc(**encoded)

    if not hasattr(result, "logits"):
        return None

    return create_lexical_bow_mask(
        encoded["input_ids"], tokenizer_splade_doc.vocab_size, tokenizer_splade_doc
    ).squeeze()
# --- Format a raw sparse vector as Markdown, given the tokenizer that produced it ---
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
    """Render a sparse vocabulary vector as a Markdown term list.

    Args:
        splade_vector: 1-D tensor indexed by token id, or ``None``.
        tokenizer: object providing ``decode(list_of_ids) -> str`` and ``vocab_size``.
        is_binary: if True, list bare terms alphabetically (at most 50, then a
            "...and N more terms." line). Otherwise list ``term: weight`` pairs
            by descending weight.

    Returns:
        The term list followed by the non-zero count and sparsity, or
        "Failed to generate vector." when ``splade_vector`` is ``None``.
    """
    if splade_vector is None:
        return "Failed to generate vector."

    # flatten() keeps the result 1-D, so tolist() is always a list. The previous
    # squeeze() + truthiness fallback dropped a lone non-zero entry at token id 0.
    indices = torch.nonzero(splade_vector).flatten().cpu().tolist()
    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[indices].cpu().tolist()

    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]") and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    parts = []
    if not meaningful_tokens:
        parts.append("No significant terms found.\n")
    elif is_binary:
        max_listed = 50  # cap the display for very long binary lists
        sorted_terms = sorted(meaningful_tokens)  # alphabetical
        parts.extend(f"- **{term}**\n" for term in sorted_terms[:max_listed])
        if len(sorted_terms) > max_listed:
            parts.append(f"...and {len(sorted_terms) - max_listed} more terms.\n")
    else:
        ranked = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
        parts.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in ranked)

    parts.append(f"\nTotal non-zero terms: {len(indices)}\n")
    parts.append(f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n")
    return "".join(parts)
# --- Dot product between a query and a document (Dot Product tab) ---
def calculate_dot_product_and_representations(model_choice, query_text, doc_text):
    """Encode a query and a document with one model and report their similarity.

    Args:
        model_choice: one of the three model labels offered in the UI.
        query_text: text encoded as the query.
        doc_text: text encoded as the document.

    Returns:
        A single Markdown string (the tab wires exactly one output) containing
        the dot-product score and both sparse representations. On an unknown
        model choice or a failed encoding, a single error-message string.
    """
    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
        encode = get_splade_cocondenser_vector
        selected_tokenizer = tokenizer_splade
        rep_title = "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
        is_binary = False
    elif model_choice == "SPLADE-v3-Lexical (weighting)":
        encode = get_splade_lexical_vector
        selected_tokenizer = tokenizer_splade_lexical
        rep_title = "SPLADE-v3-Lexical Representation (Weighting):\n"
        is_binary = False
    elif model_choice == "SPLADE-v3-Doc (binary)":
        encode = get_splade_doc_vector
        selected_tokenizer = tokenizer_splade_doc
        rep_title = "SPLADE-v3-Doc Representation (Binary):\n"
        is_binary = True
    else:
        # Previously a 3-tuple, which did not match the single Markdown output.
        return "Please select a model."

    query_vector = encode(query_text)
    doc_vector = encode(doc_text)
    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading."

    # Both vectors come from the same model, so they share one vocabulary axis.
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())

    query_rep_str = "Query " + rep_title + format_sparse_vector_output(query_vector, selected_tokenizer, is_binary)
    doc_rep_str = "Document " + rep_title + format_sparse_vector_output(doc_vector, selected_tokenizer, is_binary)

    return (
        f"### Dot Product Score: {dot_product:.6f}\n\n"
        "---\n\n"
        f"{query_rep_str}\n\n---\n\n{doc_rep_str}"
    )
# --- Gradio Interface Setup with Tabs ---
# Both tabs offer the same three encoders; the labels must match the strings the
# prediction functions compare against.
_MODEL_CHOICES = (
    "SPLADE-cocondenser-distil (weighting and expansion)",
    "SPLADE-v3-Lexical (weighting)",
    "SPLADE-v3-Doc (binary)",
)
_DEFAULT_MODEL = _MODEL_CHOICES[0]

with gr.Blocks(title="SPLADE Demos") as demo:
    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer and Retriever")
    gr.Markdown(
        "Explore different SPLADE models and their sparse representation types, "
        "and calculate similarity between query and document representations."
    )
    with gr.Tabs():
        # Tab 1: inspect the sparse representation of a single text.
        with gr.TabItem("Sparse Representation Explorer"):
            gr.Markdown("### Explore Raw SPLADE Representations for Any Text")
            gr.Interface(
                fn=predict_representation_explorer,
                inputs=[
                    gr.Radio(
                        list(_MODEL_CHOICES),
                        label="Choose Representation Model",
                        value=_DEFAULT_MODEL,
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?",
                    ),
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
                # live=True would re-encode on every keystroke, which is slow for these models.
            )

        # Tab 2: encode a query and a document and score them by dot product.
        with gr.TabItem("Query-Document Dot Product Calculator"):
            gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
            gr.Markdown(
                "Select a SPLADE model to encode both your query and document, "
                "then see their sparse representations and their similarity score."
            )
            gr.Interface(
                fn=calculate_dot_product_and_representations,
                inputs=[
                    gr.Radio(
                        list(_MODEL_CHOICES),
                        label="Choose Encoding Model",
                        value=_DEFAULT_MODEL,
                    ),
                    gr.Textbox(
                        lines=3,
                        label="Enter Query Text:",
                        placeholder="e.g., best pizza in Naples",
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter Document Text:",
                        placeholder=(
                            "e.g., Naples is famous for its delicious pizza, known for its "
                            "soft, chewy crust and fresh ingredients."
                        ),
                    ),
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
            )

demo.launch()