SiddharthAK committed
Commit 7c4de94 · verified · parent: ab88097

Update app.py

Files changed (1): app.py (+56 −6)
app.py CHANGED
@@ -29,6 +29,34 @@ except Exception as e:
     print(f"Error loading SPLADE v3 Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
+# --- Helper function for lexical mask ---
+def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
+    """
+    Creates a binary bag-of-words mask from input_ids,
+    zeroing out special tokens and padding.
+    """
+    # Initialize a zero vector for the entire vocabulary
+    bow_mask = torch.zeros(vocab_size, device=input_ids.device)
+
+    # Get unique token IDs from the input, excluding special tokens.
+    # input_ids is typically [batch_size, seq_len]; we assume batch_size=1.
+    meaningful_token_ids = []
+    for token_id in input_ids.squeeze().tolist():  # Squeeze out the batch dim and convert to a list
+        if token_id not in [
+            tokenizer.pad_token_id,
+            tokenizer.cls_token_id,
+            tokenizer.sep_token_id,
+            tokenizer.mask_token_id,
+            tokenizer.unk_token_id  # Also exclude unknown tokens
+        ]:
+            meaningful_token_ids.append(token_id)
+
+    # Set 1 for tokens present in the original input
+    if meaningful_token_ids:
+        bow_mask[list(set(meaningful_token_ids))] = 1  # Use set to handle duplicates
+
+    return bow_mask.unsqueeze(0)  # Keep batch dimension for consistency
+
 
 # --- Core Representation Functions ---
 
@@ -79,7 +107,7 @@ def get_splade_representation(text):
     return formatted_output
 
 
-def get_splade_lexical_representation(text):
+def get_splade_lexical_representation(text, apply_lexical_mask: bool):  # Added parameter
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
         return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
 
@@ -97,6 +125,20 @@ def get_splade_lexical_representation(text):
     else:
         return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
 
+    # --- Apply Lexical Mask if requested ---
+    if apply_lexical_mask:
+        # Get the vocabulary size from the tokenizer
+        vocab_size = tokenizer_splade_lexical.vocab_size
+
+        # Create the Bag-of-Words mask
+        bow_mask = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
+        ).squeeze()  # Squeeze to match splade_vector's [vocab_size] shape
+
+        # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
+        splade_vector = splade_vector * bow_mask
+    # --- End Lexical Mask Logic ---
+
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
         indices = [indices]
@@ -130,8 +172,12 @@ def get_splade_lexical_representation(text):
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical":
-        return get_splade_lexical_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical (with expansion)":
+        # Call the lexical function without applying the mask
+        return get_splade_lexical_representation(text, apply_lexical_mask=False)
+    elif model_choice == "SPLADE-v3-Lexical (lexical-only)":
+        # Call the lexical function applying the mask
+        return get_splade_lexical_representation(text, apply_lexical_mask=True)
     else:
         return "Please select a model."
 
@@ -140,7 +186,11 @@ demo = gr.Interface(
     fn=predict_representation,
     inputs=[
        gr.Radio(
-            ["SPLADE (cocondenser)", "SPLADE-v3-Lexical"],  # Updated options
+            [
+                "SPLADE (cocondenser)",
+                "SPLADE-v3-Lexical (with expansion)",  # Option to see full neural output
+                "SPLADE-v3-Lexical (lexical-only)"     # Option with lexical mask applied
+            ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)"  # Default selection
         ),
@@ -151,8 +201,8 @@
         )
     ],
    outputs=gr.Markdown(),
-    title="🌌 Sparse and Binary Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector or SPLADE-v3-Lexical representation.",
+    title="🌌 Sparse Representation Generator",
+    description="Enter any text to see its SPLADE sparse vector. Explore the difference between full neural expansion and lexical-only representations.",
     allow_flagging="never"
 )
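For context, a minimal sketch of how the new create_lexical_bow_mask helper behaves, assuming the helper from the diff above is in scope and using bert-base-uncased as a stand-in for the Space's (BERT-compatible) SPLADE tokenizer:

    import torch
    from transformers import AutoTokenizer

    # Stand-in tokenizer; the Space's actual SPLADE tokenizer is an assumption here.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("sparse retrieval with splade", return_tensors="pt")

    # create_lexical_bow_mask is the helper added in the commit above.
    bow_mask = create_lexical_bow_mask(inputs["input_ids"], tokenizer.vocab_size, tokenizer)

    print(bow_mask.shape)              # torch.Size([1, 30522]) for bert-base-uncased
    print(int(bow_mask.sum().item()))  # count of unique, non-special input tokens

The mask is binary over the whole vocabulary: 1 at every token id that occurs in the input (special tokens excluded), 0 everywhere else.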
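The practical difference between the two new radio options comes down to the elementwise multiply in the masking branch. A toy continuation of the sketch above; the "dense" expansion term is purely illustrative:

    # Toy SPLADE-style vector: weight on input terms plus one "expansion" term.
    splade_vector = torch.zeros(tokenizer.vocab_size)
    term_ids = tokenizer("sparse retrieval", add_special_tokens=False)["input_ids"]
    splade_vector[term_ids] = 1.5                            # terms that appear in the text
    expansion_id = tokenizer.convert_tokens_to_ids("dense")  # a term the model might activate
    splade_vector[expansion_id] = 0.7

    mask = create_lexical_bow_mask(
        tokenizer("sparse retrieval", return_tensors="pt")["input_ids"],
        tokenizer.vocab_size,
        tokenizer,
    ).squeeze()

    lexical_only = splade_vector * mask
    print(splade_vector[expansion_id].item())  # 0.7 -> kept in "with expansion" mode
    print(lexical_only[expansion_id].item())   # 0.0 -> zeroed in "lexical-only" mode

So "with expansion" surfaces every term the model activates, while "lexical-only" restricts the output to terms that literally occur in the input text.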