SiddharthAK commited on
Commit
17afa62
·
verified ·
1 Parent(s): 7c4de94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -22
app.py CHANGED
@@ -107,7 +107,7 @@ def get_splade_representation(text):
107
  return formatted_output
108
 
109
 
110
- def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added parameter
111
  if tokenizer_splade_lexical is None or model_splade_lexical is None:
112
  return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
113
 
@@ -125,18 +125,17 @@ def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
125
  else:
126
  return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
127
 
128
- # --- Apply Lexical Mask if requested ---
129
- if apply_lexical_mask:
130
- # Get the vocabulary size from the tokenizer
131
- vocab_size = tokenizer_splade_lexical.vocab_size
132
-
133
- # Create the Bag-of-Words mask
134
- bow_mask = create_lexical_bow_mask(
135
- inputs['input_ids'], vocab_size, tokenizer_splade_lexical
136
- ).squeeze() # Squeeze to match splade_vector's [vocab_size] shape
137
-
138
- # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
139
- splade_vector = splade_vector * bow_mask
140
  # --- End Lexical Mask Logic ---
141
 
142
  indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
@@ -172,12 +171,8 @@ def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
172
  def predict_representation(model_choice, text):
173
  if model_choice == "SPLADE (cocondenser)":
174
  return get_splade_representation(text)
175
- elif model_choice == "SPLADE-v3-Lexical (with expansion)":
176
- # Call the lexical function without applying the mask
177
- return get_splade_lexical_representation(text, apply_lexical_mask=False)
178
- elif model_choice == "SPLADE-v3-Lexical (lexical-only)":
179
- # Call the lexical function applying the mask
180
- return get_splade_lexical_representation(text, apply_lexical_mask=True)
181
  else:
182
  return "Please select a model."
183
 
@@ -188,8 +183,7 @@ demo = gr.Interface(
188
  gr.Radio(
189
  [
190
  "SPLADE (cocondenser)",
191
- "SPLADE-v3-Lexical (with expansion)", # Option to see full neural output
192
- "SPLADE-v3-Lexical (lexical-only)" # Option with lexical mask applied
193
  ],
194
  label="Choose Representation Model",
195
  value="SPLADE (cocondenser)" # Default selection
@@ -202,7 +196,7 @@ demo = gr.Interface(
202
  ],
203
  outputs=gr.Markdown(),
204
  title="🌌 Sparse Representation Generator",
205
- description="Enter any text to see its SPLADE sparse vector. Explore the difference between full neural expansion and lexical-only representations.",
206
  allow_flagging="never"
207
  )
208
 
 
107
  return formatted_output
108
 
109
 
110
+ def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
111
  if tokenizer_splade_lexical is None or model_splade_lexical is None:
112
  return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
113
 
 
125
  else:
126
  return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
127
 
128
+ # --- Apply Lexical Mask (always applied for this function now) ---
129
+ # Get the vocabulary size from the tokenizer
130
+ vocab_size = tokenizer_splade_lexical.vocab_size
131
+
132
+ # Create the Bag-of-Words mask
133
+ bow_mask = create_lexical_bow_mask(
134
+ inputs['input_ids'], vocab_size, tokenizer_splade_lexical
135
+ ).squeeze()
136
+
137
+ # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
138
+ splade_vector = splade_vector * bow_mask
 
139
  # --- End Lexical Mask Logic ---
140
 
141
  indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
 
171
  def predict_representation(model_choice, text):
172
  if model_choice == "SPLADE (cocondenser)":
173
  return get_splade_representation(text)
174
+ elif model_choice == "SPLADE-v3-Lexical": # Simplified choice
175
+ return get_splade_lexical_representation(text) # Always applies lexical mask
 
 
 
 
176
  else:
177
  return "Please select a model."
178
 
 
183
  gr.Radio(
184
  [
185
  "SPLADE (cocondenser)",
186
+ "SPLADE-v3-Lexical" # Simplified option
 
187
  ],
188
  label="Choose Representation Model",
189
  value="SPLADE (cocondenser)" # Default selection
 
196
  ],
197
  outputs=gr.Markdown(),
198
  title="🌌 Sparse Representation Generator",
199
+ description="Enter any text to see its SPLADE sparse vector.", # Simplified description
200
  allow_flagging="never"
201
  )
202