SiddharthAK committed
Commit b7f6be7 · verified · Parent: f384e43

Update app.py

Files changed (1): app.py (+27 -37)
app.py CHANGED
@@ -2,8 +2,8 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import torch
 import numpy as np
-from tqdm.auto import tqdm # Still useful for model loading progress if desired, but not strictly necessary for this simplified version
-import os # Still useful for general purpose, but not explicitly used in this simplified version
+from tqdm.auto import tqdm
+import os
 
 # --- Model Loading ---
 tokenizer_splade = None
@@ -34,11 +34,11 @@ except Exception as e:
     print(f"Error loading SPLADE-v3-Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
-# Load SPLADE v3 Doc model
+# Load SPLADE v3 Doc model - Model loading is still necessary even if its logits aren't used for BoW
 try:
     splade_doc_model_name = "naver/splade-v3-doc"
     tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
-    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
+    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name) # Still load the model
     model_splade_doc.eval()
     print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
 except Exception as e:
@@ -183,25 +183,19 @@ def get_splade_lexical_representation(text):
 
 
 def get_splade_doc_representation(text):
-    if tokenizer_splade_doc is None or model_splade_doc is None:
-        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."
+    if tokenizer_splade_doc is None: # No longer need model_splade_doc to be loaded for 'logits'
+        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
-
-    with torch.no_grad():
-        output = model_splade_doc(**inputs)
-
-    if not hasattr(output, "logits"):
-        return "Model output structure not as expected. 'logits' not found."
+    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
 
     vocab_size = tokenizer_splade_doc.vocab_size
-    # Call with unsqueezed input_ids for single sample processing
-    binary_splade_vector = create_lexical_bow_mask(
+    # Directly create the binary Bag-of-Words vector using the input_ids
+    binary_bow_vector = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
     ).squeeze() # Squeeze back for single output
 
-    indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
+    indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
         indices = [indices] if indices else []
 
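Both rewritten functions in this commit route everything through create_lexical_bow_mask, which app.py defines outside the hunks shown here. As a reading aid, here is a minimal sketch of what a mask builder with this signature plausibly does (an assumption, not the Space's actual implementation): scatter 1.0 into a (batch, vocab_size) tensor at every input token id, then zero out special tokens.

    import torch

    def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
        # input_ids: (batch, seq_len) LongTensor from the tokenizer.
        # Returns a (batch, vocab_size) float tensor with 1.0 at every
        # vocabulary id that occurs in the input, 0.0 elsewhere.
        bow_mask = torch.zeros(input_ids.size(0), vocab_size, device=input_ids.device)
        bow_mask.scatter_(1, input_ids, 1.0)
        # Drop special tokens ([CLS], [SEP], [PAD], ...) so they are not
        # counted as activated terms.
        for special_id in tokenizer.all_special_ids:
            if 0 <= special_id < vocab_size:
                bow_mask[:, special_id] = 0.0
        return bow_mask

With a single input text the batch dimension is 1, which is why both call sites above apply .squeeze() to get a flat (vocab_size,) vector.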
@@ -216,7 +210,7 @@ def get_splade_doc_representation(text):
 
     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for clarity
 
-    formatted_output = "SPLADE-v3-Doc Representation (Binary):\n"
+    formatted_output = "Binary Bag-of-Words Representation:\n" # Changed title
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -226,7 +220,7 @@ def get_splade_doc_representation(text):
             break
         formatted_output += f"- **{term}**\n"
 
-    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+    formatted_output += "\n--- Raw Binary Bag-of-Words Vector Info ---\n" # Changed title
     formatted_output += f"Total activated terms: {len(indices)}\n"
     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
 
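The sparsity figure printed here is simply the share of vocabulary entries left at zero. Assuming the 30,522-entry bert-base-uncased vocabulary that the naver SPLADE checkpoints build on (an assumption; the diff never states the size), a sentence activating a dozen terms is already about 99.96% sparse:

    vocab_size = 30522  # assumed bert-base-uncased vocabulary size
    activated = 12      # hypothetical count of nonzero terms
    print(f"Sparsity: {1 - (activated / vocab_size):.2%}")  # Sparsity: 99.96%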
@@ -235,11 +229,11 @@ def get_splade_doc_representation(text):
 
 # --- Unified Prediction Function for the Explorer Tab ---
 def predict_representation_explorer(model_choice, text):
-    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
+    if model_choice == "MLM encoder (SPLADE-cocondenser-distil)":
         return get_splade_cocondenser_representation(text)
     elif model_choice == "MLP encoder (SPLADE-v3-lexical)":
         return get_splade_lexical_representation(text)
-    elif model_choice == "Binary encoder":
+    elif model_choice == "Binary Bag-of-Words": # Changed name
         return get_splade_doc_representation(text)
     else:
         return "Please select a model."
@@ -290,22 +284,18 @@ def get_splade_lexical_vector(text):
     return None
 
 def get_splade_doc_vector(text):
-    if tokenizer_splade_doc is None or model_splade_doc is None:
+    if tokenizer_splade_doc is None: # No longer need model_splade_doc to be loaded for 'logits'
        return None
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
-
-    with torch.no_grad():
-        output = model_splade_doc(**inputs)
+    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
 
-    if hasattr(output, "logits"):
-        vocab_size = tokenizer_splade_doc.vocab_size
-        binary_splade_vector = create_lexical_bow_mask(
-            inputs['input_ids'], vocab_size, tokenizer_splade_doc
-        ).squeeze()
-        return binary_splade_vector
-    return None
+    vocab_size = tokenizer_splade_doc.vocab_size
+    # Directly create the binary Bag-of-Words vector using the input_ids
+    binary_bow_vector = create_lexical_bow_mask(
+        inputs['input_ids'], vocab_size, tokenizer_splade_doc
+    ).squeeze()
+    return binary_bow_vector
 
 
 # --- Function to get formatted representation from a raw vector and tokenizer ---
@@ -322,7 +312,7 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
         values = [1.0] * len(indices)
     else:
         values = splade_vector[indices].cpu().tolist()
-
+
     token_weights = dict(zip(indices, values))
 
     meaningful_tokens = {}
@@ -361,8 +351,8 @@ def get_model_assets(model_choice_str):
         return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
     elif model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
         return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
-    elif model_choice_str == "Binary encoder":
-        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary encoder"
+    elif model_choice_str == "Binary Bag-of-Words": # Changed name
+        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words" # Changed name
     else:
         return None, None, False, "Unknown Model"
 
@@ -415,7 +405,7 @@ with gr.Blocks(title="SPLADE Demos") as demo:
             [
                 "MLM encoder (SPLADE-cocondenser-distil)",
                 "MLP encoder (SPLADE-v3-lexical)",
-                "Binary Encoder"
+                "Binary Bag-of-Words" # Changed name here
             ],
             label="Choose Sparse Encoder",
             value="MLM encoder (SPLADE-cocondenser-distil)"
@@ -439,7 +429,7 @@ with gr.Blocks(title="SPLADE Demos") as demo:
     model_choices = [
         "MLM encoder (SPLADE-cocondenser-distil)",
         "MLP encoder (SPLADE-v3-lexical)",
-        "Binary encoder"
+        "Binary Bag-of-Words" # Changed name here
     ]
 
     gr.Interface(
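Stepping back, the underlying bug this commit chases was label drift: the old Radio offered "Binary Encoder" (capital E) while predict_representation_explorer compared against "Binary encoder", so that branch could never match. Renaming everything to "Binary Bag-of-Words" fixes the mismatch, but the same literal still lives in several places. A follow-up sketch (not part of this commit) that defines the choices once and reuses them:

    MODEL_CHOICES = [
        "MLM encoder (SPLADE-cocondenser-distil)",
        "MLP encoder (SPLADE-v3-lexical)",
        "Binary Bag-of-Words",
    ]

    # One shared list keeps the Radio widgets and the string comparisons in
    # predict_representation_explorer / get_model_assets from drifting apart.
    model_radio = gr.Radio(MODEL_CHOICES, label="Choose Sparse Encoder", value=MODEL_CHOICES[0])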