# TextLSRDemo / app.py
# Author: SiddharthAK — commit d9c4db0 ("Update app.py", verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm
import os
# Custom CSS passed to gr.Blocks(css=...) below.
#   .share-button-container / .custom-share-button: layout and look of the custom
#     "🔗 Get Share Link" button on the "Sparse Representation" tab.
#   .share-button: per the comment inside the string, Gradio's own share button; every
#     rule is marked !important so it overrides Gradio's default placement.
# NOTE(review): the "scrollable-output" class put on the dot-product Markdown output has
# no rule in this string.
css = """
.share-button-container {
display: flex;
justify-content: center;
margin-top: 10px;
margin-bottom: 20px;
}
.custom-share-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
color: white;
padding: 8px 16px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 14px;
margin: 4px 2px;
cursor: pointer;
border-radius: 6px;
transition: all 0.3s ease;
}
.custom-share-button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
/* IMPORTANT: This CSS targets Gradio's *default* share button that appears
when demo.launch(share=True) is used.
You might want to comment this out if you prefer Gradio's default positioning
for the main share button (usually in the header/footer) and rely only on your custom one.
*/
.share-button {
position: fixed !important;
top: 20px !important;
right: 20px !important;
z-index: 1000 !important;
background: #4CAF50 !important;
color: white !important;
border-radius: 8px !important;
padding: 8px 16px !important;
font-weight: bold !important;
box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}
.share-button:hover {
background: #45a049 !important;
transform: translateY(-1px) !important;
}
"""
# --- Model Loading ---
def _load_masked_lm(model_name, loaded_message, error_label, hint):
    """Load a tokenizer + masked-LM pair from the Hugging Face Hub and put the model in eval mode.

    Failures are printed rather than raised, so the app can still start when one of the
    models is unavailable; every consumer checks for ``None`` before using the objects.

    Args:
        model_name: Hub id passed to ``from_pretrained``.
        loaded_message: line printed after both objects loaded successfully.
        error_label: model label used in the "Error loading ..." line.
        hint: troubleshooting line printed right after an error.

    Returns:
        ``(tokenizer, model)``. Anything that failed to load (and anything after it) is
        ``None``; a tokenizer that loaded before the model failed is still returned.
    """
    tokenizer = None
    model = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        model.eval()  # inference only
        print(loaded_message)
    except Exception as e:  # deliberate best-effort: keep the app usable without this model
        print(f"Error loading {error_label} model: {e}")
        print(hint)
    return tokenizer, model


# SPLADE CoCondenser-SelfDistil: backs the "MLM encoder" option.
splade_cocondenser_model_name = "naver/splade-cocondenser-selfdistil"
tokenizer_splade, model_splade = _load_masked_lm(
    splade_cocondenser_model_name,
    "SPLADE-cocondenser-distil model loaded successfully!",
    "SPLADE-cocondenser-distil",
    f"Please ensure you have accepted any user access agreements on the Hugging Face Hub page for '{splade_cocondenser_model_name}'.",
)

# SPLADE-v3-lexical: backs the "MLP encoder" option (output is masked to the input's own terms).
splade_lexical_model_name = "naver/splade-v3-lexical"
tokenizer_splade_lexical, model_splade_lexical = _load_masked_lm(
    splade_lexical_model_name,
    f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!",
    "SPLADE-v3-Lexical",
    f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).",
)

# SPLADE-v3-doc: the "Binary" option only uses this tokenizer.
# NOTE(review): model_splade_doc is loaded but never used by the representation code;
# dropping it would save startup time and memory if nothing else relies on it.
splade_doc_model_name = "naver/splade-v3-doc"
tokenizer_splade_doc, model_splade_doc = _load_masked_lm(
    splade_doc_model_name,
    f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!",
    "SPLADE-v3-Doc",
    f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).",
)
# --- Helper: binary bag-of-words mask over the vocabulary ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """Build a binary bag-of-words mask for every sequence in a batch.

    Each token id that occurs in a row of ``input_ids_batch`` gets a 1 in the matching
    row of the result. The tokenizer's pad, CLS, SEP, mask and unknown ids are always
    left at 0.

    Args:
        input_ids_batch: integer ``torch.Tensor`` of shape ``(batch_size, sequence_length)``.
        vocab_size: number of columns of the mask; every non-special token id must be
            smaller than this.
        tokenizer: object exposing ``pad_token_id``, ``cls_token_id``, ``sep_token_id``,
            ``mask_token_id`` and ``unk_token_id``. Any of them may be ``None``.

    Returns:
        ``torch.Tensor`` of shape ``(batch_size, vocab_size)`` with the default float dtype,
        on the same device as ``input_ids_batch``, containing only 0s and 1s.
    """
    # The excluded ids are identical for every token and row, so build the set once.
    special_ids = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.mask_token_id,
        tokenizer.unk_token_id,
    }
    special_ids.discard(None)

    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
    for row, input_ids in enumerate(input_ids_batch.tolist()):
        present_ids = sorted(set(input_ids) - special_ids)
        if present_ids:
            bow_masks[row, present_ids] = 1
    return bow_masks
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# Each returns a tuple: (main_representation_markdown, info_str).
def get_splade_cocondenser_representation(text):
    """Encode ``text`` with SPLADE-cocondenser-selfdistil and format its sparse vector.

    The vector is the max over sequence positions of ``log(1 + relu(logits))``, with the
    attention mask applied. Its non-zero entries are decoded back to vocabulary terms and
    listed from highest to lowest weight.

    Args:
        text: query or document text to encode.

    Returns:
        ``(main_representation_markdown, info_string)``. If the model is not loaded or its
        output has no ``logits``, returns an error message and ``""``.
    """
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)
    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""

    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()  # single input: drop the batch dimension

    # flatten() always yields a list of ids. squeeze() would turn a single non-zero entry into
    # a bare int, and the old truthiness check then dropped it when that id was 0.
    indices = torch.nonzero(splade_vector).flatten().cpu().tolist()
    values = splade_vector[indices].cpu().tolist()

    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in special_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLM Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # One paragraph: "**term**: weight, ..." ending with a period
        formatted_output += ", ".join(f"**{term}**: {weight:.4f}" for term, weight in sorted_representation) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output
def get_splade_lexical_representation(text):
    """Encode ``text`` with SPLADE-v3-lexical and format its lexically masked sparse vector.

    The SPLADE vector (max over positions of ``log(1 + relu(logits))`` with the attention mask
    applied) is multiplied by a bag-of-words mask built from the input's own token ids. Only
    terms that actually occur in ``text`` (special tokens excluded) can keep a non-zero weight.

    Args:
        text: query or document text to encode.

    Returns:
        ``(main_representation_markdown, info_string)``. If the model is not loaded or its
        output has no ``logits``, returns an error message and ``""``.
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade_lexical(**inputs)
    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""

    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()  # single input: drop the batch dimension

    # Always apply the lexical mask. input_ids already has a batch dimension of 1, and the
    # mask's batch dimension is squeezed away to match splade_vector.
    bow_mask = create_lexical_bow_mask(
        inputs["input_ids"], tokenizer_splade_lexical.vocab_size, tokenizer_splade_lexical
    ).squeeze()
    splade_vector = splade_vector * bow_mask

    # flatten() always yields a list of ids, even when only one entry (possibly id 0) is non-zero.
    indices = torch.nonzero(splade_vector).flatten().cpu().tolist()
    values = splade_vector[indices].cpu().tolist()

    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in special_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLP Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # One paragraph: "**term**: weight, ..." ending with a period
        formatted_output += ", ".join(f"**{term}**: {weight:.4f}" for term, weight in sorted_representation) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output
def get_splade_doc_representation(text):
    """Format the binary bag-of-words representation of ``text``.

    Only the SPLADE-v3-doc tokenizer is used; no model forward pass runs. Every non-special
    token id occurring in ``text`` is "on", and the terms are listed alphabetically without
    weights (every weight would be 1).

    Args:
        text: query or document text to tokenize.

    Returns:
        ``(main_representation_markdown, info_string)``. If the tokenizer is not loaded,
        returns an error message and ``""``.
    """
    if tokenizer_splade_doc is None:
        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    # The mask is built directly from the token ids on the CPU
    binary_bow_vector = create_lexical_bow_mask(
        inputs["input_ids"].to(torch.device("cpu")), tokenizer_splade_doc.vocab_size, tokenizer_splade_doc
    ).squeeze()

    # flatten() always yields a list of ids, even when only one entry is set.
    indices = torch.nonzero(binary_bow_vector).flatten().cpu().tolist()

    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    terms = set()
    for token_id in indices:
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in special_tokens and decoded_token.strip():
            terms.add(decoded_token)

    formatted_output = "Binary:\n\n"
    if not terms:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Alphabetical order keeps the binary list easy to scan
        formatted_output += ", ".join(f"**{term}**" for term in sorted(terms)) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}"
    return formatted_output, info_output
# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    """Route the explorer tab's request to the representation function for ``model_choice``.

    Args:
        model_choice: one of the three radio labels shown in the "Sparse Representation" tab.
        text: the text to encode.

    Returns:
        The ``(main_representation, info)`` pair produced by the chosen encoder, or
        ``("Please select a model.", "")`` for an unrecognised choice.
    """
    known_choices = (
        "MLM encoder (SPLADE-cocondenser-distil)",
        "MLP encoder (SPLADE-v3-lexical)",
        "Binary",
    )
    if model_choice not in known_choices:
        # Keep the two-element shape expected by the two output components
        return "Please select a model.", ""

    handlers = {
        known_choices[0]: get_splade_cocondenser_representation,
        known_choices[1]: get_splade_lexical_representation,
        known_choices[2]: get_splade_doc_representation,
    }
    return handlers[model_choice](text)
# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
def get_splade_cocondenser_vector(text):
    """Return the raw SPLADE-cocondenser-selfdistil vector for ``text``.

    The result is the squeezed max over sequence positions of ``log(1 + relu(logits))``
    with the attention mask applied, left on the model's device. Returns ``None`` when the
    model/tokenizer is not loaded or the model output has no ``logits``.
    """
    if tokenizer_splade is None or model_splade is None:
        return None

    encoded = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        result = model_splade(**encoded)

    if not hasattr(result, 'logits'):
        return None

    saturated = torch.log(1 + torch.relu(result.logits))
    masked = saturated * encoded['attention_mask'].unsqueeze(-1)
    return torch.max(masked, dim=1)[0].squeeze()
def get_splade_lexical_vector(text):
    """Return the raw SPLADE-v3-lexical vector for ``text``, masked to the input's own terms.

    The pooled SPLADE vector is multiplied by the bag-of-words mask from
    ``create_lexical_bow_mask``. Returns ``None`` when the model/tokenizer is not loaded or
    the model output has no ``logits``.
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    encoded = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade_lexical.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        result = model_splade_lexical(**encoded)

    if not hasattr(result, 'logits'):
        return None

    saturated = torch.log(1 + torch.relu(result.logits))
    pooled = torch.max(saturated * encoded['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
    vocabulary_size = tokenizer_splade_lexical.vocab_size
    lexical_mask = create_lexical_bow_mask(
        encoded['input_ids'], vocabulary_size, tokenizer_splade_lexical
    ).squeeze()
    return pooled * lexical_mask
def get_splade_doc_vector(text):
    """Return the binary bag-of-words vector for ``text`` built from SPLADE-v3-doc token ids.

    Only the tokenizer is needed; the vector is built on the CPU. Returns ``None`` when
    the tokenizer is not loaded.
    """
    if tokenizer_splade_doc is None:
        return None

    encoded = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    token_ids = encoded['input_ids'].to(torch.device("cpu"))
    mask = create_lexical_bow_mask(token_ids, tokenizer_splade_doc.vocab_size, tokenizer_splade_doc)
    return mask.squeeze()
# --- Generic formatter: raw sparse vector + tokenizer -> Markdown ---
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
    """Render a sparse vocabulary vector as a Markdown term list plus a term-count line.

    Args:
        splade_vector: 1-D tensor with one weight per vocabulary id, or ``None``.
        tokenizer: object whose ``decode([token_id])`` returns the term for an id.
        is_binary: if True, list the terms alphabetically without weights. Otherwise list
            ``**term**: weight`` pairs from highest to lowest weight.

    Returns:
        ``(terms_markdown, info_string)``. Returns ``("Failed to generate vector.", "")`` when
        ``splade_vector`` is ``None``.
    """
    if splade_vector is None:
        return "Failed to generate vector.", ""

    # flatten() always yields a list of ids. squeeze() would turn a single non-zero entry into
    # a bare int, and the old truthiness check then dropped it when that id was 0.
    indices = torch.nonzero(splade_vector).flatten().cpu().tolist()
    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[indices].cpu().tolist()

    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in special_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        # Alphabetical order for binary vectors, since all weights are equal
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    if not sorted_representation:
        formatted_output = "No significant terms found.\n"
    elif is_binary:
        formatted_output = ", ".join(f"**{term}**" for term, _ in sorted_representation) + "."
    else:
        formatted_output = ", ".join(f"**{term}**: {weight:.4f}" for term, weight in sorted_representation) + "."

    info_output = f"Total non-zero terms: {len(indices)}\n"
    return formatted_output, info_output
# --- Helper: map a radio label to its vector function, tokenizer and display settings ---
def get_model_assets(model_choice_str):
    """Look up everything the similarity tab needs for one encoder choice.

    Args:
        model_choice_str: one of the radio labels offered in the similarity tab.

    Returns:
        ``(vector_fn, tokenizer, is_binary, display_name)``. For an unrecognised label this is
        ``(None, None, False, "Unknown Model")``.
    """
    if model_choice_str == "Binary":
        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words"
    if model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
    if model_choice_str == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
    return None, None, False, "Unknown Model"
# --- Dot Product Calculation Function for the similarity tab ---
def _overlapping_term_products(query_vector, doc_vector, tokenizer):
    """Return ``[(term, query_weight * doc_weight), ...]`` for ids non-zero in both vectors.

    Both vectors must have the same shape. Ids are visited in ascending order and decoded
    with ``tokenizer``. Special tokens and blank terms are skipped, and if two ids decode to
    the same term, the later one wins. The result is sorted from highest to lowest product.
    """
    common_indices = torch.nonzero((query_vector != 0) & (doc_vector != 0)).flatten().tolist()
    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    products = {}
    for idx in common_indices:
        term = tokenizer.decode([idx])
        if term not in special_tokens and term.strip():
            products[term] = query_vector[idx].item() * doc_vector[idx].item()
    return sorted(products.items(), key=lambda item: item[1], reverse=True)


def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
    """Score a query against a document and render both sparse representations as Markdown.

    The query and the document are each encoded with their own chosen model (see
    ``get_model_assets``). The score is the dot product of the two vocabulary vectors.

    Args:
        query_model_choice: radio label of the encoder used for the query.
        doc_model_choice: radio label of the encoder used for the document.
        query_text: the query text.
        doc_text: the document text.

    Returns:
        A single Markdown string, because the tab has exactly one output component. It holds
        the overall score, per-term products for the shared vocabulary entries, and both
        representations. If a model is unavailable or the vectors cannot be compared, it holds
        an error message instead.
    """
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
    # Error paths return a plain string as well: a tuple would be handed to the lone Markdown output.
    if query_vector_fn is None or doc_vector_fn is None:
        return "Please select valid models for both query and document encoding."

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)
    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text."

    query_vector = query_vector.cpu()
    doc_vector = doc_vector.cpu()
    # torch.dot needs two 1-D tensors of equal length; vectors over different vocabularies are not comparable.
    if query_vector.shape != doc_vector.shape:
        return (
            f"Cannot compare vectors of different sizes ({query_vector.numel()} vs {doc_vector.numel()}); "
            "the selected models use different vocabularies."
        )

    dot_product = float(torch.dot(query_vector, doc_vector).item())

    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    # Terms are decoded with the query tokenizer (all offered encoders are assumed to share a vocabulary).
    overlap = _overlapping_term_products(query_vector, doc_vector, query_tokenizer)

    parts = [f"### Overall Dot Product Score: {dot_product:.6f}\n\n", "---\n\n"]
    if overlap:
        parts.append("### Product of Query and Document Term Scores:\n\n")
        parts.append(", ".join(f"**{term}**: {product_val:.4f}" for term, product_val in overlap) + ".\n\n")
    else:
        parts.append("### No Overlapping Terms Found.\n\n")
    parts.append("---\n\n")
    parts += [
        f"#### Query Representation: {query_model_name_display}\n",
        f"> {query_main_rep_str}\n",  # blockquote for the term list
        f"> {query_info_str}\n",
        "\n---\n\n",
        f"#### Document Representation: {doc_model_name_display}\n",
        f"> {doc_main_rep_str}\n",
        f"> {doc_info_str}",
    ]
    return "".join(parts)
# Module-level holder for the app's public share URL; remains None unless launch code assigns it.
global_share_url = None


def get_current_share_url():
    """Return the stored share URL, or a placeholder message when none has been recorded."""
    return global_share_url or "Share URL not available yet."
# --- Gradio Interface Setup with Tabs ---
# Tab 1 ("Sparse Representation") encodes a single text with the chosen encoder.
# Tab 2 ("Compute Query-Document Similarity Score") encodes a query and a document with
# independently chosen encoders and reports their dot product.
with gr.Blocks(title="SPLADE Demos", css=css) as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")
    with gr.Tabs():
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")
            with gr.Row():
                with gr.Column(scale=1):  # left column: inputs, share link and vector info
                    model_radio = gr.Radio(
                        [
                            "MLM encoder (SPLADE-cocondenser-distil)",
                            "MLP encoder (SPLADE-v3-lexical)",
                            "Binary"
                        ],
                        label="Choose Sparse Encoder",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    )
                    input_text = gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )
                    # Custom share button (styled by .custom-share-button in `css`);
                    # share_output starts hidden and is revealed by the click handler below.
                    with gr.Row(elem_classes="share-button-container"):
                        share_button = gr.Button(
                            "🔗 Get Share Link",
                            elem_classes="custom-share-button",
                            size="sm"
                        )
                    share_output = gr.Textbox(
                        label="Share URL",
                        interactive=True,
                        visible=False,
                        placeholder="Click 'Get Share Link' to generate URL..."
                    )
                    # Receives the second element (term count) of the representation functions' result
                    info_output_display = gr.Markdown(
                        value="",
                        label="Vector Information",
                        elem_id="info_output_display"
                    )
                with gr.Column(scale=2):  # right column: the formatted term list
                    main_representation_output = gr.Markdown()
            # Fill share_output with the stored URL (or placeholder text), then make it visible.
            share_button.click(
                fn=get_current_share_url,
                outputs=share_output
            ).then(
                fn=lambda: gr.update(visible=True),
                outputs=share_output
            )
            # Re-encode whenever the encoder choice or the text changes.
            model_radio.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            input_text.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            # Populate the outputs once on page load. The lambda reads the components' initial
            # .value attributes, not what the user has currently entered.
            demo.load(
                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
                outputs=[main_representation_output, info_output_display]
            )
        with gr.TabItem("Compute Query-Document Similarity Score"):
            gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document")
            # Same labels as the explorer tab; get_model_assets maps each label to its encoder.
            model_choices = [
                "MLM encoder (SPLADE-cocondenser-distil)",
                "MLP encoder (SPLADE-v3-lexical)",
                "Binary"
            ]
            query_model_radio = gr.Radio(
                model_choices,
                label="Choose Query Encoding Model",
                value="MLM encoder (SPLADE-cocondenser-distil)"
            )
            doc_model_radio = gr.Radio(
                model_choices,
                label="Choose Document Encoding Model",
                value="MLM encoder (SPLADE-cocondenser-distil)"
            )
            query_text_input = gr.Textbox(
                lines=3,
                label="Enter Query Text:",
                placeholder="e.g., famous dishes of Padua"
            )
            doc_text_input = gr.Textbox(
                lines=5,
                label="Enter Document Text:",
                placeholder="e.g., Padua's cuisine is as famous as its legendary University."
            )
            # Single Markdown output for the whole similarity report.
            # height=500 requests a fixed-height area. NOTE(review): whether overflowing content
            # scrolls depends on the installed Gradio version, and `css` defines no rule for the
            # "scrollable-output" class.
            output_dot_product_markdown = gr.Markdown(
                value="",
                height=500,
                elem_classes=["scrollable-output"]
            )
            # Wire the scoring function to the components above through a nested gr.Interface
            # with flagging disabled. The function's return value fills the one Markdown output.
            gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
                inputs=[
                    query_model_radio,
                    doc_model_radio,
                    query_text_input,
                    doc_text_input
                ],
                outputs=output_dot_product_markdown,
                allow_flagging="never"
            )
    # References for the SPLADE models used by this demo (shown below the tabs).
    gr.Markdown(
        """
---
### References
This demo utilizes **SPLADE** models. For more details, please refer to the following papers:
1. Formal, T., Lassance, C., Piwowarski, B., & Clinchant, S. (2022). **From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective**. *arXiv preprint arXiv:2205.04733*. Available at: [https://arxiv.org/abs/2205.04733](https://arxiv.org/abs/2205.04733)
2. Lassance, C., Déjean, H., Formal, T., & Clinchant, S. (2024). **SPLADE-v3: New baselines for SPLADE**. *arXiv preprint arXiv:2403.06789*. Available at: [https://arxiv.org/abs/2403.06789](https://arxiv.org/abs/2403.06789)
"""
    )
# --- Entry point ---
if __name__ == "__main__":
    # By default launch() blocks until the server stops. That meant global_share_url was never
    # assigned and these notes printed only at shutdown. Instead, launch without blocking,
    # record the share URL for the "🔗 Get Share Link" button, print the notes, and then block
    # explicitly.
    launched_demo = demo.launch(share=True, prevent_thread_lock=True)
    # Blocks.share_url stays None when no public link was created.
    global_share_url = getattr(demo, "share_url", None)
    print("\n--- Gradio App Launched ---")
    print("If a public share link is generated, it will be displayed in your console.")
    print("You can also use the '🔗 Get Share Link' button on the 'Sparse Representation' tab.")
    print("---------------------------\n")
    demo.block_thread()