import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
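# The demo depends on the `gradio`, `transformers`, and `torch` packages
# (e.g. `pip install gradio transformers torch`).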

# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_unicoil = None
model_unicoil = None

# Load SPLADE model (naver/splade-cocondenser-selfdistil checkpoint)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval() # Set to evaluation mode for inference
    print("SPLADE v3 model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load UNICOIL model for binary sparse encoding
try:
    # 'castorini/unicoil-msmarco-passage' is a common UNICOIL checkpoint.
    # The scoring code below reads per-token scores from `output.logits`, so the
    # checkpoint is loaded with AutoModelForMaskedLM. This assumes the checkpoint
    # exposes an MLM-style head; if it only ships a custom UNICOIL scoring head,
    # the weights produced here are an approximation of UNICOIL's term scores.
    unicoil_model_name = "castorini/unicoil-msmarco-passage"
    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
    model_unicoil.eval()  # Set to evaluation mode for inference
    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading UNICOIL model: {e}")
    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


# --- Core Representation Functions ---

def get_splade_representation(text):
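    """Compute a SPLADE sparse representation for the given text.

    Returns a Markdown-formatted string listing the top-weighted vocabulary
    terms along with basic vector statistics (non-zero count and sparsity).
    """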
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
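        # SPLADE pooling: log(1 + ReLU(logits)) per token, masked by the
        # attention mask, then max-pooled over the sequence dimension so each
        # vocabulary term keeps its strongest activation.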
        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
    else:
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices]
            
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "SPLADE Representation (Top 20 Terms):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for i, (term, weight) in enumerate(sorted_representation):
            if i >= 20:
                break
            formatted_output += f"- **{term}**: {weight:.4f}\n"
        
    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"

    return formatted_output


def get_unicoil_binary_representation(text):
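    """Compute a UNICOIL-style binary sparse representation for the given text.

    Per-token scores from the model's logits are pooled, thresholded into a
    binary vector, and returned as a Markdown-formatted string listing the
    activated terms plus basic sparsity statistics.
    """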
    if tokenizer_unicoil is None or model_unicoil is None:
        return "UNICOIL model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}

    with torch.no_grad():
        # UNICOIL learns an explicit importance weight for each input token.
        # With the checkpoint loaded via an MLM-style head, we read per-token
        # vocabulary scores from `output.logits` (as with SPLADE) and binarize
        # them below. This is a simplified interpretation of UNICOIL-style
        # scoring, not a faithful reimplementation of the original UNICOIL head.
        output = model_unicoil(**inputs)

        if not hasattr(output, 'logits'):
            # Without an MLM head there are no per-token vocabulary scores to
            # binarize, so report the problem instead of guessing.
            return "UNICOIL model output structure not as expected. 'logits' not found."

        # Turn the logits into non-negative term weights with a softplus-style
        # transform, log(1 + exp(x)), mask out padding tokens, and max-pool over
        # the sequence dimension to get one weight per vocabulary term.
        sparse_weights = torch.max(torch.log(1 + torch.exp(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()

        # --- Binarization Step ---
        # For a binary sparse vector, threshold the weights: any value above a
        # small epsilon becomes 1, everything else becomes 0.
        threshold = 1e-6
        binary_sparse_vector = (sparse_weights > threshold).int()
        
        # Get indices of the '1's in the binary vector
        binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()

        # A single activated term comes back as a plain int rather than a list
        if not isinstance(binary_indices, list):
            binary_indices = [binary_indices]

        # Map token IDs back to terms for the binary representation
        binary_terms = {}
        for token_id in binary_indices:
            decoded_token = tokenizer_unicoil.decode([token_id])
            if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
                binary_terms[decoded_token] = 1 # Value is always 1 for binary

        sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0]) # Sort by term for consistent display

        formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
        if not sorted_binary_terms:
            formatted_output += "No significant terms activated for this input.\n"
        else:
            # Display up to 50 activated terms for readability
            for i, (term, _) in enumerate(sorted_binary_terms):
                if i >= 50:
                    break
                formatted_output += f"- **{term}**\n" # Only show term, as weight is always 1
            if len(sorted_binary_terms) > 50:
                formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
        
        formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
        formatted_output += f"Total activated terms: {len(binary_indices)}\n"
        # Calculate sparsity based on the number of '1's vs. total vocabulary size
        formatted_output += f"Sparsity: {1 - (len(binary_indices) / tokenizer_unicoil.vocab_size):.2%}\n"

    return formatted_output


# --- Unified Prediction Function for Gradio ---
def predict_representation(model_choice, text):
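    """Dispatch the input text to the representation function selected in the UI."""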
    if model_choice == "SPLADE":
        return get_splade_representation(text)
    elif model_choice == "UNICOIL (Binary Sparse)":
        return get_unicoil_binary_representation(text)
    else:
        return "Please select a model."

# --- Gradio Interface Setup ---
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            ["SPLADE", "UNICOIL (Binary Sparse)"], # Added UNICOIL option
            label="Choose Representation Model",
            value="SPLADE" # Default selection
        ),
        gr.Textbox(
            lines=5,
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?"
        )
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
    allow_flagging="never"
)

# Launch the Gradio app
demo.launch()