# TextLSRDemo / app.py — Hugging Face Space source (author: SiddharthAK,
# commit 1728ea0 "Update app.py", 6.24 kB).
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# --- Model Loading ---
# Hugging Face Hub id of the SPLADE checkpoint to load. Note that this is NOT
# "naver/splade-v3"; any SPLADE checkpoint with a masked-language-model head can
# be substituted here (gated models such as "naver/splade-v3" additionally
# require accepting their user access agreement on the Hub first).
MODEL_NAME = "naver/splade-cocondenser-selfdistil"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    model.eval()  # Set the model to evaluation mode for inference
    print(f"SPLADE model '{MODEL_NAME}' and tokenizer loaded successfully!")
except Exception as e:
    # Deliberately best-effort: the UI still starts, and the representation
    # function reports that the model is not loaded instead of the app crashing.
    print(f"Error loading SPLADE model or tokenizer: {e}")
    print(
        "Please ensure you have accepted any user access agreements on the Hugging Face "
        f"Hub page for '{MODEL_NAME}' (https://huggingface.co/{MODEL_NAME})."
    )
    print("If the problem persists, check your internet connection or try a different SPLADE model if available.")
    tokenizer = None
    model = None
# --- Core SPLADE Representation Function ---

# Number of highest-weighted terms listed in the output.
TOP_K_TERMS = 20

# Markdown punctuation that is backslash-escaped in displayed terms, so
# vocabulary entries such as "*", "_", "[unused0]" or "##ing" render literally
# instead of being interpreted as formatting by gr.Markdown.
_MARKDOWN_SPECIAL_CHARS = frozenset("\\`*_{}[]()#+-.!|<>~")


def _escape_markdown(term):
    """Return *term* with every Markdown control character backslash-escaped."""
    return "".join("\\" + ch if ch in _MARKDOWN_SPECIAL_CHARS else ch for ch in term)


def get_splade_representation(text):
    """Compute the SPLADE sparse representation of *text* and format it as Markdown.

    The text is tokenized and run through the masked-language-model head. The
    logits are reduced to one weight per vocabulary entry with the SPLADE
    aggregation: max over token positions of log(1 + ReLU(logits)), with
    padding positions masked out. Non-zero entries are decoded back to
    vocabulary strings. Special tokens and blank strings are dropped.

    Args:
        text: Query or document text. ``None`` is treated as an empty string.

    Returns:
        A Markdown string listing the top ``TOP_K_TERMS`` terms by weight, plus
        the non-zero count and sparsity of the full vector. If the model is not
        loaded, or its output has no ``logits`` attribute, an error message is
        returned instead.
    """
    if tokenizer is None or model is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    # A single string is tokenized, so the batch dimension is 1.
    # truncation=True clips input longer than the model's maximum length.
    inputs = tokenizer(text or "", return_tensors="pt", padding=True, truncation=True)
    # Keep the tensors on whatever device the model lives on (usually CPU on free Spaces).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model(**inputs)

    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    # SPLADE aggregation: log(1 + ReLU(logits)). Padding positions are zeroed
    # via the attention mask, then the result is max-pooled over the sequence
    # dimension. Logits are indexed (batch, seq_len, vocab); row 0 is the only
    # batch entry.
    mask = inputs["attention_mask"].unsqueeze(-1)
    token_weights = torch.log1p(torch.relu(output.logits)) * mask
    splade_vector = token_weights.max(dim=1).values[0]

    # as_tuple=True yields a 1-D index tensor even for zero or one hits, so no
    # squeeze / scalar special-casing is needed.
    nonzero_ids = torch.nonzero(splade_vector, as_tuple=True)[0]
    nonzero_weights = splade_vector[nonzero_ids]
    num_nonzero = nonzero_ids.numel()

    # Filter by id so every special token the tokenizer defines (e.g. [CLS],
    # [SEP], [PAD], [UNK], [MASK]) is excluded.
    special_ids = set(tokenizer.all_special_ids)
    term_weights = {}
    for token_id, weight in zip(nonzero_ids.cpu().tolist(), nonzero_weights.cpu().tolist()):
        if token_id in special_ids:
            continue
        term = tokenizer.decode([token_id]).strip()
        if not term:
            continue
        # Different ids can decode to the same string; keep the largest weight
        # rather than whichever id happened to be visited last.
        if weight > term_weights.get(term, float("-inf")):
            term_weights[term] = weight

    top_terms = sorted(term_weights.items(), key=lambda item: item[1], reverse=True)[:TOP_K_TERMS]

    lines = [f"SPLADE Representation (Top {TOP_K_TERMS} Terms):"]
    if top_terms:
        lines.extend(f"- **{_escape_markdown(term)}**: {weight:.4f}" for term, weight in top_terms)
    else:
        lines.append("No significant terms found for this input.")
    lines.append("")
    lines.append("--- Raw SPLADE Vector Info ---")
    # List items, so gr.Markdown does not merge the two stats into one paragraph.
    lines.append(f"- Total non-zero terms in vector: {num_nonzero}")
    # The vector's own length is the size of the space sparsity is measured over.
    lines.append(f"- Sparsity: {1 - num_nonzero / splade_vector.numel():.2%}")
    return "\n".join(lines) + "\n"
# --- Gradio Interface Setup ---
# The loaded checkpoint is not SPLADE v3, so the UI text refers to SPLADE generically.
demo = gr.Interface(
    fn=get_splade_representation,
    inputs=gr.Textbox(
        lines=5,
        label="Enter your query or document text here:",
        placeholder="Why is Padua the nicest city in Italy?",
    ),
    outputs=gr.Markdown(),  # Use Markdown for richer text formatting (bolding terms)
    title="🌌 SPLADE Sparse Representation Generator",
    description=(
        "Enter any text (query or document) to see its SPLADE sparse vector representation. "
        "The output highlights the most important terms with their learned weights."
    ),
    allow_flagging="never",  # Disable flagging for this demo
)

# Launch only when executed as a script (`python app.py`), so importing this
# module (e.g. from tests) does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()