Update app.py
app.py CHANGED
@@ -29,6 +29,34 @@ except Exception as e:
     print(f"Error loading SPLADE v3 Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
+# --- Helper function for lexical mask ---
+def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
+    """
+    Creates a binary bag-of-words mask from input_ids,
+    zeroing out special tokens and padding.
+    """
+    # Initialize a zero vector for the entire vocabulary
+    bow_mask = torch.zeros(vocab_size, device=input_ids.device)
+
+    # Get unique token IDs from the input, excluding special tokens.
+    # input_ids is typically [batch_size, seq_len]; batch_size=1 is assumed here.
+    meaningful_token_ids = []
+    for token_id in input_ids.squeeze().tolist():  # Squeeze to remove batch dim and convert to list
+        if token_id not in [
+            tokenizer.pad_token_id,
+            tokenizer.cls_token_id,
+            tokenizer.sep_token_id,
+            tokenizer.mask_token_id,
+            tokenizer.unk_token_id,  # Also exclude unknown tokens
+        ]:
+            meaningful_token_ids.append(token_id)
+
+    # Set 1 for tokens present in the original input
+    if meaningful_token_ids:
+        bow_mask[list(set(meaningful_token_ids))] = 1  # Use set to handle duplicates
+
+    return bow_mask.unsqueeze(0)  # Keep batch dimension for consistency
+
 
 # --- Core Representation Functions ---
 
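Side note: the loop above is easy to follow, but the same bag-of-words mask can be built without a Python-level loop. A minimal standalone sketch, assuming a Hugging Face tokenizer (bert-base-uncased is a stand-in here, not necessarily the Space's model) and torch.isin from PyTorch >= 1.10:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer for illustration

    def bow_mask_vectorized(input_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
        ids = input_ids.squeeze()  # assumes batch_size=1, as in the helper above
        special = torch.tensor(tokenizer.all_special_ids, device=ids.device)
        keep = ids[~torch.isin(ids, special)]  # drop PAD/CLS/SEP/MASK/UNK in one shot
        mask = torch.zeros(vocab_size, device=ids.device)
        mask[keep.unique()] = 1  # .unique() plays the role of set()
        return mask.unsqueeze(0)  # keep the batch dimension, matching the helper

    inputs = tokenizer("sparse retrieval with SPLADE", return_tensors="pt")
    mask = bow_mask_vectorized(inputs["input_ids"], tokenizer.vocab_size)
    print(int(mask.sum().item()))  # number of distinct non-special tokens in the input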
@@ -79,7 +107,7 @@ def get_splade_representation(text):
     return formatted_output
 
 
-def get_splade_lexical_representation(text):
+def get_splade_lexical_representation(text, apply_lexical_mask: bool):  # Added parameter
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
         return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
 
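The masking logic added in the next hunk references `inputs`, which never appears in the visible parts of this diff. Presumably the elided function body tokenizes the text and runs a forward pass along these lines (an assumption about the hidden lines, not the Space's verbatim code):

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model_splade_lexical(**inputs)  # the 'logits' check in the next hunk inspects this output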
@@ -97,6 +125,20 @@ def get_splade_lexical_representation(text):
     else:
         return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
 
+    # --- Apply Lexical Mask if requested ---
+    if apply_lexical_mask:
+        # Get the vocabulary size from the tokenizer
+        vocab_size = tokenizer_splade_lexical.vocab_size
+
+        # Create the Bag-of-Words mask
+        bow_mask = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
+        ).squeeze()  # Squeeze to match splade_vector's [vocab_size] shape
+
+        # Multiply the SPLADE vector by the BoW mask to zero out expanded terms
+        splade_vector = splade_vector * bow_mask
+    # --- End Lexical Mask Logic ---
+
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
         indices = [indices]
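For context on what the mask multiplies: this diff never shows the line that builds splade_vector, but the standard SPLADE formulation pools log(1 + ReLU(logits)) with a max over the sequence dimension. A self-contained sketch under that assumption:

    import torch

    def splade_pool(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # logits: [1, seq_len, vocab_size]; attention_mask: [1, seq_len]
        weights = torch.log1p(torch.relu(logits))
        weights = weights * attention_mask.unsqueeze(-1)  # zero out padding positions
        return torch.max(weights, dim=1).values.squeeze(0)  # -> [vocab_size]

Because bow_mask is a 0/1 vector of shape [vocab_size], the `splade_vector * bow_mask` product keeps only terms that literally occur in the input and zeroes every neural-expansion term, which is exactly what the lexical-only mode advertises.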
@@ -130,8 +172,12 @@ def get_splade_lexical_representation(text):
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical":
-        return get_splade_lexical_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical (with expansion)":
+        # Call the lexical function without applying the mask
+        return get_splade_lexical_representation(text, apply_lexical_mask=False)
+    elif model_choice == "SPLADE-v3-Lexical (lexical-only)":
+        # Call the lexical function applying the mask
+        return get_splade_lexical_representation(text, apply_lexical_mask=True)
     else:
         return "Please select a model."
 
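A quick sanity check of the new dispatch, assuming the Space's models loaded successfully; the choice strings must match the Radio options exactly, parenthesized suffixes included:

    print(predict_representation("SPLADE-v3-Lexical (with expansion)", "sparse retrieval"))
    print(predict_representation("SPLADE-v3-Lexical (lexical-only)", "sparse retrieval"))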
@@ -140,7 +186,11 @@ demo = gr.Interface(
     fn=predict_representation,
     inputs=[
         gr.Radio(
-            ["SPLADE (cocondenser)", "SPLADE-v3-Lexical"],
+            [
+                "SPLADE (cocondenser)",
+                "SPLADE-v3-Lexical (with expansion)",  # Option to see full neural output
+                "SPLADE-v3-Lexical (lexical-only)"  # Option with lexical mask applied
+            ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)"  # Default selection
         ),
@@ -151,8 +201,8 @@ demo = gr.Interface(
         )
     ],
     outputs=gr.Markdown(),
-    title="🌌 Sparse
-    description="Enter any text to see its SPLADE sparse vector
+    title="🌌 Sparse Representation Generator",
+    description="Enter any text to see its SPLADE sparse vector. Explore the difference between full neural expansion and lexical-only representations.",
     allow_flagging="never"
 )
 
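Not shown in this diff, but a Gradio Space's app.py conventionally ends by launching the interface; a minimal sketch:

    if __name__ == "__main__":
        demo.launch()  # starts the local server when run outside Spaces

One caution: newer Gradio releases deprecate the allow_flagging argument used above (Gradio 5.x renames it to flagging_mode), so pinning the Gradio version in the Space's requirements.txt avoids surprises.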