SiddharthAK committed
Commit 6024481 · verified · Parent: 358e025

Update app.py

Files changed (1):
  1. app.py +40 -55
app.py CHANGED
@@ -7,17 +7,17 @@ tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
 model_splade_lexical = None
-tokenizer_splade_doc = None # New tokenizer for SPLADE-v3-Doc
-model_splade_doc = None # New model for SPLADE-v3-Doc
+tokenizer_splade_doc = None
+model_splade_doc = None
 
 # Load SPLADE v3 model (original)
 try:
     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade.eval() # Set to evaluation mode for inference
-    print("SPLADE v3 (cocondenser) model loaded successfully!")
+    print("SPLADE-cocondenser-distil model loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE (cocondenser) model: {e}")
+    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
 
 # Load SPLADE v3 Lexical model
@@ -26,24 +26,24 @@ try:
     tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
     model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
     model_splade_lexical.eval() # Set to evaluation mode for inference
-    print(f"SPLADE v3 Lexical model '{splade_lexical_model_name}' loaded successfully!")
+    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE v3 Lexical model: {e}")
+    print(f"Error loading SPLADE-v3-Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
-# Load SPLADE v3 Doc model (NEW)
+# Load SPLADE v3 Doc model
 try:
     splade_doc_model_name = "naver/splade-v3-doc"
     tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
     model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
     model_splade_doc.eval() # Set to evaluation mode for inference
-    print(f"SPLADE v3 Doc model '{splade_doc_model_name}' loaded successfully!")
+    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE v3 Doc model: {e}")
+    print(f"Error loading SPLADE-v3-Doc model: {e}")
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
-# --- Helper function for lexical mask (still needed for splade-v3-lexical) ---
+# --- Helper function for lexical mask ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
     Creates a binary bag-of-words mask from input_ids,
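
Note: the body of create_lexical_bow_mask is unchanged by this commit, so the hunk above cuts off mid-docstring. For readers without the full file, a minimal sketch of what such a helper plausibly looks like, assuming a standard Hugging Face tokenizer exposing all_special_ids (the actual implementation in app.py may differ):

    import torch

    def create_lexical_bow_mask_sketch(input_ids, vocab_size, tokenizer):
        # 1.0 for every non-special token id present in the input, 0.0 elsewhere
        mask = torch.zeros(vocab_size, device=input_ids.device)
        special_ids = set(tokenizer.all_special_ids)
        for token_id in input_ids.squeeze(0).tolist():
            if token_id not in special_ids:
                mask[token_id] = 1.0
        return mask.unsqueeze(0)  # (1, vocab_size), so callers can .squeeze() it
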
@@ -69,9 +69,9 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
 
 # --- Core Representation Functions ---
 
-def get_splade_representation(text):
+def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE (cocondenser) model is not loaded. Please check the console for loading errors."
+        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -80,12 +80,13 @@ def get_splade_representation(text):
         output = model_splade(**inputs)
 
     if hasattr(output, 'logits'):
+        # Standard SPLADE calculation for learned weighting and expansion
         splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE (cocondenser). 'logits' not found."
+        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
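
The comment added above documents the standard SPLADE aggregation: log(1 + ReLU(logits)) saturates per-token activations, multiplying by the attention mask zeroes out padding, and the max over the sequence dimension keeps each vocabulary term's strongest activation. A self-contained toy run with made-up dimensions (not the real ~30k WordPiece vocabulary):

    import torch

    logits = torch.randn(1, 4, 10)                 # (batch, seq_len, toy vocab)
    attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

    activations = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
    splade_vector = torch.max(activations, dim=1)[0].squeeze()

    print(splade_vector.shape)        # torch.Size([10])
    print((splade_vector > 0).sum())  # only terms some real token activated are non-zero
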
@@ -102,7 +103,7 @@ def get_splade_representation(text):
 
     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
 
-    formatted_output = "SPLADE (cocondenser) Representation (All Non-Zero Terms):\n"
+    formatted_output = "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -118,7 +119,7 @@ def get_splade_representation(text):
 
 def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
-        return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
@@ -132,15 +133,14 @@ def get_splade_lexical_representation(text):
            dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
 
-    # --- Apply Lexical Mask (always applied for this function now) ---
+    # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
     ).squeeze()
     splade_vector = splade_vector * bow_mask
-    # --- End Lexical Mask Logic ---
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
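
The splade_vector * bow_mask multiplication retained above is what removes expansion for this model: any vocabulary term not literally present in the input has its weight zeroed. In miniature:

    import torch

    splade_vector = torch.tensor([0.0, 1.2, 0.7, 0.3])  # toy weights over a 4-term vocab
    bow_mask = torch.tensor([0.0, 1.0, 0.0, 1.0])       # terms 1 and 3 occur in the input
    print(splade_vector * bow_mask)  # tensor([0.0000, 1.2000, 0.0000, 0.3000])
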
@@ -157,7 +157,7 @@ def get_splade_lexical_representation(text):
 
     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
 
-    formatted_output = "SPLADE v3 Lexical Representation (All Non-Zero Terms):\n"
+    formatted_output = "SPLADE-v3-Lexical Representation (Weighting):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -171,10 +171,10 @@ def get_splade_lexical_representation(text):
     return formatted_output
 
 
-# NEW: Function for SPLADE-v3-Doc representation (Binary Sparse)
+# Function for SPLADE-v3-Doc representation (Binary Sparse - Lexical Only)
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
-        return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
@@ -183,33 +183,22 @@ def get_splade_doc_representation(text):
         output = model_splade_doc(**inputs)
 
     if not hasattr(output, "logits"):
-        return "SPLADE v3 Doc model output structure not as expected. 'logits' not found."
+        return "SPLADE-v3-Doc model output structure not as expected. 'logits' not found."
 
-    # For SPLADE-v3-Doc, the output is often a binary sparse vector.
-    # We will assume a simple binarization based on a threshold or selecting active tokens.
-    # A common way to get "binary" is to use softplus and then binarize, or directly binarize max logits.
-    # Given the "no weighting, no expansion" request, we'll aim for a strict presence check.
-
-    # Option 1: Binarize based on softplus output and threshold (similar to UNICOIL)
-    # This might still activate some "expanded" terms if the model predicts them strongly.
-    # transformed_scores = torch.log(1 + torch.exp(output.logits)) # Softplus
-    # splade_vector_raw = torch.max(transformed_scores * inputs['attention_mask'].unsqueeze(-1), dim=1).values
-    # binary_splade_vector = (splade_vector_raw > 0.5).float() # Binarize
-
-    # Option 2: Rely on the original BoW for terms, with 1 for presence
-    # This aligns best with "no weighting, no expansion"
+    # For SPLADE-v3-Doc, assuming output is designed to be binary and lexical-only.
+    # We will derive the output directly from the input tokens themselves,
+    # as the model's primary role in this context is as a pre-trained LM feature extractor
+    # for a document-side, lexical-only binary sparse representation.
     vocab_size = tokenizer_splade_doc.vocab_size
-    binary_splade_vector = create_lexical_bow_mask(
+    binary_splade_vector = create_lexical_bow_mask( # Use the BOW mask directly for binary
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
     ).squeeze()
 
-    # We set values to 1 as it's a binary representation, not weighted
     indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
-    if not isinstance(indices, list): # Handle case where only one non-zero index
-        indices = [indices] if indices else [] # Ensure it's a list even if empty or single
+    if not isinstance(indices, list):
+        indices = [indices] if indices else []
 
-    # Values are all 1 for binary representation
-    values = [1.0] * len(indices)
+    values = [1.0] * len(indices) # All values are 1 for binary representation
     token_weights = dict(zip(indices, values))
 
     meaningful_tokens = {}
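
The comment block deleted above had sketched a different route: binarize the model's own activations, UNICOIL-style, rather than ignore them. For reference, a runnable version of that discarded option, using the 0.5 threshold from the original comment; unlike the bag-of-words route the commit keeps, it can still activate expanded terms the model predicts strongly:

    import torch

    def binarize_model_activations(logits, attention_mask, threshold=0.5):
        scores = torch.log(1 + torch.exp(logits))  # softplus, as in the removed comment
        pooled = torch.max(scores * attention_mask.unsqueeze(-1), dim=1).values
        return (pooled > threshold).float()  # binary, but may include expansion terms
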
@@ -218,16 +207,14 @@ def get_splade_doc_representation(text):
         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
             meaningful_tokens[decoded_token] = weight
 
-    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for binary
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for clarity
 
-    formatted_output = "SPLADE v3 Doc Representation (Binary Sparse - Lexical Only):\n"
+    formatted_output = "SPLADE-v3-Doc Representation (Binary):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
-        # Display as terms with no weights as they are binary (value 1)
         for i, (term, _) in enumerate(sorted_representation):
-            # Limit display for very long lists for readability
-            if i >= 50:
+            if i >= 50: # Limit display for very long lists
                 formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
                 break
             formatted_output += f"- **{term}**\n"
@@ -241,13 +228,11 @@ def get_splade_doc_representation(text):
 
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
-    if model_choice == "SPLADE (cocondenser)":
-        return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical":
-        # Always applies lexical mask for this option
+    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
+        return get_splade_cocondenser_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical (weighting)":
         return get_splade_lexical_representation(text)
-    elif model_choice == "SPLADE-v3-Doc": # Simplified to a single option
-        # This function now intrinsically handles binary, lexical-only output
+    elif model_choice == "SPLADE-v3-Doc (binary)":
         return get_splade_doc_representation(text)
     else:
         return "Please select a model."
@@ -260,10 +245,10 @@ demo = gr.Interface(
         [
             "SPLADE-cocondenser-distil (weighting and expansion)",
             "SPLADE-v3-Lexical (weighting)",
-            "SPLADE-v3-Doc (binary)" # Only one option for Doc model
+            "SPLADE-v3-Doc (binary)"
         ],
         label="Choose Representation Model",
-        value="SPLADE (cocondenser)" # Default selection
+        value="SPLADE-cocondenser-distil (weighting and expansion)" # Corrected default value
     ),
     gr.Textbox(
         lines=5,
@@ -273,7 +258,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its sparse vector representation.", # Simplified description
+    description="Explore different SPLADE models and their sparse representation types: weighted and expansive, weighted and lexical-only, or strictly binary.",
     allow_flagging="never"
 )
 
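Taken together, the commit's main fix is consistency: the gr.Radio choices, the predict_representation branches, and the default value now all use the same three label strings (the old default "SPLADE (cocondenser)" matched none of the radio options, so the app started with an invalid selection). A quick sanity check of that wiring, as a hypothetical snippet run against app.py's module-level names, not part of the commit:

    choices = [
        "SPLADE-cocondenser-distil (weighting and expansion)",
        "SPLADE-v3-Lexical (weighting)",
        "SPLADE-v3-Doc (binary)",
    ]
    for choice in choices:
        out = predict_representation(choice, "sparse lexical retrieval")
        assert out != "Please select a model.", f"unwired choice: {choice}"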