SiddharthAK committed
Commit f4c84bc · verified · 1 Parent(s): 4a365e4

Update app.py

Files changed (1):
  1. app.py +54 -55
app.py CHANGED
@@ -43,33 +43,28 @@ except Exception as e:
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


-# --- Helper function for lexical mask ---
+# --- Helper function for lexical mask (still needed for splade-v3-lexical) ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
     Creates a binary bag-of-words mask from input_ids,
     zeroing out special tokens and padding.
     """
-    # Initialize a zero vector for the entire vocabulary
     bow_mask = torch.zeros(vocab_size, device=input_ids.device)
-
-    # Get unique token IDs from the input, excluding special tokens
-    # input_ids is typically [batch_size, seq_len], we assume batch_size=1
     meaningful_token_ids = []
-    for token_id in input_ids.squeeze().tolist(): # Squeeze to remove batch dim and convert to list
+    for token_id in input_ids.squeeze().tolist():
         if token_id not in [
             tokenizer.pad_token_id,
             tokenizer.cls_token_id,
             tokenizer.sep_token_id,
             tokenizer.mask_token_id,
-            tokenizer.unk_token_id # Also exclude unknown tokens
+            tokenizer.unk_token_id
         ]:
             meaningful_token_ids.append(token_id)

-    # Set 1 for tokens present in the original input
     if meaningful_token_ids:
-        bow_mask[list(set(meaningful_token_ids))] = 1 # Use set to handle duplicates
+        bow_mask[list(set(meaningful_token_ids))] = 1

-    return bow_mask.unsqueeze(0) # Keep batch dimension for consistency
+    return bow_mask.unsqueeze(0)


 # --- Core Representation Functions ---
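For reference, a minimal sketch of what this helper produces in isolation, assuming `create_lexical_bow_mask` from the hunk above is in scope; `bert-base-uncased` is an illustrative stand-in, not necessarily the app's checkpoint:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("sparse retrieval", return_tensors="pt")

    # Vocabulary-sized binary vector: 1 only at positions of input tokens
    mask = create_lexical_bow_mask(inputs["input_ids"], tokenizer.vocab_size, tokenizer)
    print(mask.shape)              # torch.Size([1, 30522])
    print(int(mask.sum().item()))  # 2, if both words map to single in-vocabulary tokens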
@@ -140,15 +135,10 @@ def get_splade_lexical_representation(text):
         return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."

     # --- Apply Lexical Mask (always applied for this function now) ---
-    # Get the vocabulary size from the tokenizer
     vocab_size = tokenizer_splade_lexical.vocab_size
-
-    # Create the Bag-of-Words mask
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
     ).squeeze()
-
-    # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
     splade_vector = splade_vector * bow_mask
     # --- End Lexical Mask Logic ---

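The element-wise multiplication above is the whole masking trick: any weight the model assigned to a term that does not literally occur in the input is zeroed, while weights on input terms pass through unchanged. A self-contained toy illustration (all numbers made up):

    import torch

    # 5-term toy vocabulary; the model "expanded" onto indices 1 and 3
    splade_vector = torch.tensor([0.8, 0.4, 1.2, 0.6, 0.0])
    bow_mask      = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0])  # only terms 0 and 2 occur in the input

    print(splade_vector * bow_mask)
    # tensor([0.8000, 0.0000, 1.2000, 0.0000, 0.0000])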
@@ -181,8 +171,8 @@ def get_splade_lexical_representation(text):
     return formatted_output


-# NEW: Function for SPLADE-v3-Doc representation
-def get_splade_doc_representation(text, apply_lexical_mask: bool):
+# NEW: Function for SPLADE-v3-Doc representation (Binary Sparse)
+def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
         return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."

@@ -192,28 +182,34 @@ def get_splade_doc_representation(text, apply_lexical_mask: bool):
     with torch.no_grad():
         output = model_splade_doc(**inputs)

-    if hasattr(output, 'logits'):
-        splade_vector = torch.max(
-            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
-            dim=1
-        )[0].squeeze()
-    else:
-        return "Model output structure not as expected for SPLADE v3 Doc. 'logits' not found."
-
-    # --- Apply Lexical Mask if requested ---
-    if apply_lexical_mask:
-        vocab_size = tokenizer_splade_doc.vocab_size
-        bow_mask = create_lexical_bow_mask(
-            inputs['input_ids'], vocab_size, tokenizer_splade_doc
-        ).squeeze()
-        splade_vector = splade_vector * bow_mask
-    # --- End Lexical Mask Logic ---
-
-    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
-    if not isinstance(indices, list):
-        indices = [indices]
-
-    values = splade_vector[indices].cpu().tolist()
+    if not hasattr(output, "logits"):
+        return "SPLADE v3 Doc model output structure not as expected. 'logits' not found."
+
+    # For SPLADE-v3-Doc, the output is treated as a binary sparse vector:
+    # either a simple binarization by thresholding, or a direct selection of input tokens.
+    # A common way to get "binary" is to apply softplus and then threshold,
+    # but given the "no weighting, no expansion" goal, we use a strict presence check.
+
+    # Option 1: Binarize based on softplus output and a threshold (similar to uniCOIL).
+    # This might still activate some "expanded" terms if the model predicts them strongly.
+    # transformed_scores = torch.log(1 + torch.exp(output.logits))  # Softplus
+    # splade_vector_raw = torch.max(transformed_scores * inputs['attention_mask'].unsqueeze(-1), dim=1).values
+    # binary_splade_vector = (splade_vector_raw > 0.5).float()  # Binarize
+
+    # Option 2: Rely on the original bag of words, with 1 for presence.
+    # This aligns best with "no weighting, no expansion".
+    vocab_size = tokenizer_splade_doc.vocab_size
+    binary_splade_vector = create_lexical_bow_mask(
+        inputs['input_ids'], vocab_size, tokenizer_splade_doc
+    ).squeeze()
+
+    # Collect the indices of activated (non-zero) vocabulary entries
+    indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list):  # a single activation squeezes to a scalar
+        indices = [indices]
+
+    # Values are all 1 for a binary (presence-only) representation
+    values = [1.0] * len(indices)
     token_weights = dict(zip(indices, values))

     meaningful_tokens = {}
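The two options sketched in the comments differ in where the ones come from: Option 1 derives them from model scores and can keep expansion terms, while Option 2 ignores the scores entirely. A hedged, self-contained comparison on dummy logits (shapes and the 0.5 threshold are illustrative, not values from SPLADE-v3-Doc):

    import torch

    logits = torch.randn(1, 4, 10)     # [batch, seq_len, vocab], dummy values
    attention_mask = torch.ones(1, 4)

    # Option 1: softplus activation, max-pool over the sequence, then threshold
    scores = torch.log(1 + torch.exp(logits)) * attention_mask.unsqueeze(-1)
    pooled = torch.max(scores, dim=1).values   # [1, vocab]
    option1 = (pooled > 0.5).float()           # may keep model-expanded terms

    # Option 2 (what the committed code does): one presence bit per input
    # token via create_lexical_bow_mask, so no expansion can survive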
@@ -222,17 +218,22 @@ def get_splade_doc_representation(text, apply_lexical_mask: bool):
         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
             meaningful_tokens[decoded_token] = weight

-    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for binary output

-    formatted_output = "SPLADE v3 Doc Representation (All Non-Zero Terms):\n"
+    formatted_output = "SPLADE v3 Doc Representation (Binary Sparse - Lexical Only):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
-        for term, weight in sorted_representation:
-            formatted_output += f"- **{term}**: {weight:.4f}\n"
-
-    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
-    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+        # Terms are displayed without weights, since every value is 1
+        for i, (term, _) in enumerate(sorted_representation):
+            # Limit the display of very long lists for readability
+            if i >= 50:
+                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
+                break
+            formatted_output += f"- **{term}**\n"
+
+    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+    formatted_output += f"Total activated terms: {len(indices)}\n"
     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"

     return formatted_output
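The sparsity figure is simply the fraction of the vocabulary left inactive, so for a binary lexical vector it is driven entirely by input length. For a sense of scale, with a BERT-sized vocabulary of 30,522 entries (illustrative) and 12 activated terms:

    vocab_size = 30522
    num_active = 12
    print(f"Sparsity: {1 - (num_active / vocab_size):.2%}")  # Sparsity: 99.96%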
@@ -243,12 +244,11 @@ def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
     elif model_choice == "SPLADE-v3-Lexical":
-        # Always applies lexical mask for this option as per last request
+        # Always applies lexical mask for this option
         return get_splade_lexical_representation(text)
-    elif model_choice == "SPLADE-v3-Doc (with expansion)": # New option
-        return get_splade_doc_representation(text, apply_lexical_mask=False)
-    elif model_choice == "SPLADE-v3-Doc (lexical-only)": # New option
-        return get_splade_doc_representation(text, apply_lexical_mask=True)
+    elif model_choice == "SPLADE-v3-Doc": # Simplified to a single option
+        # This function now intrinsically handles binary, lexical-only output
+        return get_splade_doc_representation(text)
     else:
         return "Please select a model."

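Since each radio label must now match one dispatch branch exactly, a quick smoke test catches label drift early. A minimal sketch, assuming the functions in app.py are importable:

    for choice in ["SPLADE (cocondenser)", "SPLADE-v3-Lexical",
                   "SPLADE-v3-Doc", "unknown"]:
        print(choice, "->", predict_representation(choice, "hello world")[:60])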
@@ -259,9 +259,8 @@ demo = gr.Interface(
         gr.Radio(
             [
                 "SPLADE (cocondenser)",
-                "SPLADE-v3-Lexical", # Lexical-only by default now
-                "SPLADE-v3-Doc (with expansion)", # Option to see full neural output
-                "SPLADE-v3-Doc (lexical-only)" # Option with lexical mask applied
+                "SPLADE-v3-Lexical",
+                "SPLADE-v3-Doc" # Only one option for Doc model
             ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)" # Default selection
@@ -274,7 +273,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector. Explore different SPLADE models and their expansion behaviors.",
+    description="Enter any text to see its sparse vector representation.", # Simplified description
     allow_flagging="never"
 )

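To exercise the simplified interface locally, the usual Gradio entry point applies; a minimal sketch (queueing and share options omitted):

    if __name__ == "__main__":
        demo.launch()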