import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# --- Model Loading ---
# Load a SPLADE model and tokenizer.
# "naver/splade-cocondenser-selfdistil" is a widely used, robust SPLADE checkpoint.
# Make sure you've accepted any user access agreement on its Hugging Face Hub page.
try:
    tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model.eval()  # Set the model to evaluation mode for inference
    print("SPLADE model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE model or tokenizer: {e}")
    print("Please ensure you have accepted any user access agreement on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil' (https://huggingface.co/naver/splade-cocondenser-selfdistil).")
    print("If the problem persists, check your internet connection or try a different SPLADE model.")
    tokenizer = None
    model = None
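
# Optional: place the model on a GPU when one is available. This is a minimal
# sketch; on the Hugging Face Spaces free tier there is usually no GPU, so
# this block is a no-op there.
if model is not None and torch.cuda.is_available():
    model = model.to("cuda")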

# --- Core SPLADE Representation Function ---
def get_splade_representation(text):
    if tokenizer is None or model is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    # Tokenize the input text
    # return_tensors="pt" ensures PyTorch tensors are returned.
    # padding=True pads to the longest sequence in the batch (for a single input, that is just the input length).
    # truncation=True truncates if the text is too long for the model's max input size.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Move inputs to the same device as the model (CPU or GPU).
    # On the Hugging Face Spaces free tier this is usually the CPU.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Get the model's output without calculating gradients (inference mode)
    with torch.no_grad():
        output = model(**inputs)

    # Extract the logits from the model's output.
    # SPLADE uses the masked language modeling head's logits to derive term importance.
    # Apply the SPLADE aggregation function: log(1 + ReLU(logits))
    # This transforms the raw logits into a sparse vector where higher values indicate more importance.
    # The attention_mask ensures we only consider actual tokens, not padding.
    
    # Check if 'logits' is in the output (standard for AutoModelForMaskedLM)
    if hasattr(output, 'logits'):
        # output.logits has shape [batch_size, sequence_length, vocab_size].
        # Apply log-saturation to the ReLU'd logits, zero out padding positions
        # via the attention mask, then max-pool over the sequence_length
        # dimension to get a single weight per vocabulary term.
        relu_log = torch.log1p(torch.relu(output.logits))  # log(1 + ReLU(logits))
        masked = relu_log * inputs['attention_mask'].unsqueeze(-1)
        splade_vector = torch.max(masked, dim=1).values.squeeze()
    else:
        # Fallback/error message if the output structure is unexpected
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    # Convert the sparse vector to a human-readable format.
    # We only care about the non-zero entries, as they represent the activated terms.
    
    # Get the indices (token IDs) of the non-zero elements in the SPLADE vector
    # torch.nonzero returns coordinates of non-zero elements. squeeze() removes dimensions of size 1.
    # .cpu().tolist() moves the tensor to CPU and converts to a Python list.
    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()

    # If it's a single index (e.g., a very short text), make it a list for consistent processing
    if not isinstance(indices, list):
        indices = [indices]
        
    # Get the corresponding values (weights) for these non-zero indices
    values = splade_vector[indices].cpu().tolist()

    # Create a dictionary mapping token ID to its weight
    token_weights = dict(zip(indices, values))

    # Decode token IDs back to actual words/subwords
    # Filter out common special tokens that are not meaningful for retrieval (e.g., [CLS], [SEP], [PAD])
    # You can add more tokens to this list if they appear frequently and are not helpful.
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        # Filter out special tokens or very short/noisy tokens
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    # Sort the meaningful tokens by their weight in descending order
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    # Format the output for display
    formatted_output = "SPLADE Representation (Top 20 Terms):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for term, weight in sorted_representation[:20]:  # limit to top 20 terms for readability
            formatted_output += f"- **{term}**: {weight:.4f}\n"
    
    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n" # Calculate sparsity

    return formatted_output
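
# Quick smoke test when running this file directly (the example text below is
# arbitrary, purely for illustration).
if __name__ == "__main__" and model is not None:
    print(get_splade_representation("sparse retrieval with learned term weights"))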

# --- Gradio Interface Setup ---
demo = gr.Interface(
    fn=get_splade_representation,
    inputs=gr.Textbox(
        lines=5,
        label="Enter your query or document text here:",
        placeholder="Why is Padua the nicest city in Italy?"
    ),
    outputs=gr.Markdown(), # Use Markdown for richer text formatting (bolding terms)
    title="🌌 SPLADE v3 Sparse Representation Generator",
    description="Enter any text (query or document) to see its SPLADE v3 sparse vector representation. The output highlights the most important terms with their learned weights.",
    allow_flagging="never" # Disable flagging for this demo
)

# Launch the Gradio app
demo.launch()
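
# Note: demo.launch(share=True) would additionally create a temporary public
# URL, which can be handy for testing outside of Hugging Face Spaces.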