# TextLSRDemo / app.py
# SiddharthAK's picture
# Update app.py
# 8cbba1e verified
# raw
# history blame
# 26.8 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm
import os
# Custom CSS passed to gr.Blocks(css=css) further down. It defines:
#   * .share-button-container / .custom-share-button - layout and look of the
#     custom share button row on the "Sparse Representation" tab, and
#   * .share-button - rules that pin elements carrying that class to the
#     top-right corner of the page (position: fixed).
css = """
.share-button-container {
display: flex;
justify-content: center;
margin-top: 10px;
margin-bottom: 20px;
}
.custom-share-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
color: white;
padding: 8px 16px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 14px;
margin: 4px 2px;
cursor: pointer;
border-radius: 6px;
transition: all 0.3s ease;
}
.custom-share-button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
/* IMPORTANT: This CSS targets Gradio's *default* share button that appears
when demo.launch(share=True) is used.
You might want to comment this out if you prefer Gradio's default positioning
for the main share button (usually in the header/footer) and rely only on your custom one.
*/
.share-button {
position: fixed !important;
top: 20px !important;
right: 20px !important;
z-index: 1000 !important;
background: #4CAF50 !important;
color: white !important;
border-radius: 8px !important;
padding: 8px 16px !important;
font-weight: bold !important;
box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}
.share-button:hover {
background: #45a049 !important;
transform: translateY(-1px) !important;
}
"""
# --- Model Loading ---
# Every handle starts as None so the functions below can detect a failed load
# and return a message (or None) instead of raising.
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None
# Load the SPLADE cocondenser self-distil checkpoint (the "MLM encoder" choice in the UI).
# A failure is only printed: whatever was assigned before the failing call keeps
# its value and the remaining handles stay None.
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
# Load the SPLADE v3 Lexical checkpoint (the "MLP encoder" choice in the UI).
try:
    # Assigned first so the name is already bound when the except branch formats its message.
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# Load the SPLADE v3 Doc checkpoint (the "Binary Bag-of-Words" choice in the UI).
# NOTE(review): the Binary Bag-of-Words functions visible in this file read only
# tokenizer_splade_doc; model_splade_doc is loaded here but not referenced by them.
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name) # Still load the model
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Helper function for lexical mask (now handles batches, but used for single input here) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """
    Creates a batch of binary lexical bag-of-words masks.

    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
    vocab_size: int, size of the tokenizer vocabulary
    tokenizer: the tokenizer object; its pad/cls/sep/mask/unk token ids are
        never set in the mask
    Returns: torch.Tensor of shape (batch_size, vocab_size), created on the
        same device as input_ids_batch, holding 1 at every non-special token
        id that occurs in the corresponding row and 0 everywhere else
    """
    # The special-token ids are the same for every token of every row, so
    # collect them once instead of rebuilding the list per token.
    special_token_ids = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.mask_token_id,
        tokenizer.unk_token_id,
    }
    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
    for row, input_ids in enumerate(input_ids_batch):
        # Set difference both de-duplicates the ids and drops the special ones.
        meaningful_token_ids = set(input_ids.tolist()) - special_token_ids
        if meaningful_token_ids:
            bow_masks[row, list(meaningful_token_ids)] = 1
    return bow_masks
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# These functions now return a tuple: (main_representation_str, info_str)
def get_splade_cocondenser_representation(text):
    """
    Encode text with the SPLADE-cocondenser-distil model and format the result.

    The sparse vector is log(1 + relu(logits)), multiplied by the attention
    mask and max-pooled over dim=1 of the logits.

    text: the input passed to the tokenizer (the UI passes a single string)
    Returns: tuple (main_representation_str, info_str). The first element
        lists the decoded terms with their weights, highest weight first;
        the second reports the number of non-zero vector entries. When the
        model is not loaded or its output has no 'logits', the first element
        is an error message and the second is "".
    """
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)
    if not hasattr(output, 'logits'):
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""

    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
        dim=1
    )[0].squeeze()  # Squeeze is fine here as it's a single input

    # squeeze(-1) removes only the trailing size-1 axis that nonzero() adds, so
    # zero or one hit still yields a list. A bare squeeze() turns a single hit
    # into a scalar, and a lone index 0 was then discarded as falsy.
    nonzero_ids = torch.nonzero(splade_vector).squeeze(-1)
    indices = nonzero_ids.cpu().tolist()
    values = splade_vector[nonzero_ids].cpu().tolist()

    skipped_tokens = ("[CLS]", "[SEP]", "[PAD]", "[UNK]")
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in skipped_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLM Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Paragraph style: one comma-separated sentence of "**term**: weight" pairs.
        formatted_output += ", ".join(
            f"**{term}**: {weight:.4f}" for term, weight in sorted_representation
        ) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output
def get_splade_lexical_representation(text):
    """
    Encode text with the SPLADE-v3-Lexical model and format the result.

    The sparse vector is log(1 + relu(logits)), multiplied by the attention
    mask and max-pooled over dim=1 of the logits, then multiplied by the
    lexical bag-of-words mask from create_lexical_bow_mask so that only
    token ids present in the input keep a weight.

    text: the input passed to the tokenizer (the UI passes a single string)
    Returns: tuple (main_representation_str, info_str). The first element
        lists the decoded terms with their weights, highest weight first;
        the second reports the number of non-zero vector entries. When the
        model is not loaded or its output has no 'logits', the first element
        is an error message and the second is "".
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade_lexical(**inputs)
    if not hasattr(output, 'logits'):
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""

    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
        dim=1
    )[0].squeeze()  # Squeeze is fine here

    # Always apply the lexical mask for this model's specific behavior.
    # NOTE(review): assumes tokenizer.vocab_size equals the size of the
    # logits' last dimension - confirm for this checkpoint.
    vocab_size = tokenizer_splade_lexical.vocab_size
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze()  # Squeeze back for single output
    splade_vector = splade_vector * bow_mask

    # squeeze(-1) removes only the trailing size-1 axis that nonzero() adds, so
    # zero or one hit still yields a list. A bare squeeze() turns a single hit
    # into a scalar, and a lone index 0 was then discarded as falsy.
    nonzero_ids = torch.nonzero(splade_vector).squeeze(-1)
    indices = nonzero_ids.cpu().tolist()
    values = splade_vector[nonzero_ids].cpu().tolist()

    skipped_tokens = ("[CLS]", "[SEP]", "[PAD]", "[UNK]")
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in skipped_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLP Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Paragraph style: one comma-separated sentence of "**term**: weight" pairs.
        formatted_output += ", ".join(
            f"**{term}**: {weight:.4f}" for term, weight in sorted_representation
        ) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output
def get_splade_doc_representation(text):
    """
    Build and format a binary bag-of-words representation of text.

    Only the SPLADE-v3-Doc tokenizer is used: the vector comes straight from
    the input ids via create_lexical_bow_mask; no model forward pass is run.

    text: the input passed to the tokenizer (the UI passes a single string)
    Returns: tuple (main_representation_str, info_str). The first element
        lists the decoded terms alphabetically, without weights; the second
        reports the number of non-zero vector entries. When the tokenizer is
        not loaded, the first element is an error message and the second
        is "".
    """
    if tokenizer_splade_doc is None:  # The model itself is not needed here
        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()}  # Ensure on CPU for direct mask creation
    vocab_size = tokenizer_splade_doc.vocab_size

    # Directly create the binary Bag-of-Words vector using the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()  # Squeeze back for single output

    # squeeze(-1) removes only the trailing size-1 axis that nonzero() adds, so
    # zero or one hit still yields a list (a bare squeeze() turns a single hit
    # into a scalar).
    indices = torch.nonzero(binary_bow_vector).squeeze(-1).cpu().tolist()

    skipped_tokens = ("[CLS]", "[SEP]", "[PAD]", "[UNK]")
    # Every weight is 1, so only the set of distinct decoded terms matters.
    decoded_terms = set()
    for token_id in indices:
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in skipped_tokens and decoded_token.strip():
            decoded_terms.add(decoded_token)
    sorted_terms = sorted(decoded_terms)  # Sort alphabetically for clarity

    formatted_output = "Binary Bag-of-Words Representation:\n\n"
    if not sorted_terms:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Paragraph style: one comma-separated sentence of "**term**" entries.
        formatted_output += ", ".join(f"**{term}**" for term in sorted_terms) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}"
    return formatted_output, info_output
# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    """
    Route the explorer tab's request to the formatter for the chosen encoder.

    model_choice: label selected in the "Choose Sparse Encoder" radio
    text: the text to encode
    Returns: tuple (main_representation_str, info_str) from the selected
        formatter, or ("Please select a model.", "") for any other label
    """
    mlm_label = "MLM encoder (SPLADE-cocondenser-distil)"
    mlp_label = "MLP encoder (SPLADE-v3-lexical)"
    bow_label = "Binary Bag-of-Words"

    # Reject anything that is not one of the three known labels up front;
    # the pair keeps the same arity as the formatter results.
    if model_choice not in (mlm_label, mlp_label, bow_label):
        return "Please select a model.", ""

    if model_choice == bow_label:
        return get_splade_doc_representation(text)
    if model_choice == mlp_label:
        return get_splade_lexical_representation(text)
    return get_splade_cocondenser_representation(text)
# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
# Each function returns the raw sparse vector as a torch.Tensor, or None when the vector cannot be produced.
def get_splade_cocondenser_vector(text):
    """
    Return the raw SPLADE-cocondenser-distil sparse vector for text.

    text: the input passed to the tokenizer
    Returns: torch.Tensor holding log(1 + relu(logits)), multiplied by the
        attention mask, max-pooled over dim=1 and squeezed; None when the
        model or tokenizer is not loaded or the output has no 'logits'
    """
    if tokenizer_splade is None or model_splade is None:
        return None

    encoded = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        model_output = model_splade(**encoded)

    if not hasattr(model_output, 'logits'):
        return None

    activations = torch.log(1 + torch.relu(model_output.logits))
    masked = activations * encoded['attention_mask'].unsqueeze(-1)
    pooled, _ = torch.max(masked, dim=1)
    return pooled.squeeze()
def get_splade_lexical_vector(text):
    """
    Return the raw SPLADE-v3-Lexical sparse vector for text.

    text: the input passed to the tokenizer
    Returns: torch.Tensor holding log(1 + relu(logits)), multiplied by the
        attention mask, max-pooled over dim=1, squeezed, and multiplied by
        the lexical bag-of-words mask of the input ids; None when the model
        or tokenizer is not loaded or the output has no 'logits'
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    encoded = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade_lexical.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        model_output = model_splade_lexical(**encoded)

    if not hasattr(model_output, 'logits'):
        return None

    activations = torch.log(1 + torch.relu(model_output.logits))
    masked = activations * encoded['attention_mask'].unsqueeze(-1)
    pooled, _ = torch.max(masked, dim=1)
    pooled = pooled.squeeze()

    # Keep weights only for token ids that actually occur in the input.
    lexical_mask = create_lexical_bow_mask(
        encoded['input_ids'], tokenizer_splade_lexical.vocab_size, tokenizer_splade_lexical
    ).squeeze()
    return pooled * lexical_mask
def get_splade_doc_vector(text):
    """
    Return the raw binary bag-of-words vector for text.

    Only the SPLADE-v3-Doc tokenizer is used; no model forward pass is run.

    text: the input passed to the tokenizer
    Returns: squeezed torch.Tensor from create_lexical_bow_mask over the
        input ids, or None when the tokenizer is not loaded
    """
    if tokenizer_splade_doc is None:
        return None

    encoded = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    cpu = torch.device("cpu")
    # The mask is built directly from the ids, so keep everything on the CPU.
    encoded = {name: tensor.to(cpu) for name, tensor in encoded.items()}

    bow_mask = create_lexical_bow_mask(
        encoded['input_ids'], tokenizer_splade_doc.vocab_size, tokenizer_splade_doc
    )
    return bow_mask.squeeze()
# --- Function to get formatted representation from a raw vector and tokenizer ---
# Generic formatter: works for any sparse vector, given the tokenizer used to decode its indices.
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
    """
    Format a sparse vector as a Markdown term list plus a short info line.

    splade_vector: torch.Tensor whose non-zero positions are token ids, or
        None when vector generation failed
    tokenizer: object whose decode([token_id]) returns the term for an id
    is_binary: when True, terms are listed alphabetically without weights;
        otherwise they are listed with weights, highest weight first
    Returns: tuple (formatted_output, info_output). At most 50 terms are
        listed, followed by a note counting the remaining ones. For a None
        vector the result is ("Failed to generate vector.", "").
    """
    if splade_vector is None:
        return "Failed to generate vector.", ""

    # squeeze(-1) removes only the trailing size-1 axis that nonzero() adds, so
    # zero or one hit still yields a list. A bare squeeze() turns a single hit
    # into a scalar, and a lone index 0 was then discarded as falsy.
    nonzero_ids = torch.nonzero(splade_vector).squeeze(-1)
    indices = nonzero_ids.cpu().tolist()
    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[nonzero_ids].cpu().tolist()

    skipped_tokens = ("[CLS]", "[SEP]", "[PAD]", "[UNK]")
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in skipped_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for binary
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    max_terms = 50
    formatted_output = ""
    if not sorted_representation:
        formatted_output += "No significant terms found.\n"
    else:
        shown = sorted_representation[:max_terms]
        if is_binary:
            terms_list = [f"**{term}**" for term, _ in shown]
        else:
            terms_list = [f"**{term}**: {weight:.4f}" for term, weight in shown]
        hidden_count = len(sorted_representation) - max_terms
        if hidden_count > 0:
            # No trailing period here: the one appended below ends the
            # sentence (previously the output ended in "terms..").
            terms_list.append(f"...and {hidden_count} more terms")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms: {len(indices)}\n"
    return formatted_output, info_output
# --- Helper to get the vector function, tokenizer, binary flag, and display name for a model choice ---
def get_model_assets(model_choice_str):
    """
    Look up everything needed to encode and format text for a model choice.

    model_choice_str: label selected in one of the model radios
    Returns: tuple (vector_fn, tokenizer, is_binary, display_name). For a
        label that is not one of the three known choices the result is
        (None, None, False, "Unknown Model").
    """
    mlm_label = "MLM encoder (SPLADE-cocondenser-distil)"
    mlp_label = "MLP encoder (SPLADE-v3-lexical)"
    bow_label = "Binary Bag-of-Words"

    if model_choice_str not in (mlm_label, mlp_label, bow_label):
        return None, None, False, "Unknown Model"

    if model_choice_str == bow_label:
        # The only binary representation: formatted without weights.
        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words"
    if model_choice_str == mlp_label:
        return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
    return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
# --- Dot Product Calculation Function for the "Compare Encoders" tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
    """
    Encode a query and a document independently and report their dot product.

    query_model_choice: model label used to encode query_text
    doc_model_choice: model label used to encode doc_text
    query_text: the query string
    doc_text: the document string
    Returns: one Markdown string. On success it holds the dot-product score
        followed by both formatted representations; on failure it holds an
        error message. Every path returns a single string because the
        function feeds a single Markdown output component.
    """
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
    if query_vector_fn is None or doc_vector_fn is None:
        return "Please select valid models for both query and document encoding."

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)
    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text."

    # torch.dot needs two 1-D tensors of equal length; report a mismatch as a
    # message instead of letting the call raise.
    if query_vector.dim() != 1 or query_vector.shape != doc_vector.shape:
        return "Query and document vectors have incompatible shapes; the dot product cannot be computed."

    # Move both vectors to the CPU first to avoid device-mismatch issues.
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())

    # Each formatter call returns (main_output, info_output).
    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    # Combine everything into a single string for the Markdown component.
    sections = [
        f"### Dot Product Score: {dot_product:.6f}\n\n",
        "---\n\n",
        f"Query Representation ({query_model_name_display}):\n\n",
        query_main_rep_str + "\n\n" + query_info_str,
        "\n\n---\n\n",  # Separator
        f"Document Representation ({doc_model_name_display}):\n\n",
        doc_main_rep_str + "\n\n" + doc_info_str,
    ]
    return "".join(sections)
# Public share URL of the running app; stays None until something assigns it.
global_share_url = None


def get_current_share_url():
    """Return the stored share URL, or a placeholder message while it is unset."""
    return global_share_url or "Share URL not available yet."
# --- Gradio Interface Setup with Tabs ---
# NOTE(review): the nesting below is reconstructed from the component order and
# the inline comments - confirm the layout against the running app.
with gr.Blocks(title="SPLADE Demos", css=css) as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")
    with gr.Tabs():
        # Tab 1: encode one text and show its sparse representation.
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")
            with gr.Row():
                with gr.Column(scale=1):  # Left column for inputs and info
                    model_radio = gr.Radio(
                        [
                            "MLM encoder (SPLADE-cocondenser-distil)",
                            "MLP encoder (SPLADE-v3-lexical)",
                            "Binary Bag-of-Words"
                        ],
                        label="Choose Sparse Encoder",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    )
                    input_text = gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )
                    # Custom Share Button and URL display
                    with gr.Row(elem_classes="share-button-container"):
                        share_button = gr.Button(
                            "🔗 Get Share Link",
                            elem_classes="custom-share-button",
                            size="sm"
                        )
                        share_output = gr.Textbox(
                            label="Share URL",
                            interactive=True,  # Make it interactive so user can copy
                            visible=False,  # Start as hidden
                            placeholder="Click 'Get Share Link' to generate URL..."
                        )
                    info_output_display = gr.Markdown(
                        value="",
                        label="Vector Information",
                        elem_id="info_output_display"
                    )
                with gr.Column(scale=2):  # Right column for the main representation output
                    main_representation_output = gr.Markdown()
            # Share button: first fill the textbox with the stored URL (or the
            # placeholder message), then make the textbox visible.
            share_button.click(
                fn=get_current_share_url,
                outputs=share_output
            ).then(
                fn=lambda: gr.update(visible=True),
                outputs=share_output
            )
            # Re-encode whenever the model choice or the text changes.
            model_radio.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            input_text.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            # Populate the outputs on page load. The components are passed as
            # inputs so the handler receives their current values, exactly like
            # the change handlers above, rather than the static `.value`
            # attributes captured when the layout was built.
            demo.load(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
        # Tab 2: encode a query and a document (possibly with different
        # encoders) and show their dot product.
        with gr.TabItem("Compare Encoders"):
            gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document")
            model_choices = [
                "MLM encoder (SPLADE-cocondenser-distil)",
                "MLP encoder (SPLADE-v3-lexical)",
                "Binary Bag-of-Words"
            ]
            gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
                inputs=[
                    gr.Radio(
                        model_choices,
                        label="Choose Query Encoding Model",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    ),
                    gr.Radio(
                        model_choices,
                        label="Choose Document Encoding Model",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    ),
                    gr.Textbox(
                        lines=3,
                        label="Enter Query Text:",
                        placeholder="e.g., famous dishes of Padua"
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter Document Text:",
                        placeholder="e.g., Padua's cuisine is as famous as its legendary University."
                    )
                ],
                outputs=gr.Markdown(),
                allow_flagging="never"
            )
# Launch the app and record the share URL so the custom share button can show it.
if __name__ == "__main__":
    # Per the Gradio docs, launch() keeps the main thread busy when run as a
    # script unless prevent_thread_lock=True. Returning right away lets the
    # share URL be recorded and the banner printed while the server is up.
    launched_demo = demo.launch(share=True, prevent_thread_lock=True)

    # Per the Gradio docs, the public link (if any) is kept on the Blocks
    # object. It stays None when no share link was created, in which case
    # get_current_share_url() keeps returning its placeholder message.
    global_share_url = getattr(demo, "share_url", None)

    print("\n--- Gradio App Launched ---")
    print("If a public share link is generated, it will be displayed in your console.")
    print("You can also use the '🔗 Get Share Link' button on the 'Sparse Representation' tab.")
    print("---------------------------\n")

    # Keep the process alive now that the bookkeeping is done.
    demo.block_thread()