import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
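# Gradio demo for exploring sparse lexical representations: one tab shows the
# sparse encoding of a single text under different SPLADE-style encoders (plus a
# binary bag-of-words baseline); a second tab computes the dot-product similarity
# between an independently encoded query and document.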
# Custom CSS: pins the Gradio share button to the top-right corner (alternative placements below)
css = """
/* Move share button to top-right corner */
.share-button {
position: fixed !important;
top: 20px !important;
right: 20px !important;
z-index: 1000 !important;
background: #4CAF50 !important;
color: white !important;
border-radius: 8px !important;
padding: 8px 16px !important;
font-weight: bold !important;
box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}
.share-button:hover {
background: #45a049 !important;
transform: translateY(-1px) !important;
}
/* Alternative positions - uncomment the block you want to use instead */
/* Top-left corner */
/*
.share-button {
position: fixed !important;
top: 20px !important;
left: 20px !important;
z-index: 1000 !important;
}
*/
/* Bottom-right corner (mobile-friendly) */
/*
.share-button {
position: fixed !important;
bottom: 20px !important;
right: 20px !important;
z-index: 1000 !important;
}
*/
/* Bottom-center */
/*
.share-button {
position: fixed !important;
bottom: 20px !important;
left: 50% !important;
transform: translateX(-50%) !important;
z-index: 1000 !important;
}
*/
"""
# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None
# Load SPLADE cocondenser-selfdistil model (used as the MLM encoder)
try:
tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
model_splade.eval()
print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
print(f"Error loading SPLADE-cocondenser-distil model: {e}")
print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
# Load SPLADE v3 Lexical model
try:
splade_lexical_model_name = "naver/splade-v3-lexical"
tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
model_splade_lexical.eval()
print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
print(f"Error loading SPLADE-v3-Lexical model: {e}")
print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# Load SPLADE v3 Doc model (only its tokenizer is used for the binary BoW representation; the model is loaded for completeness)
try:
splade_doc_model_name = "naver/splade-v3-doc"
tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)  # loaded for completeness; only the tokenizer is used below
model_splade_doc.eval()
print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
print(f"Error loading SPLADE-v3-Doc model: {e}")
print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Helper: build a lexical bag-of-words mask (handles batches; used with single inputs here) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
"""
Creates a batch of lexical BOW masks.
input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
vocab_size: int, size of the tokenizer vocabulary
tokenizer: the tokenizer object
Returns: torch.Tensor of shape (batch_size, vocab_size)
"""
batch_size = input_ids_batch.shape[0]
bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
for i in range(batch_size):
input_ids = input_ids_batch[i] # Get input_ids for the current item in the batch
meaningful_token_ids = []
for token_id in input_ids.tolist():
if token_id not in [
tokenizer.pad_token_id,
tokenizer.cls_token_id,
tokenizer.sep_token_id,
tokenizer.mask_token_id,
tokenizer.unk_token_id
]:
meaningful_token_ids.append(token_id)
if meaningful_token_ids:
# Apply mask to the current row in the batch
bow_masks[i, list(set(meaningful_token_ids))] = 1
return bow_masks
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# Each returns a tuple: (main_representation_str, info_str)
def get_splade_cocondenser_representation(text):
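    """
    Encode `text` with SPLADE-cocondenser-selfdistil and format the result for display.

    The sparse vector is max(log(1 + ReLU(logits))) over sequence positions, with
    padding positions zeroed via the attention mask. Special tokens are filtered out
    and the surviving terms are sorted by descending weight.
    Returns (formatted_terms_str, info_str).
    """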
if tokenizer_splade is None or model_splade is None:
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""
inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
with torch.no_grad():
output = model_splade(**inputs)
if hasattr(output, 'logits'):
splade_vector = torch.max(
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
dim=1
)[0].squeeze() # Squeeze is fine here as it's a single input
else:
return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""
indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
if not isinstance(indices, list):
indices = [indices] if indices else []
values = splade_vector[indices].cpu().tolist()
token_weights = dict(zip(indices, values))
meaningful_tokens = {}
for token_id, weight in token_weights.items():
decoded_token = tokenizer_splade.decode([token_id])
if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
meaningful_tokens[decoded_token] = weight
sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
formatted_output = "MLM encoder (SPLADE-cocondenser-distil):\n\n"
if not sorted_representation:
formatted_output += "No significant terms found for this input.\n"
else:
        # Render the terms as a single comma-separated paragraph
terms_list = []
for term, weight in sorted_representation:
terms_list.append(f"**{term}**: {weight:.4f}")
formatted_output += ", ".join(terms_list) + "."
info_output = f"" # Line 1
info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
return formatted_output, info_output
def get_splade_lexical_representation(text):
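    """
    Encode `text` with SPLADE-v3-lexical and format the result for display.

    Uses the same log(1 + ReLU(logits)) max-pooling as the cocondenser encoder, then
    multiplies by a lexical bag-of-words mask so that only terms actually present in
    the input keep a non-zero weight.
    Returns (formatted_terms_str, info_str).
    """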
if tokenizer_splade_lexical is None or model_splade_lexical is None:
return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""
inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
with torch.no_grad():
output = model_splade_lexical(**inputs)
if hasattr(output, 'logits'):
splade_vector = torch.max(
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
dim=1
)[0].squeeze() # Squeeze is fine here
else:
return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""
# Always apply lexical mask for this model's specific behavior
vocab_size = tokenizer_splade_lexical.vocab_size
    # inputs['input_ids'] already carries a batch dimension of 1; the mask is squeezed back to 1-D below
bow_mask = create_lexical_bow_mask(
inputs['input_ids'], vocab_size, tokenizer_splade_lexical
).squeeze() # Squeeze back for single output
splade_vector = splade_vector * bow_mask
indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
if not isinstance(indices, list):
indices = [indices] if indices else []
values = splade_vector[indices].cpu().tolist()
token_weights = dict(zip(indices, values))
meaningful_tokens = {}
for token_id, weight in token_weights.items():
decoded_token = tokenizer_splade_lexical.decode([token_id])
if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
meaningful_tokens[decoded_token] = weight
sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
formatted_output = "SPLADE-v3-Lexical Representation (Weighting):\n\n"
if not sorted_representation:
formatted_output += "No significant terms found for this input.\n"
else:
        # Render the terms as a single comma-separated paragraph
terms_list = []
for term, weight in sorted_representation:
terms_list.append(f"**{term}**: {weight:.4f}")
formatted_output += ", ".join(terms_list) + "."
info_output = f"" # Line 1
info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
return formatted_output, info_output
def get_splade_doc_representation(text):
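    """
    Build a binary bag-of-words representation of `text` directly from its token ids
    (no model forward pass). Every vocabulary term present in the input gets weight 1.0.
    Returns (formatted_terms_str, info_str).
    """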
    if tokenizer_splade_doc is None:  # only the tokenizer is needed for the binary BoW representation
return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""
inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
vocab_size = tokenizer_splade_doc.vocab_size
# Directly create the binary Bag-of-Words vector using the input_ids
binary_bow_vector = create_lexical_bow_mask(
inputs['input_ids'], vocab_size, tokenizer_splade_doc
).squeeze() # Squeeze back for single output
indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
if not isinstance(indices, list):
indices = [indices] if indices else []
values = [1.0] * len(indices) # All values are 1 for binary representation
token_weights = dict(zip(indices, values))
meaningful_tokens = {}
for token_id, weight in token_weights.items():
decoded_token = tokenizer_splade_doc.decode([token_id])
if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
meaningful_tokens[decoded_token] = weight
sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for clarity
formatted_output = "Binary Bag-of-Words Representation:\n\n"
if not sorted_representation:
formatted_output += "No significant terms found for this input.\n"
else:
        # Render the terms as a single comma-separated paragraph
terms_list = []
for term, _ in sorted_representation: # For binary, weight is always 1, so no need to display
terms_list.append(f"**{term}**")
formatted_output += ", ".join(terms_list) + "."
info_output = f"" # Line 1
info_output += f"Total non-zero terms in vector: {len(indices)}" # Line 2
return formatted_output, info_output
# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
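    """Dispatch to the representation function matching the selected encoder."""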
if model_choice == "MLM encoder (SPLADE-cocondenser-distil)":
return get_splade_cocondenser_representation(text)
elif model_choice == "MLP encoder (SPLADE-v3-lexical)":
return get_splade_lexical_representation(text)
    elif model_choice == "Binary Bag-of-Words":
return get_splade_doc_representation(text)
else:
return "Please select a model.", "" # Return two empty strings for consistency
# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
# These return raw tensors so they can be reused for dot-product similarity scoring.
def get_splade_cocondenser_vector(text):
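    """Return the raw SPLADE-cocondenser vector for `text` (shape: vocab_size), or None on failure."""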
if tokenizer_splade is None or model_splade is None:
return None
inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
with torch.no_grad():
output = model_splade(**inputs)
if hasattr(output, 'logits'):
splade_vector = torch.max(
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
dim=1
)[0].squeeze()
return splade_vector
return None
def get_splade_lexical_vector(text):
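    """Return the raw SPLADE-v3-lexical vector for `text`, restricted to input terms via the BoW mask, or None on failure."""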
if tokenizer_splade_lexical is None or model_splade_lexical is None:
return None
inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
with torch.no_grad():
output = model_splade_lexical(**inputs)
if hasattr(output, 'logits'):
splade_vector = torch.max(
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
dim=1
)[0].squeeze()
vocab_size = tokenizer_splade_lexical.vocab_size
bow_mask = create_lexical_bow_mask(
inputs['input_ids'], vocab_size, tokenizer_splade_lexical
).squeeze()
splade_vector = splade_vector * bow_mask
return splade_vector
return None
def get_splade_doc_vector(text):
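    """Return the binary bag-of-words vector for `text`, built directly from its token ids, or None if the tokenizer is unavailable."""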
    if tokenizer_splade_doc is None:  # only the tokenizer is needed for the binary BoW representation
return None
inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
vocab_size = tokenizer_splade_doc.vocab_size
# Directly create the binary Bag-of-Words vector using the input_ids
binary_bow_vector = create_lexical_bow_mask(
inputs['input_ids'], vocab_size, tokenizer_splade_doc
).squeeze()
return binary_bow_vector
# --- Function to get formatted representation from a raw vector and tokenizer ---
# Generic formatter: works for any sparse vocabulary-sized vector.
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
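    """
    Format a sparse vocabulary-sized vector as a Markdown term list (capped at 50 terms)
    plus an info string with the non-zero term count. Binary vectors are sorted
    alphabetically; weighted vectors by descending weight.
    """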
if splade_vector is None:
return "Failed to generate vector.", ""
indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
if not isinstance(indices, list):
indices = [indices] if indices else []
if is_binary:
values = [1.0] * len(indices)
else:
values = splade_vector[indices].cpu().tolist()
token_weights = dict(zip(indices, values))
meaningful_tokens = {}
for token_id, weight in token_weights.items():
decoded_token = tokenizer.decode([token_id])
if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
meaningful_tokens[decoded_token] = weight
if is_binary:
sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for binary
else:
sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
formatted_output = ""
if not sorted_representation:
formatted_output += "No significant terms found.\n"
else:
terms_list = []
for i, (term, weight) in enumerate(sorted_representation):
if i >= 50:
terms_list.append(f"...and {len(sorted_representation) - 50} more terms.")
break
if is_binary:
terms_list.append(f"**{term}**")
else:
terms_list.append(f"**{term}**: {weight:.4f}")
formatted_output += ", ".join(terms_list) + "."
    info_output = f"Total non-zero terms: {len(indices)}\n"
return formatted_output, info_output
# --- Helper to map a model choice to its vector function, tokenizer, and binary flag ---
def get_model_assets(model_choice_str):
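    """Map a model-choice label to (vector_fn, tokenizer, is_binary, display_name)."""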
if model_choice_str == "MLM encoder (SPLADE-cocondenser-distil)":
return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
elif model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
elif model_choice_str == "Binary Bag-of-Words":
return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words"
else:
return None, None, False, "Unknown Model"
# --- Dot product calculation for the Compare Encoders tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
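    """
    Encode the query and document with independently chosen encoders, compute the
    dot product of their sparse vectors, and return a single Markdown string with
    the score and both representations.
    """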
query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
if query_vector_fn is None or doc_vector_fn is None:
return "Please select valid models for both query and document encoding.", "", ""
query_vector = query_vector_fn(query_text)
doc_vector = doc_vector_fn(doc_text)
if query_vector is None or doc_vector is None:
return "Failed to generate one or both vectors. Please check model loading and input text.", "", ""
# Calculate dot product
# Ensure both vectors are on CPU before dot product to avoid device mismatch issues
# and to ensure .item() works reliably for conversion to float.
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
# Format representations - these functions now return two strings (main_output, info_output)
query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
# Combine output into a single string for the Markdown component
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
full_output += "---\n\n"
# Query Representation
full_output += f"Query Representation ({query_model_name_display}):\n\n"
    full_output += query_main_rep_str + "\n\n" + query_info_str
full_output += "\n\n---\n\n" # Separator
# Document Representation
full_output += f"Document Representation ({doc_model_name_display}):\n\n"
    full_output += doc_main_rep_str + "\n\n" + doc_info_str
return full_output
# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos", css=css) as demo:
gr.Markdown("# 🌌 Sparse Encoder Playground") # Updated title
gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.") # Updated description
with gr.Tabs():
with gr.TabItem("Sparse Representation"):
gr.Markdown("### Produce a Sparse Representation of an Input Text")
with gr.Row():
with gr.Column(scale=1): # Left column for inputs and info
model_radio = gr.Radio(
[
"MLM encoder (SPLADE-cocondenser-distil)",
"MLP encoder (SPLADE-v3-lexical)",
"Binary Bag-of-Words"
],
label="Choose Sparse Encoder",
value="MLM encoder (SPLADE-cocondenser-distil)"
)
input_text = gr.Textbox(
lines=5,
label="Enter your query or document text here:",
placeholder="e.g., Why is Padua the nicest city in Italy?"
)
# New Markdown component for the info output
info_output_display = gr.Markdown(
value="",
label="Vector Information",
elem_id="info_output_display" # Add an ID for potential CSS if needed
)
with gr.Column(scale=2): # Right column for the main representation output
main_representation_output = gr.Markdown()
# Connect the interface elements
model_radio.change(
fn=predict_representation_explorer,
inputs=[model_radio, input_text],
outputs=[main_representation_output, info_output_display]
)
input_text.change(
fn=predict_representation_explorer,
inputs=[model_radio, input_text],
outputs=[main_representation_output, info_output_display]
)
# Initial call to populate on load (optional, but good for demo)
demo.load(
fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
outputs=[main_representation_output, info_output_display]
)
with gr.TabItem("Compare Encoders"): # NEW TAB
gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
# Define the common model choices for cleaner code
model_choices = [
"MLM encoder (SPLADE-cocondenser-distil)",
"MLP encoder (SPLADE-v3-lexical)",
"Binary Bag-of-Words"
]
gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
inputs=[
gr.Radio(
model_choices,
label="Choose Query Encoding Model",
value="MLM encoder (SPLADE-cocondenser-distil)" # Default value
),
gr.Radio(
model_choices,
label="Choose Document Encoding Model",
value="MLM encoder (SPLADE-cocondenser-distil)" # Default value
),
gr.Textbox(
lines=3,
label="Enter Query Text:",
placeholder="e.g., famous dishes of Padua"
),
gr.Textbox(
lines=5,
label="Enter Document Text:",
placeholder="e.g., Padua's cuisine is as famous as its legendary University."
)
],
outputs=gr.Markdown(),
allow_flagging="never"
)
demo.launch()