SiddharthAK committed
Commit da3acda · verified · 1 Parent(s): 1728ea0

added unicoil

Files changed (1)
  1. app.py +197 -72
app.py CHANGED
@@ -1,119 +1,244 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForMaskedLM
  import torch

  # --- Model Loading ---
- # Load SPLADE v3 model and tokenizer
- # "naver/splade-v3" is a common and robust SPLADE v3 model.
- # Make sure you've accepted any user access agreements on its Hugging Face Hub page.
  try:
-     tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
-     model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
-     model.eval()  # Set the model to evaluation mode for inference
-     print("SPLADE v3 model and tokenizer loaded successfully!")
  except Exception as e:
-     print(f"Error loading SPLADE model or tokenizer: {e}")
-     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-v3' (https://huggingface.co/naver/splade-v3).")
-     print("If the problem persists, check your internet connection or try a different SPLADE model if available.")
-     tokenizer = None
-     model = None

- # --- Core SPLADE Representation Function ---
  def get_splade_representation(text):
-     if tokenizer is None or model is None:
          return "SPLADE model is not loaded. Please check the console for loading errors."

-     # Tokenize the input text
-     # return_tensors="pt" ensures PyTorch tensors are returned.
-     # padding=True pads to the longest sequence in the batch (though for single input, it's just the input length).
-     # truncation=True truncates if the text is too long for the model's max input size.
-     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

-     # Move inputs to the same device as the model (e.g., CPU or GPU)
-     # This is important if you were running on a GPU in a production environment.
-     # For Hugging Face Spaces free tier, it's usually CPU.
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-     # Get the model's output without calculating gradients (inference mode)
      with torch.no_grad():
-         output = model(**inputs)
-
-     # Extract the logits from the model's output.
-     # SPLADE uses the masked language modeling head's logits to derive term importance.
-     # Apply the SPLADE aggregation function: log(1 + ReLU(logits))
-     # This transforms the raw logits into a sparse vector where higher values indicate more importance.
-     # The attention_mask ensures we only consider actual tokens, not padding.
-
-     # Check if 'logits' is in the output (standard for AutoModelForMaskedLM)
      if hasattr(output, 'logits'):
-         # Apply the SPLADE transformation
-         # output.logits is typically [batch_size, sequence_length, vocab_size]
-         # We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
-         # inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
          splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
      else:
-         # Fallback/error message if the output structure is unexpected
          return "Model output structure not as expected for SPLADE. 'logits' not found."

-     # Convert the sparse vector to a human-readable format.
-     # We only care about the non-zero (or very small) entries, as they represent activated terms.
-
-     # Get the indices (token IDs) of the non-zero elements in the SPLADE vector
-     # torch.nonzero returns coordinates of non-zero elements. squeeze() removes dimensions of size 1.
-     # .cpu().tolist() moves the tensor to CPU and converts to a Python list.
      indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
-
-     # If it's a single index (e.g., a very short text), make it a list for consistent processing
      if not isinstance(indices, list):
          indices = [indices]
-
-     # Get the corresponding values (weights) for these non-zero indices
      values = splade_vector[indices].cpu().tolist()
-
-     # Create a dictionary mapping token ID to its weight
      token_weights = dict(zip(indices, values))

-     # Decode token IDs back to actual words/subwords
-     # Filter out common special tokens that are not meaningful for retrieval (e.g., [CLS], [SEP], [PAD])
-     # You can add more tokens to this list if they appear frequently and are not helpful.
      meaningful_tokens = {}
      for token_id, weight in token_weights.items():
-         decoded_token = tokenizer.decode([token_id])
-         # Filter out special tokens or very short/noisy tokens
          if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
              meaningful_tokens[decoded_token] = weight

-     # Sort the meaningful tokens by their weight in descending order
      sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

-     # Format the output for display
      formatted_output = "SPLADE Representation (Top 20 Terms):\n"
      if not sorted_representation:
          formatted_output += "No significant terms found for this input.\n"
      else:
          for i, (term, weight) in enumerate(sorted_representation):
-             if i >= 20:  # Limit to top 20 terms for readability
                  break
              formatted_output += f"- **{term}**: {weight:.4f}\n"
-
      formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
      formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
-     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"  # Calculate sparsity

      return formatted_output

  # --- Gradio Interface Setup ---
  demo = gr.Interface(
-     fn=get_splade_representation,
-     inputs=gr.Textbox(
-         lines=5,
-         label="Enter your query or document text here:",
-         placeholder="Why is Padua the nicest city in Italy?"
-     ),
-     outputs=gr.Markdown(),  # Use Markdown for richer text formatting (bolding terms)
-     title="🌌 SPLADE v3 Sparse Representation Generator",
-     description="Enter any text (query or document) to see its SPLADE v3 sparse vector representation. The output highlights the most important terms with their learned weights.",
-     allow_flagging="never"  # Disable flagging for this demo
  )

  # Launch the Gradio app
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
  import torch

  # --- Model Loading ---
+ tokenizer_splade = None
+ model_splade = None
+ tokenizer_unicoil = None
+ model_unicoil = None
+
+ # Load SPLADE v3 model
+ try:
+     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
+     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
+     model_splade.eval()  # Set to evaluation mode for inference
+     print("SPLADE v3 model loaded successfully!")
+ except Exception as e:
+     print(f"Error loading SPLADE model: {e}")
+     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
+
+ # Load UNICOIL model for binary sparse encoding
  try:
+     # UNICOIL checkpoints are typically a BERT-like encoder with a small linear head
+     # that predicts a weight for each input token. 'castorini/unicoil-msmarco-passage'
+     # is a common checkpoint; it is loaded here with AutoModel (base encoder only).
+     unicoil_model_name = "castorini/unicoil-msmarco-passage"
+     tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
+     model_unicoil = AutoModel.from_pretrained(unicoil_model_name)
+     model_unicoil.eval()  # Set to evaluation mode for inference
+     print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
  except Exception as e:
+     print(f"Error loading UNICOIL model: {e}")
+     print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
+
+
+ # --- Core Representation Functions ---

  def get_splade_representation(text):
+     if tokenizer_splade is None or model_splade is None:
          return "SPLADE model is not loaded. Please check the console for loading errors."

+     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
+     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

      with torch.no_grad():
+         output = model_splade(**inputs)
+
      if hasattr(output, 'logits'):
          splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
      else:
          return "Model output structure not as expected for SPLADE. 'logits' not found."

      indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
      if not isinstance(indices, list):
          indices = [indices]
+
      values = splade_vector[indices].cpu().tolist()
      token_weights = dict(zip(indices, values))

      meaningful_tokens = {}
      for token_id, weight in token_weights.items():
+         decoded_token = tokenizer_splade.decode([token_id])
          if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
              meaningful_tokens[decoded_token] = weight

      sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

      formatted_output = "SPLADE Representation (Top 20 Terms):\n"
      if not sorted_representation:
          formatted_output += "No significant terms found for this input.\n"
      else:
          for i, (term, weight) in enumerate(sorted_representation):
+             if i >= 20:
                  break
              formatted_output += f"- **{term}**: {weight:.4f}\n"
+
      formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
      formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
+
+     return formatted_output
+
+
+ def get_unicoil_binary_representation(text):
+     if tokenizer_unicoil is None or model_unicoil is None:
+         return "UNICOIL model is not loaded. Please check the console for loading errors."
+
+     inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
+     inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         # UNICOIL learns an explicit weight for each token. How those weights are exposed
+         # depends on the checkpoint: an MLM-style head returns per-token scores over the
+         # vocabulary in `output.logits`, while a plain encoder (as loaded by AutoModel)
+         # returns only hidden states. The guard below assumes the former and reports an
+         # error otherwise.
+         output = model_unicoil(**inputs)
+
+     if not hasattr(output, 'logits'):
+         return "UNICOIL model output structure not as expected. 'logits' not found."
+
+     # Turn the logits into non-negative per-term scores with softplus, log(1 + exp(x)),
+     # then max-pool over the sequence dimension (padding is masked out). Softplus keeps
+     # every value positive, so the sparsity of the final vector comes from the
+     # binarization threshold applied below, not from the transformation itself.
+     sparse_weights = torch.max(torch.log(1 + torch.exp(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
+
+     # --- Binarization Step for UNICOIL ---
+     # Treat any score above a small epsilon as an activated term (1), everything else as 0.
+     threshold = 1e-6  # Small threshold for binarization
+     binary_sparse_vector = (sparse_weights > threshold).int()
+
+     # Get indices of the '1's in the binary vector
+     binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()
+
+     # A single activation comes back as a plain int; wrap it for consistent processing
+     if not isinstance(binary_indices, list):
+         binary_indices = [binary_indices]
+
+     # Map token IDs back to terms for the binary representation
+     binary_terms = {}
+     for token_id in binary_indices:
+         decoded_token = tokenizer_unicoil.decode([token_id])
+         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+             binary_terms[decoded_token] = 1  # Value is always 1 for binary
+
+     sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])  # Sort by term for consistent display
+
+     formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
+     if not sorted_binary_terms:
+         formatted_output += "No significant terms activated for this input.\n"
+     else:
+         # Display up to 50 activated terms for readability
+         for i, (term, _) in enumerate(sorted_binary_terms):
+             if i >= 50:
+                 break
+             formatted_output += f"- **{term}**\n"  # Only show the term, as the weight is always 1
+         if len(sorted_binary_terms) > 50:
+             formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
+
+     formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+     formatted_output += f"Total activated terms: {len(binary_indices)}\n"
+     # Calculate sparsity based on the number of activated terms vs. total vocabulary size
+     formatted_output += f"Sparsity: {1 - (len(binary_indices) / tokenizer_unicoil.vocab_size):.2%}\n"

      return formatted_output

+
+ # --- Unified Prediction Function for Gradio ---
+ def predict_representation(model_choice, text):
+     if model_choice == "SPLADE":
+         return get_splade_representation(text)
+     elif model_choice == "UNICOIL (Binary Sparse)":
+         return get_unicoil_binary_representation(text)
+     else:
+         return "Please select a model."
+
  # --- Gradio Interface Setup ---
  demo = gr.Interface(
+     fn=predict_representation,
+     inputs=[
+         gr.Radio(
+             ["SPLADE", "UNICOIL (Binary Sparse)"],  # Added UNICOIL option
+             label="Choose Representation Model",
+             value="SPLADE"  # Default selection
+         ),
+         gr.Textbox(
+             lines=5,
+             label="Enter your query or document text here:",
+             placeholder="e.g., Why is Padua the nicest city in Italy?"
+         )
+     ],
+     outputs=gr.Markdown(),
+     title="🌌 Sparse and Binary Sparse Representation Generator",
+     description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
+     allow_flagging="never"
  )

  # Launch the Gradio app
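
Note on the UNICOIL branch: get_unicoil_binary_representation only returns real weights if the checkpoint's forward pass exposes an MLM-style 'logits' output; a model loaded with plain AutoModel returns only hidden states, in which case the function falls back to its error message. A minimal probe for checking what 'castorini/unicoil-msmarco-passage' actually provides, as a standalone sketch rather than part of app.py:

    from transformers import AutoConfig, AutoModel, AutoTokenizer
    import torch

    name = "castorini/unicoil-msmarco-passage"

    # The saved config records which architecture(s) the checkpoint was exported with.
    config = AutoConfig.from_pretrained(name)
    print(config.architectures)

    # Probe the forward pass: does the output carry 'logits', or only hidden states?
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)
    model.eval()
    with torch.no_grad():
        output = model(**tokenizer("sparse retrieval test", return_tensors="pt"))
    print(hasattr(output, "logits"))

If the probe shows no 'logits', the loading block above would need a model class that includes the token-scoring head before the UNICOIL option can produce meaningful output.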