import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm
import os
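# This Space compares three sparse encoders (an MLM-based SPLADE model, an MLP/"lexical"
# SPLADE model, and a plain binary bag-of-words baseline): it visualises their sparse term
# representations and computes dot-product similarity between an encoded query and document.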
# CSS to style the custom share button (for the "Sparse Representation" tab)
css = """
.share-button-container {
    display: flex;
    justify-content: center;
    margin-top: 10px;
    margin-bottom: 20px;
}

.custom-share-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border: none;
    color: white;
    padding: 8px 16px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 14px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 6px;
    transition: all 0.3s ease;
}

.custom-share-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}

/* IMPORTANT: This CSS targets Gradio's *default* share button that appears
   when demo.launch(share=True) is used.
   You might want to comment this out if you prefer Gradio's default positioning
   for the main share button (usually in the header/footer) and rely only on your custom one. */
.share-button {
    position: fixed !important;
    top: 20px !important;
    right: 20px !important;
    z-index: 1000 !important;
    background: #4CAF50 !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 8px 16px !important;
    font-weight: bold !important;
    box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}

.share-button:hover {
    background: #45a049 !important;
    transform: translateY(-1px) !important;
}
"""

# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None

# Load SPLADE cocondenser self-distil model (labelled "SPLADE-cocondenser-distil" in the UI)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load SPLADE v3 Lexical model
try:
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# Load SPLADE v3 Doc model - the model itself is still loaded even though only its
# tokenizer is used for the binary bag-of-words representation below.
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# --- Helper function for lexical mask (handles batches, but used for single inputs here) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """
    Creates a batch of lexical BOW masks.

    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
    vocab_size: int, size of the tokenizer vocabulary
    tokenizer: the tokenizer object
    Returns: torch.Tensor of shape (batch_size, vocab_size)
    """
    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)

    for i in range(batch_size):
        input_ids = input_ids_batch[i]  # input_ids for the current item in the batch
        meaningful_token_ids = []
        for token_id in input_ids.tolist():
            if token_id not in [
                tokenizer.pad_token_id,
                tokenizer.cls_token_id,
                tokenizer.sep_token_id,
                tokenizer.mask_token_id,
                tokenizer.unk_token_id,
            ]:
                meaningful_token_ids.append(token_id)
        if meaningful_token_ids:
            # Set the mask to 1 at every unique non-special token id for this row
            bow_masks[i, list(set(meaningful_token_ids))] = 1

    return bow_masks
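
# Illustrative helper (a minimal sketch, not used by the app itself): shows how the mask
# produced by create_lexical_bow_mask maps back to surface terms for a single input text.
# The name `_example_bow_terms` is ours, not part of the original interface.
def _example_bow_terms(text, tokenizer):
    """Return the unique decoded terms that create_lexical_bow_mask keeps for `text`."""
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    mask = create_lexical_bow_mask(encoded["input_ids"], tokenizer.vocab_size, tokenizer)
    kept_ids = torch.nonzero(mask[0]).squeeze(-1).tolist()
    return [tokenizer.decode([token_id]) for token_id in kept_ids]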

# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# These functions return a tuple: (main_representation_str, info_str)

def get_splade_cocondenser_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
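
    # SPLADE term weights: for each vocabulary entry j, the weight is
    #   w_j = max_i log(1 + ReLU(logit_{i,j}))
    # where i runs over the input token positions and padded positions are zeroed out
    # via the attention mask before the max-pooling below.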
    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()  # Squeeze is fine here as it's a single input
    else:
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLM Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render as a single comma-separated paragraph of "term: weight" pairs
        terms_list = []
        for term, weight in sorted_representation:
            terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"

    return formatted_output, info_output

def get_splade_lexical_representation(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()  # Squeeze is fine here
    else:
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""

    # Always apply the lexical mask: this model only scores terms that appear in the input text
    vocab_size = tokenizer_splade_lexical.vocab_size
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze()  # Squeeze back to a single vector
    splade_vector = splade_vector * bow_mask

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLP Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render as a single comma-separated paragraph of "term: weight" pairs
        terms_list = []
        for term, weight in sorted_representation:
            terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"

    return formatted_output, info_output

def get_splade_doc_representation(text):
    if tokenizer_splade_doc is None:  # Only the tokenizer is needed for the binary BoW representation
        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()}  # Keep on CPU for direct mask creation

    vocab_size = tokenizer_splade_doc.vocab_size
    # Directly create the binary Bag-of-Words vector from the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()  # Squeeze back to a single vector

    indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []
    values = [1.0] * len(indices)  # All values are 1 for the binary representation
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for clarity

    formatted_output = "Binary Bag-of-Words Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render as a single comma-separated paragraph; weights are all 1, so only terms are shown
        terms_list = []
        for term, _ in sorted_representation:
            terms_list.append(f"**{term}**")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}"

    return formatted_output, info_output

# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    if model_choice == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_representation(text)
    elif model_choice == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_representation(text)
    elif model_choice == "Binary Bag-of-Words":
        return get_splade_doc_representation(text)
    else:
        return "Please select a model.", ""  # Two values, matching the other branches

# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
def get_splade_cocondenser_vector(text):
    if tokenizer_splade is None or model_splade is None:
        return None

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        return splade_vector
    return None

def get_splade_lexical_vector(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()

        vocab_size = tokenizer_splade_lexical.vocab_size
        bow_mask = create_lexical_bow_mask(
            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
        ).squeeze()
        splade_vector = splade_vector * bow_mask
        return splade_vector
    return None

def get_splade_doc_vector(text):
    if tokenizer_splade_doc is None:  # Only the tokenizer is needed for the binary BoW representation
        return None

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()}  # Keep on CPU for direct mask creation

    vocab_size = tokenizer_splade_doc.vocab_size
    # Directly create the binary Bag-of-Words vector from the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()
    return binary_bow_vector
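
# Note: each vector function above returns a dense torch tensor of length tokenizer.vocab_size
# (roughly 30k entries for the BERT-style vocabularies these checkpoints use), in which only a
# small fraction of entries are non-zero; that sparsity is what the formatter below exploits
# when listing terms.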

# --- Generic formatter: turns a raw sparse vector into a readable term list ---
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
    if splade_vector is None:
        return "Failed to generate vector.", ""

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[indices].cpu().tolist()

    token_weights = dict(zip(indices, values))
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Alphabetical for binary
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = ""
    if not sorted_representation:
        formatted_output += "No significant terms found.\n"
    else:
        terms_list = []
        for i, (term, weight) in enumerate(sorted_representation):
            if i >= 50:  # Truncate the display after 50 terms
                terms_list.append(f"...and {len(sorted_representation) - 50} more terms.")
                break
            if is_binary:
                terms_list.append(f"**{term}**")
            else:
                terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms: {len(indices)}\n"

    return formatted_output, info_output

# --- Helper to get the correct vector function, tokenizer, binary flag, and display name ---
def get_model_assets(model_choice_str):
    if model_choice_str == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
    elif model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
    elif model_choice_str == "Binary Bag-of-Words":
        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words"
    else:
        return None, None, False, "Unknown Model"

# --- Dot Product Calculation Function for the "Compare Encoders" tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)

    if query_vector_fn is None or doc_vector_fn is None:
        return "Please select valid models for both query and document encoding."

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)

    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text."

    # Calculate the dot product on CPU to avoid device mismatches and so .item()
    # reliably converts the result to a Python float.
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())

    # Format both representations; each formatter returns (main_output, info_output)
    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    # Combine everything into a single string for the Markdown output component
    full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
    full_output += "---\n\n"

    full_output += f"Query Representation ({query_model_name_display}):\n\n"
    full_output += query_main_rep_str + "\n\n" + query_info_str
    full_output += "\n\n---\n\n"

    full_output += f"Document Representation ({doc_model_name_display}):\n\n"
    full_output += doc_main_rep_str + "\n\n" + doc_info_str

    return full_output
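
# The score above is the plain inner product over the shared vocabulary space:
#   score(q, d) = sum_t q_t * d_t
# so only terms with non-zero weight in *both* the query and the document contribute.
# Mixing encoders (e.g. an MLM query against a binary BoW document) still yields an
# interpretable term-overlap score, though scores from different encoder pairs are not
# calibrated against each other.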

# Global variable to store the share URL once the app is launched
global_share_url = None

# JavaScript injected into the Gradio UI to copy text to the clipboard
copy_to_clipboard_js = """
async (url) => {
    if (url) {
        try {
            await navigator.clipboard.writeText(url);
            console.log('Share URL copied to clipboard:', url);
            return 'Copied to clipboard!'; // Message displayed to the user
        } catch (err) {
            console.error('Failed to copy share URL:', err);
            return 'Failed to copy!';
        }
    }
    return 'URL not available.';
}
"""

def get_share_url_for_button():
    """
    Provides the share URL to the Gradio Textbox when the custom button is clicked.
    """
    if global_share_url:
        return global_share_url
    else:
        # The URL may not be set yet: this function runs server-side, and the share link
        # only becomes available once the app has actually launched (see the __main__ block).
        return "Generating share link..."

# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos", css=css) as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")

    with gr.Tabs():
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")

            with gr.Row():
                with gr.Column(scale=1):
                    model_radio = gr.Radio(
                        [
                            "MLM encoder (SPLADE-cocondenser-distil)",
                            "MLP encoder (SPLADE-v3-lexical)",
                            "Binary Bag-of-Words"
                        ],
                        label="Choose Sparse Encoder",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    )
                    input_text = gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )

                    # Custom share button and URL display
                    with gr.Row(elem_classes="share-button-container"):
                        share_button = gr.Button(
                            "🔗 Get & Copy Share Link",
                            elem_classes="custom-share-button",
                            size="sm"
                        )
                    share_output_textbox = gr.Textbox(
                        label="Share URL (Click button to copy)",
                        interactive=True,
                        visible=False,
                        placeholder="Click 'Get & Copy Share Link' to generate and copy URL..."
                    )
                    # A small text display to confirm the copy
                    copy_confirmation_text = gr.Markdown(value="", visible=False)

                    info_output_display = gr.Markdown(
                        value="",
                        label="Vector Information",
                        elem_id="info_output_display"
                    )

                with gr.Column(scale=2):
                    main_representation_output = gr.Markdown()

            # Connect the custom share button:
            #   1. populate the textbox with the share URL (from the backend),
            #   2. make the textbox visible,
            #   3. copy its content to the clipboard via client-side JS and show a confirmation
            #      (the confirmation stays visible until the next click).
            share_button.click(
                fn=get_share_url_for_button,
                outputs=share_output_textbox
            ).then(
                fn=lambda: gr.update(visible=True),
                outputs=share_output_textbox
            ).then(
                fn=None,  # no server-side work; the copy happens in the browser
                _js=copy_to_clipboard_js,  # receives the textbox value, returns a status message
                inputs=[share_output_textbox],
                outputs=[copy_confirmation_text]
            ).then(
                fn=lambda msg: gr.update(value=msg, visible=True),
                inputs=[copy_confirmation_text],
                outputs=[copy_confirmation_text]
            )

            # Connect the core prediction logic
            model_radio.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            input_text.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )

            # Initial call to populate the outputs on load (optional, but good for the demo)
            demo.load(
                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
                outputs=[main_representation_output, info_output_display]
            )
with gr.TabItem("Compare Encoders"): | |
gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document") | |
model_choices = [ | |
"MLM encoder (SPLADE-cocondenser-distil)", | |
"MLP encoder (SPLADE-v3-lexical)", | |
"Binary Bag-of-Words" | |
] | |
gr.Interface( | |
fn=calculate_dot_product_and_representations_independent, | |
inputs=[ | |
gr.Radio( | |
model_choices, | |
label="Choose Query Encoding Model", | |
value="MLM encoder (SPLADE-cocondenser-distil)" | |
), | |
gr.Radio( | |
model_choices, | |
label="Choose Document Encoding Model", | |
value="MLM encoder (SPLADE-cocondenser-distil)" | |
), | |
gr.Textbox( | |
lines=3, | |
label="Enter Query Text:", | |
placeholder="e.g., famous dishes of Padua" | |
), | |
gr.Textbox( | |
lines=5, | |
label="Enter Document Text:", | |
placeholder="e.g., Padua's cuisine is as famous as its legendary University." | |
) | |
], | |
outputs=gr.Markdown(), | |
allow_flagging="never" | |
) | |

# Capture the share URL (when available) so the custom share button can serve it
if __name__ == "__main__":
    # demo.launch(share=True) prints the public share URL to the console automatically.
    # Note that in a plain script, launch() blocks until the server is shut down, so the
    # code below mainly helps when running with prevent_thread_lock=True or in a notebook.
    launched_info = demo.launch(share=True)

    # Blocks.launch() returns a (FastAPI app, local_url, share_url) tuple; handle that form
    # as well as a possible object exposing a .share_url attribute.
    share_url = None
    if isinstance(launched_info, tuple) and len(launched_info) >= 3:
        share_url = launched_info[2]
    elif hasattr(launched_info, 'share_url'):
        share_url = launched_info.share_url

    if share_url:
        global_share_url = share_url
        print(f"\n--- Public Share URL (for your custom button): {global_share_url} ---\n")
    else:
        print("\n--- Public Share URL not available from launch(). "
              "Please check your console for the Gradio-generated share link. ---\n")

    print("\n--- Gradio App Running ---")
    print("Open the link in your browser to interact with the app.")
    print("---------------------------\n")