SiddharthAK committed on
Commit 4a365e4 · verified · 1 Parent(s): 63d582f

Update app.py

Files changed (1):
  1. app.py +84 -6
app.py CHANGED
@@ -7,6 +7,8 @@ tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
 model_splade_lexical = None
+tokenizer_splade_doc = None # New tokenizer for SPLADE-v3-Doc
+model_splade_doc = None # New model for SPLADE-v3-Doc
 
 # Load SPLADE v3 model (original)
 try:
@@ -29,6 +31,18 @@ except Exception as e:
     print(f"Error loading SPLADE v3 Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
+# Load SPLADE v3 Doc model (NEW)
+try:
+    splade_doc_model_name = "naver/splade-v3-doc"
+    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
+    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
+    model_splade_doc.eval() # Set to evaluation mode for inference
+    print(f"SPLADE v3 Doc model '{splade_doc_model_name}' loaded successfully!")
+except Exception as e:
+    print(f"Error loading SPLADE v3 Doc model: {e}")
+    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
+
+
 # --- Helper function for lexical mask ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
@@ -107,7 +121,7 @@ def get_splade_representation(text):
     return formatted_output
 
 
-def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
+def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
         return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
 
@@ -167,12 +181,74 @@ def get_splade_lexical_representation(text): # Removed apply_lexical_mask parame
     return formatted_output
 
 
+# NEW: Function for SPLADE-v3-Doc representation
+def get_splade_doc_representation(text, apply_lexical_mask: bool):
+    if tokenizer_splade_doc is None or model_splade_doc is None:
+        return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."
+
+    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model_splade_doc(**inputs)
+
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+    else:
+        return "Model output structure not as expected for SPLADE v3 Doc. 'logits' not found."
+
+    # --- Apply Lexical Mask if requested ---
+    if apply_lexical_mask:
+        vocab_size = tokenizer_splade_doc.vocab_size
+        bow_mask = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_doc
+        ).squeeze()
+        splade_vector = splade_vector * bow_mask
+    # --- End Lexical Mask Logic ---
+
+    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list):
+        indices = [indices]
+
+    values = splade_vector[indices].cpu().tolist()
+    token_weights = dict(zip(indices, values))
+
+    meaningful_tokens = {}
+    for token_id, weight in token_weights.items():
+        decoded_token = tokenizer_splade_doc.decode([token_id])
+        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+            meaningful_tokens[decoded_token] = weight
+
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+
+    formatted_output = "SPLADE v3 Doc Representation (All Non-Zero Terms):\n"
+    if not sorted_representation:
+        formatted_output += "No significant terms found for this input.\n"
+    else:
+        for term, weight in sorted_representation:
+            formatted_output += f"- **{term}**: {weight:.4f}\n"
+
+    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
+    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
+
+    return formatted_output
+
+
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical": # Simplified choice
-        return get_splade_lexical_representation(text) # Always applies lexical mask
+    elif model_choice == "SPLADE-v3-Lexical":
+        # Always applies lexical mask for this option as per last request
+        return get_splade_lexical_representation(text)
+    elif model_choice == "SPLADE-v3-Doc (with expansion)": # New option
+        return get_splade_doc_representation(text, apply_lexical_mask=False)
+    elif model_choice == "SPLADE-v3-Doc (lexical-only)": # New option
+        return get_splade_doc_representation(text, apply_lexical_mask=True)
     else:
         return "Please select a model."
 
 
@@ -182,8 +258,10 @@ demo = gr.Interface(
182
  inputs=[
183
  gr.Radio(
184
  [
185
- "SPLADE-cocondenser-distil (expansion and weighting)",
186
- "SPLADE-v3-Lexical (weighting only)" # Simplified option
 
 
187
  ],
188
  label="Choose Representation Model",
189
  value="SPLADE (cocondenser)" # Default selection
@@ -196,7 +274,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector.", # Simplified description
+    description="Enter any text to see its SPLADE sparse vector. Explore different SPLADE models and their expansion behaviors.",
     allow_flagging="never"
 )
 
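Note: after this change, each radio label maps one-to-one onto a representation function in predict_representation. A hypothetical smoke test (it assumes app.py is importable as a module and that the models loaded successfully):

# Hypothetical smoke test; exercises the dispatch outside the Gradio UI.
from app import predict_representation

for choice in [
    "SPLADE (cocondenser)",
    "SPLADE-v3-Lexical",
    "SPLADE-v3-Doc (with expansion)",
    "SPLADE-v3-Doc (lexical-only)",
]:
    print(choice)
    print(predict_representation(choice, "sparse retrieval with splade"))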