import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm  # Optional: progress bars for model loading; not strictly needed in this simplified version
import os  # Not explicitly used in this simplified version

# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None

# Load SPLADE cocondenser (self-distil) model
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load SPLADE v3 Lexical model
try:
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# Load SPLADE v3 Doc model
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


# --- Helper function for lexical mask (handles batches, but used for single inputs here) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """
    Creates a batch of lexical BOW masks.
    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
    vocab_size: int, size of the tokenizer vocabulary
    tokenizer: the tokenizer object
    Returns: torch.Tensor of shape (batch_size, vocab_size)
    """
    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)

    for i in range(batch_size):
        input_ids = input_ids_batch[i]  # input_ids for the current item in the batch
        meaningful_token_ids = []
        for token_id in input_ids.tolist():
            if token_id not in [
                tokenizer.pad_token_id,
                tokenizer.cls_token_id,
                tokenizer.sep_token_id,
                tokenizer.mask_token_id,
                tokenizer.unk_token_id
            ]:
                meaningful_token_ids.append(token_id)

        if meaningful_token_ids:
            # Apply mask to the current row in the batch
            bow_masks[i, list(set(meaningful_token_ids))] = 1

    return bow_masks


# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# These functions take a single text input for the Explorer tab.

def get_splade_cocondenser_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
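    # The representation computed below is the standard SPLADE aggregation:
    # rep[j] = max over sequence positions i of log(1 + relu(logits[i, j])),
    # with padded positions zeroed out via the attention mask, yielding one
    # sparse weight per vocabulary term.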
    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()  # Squeeze is fine here as it's a single input
    else:
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for term, weight in sorted_representation:
            formatted_output += f"- **{term}**: {weight:.4f}\n"

    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"

    return formatted_output


def get_splade_lexical_representation(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()  # Squeeze is fine here
    else:
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
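    # Unlike the cocondenser model above, the lexical variant is restricted to
    # terms that actually occur in the input: the bag-of-words mask built below
    # zeroes out any expansion term the MLM head might otherwise activate.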
    # Always apply lexical mask for this model's specific behavior
    vocab_size = tokenizer_splade_lexical.vocab_size
    # The tokenizer output is already batched (shape (1, seq_len)), so it can be passed directly
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze()  # Squeeze back to a single vector
    splade_vector = splade_vector * bow_mask

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "SPLADE-v3-Lexical Representation (Weighting):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for term, weight in sorted_representation:
            formatted_output += f"- **{term}**: {weight:.4f}\n"

    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"

    return formatted_output


def get_splade_doc_representation(text):
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_doc(**inputs)

    if not hasattr(output, "logits"):
        return "Model output structure not as expected. 'logits' not found."
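    # The Doc variant is shown as a purely binary bag-of-words here: the logits
    # are only checked for presence, and the vector built below simply marks
    # which input terms are active (weight 1.0 each).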
    vocab_size = tokenizer_splade_doc.vocab_size
    # The tokenizer output is already batched (shape (1, seq_len)), so it can be passed directly
    binary_splade_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()  # Squeeze back to a single vector

    indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    values = [1.0] * len(indices)  # All values are 1 for the binary representation
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for clarity

    formatted_output = "SPLADE-v3-Doc Representation (Binary):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for i, (term, _) in enumerate(sorted_representation):
            if i >= 50:  # Limit display for very long lists
                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
                break
            formatted_output += f"- **{term}**\n"

    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
    formatted_output += f"Total activated terms: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"

    return formatted_output


# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
        return get_splade_cocondenser_representation(text)
    elif model_choice == "SPLADE-v3-Lexical (weighting)":
        return get_splade_lexical_representation(text)
    elif model_choice == "SPLADE-v3-Doc (binary)":
        return get_splade_doc_representation(text)
    else:
        return "Please select a model."


# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
# These functions return the raw tensors rather than formatted strings.
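# Each *_vector function below returns a dense torch.Tensor of shape (vocab_size,)
# whose non-zero entries are the term weights; the comparison tab scores a
# query/document pair as the dot product of two such vectors.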
def get_splade_cocondenser_vector(text):
    if tokenizer_splade is None or model_splade is None:
        return None

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        return splade_vector
    return None


def get_splade_lexical_vector(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()

        vocab_size = tokenizer_splade_lexical.vocab_size
        bow_mask = create_lexical_bow_mask(
            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
        ).squeeze()
        splade_vector = splade_vector * bow_mask
        return splade_vector
    return None


def get_splade_doc_vector(text):
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return None

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_doc(**inputs)

    if hasattr(output, "logits"):
        vocab_size = tokenizer_splade_doc.vocab_size
        binary_splade_vector = create_lexical_bow_mask(
            inputs['input_ids'], vocab_size, tokenizer_splade_doc
        ).squeeze()
        return binary_splade_vector
    return None


# --- Function to get a formatted representation from a raw vector and tokenizer ---
# This is a generic formatter for any sparse vector.
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
    if splade_vector is None:
        return "Failed to generate vector."
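    # torch.nonzero gives the indices of activated vocabulary dimensions; each
    # index is decoded back to its term below so the sparse vector can be shown
    # as a human-readable term/weight list.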
    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for binary
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = ""
    if not sorted_representation:
        formatted_output += "No significant terms found.\n"
    else:
        for i, (term, weight) in enumerate(sorted_representation):
            if i >= 50 and is_binary:  # Limit display for very long binary lists
                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
                break
            if is_binary:
                formatted_output += f"- **{term}**\n"
            else:
                formatted_output += f"- **{term}**: {weight:.4f}\n"

    formatted_output += f"\nTotal non-zero terms: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"

    return formatted_output


# --- Helper to get the correct vector function, tokenizer, and binary flag ---
def get_model_assets(model_choice_str):
    if model_choice_str == "SPLADE-cocondenser-distil (weighting and expansion)":
        return get_splade_cocondenser_vector, tokenizer_splade, False, "SPLADE-cocondenser-distil (Weighting and Expansion)"
    elif model_choice_str == "SPLADE-v3-Lexical (weighting)":
        return get_splade_lexical_vector, tokenizer_splade_lexical, False, "SPLADE-v3-Lexical (Weighting)"
    elif model_choice_str == "SPLADE-v3-Doc (binary)":
        return get_splade_doc_vector, tokenizer_splade_doc, True, "SPLADE-v3-Doc (Binary)"
    else:
        return None, None, False, "Unknown Model"


# --- Dot Product Calculation Function for the comparison tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)

    if query_vector_fn is None or doc_vector_fn is None:
        # The interface has a single Markdown output, so errors are returned as a single string too
        return "Please select valid models for both query and document encoding."

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)

    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text."

    # Calculate dot product
    # Ensure both vectors are on CPU before the dot product to avoid device mismatch issues
    # and so .item() works reliably for conversion to float.
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())

    # Format representations
    query_rep_str = f"Query Representation ({query_model_name_display}):\n"
    query_rep_str += format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)

    doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
    doc_rep_str += format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    # Combine output
    full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
    full_output += "---\n\n"
    full_output += f"{query_rep_str}\n\n---\n\n{doc_rep_str}"

    return full_output


# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos") as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")

    # Common model choices, shared by both tabs so the UI labels always match
    # the strings expected by the prediction functions.
    model_choices = [
        "SPLADE-cocondenser-distil (weighting and expansion)",
        "SPLADE-v3-Lexical (weighting)",
        "SPLADE-v3-Doc (binary)"
    ]

    with gr.Tabs():
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")
            gr.Interface(
                fn=predict_representation_explorer,
                inputs=[
                    gr.Radio(
                        model_choices,
                        label="Choose Sparse Encoder",
                        value="SPLADE-cocondenser-distil (weighting and expansion)"
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
                # live=True  # Setting live=True might be slow for complex models on every keystroke
            )

        with gr.TabItem("Compare Encoders"):
            gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
            gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")

            gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
                inputs=[
                    gr.Radio(
                        model_choices,
                        label="Choose Query Encoding Model",
                        value="SPLADE-cocondenser-distil (weighting and expansion)"
                    ),
                    gr.Radio(
                        model_choices,
                        label="Choose Document Encoding Model",
                        value="SPLADE-cocondenser-distil (weighting and expansion)"
                    ),
                    gr.Textbox(
                        lines=3,
                        label="Enter Query Text:",
                        placeholder="e.g., famous dishes of Padua"
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter Document Text:",
                        placeholder="e.g., Padua's cuisine is as famous as its legendary University."
                    )
                ],
                outputs=gr.Markdown(),
                allow_flagging="never"
            )

demo.launch()