"""Gradio demo that displays the SPLADE sparse representation of a text.

The script loads a SPLADE masked-language-model checkpoint once at import
time, exposes ``get_splade_representation`` as the Gradio callback, and then
launches the web UI.
"""

import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# --- Model Loading ---

# Checkpoint that is actually loaded. Keep every message in sync with this
# constant so that troubleshooting hints never point at a different model.
MODEL_NAME = "naver/splade-cocondenser-selfdistil"

# Number of highest-weighted terms shown in the formatted output.
TOP_K_TERMS = 20

# Decoded tokens that carry no retrieval meaning and are hidden from the
# term list. Add more entries here if other unhelpful tokens show up.
_IGNORED_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]", "[UNK]"})

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    model.eval()  # Inference only: disables dropout and similar layers.
    print(f"SPLADE model and tokenizer loaded successfully ({MODEL_NAME})!")
except Exception as e:
    # Deliberate best-effort boundary: the UI still starts and the callback
    # reports that the model is unavailable instead of the app crashing.
    print(f"Error loading SPLADE model or tokenizer: {e}")
    print(
        "Please ensure you have accepted any user access agreements on the "
        f"Hugging Face Hub page for '{MODEL_NAME}' "
        f"(https://huggingface.co/{MODEL_NAME})."
    )
    print(
        "If the problem persists, check your internet connection or try a "
        "different SPLADE model if available."
    )
    tokenizer = None
    model = None


# --- Core SPLADE Representation Function ---
def get_splade_representation(text):
    """Return a Markdown summary of the SPLADE sparse vector for *text*.

    The vector is computed as ``max over tokens of log(1 + ReLU(logits))``
    using the masked-language-modeling logits, with positions excluded by the
    attention mask zeroed out before the max.

    Args:
        text: Query or document text entered by the user.

    Returns:
        A string listing up to ``TOP_K_TERMS`` terms with their weights,
        followed by the non-zero count and sparsity of the vector. If the
        model failed to load, the input is blank, or the model output has no
        ``logits`` attribute, a short explanatory message is returned
        instead.
    """
    if tokenizer is None or model is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    # Nothing useful can be shown for blank input; skip the forward pass.
    if text is None or not text.strip():
        return "Please enter some text to generate a SPLADE representation."

    # truncation=True clips text that exceeds the model's maximum input size.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Move tensors to the model's device. NOTE: this produces a plain dict,
    # so fields must be read with ["key"], not attribute access.
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model(**inputs)

    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    # Broadcast the mask over the vocabulary dimension so masked positions
    # contribute zero, then take the max over the sequence dimension.
    mask = inputs["attention_mask"].unsqueeze(-1)
    activations = torch.log1p(torch.relu(output.logits)) * mask
    splade_vector = activations.max(dim=1).values.squeeze(0)

    # flatten() always yields a 1-D tensor of indices, including when there
    # are zero or exactly one non-zero entries.
    nonzero_ids = torch.nonzero(splade_vector).flatten()
    weights = splade_vector[nonzero_ids]
    token_weights = dict(zip(nonzero_ids.cpu().tolist(), weights.cpu().tolist()))

    # Decode token IDs back to (sub)words, dropping special and blank tokens.
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in _IGNORED_TOKENS and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    ranked_terms = sorted(
        meaningful_tokens.items(), key=lambda item: item[1], reverse=True
    )

    lines = [f"SPLADE Representation (Top {TOP_K_TERMS} Terms):"]
    if ranked_terms:
        lines.extend(
            f"- **{term}**: {weight:.4f}"
            for term, weight in ranked_terms[:TOP_K_TERMS]
        )
    else:
        lines.append("No significant terms found for this input.")

    nonzero_count = len(token_weights)
    # Sparsity is measured against the vector's own length (the size of the
    # logits' last dimension) rather than a separately reported vocab size.
    sparsity = 1 - nonzero_count / splade_vector.numel()
    lines.append("")
    lines.append("--- Raw SPLADE Vector Info ---")
    lines.append(f"Total non-zero terms in vector: {nonzero_count}")
    lines.append(f"Sparsity: {sparsity:.2%}")

    return "\n".join(lines) + "\n"


# --- Gradio Interface Setup ---
# NOTE(review): the title and description say "SPLADE v3", but MODEL_NAME is
# a cocondenser-selfdistil checkpoint -- confirm which one is intended.
demo = gr.Interface(
    fn=get_splade_representation,
    inputs=gr.Textbox(
        lines=5,
        label="Enter your query or document text here:",
        placeholder="e.g., Why is Padua the nicest city in Italy?",
    ),
    outputs=gr.Markdown(),  # Markdown so the term names render in bold.
    title="🌌 SPLADE v3 Sparse Representation Generator",
    description="Enter any text (query or document) to see its SPLADE v3 sparse vector representation. The output highlights the most important terms with their learned weights.",
    allow_flagging="never",  # Disable flagging for this demo
)

# Launch the Gradio app
demo.launch()