Update app.py

app.py CHANGED
@@ -107,7 +107,7 @@ def get_splade_representation(text):
     return formatted_output


-def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
+def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
         return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."

@@ -125,18 +125,17 @@ def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
     else:
         return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."

-    # --- Apply Lexical Mask
-
-
-
-
-
-
-
-
-
-
-        splade_vector = splade_vector * bow_mask
+    # --- Apply Lexical Mask (always applied for this function now) ---
+    # Get the vocabulary size from the tokenizer
+    vocab_size = tokenizer_splade_lexical.vocab_size
+
+    # Create the Bag-of-Words mask
+    bow_mask = create_lexical_bow_mask(
+        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
+    ).squeeze()
+
+    # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
+    splade_vector = splade_vector * bow_mask
     # --- End Lexical Mask Logic ---

     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()

@@ -172,12 +171,8 @@ def get_splade_lexical_representation(text, apply_lexical_mask: bool): # Added p
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical
-        #
-        return get_splade_lexical_representation(text, apply_lexical_mask=False)
-    elif model_choice == "SPLADE-v3-Lexical (lexical-only)":
-        # Call the lexical function applying the mask
-        return get_splade_lexical_representation(text, apply_lexical_mask=True)
+    elif model_choice == "SPLADE-v3-Lexical": # Simplified choice
+        return get_splade_lexical_representation(text) # Always applies lexical mask
     else:
         return "Please select a model."

@@ -188,8 +183,7 @@ demo = gr.Interface(
         gr.Radio(
             [
                 "SPLADE (cocondenser)",
-                "SPLADE-v3-Lexical
-                "SPLADE-v3-Lexical (lexical-only)" # Option with lexical mask applied
+                "SPLADE-v3-Lexical" # Simplified option
             ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)" # Default selection

@@ -202,7 +196,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector.
+    description="Enter any text to see its SPLADE sparse vector.", # Simplified description
     allow_flagging="never"
 )
|
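For context: the lexical-mask hunk checks the model output for 'logits' and then multiplies an existing splade_vector by the BoW mask, but the step that produces splade_vector sits outside this diff. A minimal sketch of the standard SPLADE pooling, assuming app.py follows the usual recipe and that tokenizer_splade_lexical / model_splade_lexical are the objects loaded elsewhere in the file:

import torch

# Assumed context, not part of this diff: tokenizer_splade_lexical and
# model_splade_lexical (a masked-LM head model) are loaded elsewhere in app.py.
inputs = tokenizer_splade_lexical(text, return_tensors="pt")
with torch.no_grad():
    output = model_splade_lexical(**inputs)  # output.logits: (1, seq_len, vocab_size)

# Standard SPLADE pooling: max over token positions of log(1 + ReLU(logits)),
# with padding positions zeroed out via the attention mask.
splade_vector = torch.max(
    torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
    dim=1,
).values.squeeze()  # (vocab_size,) term-importance vector the lexical mask is applied to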
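The new block at lines 128-138 calls create_lexical_bow_mask, which is defined elsewhere in app.py and not shown in this diff. A minimal sketch of what such a helper could look like, assuming it marks exactly the vocabulary ids present in the input and skips special tokens; the implementation below is illustrative, not the app's actual code:

import torch

def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    # Hypothetical sketch: a (1, vocab_size) binary bag-of-words mask with 1.0
    # at the ids that literally occur in the input text, 0.0 everywhere else.
    bow_mask = torch.zeros(1, vocab_size, device=input_ids.device)
    special_ids = set(tokenizer.all_special_ids)  # skip [CLS], [SEP], [PAD], ...
    present_ids = [tid for tid in input_ids.squeeze().tolist() if tid not in special_ids]
    if present_ids:
        bow_mask[0, present_ids] = 1.0
    return bow_mask

After the .squeeze() in the hunk this becomes a (vocab_size,) vector, so the element-wise product keeps only the weights of terms that appear verbatim in the input and zeroes out SPLADE's expansion terms, as the comment in the hunk describes.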
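On the UI side, a sketch of how the fragments touched by the last three hunks fit together after the change; the gr.Textbox input and the fn/launch wiring are assumptions about parts of the file this diff does not show:

import gradio as gr

demo = gr.Interface(
    fn=predict_representation,  # dispatches on the radio choice, as in the diff
    inputs=[
        gr.Radio(
            [
                "SPLADE (cocondenser)",
                "SPLADE-v3-Lexical"  # single, simplified lexical option
            ],
            label="Choose Representation Model",
            value="SPLADE (cocondenser)"
        ),
        gr.Textbox(lines=3, label="Enter text"),  # assumed second input
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector.",
    allow_flagging="never"
)

demo.launch()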