Update app.py

app.py (CHANGED)
@@ -7,6 +7,8 @@ tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
 model_splade_lexical = None
+tokenizer_splade_doc = None # New tokenizer for SPLADE-v3-Doc
+model_splade_doc = None # New model for SPLADE-v3-Doc
 
 # Load SPLADE v3 model (original)
 try:
@@ -29,6 +31,18 @@ except Exception as e:
     print(f"Error loading SPLADE v3 Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
+# Load SPLADE v3 Doc model (NEW)
+try:
+    splade_doc_model_name = "naver/splade-v3-doc"
+    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
+    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
+    model_splade_doc.eval() # Set to evaluation mode for inference
+    print(f"SPLADE v3 Doc model '{splade_doc_model_name}' loaded successfully!")
+except Exception as e:
+    print(f"Error loading SPLADE v3 Doc model: {e}")
+    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
+
+
 # --- Helper function for lexical mask ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
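Note: the diff only shows the signature of create_lexical_bow_mask; its body is unchanged and collapsed here. A minimal sketch of what such a helper plausibly does, assuming a binary bag-of-words mask over the tokenizer vocabulary (illustrative only, not the file's actual implementation):

# Hypothetical sketch only: the commit shows just the helper's signature, so this
# is one plausible implementation, not the code in app.py.
import torch

def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    # 1.0 at the vocabulary index of every non-special input token, 0.0 elsewhere,
    # so multiplying a SPLADE vector by it keeps only the input's own terms.
    bow_mask = torch.zeros(1, vocab_size, device=input_ids.device)
    special_ids = set(tokenizer.all_special_ids)  # e.g. [CLS], [SEP], [PAD]
    for token_id in input_ids.squeeze(0).tolist():
        if token_id not in special_ids:
            bow_mask[0, token_id] = 1.0
    return bow_mask

Whatever the exact body, multiplying by such a mask is what turns a neural SPLADE vector into the "lexical-only" variants used below: all expansion terms outside the input's own tokens are zeroed out.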
@@ -107,7 +121,7 @@ def get_splade_representation(text):
     return formatted_output
 
 
-def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
+def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
         return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
 
@@ -167,12 +181,74 @@ def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
     return formatted_output
 
 
+# NEW: Function for SPLADE-v3-Doc representation
+def get_splade_doc_representation(text, apply_lexical_mask: bool):
+    if tokenizer_splade_doc is None or model_splade_doc is None:
+        return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."
+
+    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model_splade_doc(**inputs)
+
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+    else:
+        return "Model output structure not as expected for SPLADE v3 Doc. 'logits' not found."
+
+    # --- Apply Lexical Mask if requested ---
+    if apply_lexical_mask:
+        vocab_size = tokenizer_splade_doc.vocab_size
+        bow_mask = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_doc
+        ).squeeze()
+        splade_vector = splade_vector * bow_mask
+    # --- End Lexical Mask Logic ---
+
+    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list):
+        indices = [indices]
+
+    values = splade_vector[indices].cpu().tolist()
+    token_weights = dict(zip(indices, values))
+
+    meaningful_tokens = {}
+    for token_id, weight in token_weights.items():
+        decoded_token = tokenizer_splade_doc.decode([token_id])
+        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+            meaningful_tokens[decoded_token] = weight
+
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+
+    formatted_output = "SPLADE v3 Doc Representation (All Non-Zero Terms):\n"
+    if not sorted_representation:
+        formatted_output += "No significant terms found for this input.\n"
+    else:
+        for term, weight in sorted_representation:
+            formatted_output += f"- **{term}**: {weight:.4f}\n"
+
+    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
+    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
+
+    return formatted_output
+
+
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical":
+    elif model_choice == "SPLADE-v3-Lexical":
+        # Always applies lexical mask for this option as per last request
+        return get_splade_lexical_representation(text)
+    elif model_choice == "SPLADE-v3-Doc (with expansion)": # New option
+        return get_splade_doc_representation(text, apply_lexical_mask=False)
+    elif model_choice == "SPLADE-v3-Doc (lexical-only)": # New option
+        return get_splade_doc_representation(text, apply_lexical_mask=True)
     else:
         return "Please select a model."
 
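The pooling in get_splade_doc_representation above is the standard SPLADE formulation: a log-saturated ReLU of the MLM logits, zeroed at padded positions via the attention mask, then max-pooled over the sequence dimension. A minimal, self-contained sketch with toy tensors (the shapes are stand-ins, not the model's real dimensions):

import torch

batch, seq_len, vocab = 1, 4, 10             # toy dimensions for illustration
logits = torch.randn(batch, seq_len, vocab)  # stand-in for output.logits
attention_mask = torch.ones(batch, seq_len)  # stand-in for inputs['attention_mask']

# log(1 + relu(x)) keeps term weights non-negative and dampens large activations;
# multiplying by the mask zeroes padded positions before the max-pool over dim=1.
splade_vector = torch.max(
    torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1),
    dim=1
)[0].squeeze()                               # shape: (vocab,) for batch size 1

print(splade_vector.shape, int((splade_vector > 0).sum()), "non-zero terms")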
@@ -182,8 +258,10 @@ demo = gr.Interface(
     inputs=[
         gr.Radio(
             [
-                "SPLADE (cocondenser)",
-                "SPLADE-v3-Lexical",
+                "SPLADE (cocondenser)",
+                "SPLADE-v3-Lexical", # Lexical-only by default now
+                "SPLADE-v3-Doc (with expansion)", # Option to see full neural output
+                "SPLADE-v3-Doc (lexical-only)" # Option with lexical mask applied
             ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)" # Default selection
@@ -196,7 +274,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector.",
+    description="Enter any text to see its SPLADE sparse vector. Explore different SPLADE models and their expansion behaviors.",
     allow_flagging="never"
 )
 
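A hypothetical smoke test (not part of this commit) that exercises the four radio choices through predict_representation; it assumes app.py's module-level models loaded without errors:

# Hypothetical usage sketch, run from within app.py's module scope.
for choice in [
    "SPLADE (cocondenser)",
    "SPLADE-v3-Lexical",
    "SPLADE-v3-Doc (with expansion)",
    "SPLADE-v3-Doc (lexical-only)",
]:
    print(f"=== {choice} ===")
    print(predict_representation(choice, "sparse retrieval with learned term expansion"))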