Update app.py

app.py CHANGED
@@ -43,33 +43,28 @@ except Exception as e:
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
-# --- Helper function for lexical mask ---
+# --- Helper function for lexical mask (still needed for splade-v3-lexical) ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
     Creates a binary bag-of-words mask from input_ids,
     zeroing out special tokens and padding.
     """
-    # Initialize a zero vector for the entire vocabulary
     bow_mask = torch.zeros(vocab_size, device=input_ids.device)
-
-    # Get unique token IDs from the input, excluding special tokens
-    # input_ids is typically [batch_size, seq_len], we assume batch_size=1
     meaningful_token_ids = []
     for token_id in input_ids.squeeze().tolist():
         if token_id not in [
             tokenizer.pad_token_id,
             tokenizer.cls_token_id,
             tokenizer.sep_token_id,
             tokenizer.mask_token_id,
             tokenizer.unk_token_id
         ]:
             meaningful_token_ids.append(token_id)
 
-    # Set 1 for tokens present in the original input
     if meaningful_token_ids:
         bow_mask[list(set(meaningful_token_ids))] = 1
 
     return bow_mask.unsqueeze(0)
 
 
 # --- Core Representation Functions ---
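As a quick sanity check, the helper can be exercised on its own. This is a minimal sketch assuming `create_lexical_bow_mask` from the hunk above is in scope, with `bert-base-uncased` standing in for the app's actual tokenizers:

```python
import torch
from transformers import AutoTokenizer

# Stand-in tokenizer; the Space loads its own SPLADE tokenizers.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("sparse retrieval with splade", return_tensors="pt")

mask = create_lexical_bow_mask(inputs["input_ids"], tokenizer.vocab_size, tokenizer)
print(mask.shape)       # torch.Size([1, 30522]) -- one slot per vocabulary entry
print(int(mask.sum()))  # number of distinct non-special tokens in the input
```

Because the mask is built from `set(meaningful_token_ids)`, repeated words contribute a single 1; term frequency is deliberately discarded.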
@@ -140,15 +135,10 @@ def get_splade_lexical_representation(text):
         return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
 
     # --- Apply Lexical Mask (always applied for this function now) ---
-    # Get the vocabulary size from the tokenizer
     vocab_size = tokenizer_splade_lexical.vocab_size
-
-    # Create the Bag-of-Words mask
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
     ).squeeze()
-
-    # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
     splade_vector = splade_vector * bow_mask
     # --- End Lexical Mask Logic ---
 
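The effect of the multiplication is easiest to see on a toy vocabulary: weights the model assigned to expansion terms (positions absent from the input) are zeroed, while input terms keep their learned weights. A small illustration:

```python
import torch

# Toy 5-term vocabulary: the model put weight on index 4 (an "expansion" term),
# but the BoW mask only allows indices 1 and 3, which occur in the input.
splade_vector = torch.tensor([0.0, 1.2, 0.0, 0.7, 0.9])
bow_mask      = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0])
print(splade_vector * bow_mask)  # tensor([0.0000, 1.2000, 0.0000, 0.7000, 0.0000])
```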
@@ -181,8 +171,8 @@ def get_splade_lexical_representation(text):
     return formatted_output
 
 
-# NEW: Function for SPLADE-v3-Doc representation
-def get_splade_doc_representation(text, apply_lexical_mask: bool):
+# NEW: Function for SPLADE-v3-Doc representation (Binary Sparse)
+def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
         return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."
 
@@ -192,28 +182,34 @@ def get_splade_doc_representation(text, apply_lexical_mask: bool):
     with torch.no_grad():
         output = model_splade_doc(**inputs)
 
-    if hasattr(output, 'logits'):
-        ...
-    if apply_lexical_mask:
-        ...
-    values = ...
+    if not hasattr(output, "logits"):
+        return "SPLADE v3 Doc model output structure not as expected. 'logits' not found."
+
+    # For SPLADE-v3-Doc, the output is often a binary sparse vector.
+    # We will assume a simple binarization based on a threshold or selecting active tokens.
+    # A common way to get "binary" is to use softplus and then binarize, or directly binarize max logits.
+    # Given the "no weighting, no expansion" request, we'll aim for a strict presence check.
+
+    # Option 1: Binarize based on softplus output and threshold (similar to UNICOIL)
+    # This might still activate some "expanded" terms if the model predicts them strongly.
+    # transformed_scores = torch.log(1 + torch.exp(output.logits)) # Softplus
+    # splade_vector_raw = torch.max(transformed_scores * inputs['attention_mask'].unsqueeze(-1), dim=1).values
+    # binary_splade_vector = (splade_vector_raw > 0.5).float() # Binarize
+
+    # Option 2: Rely on the original BoW for terms, with 1 for presence
+    # This aligns best with "no weighting, no expansion"
+    vocab_size = tokenizer_splade_doc.vocab_size
+    binary_splade_vector = create_lexical_bow_mask(
+        inputs['input_ids'], vocab_size, tokenizer_splade_doc
+    ).squeeze()
+
+    # We set values to 1 as it's a binary representation, not weighted
+    indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list): # Handle case where only one non-zero index
+        indices = [indices] if indices else [] # Ensure it's a list even if empty or single
+
+    # Values are all 1 for binary representation
+    values = [1.0] * len(indices)
     token_weights = dict(zip(indices, values))
 
     meaningful_tokens = {}
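The commented-out Option 1 can be tried in isolation to see what threshold binarization does. A sketch with fake logits standing in for `model_splade_doc`'s output (the 0.5 threshold comes from the comment, not from any tuning):

```python
import torch

logits = torch.randn(1, 6, 30)             # fake [batch, seq_len, vocab] logits
attention_mask = torch.ones(1, 6)          # pretend all six positions are real tokens

scores = torch.log(1 + torch.exp(logits))  # softplus, as in the commented-out code
pooled = torch.max(scores * attention_mask.unsqueeze(-1), dim=1).values
binary = (pooled > 0.5).float()
print(binary.shape, int(binary.sum()))     # e.g. torch.Size([1, 30]) 24
```

The `isinstance` guard in Option 2 exists because `torch.nonzero(...).squeeze().cpu().tolist()` returns a bare `int` when exactly one entry is non-zero, and an empty list when none are.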
@@ -222,17 +218,22 @@ def get_splade_doc_representation(text, apply_lexical_mask: bool):
             if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
                 meaningful_tokens[decoded_token] = weight
 
-    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for binary
 
-    formatted_output = "SPLADE v3 Doc Representation (..."
+    formatted_output = "SPLADE v3 Doc Representation (Binary Sparse - Lexical Only):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
-        ...
+        # Display as terms with no weights as they are binary (value 1)
+        for i, (term, _) in enumerate(sorted_representation):
+            # Limit display for very long lists for readability
+            if i >= 50:
+                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
+                break
+            formatted_output += f"- **{term}**\n"
+
+    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+    formatted_output += f"Total activated terms: {len(indices)}\n"
     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
 
     return formatted_output
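The sparsity line is simple arithmetic over the vocabulary size. For a 30,522-entry BERT-style vocabulary (an assumption for illustration; the real number comes from `tokenizer_splade_doc.vocab_size`), even a few dozen active terms round to well above 99%:

```python
vocab_size = 30522          # BERT-style vocabulary, for illustration only
indices = list(range(12))   # pretend 12 terms are activated
print(f"Sparsity: {1 - (len(indices) / vocab_size):.2%}")  # Sparsity: 99.96%
```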
@@ -243,12 +244,11 @@ def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
     elif model_choice == "SPLADE-v3-Lexical":
         # Always applies lexical mask for this option
         return get_splade_lexical_representation(text)
-    elif model_choice == "SPLADE-v3-Doc (...)":
-        ...
-    elif model_choice == "SPLADE-v3-Doc (lexical-only)":
-        return get_splade_doc_representation(text, apply_lexical_mask=True)
+    elif model_choice == "SPLADE-v3-Doc": # Simplified to a single option
+        # This function now intrinsically handles binary, lexical-only output
+        return get_splade_doc_representation(text)
     else:
         return "Please select a model."
 
@@ -259,9 +259,8 @@ demo = gr.Interface(
         gr.Radio(
             [
                 "SPLADE (cocondenser)",
                 "SPLADE-v3-Lexical",
-                "SPLADE-v3-Doc (...)",
-                "SPLADE-v3-Doc (lexical-only)" # Option with lexical mask applied
+                "SPLADE-v3-Doc" # Only one option for Doc model
             ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)" # Default selection
@@ -274,7 +273,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its ...",
+    description="Enter any text to see its sparse vector representation.", # Simplified description
     allow_flagging="never"
 )
 
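Pieced together from the two `demo = gr.Interface(` hunks, the surrounding wiring plausibly looks like the sketch below; the `gr.Textbox` input and the `launch()` call are assumptions, since they fall outside the diff:

```python
import gradio as gr

demo = gr.Interface(
    fn=predict_representation,  # dispatcher shown in the diff above
    inputs=[
        gr.Radio(
            ["SPLADE (cocondenser)", "SPLADE-v3-Lexical", "SPLADE-v3-Doc"],
            label="Choose Representation Model",
            value="SPLADE (cocondenser)",
        ),
        gr.Textbox(lines=5, label="Enter your text here"),  # assumed, not in the diff
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse Representation Generator",
    description="Enter any text to see its sparse vector representation.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
```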