# Spaces:
# Running
# Running
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel | |
import torch | |
# --- Model Loading --- | |
# Every handle starts as None so the representation functions can detect a
# failed load and return a message instead of raising.
tokenizer_splade = model_splade = None
tokenizer_splade_lexical = model_splade_lexical = None

# Load the SPLADE cocondenser (self-distilled) checkpoint.
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()  # inference only
except Exception as e:
    # Best effort: report the failure and keep the app importable.
    print(f"Error loading SPLADE (cocondenser) model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
else:
    print("SPLADE v3 (cocondenser) model loaded successfully!")

# Load the SPLADE v3 Lexical checkpoint.
splade_lexical_model_name = "naver/splade-v3-lexical"
try:
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()  # inference only
except Exception as e:
    # Best effort: report the failure and keep the app importable.
    print(f"Error loading SPLADE v3 Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
else:
    print(f"SPLADE v3 Lexical model '{splade_lexical_model_name}' loaded successfully!")
# --- Helper function for lexical mask --- | |
def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    """
    Build a binary bag-of-words mask over the vocabulary from ``input_ids``.

    Positions whose token id occurs in ``input_ids`` are set to 1, except
    for the tokenizer's pad, cls, sep, mask and unk ids, which stay 0.

    Args:
        input_ids: Integer tensor of token ids. Any shape is accepted; all
            ids are pooled into a single mask, so a batch of several
            sequences yields the union of their tokens.
        vocab_size: Length of the mask to create.
        tokenizer: Object exposing ``pad_token_id``, ``cls_token_id``,
            ``sep_token_id``, ``mask_token_id`` and ``unk_token_id``.

    Returns:
        Float tensor of shape ``(1, vocab_size)`` on ``input_ids.device``.
    """
    special_ids = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.mask_token_id,
        tokenizer.unk_token_id,
    }

    # reshape(-1) instead of squeeze(): squeeze() collapses a one-token input
    # to a 0-d tensor (tolist() then returns a bare int that cannot be
    # iterated) and leaves nested lists for batches larger than one.
    present_ids = set(input_ids.reshape(-1).tolist())

    # Ids outside [0, vocab_size) have no slot in the mask, so they are
    # skipped rather than raising an IndexError (or wrapping around for
    # negative values).
    kept_ids = sorted(
        token_id
        for token_id in present_ids - special_ids
        if 0 <= token_id < vocab_size
    )

    bow_mask = torch.zeros(vocab_size, device=input_ids.device)
    if kept_ids:
        bow_mask[kept_ids] = 1

    # Leading batch dimension kept so callers can squeeze it as before.
    return bow_mask.unsqueeze(0)
# --- Core Representation Functions --- | |
def get_splade_representation(text):
    """
    Describe the SPLADE (cocondenser) sparse vector of ``text`` as Markdown.

    Returns a bullet list of every non-zero term with its weight, sorted by
    descending weight, followed by the non-zero count and the sparsity
    relative to the tokenizer's vocabulary size. If the model or tokenizer
    failed to load, or the model output has no ``logits`` attribute, an
    explanatory message string is returned instead.
    """
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE (cocondenser) model is not loaded. Please check the console for loading errors."

    encoded = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model_splade(**encoded)

    if not hasattr(output, 'logits'):
        return "Model output structure not as expected for SPLADE (cocondenser). 'logits' not found."

    # log(1 + relu(logits)), zeroed where the attention mask is 0, then
    # max-pooled over dim 1 (presumably the sequence axis -- assumes logits
    # are laid out as (batch, seq_len, vocab)).
    saturated = torch.log(1 + torch.relu(output.logits))
    attention = encoded['attention_mask'].unsqueeze(-1)
    splade_vector = torch.max(saturated * attention, dim=1)[0].squeeze()

    # flatten() always yields a flat list of ids, including when exactly one
    # entry is non-zero.
    nonzero_ids = torch.nonzero(splade_vector).flatten().cpu().tolist()
    weights = splade_vector[nonzero_ids].cpu().tolist()

    skipped_terms = ("[CLS]", "[SEP]", "[PAD]", "[UNK]")
    term_weights = {}
    for token_id, weight in zip(nonzero_ids, weights):
        term = tokenizer_splade.decode([token_id])
        if term in skipped_terms or not term.strip():
            continue
        term_weights[term] = weight

    ranked_terms = sorted(term_weights.items(), key=lambda pair: pair[1], reverse=True)

    lines = ["SPLADE (cocondenser) Representation (All Non-Zero Terms):\n"]
    if ranked_terms:
        lines.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in ranked_terms)
    else:
        lines.append("No significant terms found for this input.\n")

    # The raw counts include the special/blank tokens hidden from the list above.
    lines.append("\n--- Raw SPLADE Vector Info ---\n")
    lines.append(f"Total non-zero terms in vector: {len(nonzero_ids)}\n")
    lines.append(f"Sparsity: {1 - (len(nonzero_ids) / tokenizer_splade.vocab_size):.2%}\n")
    return "".join(lines)
def get_splade_lexical_representation(text, apply_lexical_mask: bool):
    """
    Describe the SPLADE v3 Lexical sparse vector of ``text`` as Markdown.

    When ``apply_lexical_mask`` is true, the vector is multiplied by a
    bag-of-words mask built from the input token ids, which zeroes every
    term that does not occur in the input (i.e. the expanded terms).

    Returns a bullet list of every non-zero term with its weight, sorted by
    descending weight, followed by the non-zero count and the sparsity
    relative to the tokenizer's vocabulary size. If the model or tokenizer
    failed to load, or the model output has no ``logits`` attribute, an
    explanatory message string is returned instead.
    """
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."

    encoded = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model_splade_lexical.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model_splade_lexical(**encoded)

    if not hasattr(output, 'logits'):
        return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."

    # log(1 + relu(logits)), zeroed where the attention mask is 0, then
    # max-pooled over dim 1 (presumably the sequence axis -- assumes logits
    # are laid out as (batch, seq_len, vocab)).
    saturated = torch.log(1 + torch.relu(output.logits))
    attention = encoded['attention_mask'].unsqueeze(-1)
    splade_vector = torch.max(saturated * attention, dim=1)[0].squeeze()

    if apply_lexical_mask:
        # Squeeze the mask's batch dimension so it lines up with the
        # squeezed SPLADE vector before the element-wise product.
        bow_mask = create_lexical_bow_mask(
            encoded['input_ids'],
            tokenizer_splade_lexical.vocab_size,
            tokenizer_splade_lexical,
        ).squeeze()
        splade_vector = splade_vector * bow_mask

    # flatten() always yields a flat list of ids, including when exactly one
    # entry is non-zero.
    nonzero_ids = torch.nonzero(splade_vector).flatten().cpu().tolist()
    weights = splade_vector[nonzero_ids].cpu().tolist()

    skipped_terms = ("[CLS]", "[SEP]", "[PAD]", "[UNK]")
    term_weights = {}
    for token_id, weight in zip(nonzero_ids, weights):
        term = tokenizer_splade_lexical.decode([token_id])
        if term in skipped_terms or not term.strip():
            continue
        term_weights[term] = weight

    ranked_terms = sorted(term_weights.items(), key=lambda pair: pair[1], reverse=True)

    lines = ["SPLADE v3 Lexical Representation (All Non-Zero Terms):\n"]
    if ranked_terms:
        lines.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in ranked_terms)
    else:
        lines.append("No significant terms found for this input.\n")

    # The raw counts include the special/blank tokens hidden from the list above.
    lines.append("\n--- Raw SPLADE Vector Info ---\n")
    lines.append(f"Total non-zero terms in vector: {len(nonzero_ids)}\n")
    lines.append(f"Sparsity: {1 - (len(nonzero_ids) / tokenizer_splade_lexical.vocab_size):.2%}\n")
    return "".join(lines)
# --- Unified Prediction Function for Gradio --- | |
def predict_representation(model_choice, text):
    """
    Route ``text`` to the representation function matching ``model_choice``.

    ``model_choice`` is compared against the three option labels offered by
    the UI; any other value returns a prompt to select a model.
    """
    if model_choice == "SPLADE (cocondenser)":
        return get_splade_representation(text)

    # Both lexical options use the same model and differ only in whether the
    # bag-of-words mask is applied to the output vector.
    if model_choice == "SPLADE-v3-Lexical (with expansion)":
        return get_splade_lexical_representation(text, apply_lexical_mask=False)
    if model_choice == "SPLADE-v3-Lexical (lexical-only)":
        return get_splade_lexical_representation(text, apply_lexical_mask=True)

    return "Please select a model."
# --- Gradio Interface Setup --- | |
# The radio labels must match the strings compared in predict_representation.
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            choices=[
                "SPLADE (cocondenser)",
                "SPLADE-v3-Lexical (with expansion)",  # full neural output
                "SPLADE-v3-Lexical (lexical-only)",  # lexical mask applied
            ],
            value="SPLADE (cocondenser)",  # default selection
            label="Choose Representation Model",
        ),
        gr.Textbox(
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?",
            lines=5,
        ),
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector. Explore the difference between full neural expansion and lexical-only representations.",
    allow_flagging="never",
)

# Start the Gradio app.
demo.launch()