Update app.py
app.py
CHANGED
@@ -7,17 +7,17 @@ tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
 model_splade_lexical = None
-tokenizer_splade_doc = None
-model_splade_doc = None
+tokenizer_splade_doc = None
+model_splade_doc = None
 
 # Load SPLADE v3 model (original)
 try:
     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade.eval() # Set to evaluation mode for inference
-    print("SPLADE
+    print("SPLADE-cocondenser-distil model loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE
+    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
 
 # Load SPLADE v3 Lexical model
@@ -26,24 +26,24 @@ try:
     tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
     model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
     model_splade_lexical.eval() # Set to evaluation mode for inference
-    print(f"SPLADE
+    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE
+    print(f"Error loading SPLADE-v3-Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
-# Load SPLADE v3 Doc model
+# Load SPLADE v3 Doc model
 try:
     splade_doc_model_name = "naver/splade-v3-doc"
     tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
     model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
     model_splade_doc.eval() # Set to evaluation mode for inference
-    print(f"SPLADE
+    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE
+    print(f"Error loading SPLADE-v3-Doc model: {e}")
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
-# --- Helper function for lexical mask
+# --- Helper function for lexical mask ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
     Creates a binary bag-of-words mask from input_ids,
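
The body of create_lexical_bow_mask sits outside this diff; only its signature and the first docstring line are shown. A minimal sketch of what such a helper could look like, assuming it scatters 1.0 into a (1, vocab_size) tensor at the positions of the input token ids and skips special tokens; the actual implementation in app.py may differ:

```python
import torch

def create_lexical_bow_mask_sketch(input_ids, vocab_size, tokenizer):
    """Hypothetical version: binary bag-of-words mask over the vocabulary."""
    mask = torch.zeros(1, vocab_size, device=input_ids.device)
    special_ids = set(tokenizer.all_special_ids)  # e.g. [CLS], [SEP], [PAD]
    for token_id in input_ids.squeeze(0).tolist():
        if token_id not in special_ids:
            mask[0, token_id] = 1.0  # the term occurs in the input
    return mask
```
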
@@ -69,9 +69,9 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
 
 # --- Core Representation Functions ---
 
-def get_splade_representation(text):
+def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE
+        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -80,12 +80,13 @@ def get_splade_representation(text):
         output = model_splade(**inputs)
 
     if hasattr(output, 'logits'):
+        # Standard SPLADE calculation for learned weighting and expansion
         splade_vector = torch.max(
             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
             dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE
+        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
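
For reference, the pooling above is the standard SPLADE formulation: token logits go through log(1 + ReLU(·)), are zeroed at padding positions by the attention mask, and are max-pooled over the sequence so each vocabulary entry keeps its strongest activation. A standalone sketch of the same computation (the example text and top-k value here are illustrative):

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_name = "naver/splade-cocondenser-selfdistil"  # same checkpoint the app loads
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name).eval()

text = "sparse neural retrieval"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (1, seq_len, vocab_size)

# log(1 + ReLU) transform, masked by attention, max-pooled over the sequence
weights = torch.log1p(torch.relu(logits)) * inputs["attention_mask"].unsqueeze(-1)
splade_vector = weights.max(dim=1).values.squeeze(0)  # (vocab_size,)

# show the top-weighted vocabulary terms
top = torch.topk(splade_vector, k=10)
for score, idx in zip(top.values.tolist(), top.indices.tolist()):
    print(f"{tokenizer.decode([idx])}: {score:.3f}")
```
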
@@ -102,7 +103,7 @@ def get_splade_representation(text):
 
     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
 
-    formatted_output = "SPLADE
+    formatted_output = "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -118,7 +119,7 @@ def get_splade_representation(text):
 
 def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
-        return "SPLADE
+        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
@@ -132,15 +133,14 @@ def get_splade_lexical_representation(text):
             dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE
+        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
 
-    #
+    # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
     ).squeeze()
     splade_vector = splade_vector * bow_mask
-    # --- End Lexical Mask Logic ---
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
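
The multiplication by bow_mask is what turns the expansion-capable SPLADE vector into a lexical-only one: any vocabulary entry that does not literally occur in the input is zeroed out. A toy illustration with made-up numbers:

```python
import torch

# toy 6-term vocabulary; the model assigned weight to both input and expansion terms
splade_vector = torch.tensor([0.0, 1.2, 0.0, 0.7, 0.4, 0.0])
# bag-of-words mask: only indices 1 and 3 actually occur in the input text
bow_mask = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0, 0.0])

lexical_vector = splade_vector * bow_mask
print(lexical_vector)  # tensor([0.0000, 1.2000, 0.0000, 0.7000, 0.0000, 0.0000])
```
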
@@ -157,7 +157,7 @@ def get_splade_lexical_representation(text):
 
     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
 
-    formatted_output = "SPLADE
+    formatted_output = "SPLADE-v3-Lexical Representation (Weighting):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -171,10 +171,10 @@ def get_splade_lexical_representation(text):
     return formatted_output
 
 
-#
+# Function for SPLADE-v3-Doc representation (Binary Sparse - Lexical Only)
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
-        return "SPLADE
+        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
@@ -183,33 +183,22 @@ def get_splade_doc_representation(text):
         output = model_splade_doc(**inputs)
 
     if not hasattr(output, "logits"):
-        return "SPLADE
+        return "SPLADE-v3-Doc model output structure not as expected. 'logits' not found."
 
-    # For SPLADE-v3-Doc,
-    # We will
-    #
-    #
-
-    # Option 1: Binarize based on softplus output and threshold (similar to UNICOIL)
-    # This might still activate some "expanded" terms if the model predicts them strongly.
-    # transformed_scores = torch.log(1 + torch.exp(output.logits)) # Softplus
-    # splade_vector_raw = torch.max(transformed_scores * inputs['attention_mask'].unsqueeze(-1), dim=1).values
-    # binary_splade_vector = (splade_vector_raw > 0.5).float() # Binarize
-
-    # Option 2: Rely on the original BoW for terms, with 1 for presence
-    # This aligns best with "no weighting, no expansion"
+    # For SPLADE-v3-Doc, assuming output is designed to be binary and lexical-only.
+    # We will derive the output directly from the input tokens themselves,
+    # as the model's primary role in this context is as a pre-trained LM feature extractor
+    # for a document-side, lexical-only binary sparse representation.
     vocab_size = tokenizer_splade_doc.vocab_size
-    binary_splade_vector = create_lexical_bow_mask(
+    binary_splade_vector = create_lexical_bow_mask( # Use the BOW mask directly for binary
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
     ).squeeze()
 
-    # We set values to 1 as it's a binary representation, not weighted
     indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
-    if not isinstance(indices, list):
-        indices = [indices] if indices else []
+    if not isinstance(indices, list):
+        indices = [indices] if indices else []
 
-    #
-    values = [1.0] * len(indices)
+    values = [1.0] * len(indices) # All values are 1 for binary representation
     token_weights = dict(zip(indices, values))
 
     meaningful_tokens = {}
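
The deleted "Option 1" comment block sketched an alternative: softplus-transform the logits, max-pool over the sequence, and threshold at 0.5 (UNICOIL-style). For reference, that discarded alternative as runnable code; it is not what the app ships after this commit:

```python
import torch

def binarize_splade_doc_sketch(logits, attention_mask, threshold=0.5):
    """Softplus transform, max-pool over the sequence, then threshold to 0/1."""
    transformed = torch.log(1 + torch.exp(logits))  # softplus, as in the removed comment
    pooled = torch.max(transformed * attention_mask.unsqueeze(-1), dim=1).values
    return (pooled > threshold).float()
```
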
@@ -218,16 +207,14 @@ def get_splade_doc_representation(text):
         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
             meaningful_tokens[decoded_token] = weight
 
-    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for clarity
 
-    formatted_output = "SPLADE
+    formatted_output = "SPLADE-v3-Doc Representation (Binary):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
-        # Display as terms with no weights as they are binary (value 1)
         for i, (term, _) in enumerate(sorted_representation):
-            # Limit display for very long lists
-            if i >= 50:
+            if i >= 50: # Limit display for very long lists
                 formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
                 break
             formatted_output += f"- **{term}**\n"
@@ -241,13 +228,11 @@ def get_splade_doc_representation(text):
 
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
-    if model_choice == "SPLADE (
-        return
-    elif model_choice == "SPLADE-v3-Lexical":
-        # Always applies lexical mask for this option
+    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
+        return get_splade_cocondenser_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical (weighting)":
         return get_splade_lexical_representation(text)
-    elif model_choice == "SPLADE-v3-Doc":
-        # This function now intrinsically handles binary, lexical-only output
+    elif model_choice == "SPLADE-v3-Doc (binary)":
         return get_splade_doc_representation(text)
     else:
         return "Please select a model."
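
The dispatcher can also be exercised outside the Gradio UI; the choice strings must match the radio labels defined below. A small usage sketch, assuming app.py is importable as a module named app:

```python
from app import predict_representation

text = "The quick brown fox jumps over the lazy dog."
for choice in [
    "SPLADE-cocondenser-distil (weighting and expansion)",
    "SPLADE-v3-Lexical (weighting)",
    "SPLADE-v3-Doc (binary)",
]:
    print(predict_representation(choice, text))
```
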
@@ -260,10 +245,10 @@ demo = gr.Interface(
             [
                 "SPLADE-cocondenser-distil (weighting and expansion)",
                 "SPLADE-v3-Lexical (weighting)",
-                "SPLADE-v3-Doc (binary)"
+                "SPLADE-v3-Doc (binary)"
             ],
             label="Choose Representation Model",
-            value="SPLADE (
+            value="SPLADE-cocondenser-distil (weighting and expansion)" # Corrected default value
         ),
         gr.Textbox(
             lines=5,
@@ -273,7 +258,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="
+    description="Explore different SPLADE models and their sparse representation types: weighted and expansive, weighted and lexical-only, or strictly binary.",
     allow_flagging="never"
 )
 
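The tail of the file is not part of this diff; a Space like this one typically ends by launching the interface, roughly:

```python
if __name__ == "__main__":
    demo.launch()
```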