# TextLSRDemo / app.py
# (Hugging Face Space page header captured with this file: author SiddharthAK,
#  last commit "Update app.py" 1782def, 27.6 kB.)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm
import os
# Custom CSS passed to gr.Blocks(css=...) below.
# - .share-button-container / .custom-share-button style the custom
#   "Get & Copy Share Link" button on the "Sparse Representation" tab
#   (applied through elem_classes on that Row and Button).
# - .share-button restyles any element with that class and pins it to the
#   top-right corner; the CSS comment inside the string assumes this is
#   Gradio's built-in share button shown with launch(share=True).
css = """
.share-button-container {
display: flex;
justify-content: center;
margin-top: 10px;
margin-bottom: 20px;
}
.custom-share-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
color: white;
padding: 8px 16px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 14px;
margin: 4px 2px;
cursor: pointer;
border-radius: 6px;
transition: all 0.3s ease;
}
.custom-share-button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
/* IMPORTANT: This CSS targets Gradio's *default* share button that appears
when demo.launch(share=True) is used.
You might want to comment this out if you prefer Gradio's default positioning
for the main share button (usually in the header/footer) and rely only on your custom one.
*/
.share-button {
position: fixed !important;
top: 20px !important;
right: 20px !important;
z-index: 1000 !important;
background: #4CAF50 !important;
color: white !important;
border-radius: 8px !important;
padding: 8px 16px !important;
font-weight: bold !important;
box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}
.share-button:hover {
background: #45a049 !important;
transform: translateY(-1px) !important;
}
"""
# --- Model Loading ---
# Every tokenizer/model global below is None when it could not be loaded.
# The representation functions check for None and return a user-facing
# message instead of crashing, so one unavailable checkpoint does not take
# down the whole demo.


def _load_masked_lm(model_name, success_message, error_label, error_hint):
    """Load a tokenizer and masked-LM model from the Hugging Face Hub.

    Loading errors are printed to the console instead of raised.

    Args:
        model_name: Hub identifier passed to ``from_pretrained``.
        success_message: Printed once both objects are loaded.
        error_label: Model name used in the "Error loading ..." message.
        error_hint: Extra guidance printed after a loading error.

    Returns:
        ``(tokenizer, model)``. An element is ``None`` if it could not be
        loaded. The tokenizer is kept even when the model fails afterwards,
        because the Binary Bag-of-Words path only needs the tokenizer.
    """
    tokenizer = None
    model = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        model.eval()  # inference only (disables dropout)
        print(success_message)
    except Exception as e:  # best-effort: report and keep the app running
        print(f"Error loading {error_label} model: {e}")
        print(error_hint)
    return tokenizer, model


# SPLADE CoCondenser self-distil model (the "MLM encoder" in the UI).
tokenizer_splade, model_splade = _load_masked_lm(
    "naver/splade-cocondenser-selfdistil",
    "SPLADE-cocondenser-distil model loaded successfully!",
    "SPLADE-cocondenser-distil",
    "Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.",
)

# SPLADE v3 Lexical model (the "MLP encoder" in the UI).
splade_lexical_model_name = "naver/splade-v3-lexical"
tokenizer_splade_lexical, model_splade_lexical = _load_masked_lm(
    splade_lexical_model_name,
    f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!",
    "SPLADE-v3-Lexical",
    f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).",
)

# SPLADE v3 Doc checkpoint. The Binary Bag-of-Words functions only use its
# tokenizer; the model weights are loaded but not called by them.
splade_doc_model_name = "naver/splade-v3-doc"
tokenizer_splade_doc, model_splade_doc = _load_masked_lm(
    splade_doc_model_name,
    f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!",
    "SPLADE-v3-Doc",
    f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).",
)
# --- Helper function for lexical mask (handles batches; the app passes a batch of one) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """
    Creates a batch of lexical BOW masks.

    For each row, every vocabulary id that occurs in that row's input ids is
    set to 1, except the tokenizer's pad/cls/sep/mask/unk ids, which stay 0.

    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
    vocab_size: int, size of the tokenizer vocabulary; every non-special id in
        input_ids_batch must be < vocab_size
    tokenizer: object exposing pad/cls/sep/mask/unk ``*_token_id`` attributes
        (any of them may be None)
    Returns: float torch.Tensor of shape (batch_size, vocab_size) on the same
        device as input_ids_batch
    """
    # Built once for the whole batch instead of once per token. A None id
    # (token type absent from the tokenizer) can never match, so drop it.
    special_ids = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.mask_token_id,
        tokenizer.unk_token_id,
    }
    special_ids.discard(None)

    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
    # One device->host conversion for the whole batch rather than one per row.
    for row, input_ids in enumerate(input_ids_batch.tolist()):
        meaningful_ids = set(input_ids) - special_ids
        if meaningful_ids:
            bow_masks[row, sorted(meaningful_ids)] = 1
    return bow_masks
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# These functions return a tuple: (main_representation_str, info_str)
def get_splade_cocondenser_representation(text):
    """Encode ``text`` with the SPLADE CoCondenser model and format its terms.

    The sparse vector is ``log(1 + relu(logits))`` multiplied by the attention
    mask and max-pooled over dim 1 (assumes logits are (batch, seq, vocab) —
    the attention-mask broadcast below relies on that layout).

    Returns:
        ``(formatted_output, info_output)``: Markdown listing the decoded
        terms by descending weight, and a line with the vector's non-zero
        count (special tokens included). If the model is unavailable or has
        no ``logits``, an error message and ``""``.
    """
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""
    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)
    if not hasattr(output, 'logits'):
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""
    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
        dim=1
    )[0].squeeze()  # single input text -> drop the batch dimension

    # flatten() always yields a 1-D id list, so a single non-zero entry
    # (including token id 0) is kept and counted.
    nonzero_ids = torch.nonzero(splade_vector).flatten()
    indices = nonzero_ids.cpu().tolist()
    values = splade_vector[nonzero_ids].cpu().tolist()

    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLM Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Paragraph style: "**term**: weight, ..." ending with a period.
        formatted_output += ", ".join(
            f"**{term}**: {weight:.4f}" for term, weight in sorted_representation
        ) + "."
    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output
def get_splade_lexical_representation(text):
    """Encode ``text`` with SPLADE-v3-Lexical and format its terms.

    Pools like the CoCondenser model (``log(1 + relu(logits))``, masked and
    max-pooled over dim 1), then multiplies by the lexical bag-of-words mask
    so only non-special tokens that occur in ``text`` keep a weight.

    Returns:
        ``(formatted_output, info_output)``: Markdown listing the decoded
        terms by descending weight, and a line with the vector's non-zero
        count. If the model is unavailable or has no ``logits``, an error
        message and ``""``.
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""
    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade_lexical(**inputs)
    if not hasattr(output, 'logits'):
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""
    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
        dim=1
    )[0].squeeze()  # single input text -> drop the batch dimension

    # Always apply the lexical mask: this model only keeps in-text tokens.
    # input_ids is already batched (1, seq); squeeze the mask back to 1-D.
    vocab_size = tokenizer_splade_lexical.vocab_size
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze()
    splade_vector = splade_vector * bow_mask

    # flatten() always yields a 1-D id list, so a single non-zero entry
    # (including token id 0) is kept and counted.
    nonzero_ids = torch.nonzero(splade_vector).flatten()
    indices = nonzero_ids.cpu().tolist()
    values = splade_vector[nonzero_ids].cpu().tolist()

    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLP Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Paragraph style: "**term**: weight, ..." ending with a period.
        formatted_output += ", ".join(
            f"**{term}**: {weight:.4f}" for term, weight in sorted_representation
        ) + "."
    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output
def get_splade_doc_representation(text):
    """List the distinct tokens of ``text`` as a binary bag-of-words.

    Only the SPLADE-v3-Doc tokenizer is used (no model forward pass). Every
    non-special token id of the tokenized text is a term of weight 1.

    Returns:
        ``(formatted_output, info_output)``: Markdown listing the decoded
        terms alphabetically, and the vector's non-zero count. If the
        tokenizer is unavailable, an error message and ``""``.
    """
    if tokenizer_splade_doc is None:  # only the tokenizer is needed here
        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""
    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()}  # mask is built on CPU
    vocab_size = tokenizer_splade_doc.vocab_size
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()  # batch of one -> 1-D vector

    # flatten() always yields a 1-D id list, so a single non-zero entry
    # (including token id 0) is kept and counted.
    indices = torch.nonzero(binary_bow_vector).flatten().cpu().tolist()

    terms = set()
    for token_id in indices:
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and decoded_token.strip():
            terms.add(decoded_token)

    formatted_output = "Binary Bag-of-Words Representation:\n\n"
    if not terms:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Weights are all 1, so only the terms are shown, sorted alphabetically.
        formatted_output += ", ".join(f"**{term}**" for term in sorted(terms)) + "."
    info_output = f"Total non-zero terms in vector: {len(indices)}"
    return formatted_output, info_output
# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    """Route the Explorer tab's radio selection to the matching encoder.

    Returns that encoder's ``(main_representation_str, info_str)`` pair, or
    ``("Please select a model.", "")`` for an unrecognised choice.
    """
    # Each route is (radio label, deferred call); the first label equal to
    # model_choice wins, exactly like an if/elif chain.
    routes = (
        ("MLM encoder (SPLADE-cocondenser-distil)", lambda: get_splade_cocondenser_representation(text)),
        ("MLP encoder (SPLADE-v3-lexical)", lambda: get_splade_lexical_representation(text)),
        ("Binary Bag-of-Words", lambda: get_splade_doc_representation(text)),
    )
    for label, run in routes:
        if model_choice == label:
            return run()
    return "Please select a model.", ""
# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
# Each returns a sparse vocabulary vector as a tensor, or None on failure.
def get_splade_cocondenser_vector(text):
    """Return the SPLADE CoCondenser vector for ``text``, or None.

    The vector is ``log(1 + relu(logits))`` times the attention mask,
    max-pooled over dim 1 and squeezed; it stays on the model's device.
    None means the tokenizer/model did not load or the model output has no
    ``logits`` attribute.
    """
    if tokenizer_splade is None or model_splade is None:
        return None
    device = model_splade.device
    encoded = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output = model_splade(**encoded)
    if not hasattr(output, 'logits'):
        return None
    activations = torch.log(1 + torch.relu(output.logits))
    masked = activations * encoded['attention_mask'].unsqueeze(-1)
    return torch.max(masked, dim=1)[0].squeeze()
def get_splade_lexical_vector(text):
    """Return the SPLADE-v3-Lexical vector for ``text``, or None.

    Same pooling as the CoCondenser vector, then multiplied by the lexical
    bag-of-words mask so only non-special tokens that occur in ``text`` keep
    a weight. None means the tokenizer/model did not load or the model
    output has no ``logits`` attribute.
    """
    tokenizer, model = tokenizer_splade_lexical, model_splade_lexical
    if tokenizer is None or model is None:
        return None
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output = model(**encoded)
    if not hasattr(output, 'logits'):
        return None
    activations = torch.log(1 + torch.relu(output.logits))
    pooled = torch.max(activations * encoded['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
    lexical_mask = create_lexical_bow_mask(
        encoded['input_ids'], tokenizer.vocab_size, tokenizer
    ).squeeze()
    return pooled * lexical_mask
def get_splade_doc_vector(text):
    """Return the binary bag-of-words vector of ``text``, or None.

    No model forward pass happens: the vector is 1 for every non-special
    token id produced by the SPLADE-v3-Doc tokenizer and 0 elsewhere, built
    on the CPU. None means that tokenizer did not load.
    """
    tokenizer = tokenizer_splade_doc
    if tokenizer is None:
        return None
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    cpu = torch.device("cpu")
    encoded = {key: tensor.to(cpu) for key, tensor in encoded.items()}
    mask = create_lexical_bow_mask(encoded['input_ids'], tokenizer.vocab_size, tokenizer)
    return mask.squeeze()
# --- Function to get formatted representation from a raw vector and tokenizer ---
# Generic formatter for any sparse vocabulary vector (used by the Compare tab).
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False, max_terms=50):
    """Format a 1-D sparse vocabulary vector as Markdown.

    Args:
        splade_vector: 1-D tensor indexed by token id, or None.
        tokenizer: object whose ``decode([token_id])`` returns the token text.
        is_binary: if True, terms are listed alphabetically without weights;
            otherwise by descending weight with 4-decimal weights.
        max_terms: maximum number of terms listed before the rest are
            summarised as "...and N more terms." (negative counts as 0).

    Returns:
        ``(formatted_output, info_output)``. ``info_output`` reports the
        number of non-zero entries, including special tokens that are not
        listed. ``("Failed to generate vector.", "")`` if the vector is None.
    """
    if splade_vector is None:
        return "Failed to generate vector.", ""

    # flatten() always yields a 1-D id tensor, so a single non-zero entry
    # (including token id 0) is kept and counted.
    nonzero_ids = torch.nonzero(splade_vector).flatten()
    indices = nonzero_ids.cpu().tolist()
    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[nonzero_ids].cpu().tolist()

    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    if not sorted_representation:
        formatted_output = "No significant terms found.\n"
    else:
        shown = sorted_representation[:max(max_terms, 0)]
        if is_binary:
            terms_list = [f"**{term}**" for term, _ in shown]
        else:
            terms_list = [f"**{term}**: {weight:.4f}" for term, weight in shown]
        hidden = len(sorted_representation) - len(shown)
        if hidden > 0:
            # No period here: the join below adds the single closing one.
            terms_list.append(f"...and {hidden} more terms")
        formatted_output = ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms: {len(indices)}\n"
    return formatted_output, info_output
# --- Helper mapping a radio label to (vector fn, tokenizer, binary flag, display name) ---
def get_model_assets(model_choice_str):
    """Look up everything the Compare tab needs for one encoder choice.

    Returns ``(vector_fn, tokenizer, is_binary, display_name)``; for an
    unrecognised choice ``(None, None, False, "Unknown Model")``.
    """
    if model_choice_str == "MLM encoder (SPLADE-cocondenser-distil)":
        vector_fn, tokenizer, is_binary = get_splade_cocondenser_vector, tokenizer_splade, False
    elif model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
        vector_fn, tokenizer, is_binary = get_splade_lexical_vector, tokenizer_splade_lexical, False
    elif model_choice_str == "Binary Bag-of-Words":
        vector_fn, tokenizer, is_binary = get_splade_doc_vector, tokenizer_splade_doc, True
    else:
        return None, None, False, "Unknown Model"
    # For every known choice the display name is the radio label itself.
    return vector_fn, tokenizer, is_binary, model_choice_str
# --- Dot Product Calculation Function for the "Compare Encoders" tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
    """Encode query and document independently and score them by dot product.

    Args:
        query_model_choice: radio label for the query encoder (see get_model_assets).
        doc_model_choice: radio label for the document encoder.
        query_text: raw query string.
        doc_text: raw document string.

    Returns:
        A single Markdown string (the tab's Interface has exactly one
        ``gr.Markdown`` output): the dot product followed by both formatted
        representations, or an error message.
    """
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
    # Error paths return one string, matching the single Markdown output.
    if query_vector_fn is None or doc_vector_fn is None:
        return "Please select valid models for both query and document encoding."

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)
    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text."

    # Move both vectors to the CPU to avoid device mismatches in torch.dot.
    query_vector = query_vector.cpu()
    doc_vector = doc_vector.cpu()
    # torch.dot needs equal-length vectors, i.e. encoders sharing a vocabulary.
    if query_vector.shape != doc_vector.shape:
        return (
            f"Cannot compare vectors of different sizes ({query_vector.numel()} vs {doc_vector.numel()}); "
            "the chosen encoders must share a vocabulary."
        )
    dot_product = float(torch.dot(query_vector, doc_vector).item())

    # Each formatter returns (main_output, info_output).
    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
    full_output += "---\n\n"
    full_output += f"Query Representation ({query_model_name_display}):\n\n"
    full_output += query_main_rep_str + "\n\n" + query_info_str
    full_output += "\n\n---\n\n"
    full_output += f"Document Representation ({doc_model_name_display}):\n\n"
    full_output += doc_main_rep_str + "\n\n" + doc_info_str
    return full_output
# Public share URL. Starts as None; the __main__ block assigns it from the
# launch result, and get_share_url_for_button reads it on each click.
global_share_url = None
# JavaScript source of an async arrow function: given `url`, it copies it to
# the clipboard and resolves to a status string ('Copied to clipboard!',
# 'Failed to copy!', or 'URL not available.' when url is empty). It is
# spliced into the share button's client-side `_js` handler in the UI.
copy_to_clipboard_js = """
async (url) => {
if (url) {
try {
await navigator.clipboard.writeText(url);
console.log('Share URL copied to clipboard:', url);
return 'Copied to clipboard!'; // Message to display to user
} catch (err) {
console.error('Failed to copy share URL:', err);
return 'Failed to copy!';
}
}
return 'URL not available.';
}
"""
def get_share_url_for_button():
    """Return the public share URL shown in the share textbox.

    Runs server-side when the custom share button is clicked. While
    ``global_share_url`` is still unset (None or empty), a placeholder
    message is returned instead.
    """
    placeholder = "Generating share link..."
    return global_share_url or placeholder
# --- Gradio Interface Setup with Tabs ---


def _hide_copy_confirmation_after_delay():
    """Wait briefly, then clear and hide the copy-confirmation message."""
    import time  # only this handler needs it

    time.sleep(2)
    return gr.update(value="", visible=False)


with gr.Blocks(title="SPLADE Demos", css=css) as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")
    with gr.Tabs():
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")
            with gr.Row():
                with gr.Column(scale=1):
                    model_radio = gr.Radio(
                        [
                            "MLM encoder (SPLADE-cocondenser-distil)",
                            "MLP encoder (SPLADE-v3-lexical)",
                            "Binary Bag-of-Words"
                        ],
                        label="Choose Sparse Encoder",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    )
                    input_text = gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )
                    # Custom Share Button and URL display
                    with gr.Row(elem_classes="share-button-container"):
                        share_button = gr.Button(
                            "🔗 Get & Copy Share Link",
                            elem_classes="custom-share-button",
                            size="sm"
                        )
                    share_output_textbox = gr.Textbox(
                        label="Share URL (Click button to copy)",
                        interactive=True,
                        visible=False,
                        placeholder="Click 'Get & Copy Share Link' to generate and copy URL..."
                    )
                    # A small text display to confirm the copy
                    copy_confirmation_text = gr.Markdown(value="", visible=False)
                    info_output_display = gr.Markdown(
                        value="",
                        label="Vector Information",
                        elem_id="info_output_display"
                    )
                with gr.Column(scale=2):
                    main_representation_output = gr.Markdown()

            # Share button chain:
            # 1. fetch the URL server-side and reveal the textbox,
            # 2. copy it client-side; the JS status string lands in copy_confirmation_text,
            # 3. show that message,
            # 4. hide it again after a short delay.
            share_button.click(
                fn=lambda: gr.update(value=get_share_url_for_button(), visible=True),
                outputs=share_output_textbox
            ).then(
                fn=None,
                # Parenthesise the async arrow function so it is actually
                # invoked; its resolved status string becomes the output value.
                _js=f"(url) => ({copy_to_clipboard_js.strip()})(url)",
                inputs=[share_output_textbox],
                outputs=[copy_confirmation_text]
            ).then(
                fn=lambda msg: gr.update(value=msg, visible=True),
                inputs=[copy_confirmation_text],
                outputs=[copy_confirmation_text]
            ).then(
                # One-shot delayed hide. `every=` would instead start a job
                # that re-runs every few seconds for the rest of the session.
                fn=_hide_copy_confirmation_after_delay,
                outputs=[copy_confirmation_text],
                queue=False,
                api_name="hide_confirmation_message"
            )

            # Connect the core prediction logic
            model_radio.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            input_text.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            # Populate the output on page load from the components' current values.
            demo.load(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )

        with gr.TabItem("Compare Encoders"):
            gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document")
            model_choices = [
                "MLM encoder (SPLADE-cocondenser-distil)",
                "MLP encoder (SPLADE-v3-lexical)",
                "Binary Bag-of-Words"
            ]
            gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
                inputs=[
                    gr.Radio(
                        model_choices,
                        label="Choose Query Encoding Model",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    ),
                    gr.Radio(
                        model_choices,
                        label="Choose Document Encoding Model",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    ),
                    gr.Textbox(
                        lines=3,
                        label="Enter Query Text:",
                        placeholder="e.g., famous dishes of Padua"
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter Document Text:",
                        placeholder="e.g., Padua's cuisine is as famous as its legendary University."
                    )
                ],
                outputs=gr.Markdown(),
                allow_flagging="never"
            )
# Launch the app, record the public share URL for the custom button, then serve.
if __name__ == "__main__":
    # prevent_thread_lock=True makes launch() return once the server (and the
    # share tunnel, if any) is up. Without it, launch() blocks until the server
    # shuts down, so global_share_url would only be assigned after shutdown.
    demo.launch(share=True, prevent_thread_lock=True)
    # launch() stores the public link on the Blocks object (None if no tunnel).
    global_share_url = getattr(demo, "share_url", None)
    if not global_share_url:
        # share=True is not supported on Hugging Face Spaces; there the Space
        # page itself is the shareable link. SPACE_ID ("owner/name") is set by
        # the Spaces runtime.
        space_id = os.environ.get("SPACE_ID")
        if space_id:
            global_share_url = f"https://huggingface.co/spaces/{space_id}"
    if global_share_url:
        print(f"\n--- Public Share URL (for your custom button): {global_share_url} ---\n")
    else:
        print("\n--- Public Share URL not available directly from launch info. "
              "Please check your console for the Gradio-generated share link. ---\n")
    print("\n--- Gradio App Running ---")
    print("Open the link in your browser to interact with the app.")
    print("---------------------------\n")
    # The server runs in a background thread; keep the main thread alive.
    demo.block_thread()