Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
import torch | |
# --- Model Loading ---
# Hugging Face Hub id of the SPLADE checkpoint served by this demo. Every
# log/help message below is derived from this constant so they always name
# the checkpoint that is actually loaded.
MODEL_NAME = "naver/splade-cocondenser-selfdistil"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    model.eval()  # Set the model to evaluation mode for inference
    print(f"SPLADE model and tokenizer '{MODEL_NAME}' loaded successfully!")
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: the app keeps
    # starting with tokenizer/model set to None instead of crashing on import.
    print(f"Error loading SPLADE model or tokenizer '{MODEL_NAME}': {e}")
    print(
        f"If '{MODEL_NAME}' is gated, ensure you have accepted its user access "
        f"agreement on the Hugging Face Hub page (https://huggingface.co/{MODEL_NAME})."
    )
    print("If the problem persists, check your internet connection or try a different SPLADE model if available.")
    tokenizer = None
    model = None
# --- Core SPLADE Representation Function ---
# Maximum number of terms listed in the formatted output.
_TOP_K_TERMS = 20
# Characters that Markdown would treat as inline formatting when a vocabulary
# token is wrapped in **bold** markers; each is backslash-escaped.
_MARKDOWN_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\`*_~"})


def _compute_splade_vector(text, tok, mdl):
    """Return the SPLADE weight vector for *text*, or None if there are no logits.

    Applies log(1 + ReLU(logits)) to the masked-LM logits, zeroes padded
    positions with the attention mask, and max-pools over dimension 1.
    Assumes logits are (batch, seq_len, vocab), as with AutoModelForMaskedLM
    (TODO confirm for other heads); only the first batch row is returned.
    """
    encoded = tok(text, return_tensors="pt", padding=True, truncation=True)
    # Move inputs to the model's device. Note: this produces a plain dict, so
    # entries must be read by key (encoded["attention_mask"]), not attribute.
    encoded = {name: tensor.to(mdl.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output = mdl(**encoded)
    logits = getattr(output, "logits", None)
    if logits is None:
        return None
    # log1p is the numerically stable form of log(1 + x).
    weights = torch.log1p(torch.relu(logits))
    mask = encoded["attention_mask"].unsqueeze(-1).to(weights.dtype)
    return torch.max(weights * mask, dim=1).values[0]


def _extract_weighted_terms(splade_vector, tok):
    """Return (ranked_terms, nonzero_count) for a 1-D SPLADE vector.

    ranked_terms is a list of (decoded_term, weight) pairs, sorted by
    descending weight. Special tokens and blank decodes are excluded.
    nonzero_count counts every non-zero entry, special tokens included.
    """
    # as_tuple=True always yields a 1-D index tensor (no scalar special case).
    (nonzero_ids,) = torch.nonzero(splade_vector, as_tuple=True)
    ids = nonzero_ids.cpu().tolist()
    weights = splade_vector[nonzero_ids].cpu().tolist()

    special_ids = set(tok.all_special_ids)
    best = {}
    for token_id, weight in zip(ids, weights):
        if token_id in special_ids:
            continue
        term = tok.decode([token_id])
        if not term.strip():
            continue
        # Different ids may decode to the same string; keep the strongest weight.
        if weight > best.get(term, float("-inf")):
            best[term] = weight

    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked, len(ids)


def _format_representation(ranked_terms, nonzero_count, dimension):
    """Render the top terms and vector statistics as Markdown text."""
    lines = [f"SPLADE Representation (Top {_TOP_K_TERMS} Terms):"]
    if not ranked_terms:
        lines.append("No significant terms found for this input.")
    else:
        lines.extend(
            f"- **{term.translate(_MARKDOWN_ESCAPES)}**: {weight:.4f}"
            for term, weight in ranked_terms[:_TOP_K_TERMS]
        )
    lines.append("")
    lines.append("--- Raw SPLADE Vector Info ---")
    lines.append(f"Total non-zero terms in vector: {nonzero_count}")
    lines.append(f"Sparsity: {1 - (nonzero_count / dimension):.2%}")
    return "\n".join(lines) + "\n"


def get_splade_representation(text):
    """Return a Markdown summary of the SPLADE sparse representation of *text*.

    Uses the module-level ``tokenizer`` and ``model``. The summary lists up to
    ``_TOP_K_TERMS`` non-special terms by weight, then the number of non-zero
    vector entries and the resulting sparsity.

    Args:
        text: Query or document text from the Gradio textbox.

    Returns:
        Markdown text. Returns a plain status message instead when the model
        is not loaded, the input is blank/None, or the model output has no
        ``logits``.
    """
    if tokenizer is None or model is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."
    if text is None or not text.strip():
        return "Please enter some text to encode."

    splade_vector = _compute_splade_vector(text, tokenizer, model)
    if splade_vector is None:
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    ranked_terms, nonzero_count = _extract_weighted_terms(splade_vector, tokenizer)
    # Use the vector's own length: the tokenizer's vocab_size can differ from
    # the model's output dimension (e.g. when tokens were added).
    return _format_representation(ranked_terms, nonzero_count, splade_vector.numel())
# --- Gradio Interface Setup ---
# The title and description are deliberately model-agnostic ("SPLADE" rather
# than a specific version) so they stay accurate for whichever SPLADE
# checkpoint the loader uses.
demo = gr.Interface(
    fn=get_splade_representation,
    inputs=gr.Textbox(
        lines=5,
        label="Enter your query or document text here:",
        placeholder="e.g., Why is Padua the nicest city in Italy?",
    ),
    outputs=gr.Markdown(),  # Use Markdown for richer text formatting (bolding terms)
    title="🌌 SPLADE Sparse Representation Generator",
    description=(
        "Enter any text (query or document) to see its SPLADE sparse vector "
        "representation. The output highlights the most important terms with "
        "their learned weights."
    ),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version this Space pins before switching.
    allow_flagging="never",  # Disable flagging for this demo
)

if __name__ == "__main__":
    # Launch only when executed as a script (e.g. `python app.py`), so that
    # importing this module does not start a server.
    demo.launch()