added unicoil
app.py
CHANGED
@@ -1,119 +1,244 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForMaskedLM
  import torch

  # --- Model Loading ---
-
-
-
  try:
-
-
-
-
  except Exception as e:
-     print(f"Error loading
-     print("Please ensure
-
-
-

- # --- Core SPLADE Representation Function ---
  def get_splade_representation(text):
-     if
          return "SPLADE model is not loaded. Please check the console for loading errors."

-
-
-     # padding=True pads to the longest sequence in the batch (though for single input, it's just the input length).
-     # truncation=True truncates if the text is too long for the model's max input size.
-     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

-     # Move inputs to the same device as the model (e.g., CPU or GPU)
-     # This is important if you were running on a GPU in a production environment.
-     # For Hugging Face Spaces free tier, it's usually CPU.
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-     # Get the model's output without calculating gradients (inference mode)
      with torch.no_grad():
-         output = model(**inputs)
-
-     # Extract the logits from the model's output.
-     # SPLADE uses the masked language modeling head's logits to derive term importance.
-     # Apply the SPLADE aggregation function: log(1 + ReLU(logits))
-     # This transforms the raw logits into a sparse vector where higher values indicate more importance.
-     # The attention_mask ensures we only consider actual tokens, not padding.
-
-     # Check if 'logits' is in the output (standard for AutoModelForMaskedLM)
      if hasattr(output, 'logits'):
-         # Apply the SPLADE transformation
-         # output.logits is typically [batch_size, sequence_length, vocab_size]
-         # We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
-         # inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
          splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
      else:
-         # Fallback/error message if the output structure is unexpected
          return "Model output structure not as expected for SPLADE. 'logits' not found."

-     # Convert the sparse vector to a human-readable format.
-     # We only care about the non-zero (or very small) entries, as they represent activated terms.
-
-     # Get the indices (token IDs) of the non-zero elements in the SPLADE vector
-     # torch.nonzero returns coordinates of non-zero elements. squeeze() removes dimensions of size 1.
-     # .cpu().tolist() moves the tensor to CPU and converts to a Python list.
      indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
-
-     # If it's a single index (e.g., a very short text), make it a list for consistent processing
      if not isinstance(indices, list):
          indices = [indices]
-
-     # Get the corresponding values (weights) for these non-zero indices
      values = splade_vector[indices].cpu().tolist()
-
-     # Create a dictionary mapping token ID to its weight
      token_weights = dict(zip(indices, values))

-     # Decode token IDs back to actual words/subwords
-     # Filter out common special tokens that are not meaningful for retrieval (e.g., [CLS], [SEP], [PAD])
-     # You can add more tokens to this list if they appear frequently and are not helpful.
      meaningful_tokens = {}
      for token_id, weight in token_weights.items():
-         decoded_token = tokenizer.decode([token_id])
-         # Filter out special tokens or very short/noisy tokens
          if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
              meaningful_tokens[decoded_token] = weight

-     # Sort the meaningful tokens by their weight in descending order
      sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

-     # Format the output for display
      formatted_output = "SPLADE Representation (Top 20 Terms):\n"
      if not sorted_representation:
          formatted_output += "No significant terms found for this input.\n"
      else:
          for i, (term, weight) in enumerate(sorted_representation):
-             if i >= 20:
                  break
              formatted_output += f"- **{term}**: {weight:.4f}\n"
-
      formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
      formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
-     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"

      return formatted_output

  # --- Gradio Interface Setup ---
  demo = gr.Interface(
-     fn=get_splade_representation,
-     inputs=
-
-
-
-
-
-
-
-
  )

  # Launch the Gradio app

  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
  import torch

  # --- Model Loading ---
+ tokenizer_splade = None
+ model_splade = None
+ tokenizer_unicoil = None
+ model_unicoil = None
+
+ # Load SPLADE v3 model
+ try:
+     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
+     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
+     model_splade.eval()  # Set to evaluation mode for inference
+     print("SPLADE v3 model loaded successfully!")
+ except Exception as e:
+     print(f"Error loading SPLADE model: {e}")
+     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
+
+ # Load UNICOIL model for binary sparse encoding
  try:
+     # UNICOIL models are typically just AutoModel as they add a linear layer
+     # on top of a BERT-like encoder to predict weights.
+     # 'castorini/unicoil-msmarco-passage' is a common UNICOIL checkpoint.
+     unicoil_model_name = "castorini/unicoil-msmarco-passage"
+     tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
+     model_unicoil = AutoModel.from_pretrained(unicoil_model_name)
+     model_unicoil.eval()  # Set to evaluation mode for inference
+     print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
  except Exception as e:
+     print(f"Error loading UNICOIL model: {e}")
+     print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
+
+
+ # --- Core Representation Functions ---

  def get_splade_representation(text):
+     if tokenizer_splade is None or model_splade is None:
          return "SPLADE model is not loaded. Please check the console for loading errors."

+     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
+     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

      with torch.no_grad():
+         output = model_splade(**inputs)
+
      if hasattr(output, 'logits'):
          splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
      else:
          return "Model output structure not as expected for SPLADE. 'logits' not found."

      indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
      if not isinstance(indices, list):
          indices = [indices]
+
      values = splade_vector[indices].cpu().tolist()
      token_weights = dict(zip(indices, values))

      meaningful_tokens = {}
      for token_id, weight in token_weights.items():
+         decoded_token = tokenizer_splade.decode([token_id])
          if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
              meaningful_tokens[decoded_token] = weight

      sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

      formatted_output = "SPLADE Representation (Top 20 Terms):\n"
      if not sorted_representation:
          formatted_output += "No significant terms found for this input.\n"
      else:
          for i, (term, weight) in enumerate(sorted_representation):
+             if i >= 20:
                  break
              formatted_output += f"- **{term}**: {weight:.4f}\n"
+
      formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
      formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
+
+     return formatted_output
+
+
+ def get_unicoil_binary_representation(text):
+     if tokenizer_unicoil is None or model_unicoil is None:
+         return "UNICOIL model is not loaded. Please check the console for loading errors."
+
+     inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
+     inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+
# UNICOIL models often output a dictionary where 'token_scores' or similar
|
93 |
+
# contain the learned weights for each token. The structure can vary.
|
94 |
+
# For 'castorini/unicoil-msmarco-passage', the token scores are typically
|
95 |
+
# the last hidden state of the model, which is then mapped by a linear layer
|
96 |
+
# into the sparse weights. We might need to manually extract those,
|
97 |
+
# or the model itself might be set up to produce the weights directly.
|
98 |
+
# Based on typical UNICOIL implementations, we usually take the output
|
99 |
+
# from the last layer and map it to vocabulary size.
|
100 |
+
|
101 |
+
# In many UNICOIL variations, the model itself is designed to output
|
102 |
+
# the "re-weight" scores for each token. Let's assume for simplicity
|
103 |
+
# that the model's forward pass returns something that can be interpreted
|
104 |
+
# as per-token scores, perhaps in `output.last_hidden_state`
|
105 |
+
# and then a simple linear layer (not part of the AutoModel usually)
|
106 |
+
# would project this to vocab size.
|
107 |
+
# For simplicity and to fit `AutoModel`, we'll treat the last hidden state
|
108 |
+
# directly as the basis for term importance for now, which is common in similar models,
|
109 |
+
# or if the model already has a head, we use it.
|
110 |
+
|
111 |
+
# A more robust UNICOIL implementation would involve a specific head
|
112 |
+
# if not using AutoModelForMaskedLM. However, AutoModel gives us the
|
113 |
+
# last hidden states from which we can infer.
|
114 |
+
|
115 |
+
# For UNICOIL, we're interested in the weighted token scores.
|
116 |
+
# `model(**inputs)` will typically return a `BaseModelOutput`
|
117 |
+
# or `MaskedLMOutput` if it's based on an MLM.
|
118 |
+
# Let's assume the model's output provides the token importance.
|
119 |
+
|
120 |
+
# A common way to get UNICOIL scores if not explicitly provided as logits:
|
121 |
+
# It's usually a linear layer on top of the last hidden state.
|
122 |
+
# Since AutoModel just gives the base model, we'll mimic the output
|
123 |
+
# as a direct mapping if the model doesn't have a specific head for scores.
|
124 |
+
# However, looking at `castorini/unicoil-msmarco-passage`
|
125 |
+
# its `config.json` might give hints or the model itself is structured.
|
126 |
+
# Often, it uses `BertForMaskedLM` and then applies `log(1+relu)` to the logits.
|
127 |
+
# Let's assume it behaves similar to SPLADE for simplicity of extraction for now,
|
128 |
+
# or we might need to load it as `AutoModelForMaskedLM` if its internal structure
|
129 |
+
# is indeed like that, and then apply a binarization.
|
130 |
+
|
131 |
+
# Re-evaluating: UNICOIL typically *learns* explicit token weights.
|
132 |
+
# The common approach for UNICOIL with Hugging Face is indeed to load it
|
133 |
+
# as `AutoModelForMaskedLM` and use its `logits` output, similar to SPLADE,
|
134 |
+
# but with a different aggregation strategy.
|
135 |
+
# Let's verify the model type for 'castorini/unicoil-msmarco-passage'.
|
136 |
+
# Its config.json and architecture implies it's a BertForMaskedLM variant.
|
137 |
+
|
138 |
+
output = model_unicoil(**inputs) # This should be a BaseModelOutputWithPooling or similar
|
139 |
+
|
+     if not hasattr(output, 'logits'):
+         # An AutoModel without an MLM/classification head exposes no logits (see note above).
+         return "UNICOIL model output structure not as expected. 'logits' not found."
+
+     # UNICOIL-style aggregation: softplus, log(1 + exp(logits)), maps the logits to positive
+     # scores; max pooling over the sequence dimension (masked to real tokens) then yields one
+     # weight per vocabulary entry. Softplus is never exactly zero, so the thresholding below
+     # is what makes the vector sparse and binary.
+     sparse_weights = torch.max(torch.log(1 + torch.exp(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
+
+     # --- Binarization Step for UNICOIL ---
+     # Treat any weight above a small epsilon as an activated term (1); everything else is 0.
+     threshold = 1e-6  # Small threshold for binarization
+     binary_sparse_vector = (sparse_weights > threshold).int()
+
+     # Get indices of the '1's in the binary vector
+     binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()
+
+     # A single activated term comes back as a bare int; wrap it for consistent processing
+     if not isinstance(binary_indices, list):
+         binary_indices = [binary_indices]
+
+     # Map token IDs back to terms for the binary representation
+     binary_terms = {}
+     for token_id in binary_indices:
+         decoded_token = tokenizer_unicoil.decode([token_id])
+         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+             binary_terms[decoded_token] = 1  # Value is always 1 for binary
+
+     sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])  # Sort by term for consistent display
+
+     formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
+     if not sorted_binary_terms:
+         formatted_output += "No significant terms activated for this input.\n"
+     else:
+         # Display up to 50 activated terms for readability
+         for i, (term, _) in enumerate(sorted_binary_terms):
+             if i >= 50:
+                 break
+             formatted_output += f"- **{term}**\n"  # Only show term, as weight is always 1
+         if len(sorted_binary_terms) > 50:
+             formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
+
+     formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+     formatted_output += f"Total activated terms: {len(binary_indices)}\n"
+     # Sparsity is the share of the vocabulary that is not activated
+     formatted_output += f"Sparsity: {1 - (len(binary_indices) / tokenizer_unicoil.vocab_size):.2%}\n"

      return formatted_output

+
+ # --- Unified Prediction Function for Gradio ---
+ def predict_representation(model_choice, text):
+     if model_choice == "SPLADE":
+         return get_splade_representation(text)
+     elif model_choice == "UNICOIL (Binary Sparse)":
+         return get_unicoil_binary_representation(text)
+     else:
+         return "Please select a model."
+
  # --- Gradio Interface Setup ---
  demo = gr.Interface(
+     fn=predict_representation,
+     inputs=[
+         gr.Radio(
+             ["SPLADE", "UNICOIL (Binary Sparse)"],  # Added UNICOIL option
+             label="Choose Representation Model",
+             value="SPLADE"  # Default selection
+         ),
+         gr.Textbox(
+             lines=5,
+             label="Enter your query or document text here:",
+             placeholder="e.g., Why is Padua the nicest city in Italy?"
+         )
+     ],
+     outputs=gr.Markdown(),
+     title="🌌 Sparse and Binary Sparse Representation Generator",
+     description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
+     allow_flagging="never"
  )

  # Launch the Gradio app
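
For readers skimming the diff, the only real difference between the two scoring paths is the element-wise transform applied to the MLM logits before max pooling: log(1 + ReLU(x)) for SPLADE versus log(1 + exp(x)) plus a threshold for the UNICOIL path. The snippet below is a minimal, self-contained sketch of both transforms on made-up toy logits; it is not part of app.py, and the dimensions, seed, and 1e-6 threshold are illustrative assumptions mirroring the code above.

import torch

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 1, 4, 10            # toy sizes; real models use a ~30k vocabulary
logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for output.logits
attention_mask = torch.tensor([[1, 1, 1, 0]])           # last position treated as padding

# SPLADE-style: log(1 + ReLU(logits)), masked, then max over the sequence dimension
splade_vec = torch.max(
    torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1), dim=1
)[0].squeeze()

# UNICOIL-style path as written in the diff: softplus log(1 + exp(logits)), masked,
# max-pooled, then thresholded into a binary vector
unicoil_scores = torch.max(
    torch.log(1 + torch.exp(logits)) * attention_mask.unsqueeze(-1), dim=1
)[0].squeeze()
binary_vec = (unicoil_scores > 1e-6).int()

print("SPLADE non-zero entries:", torch.count_nonzero(splade_vec).item(), "of", vocab_size)
print("UNICOIL activated entries:", int(binary_vec.sum()), "of", vocab_size)

Because softplus never returns exactly zero, the thresholded UNICOIL-style vector in this sketch activates essentially every vocabulary entry; a faithful UNICOIL encoder instead assigns weights only to tokens that actually occur in the input, which is worth keeping in mind when reading get_unicoil_binary_representation above. To sanity-check either path outside the UI, one could also call predict_representation("SPLADE", "some text") directly once the models have loaded.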