import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_unicoil = None
model_unicoil = None
# Load SPLADE model (naver/splade-cocondenser-selfdistil)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()  # Set to evaluation mode for inference
    print("SPLADE model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
# Load UNICOIL model for binary sparse encoding
try:
    # UNICOIL learns a weight for each input token on top of a BERT-like encoder.
    # 'castorini/unicoil-msmarco-passage' is a common UNICOIL checkpoint. It is an MLM-style
    # BERT variant, so we load it with AutoModelForMaskedLM and use its per-token vocabulary
    # logits below as a simplified stand-in for the learned term weights.
    unicoil_model_name = "castorini/unicoil-msmarco-passage"
    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
    model_unicoil.eval()  # Set to evaluation mode for inference
    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading UNICOIL model: {e}")
    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Core Representation Functions ---
def get_splade_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."
    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)
    if hasattr(output, 'logits'):
        # SPLADE aggregation: log(1 + ReLU(logits)), masked to real tokens, max-pooled over the sequence.
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
    else:
        return "Model output structure not as expected for SPLADE. 'logits' not found."
    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices]
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
    formatted_output = "SPLADE Representation (Top 20 Terms):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for i, (term, weight) in enumerate(sorted_representation):
            if i >= 20:
                break
            formatted_output += f"- **{term}**: {weight:.4f}\n"
    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
    return formatted_output
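# Illustrative only: a minimal sketch of how two SPLADE vectors could be compared for
# retrieval with a sparse dot product. The helper name `splade_similarity` and its use of
# the globals defined above are assumptions of this demo, not part of any library API.
def splade_similarity(text_a, text_b):
    def _vector(text):
        inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model_splade(**inputs).logits
        # Same aggregation as get_splade_representation: masked log(1 + ReLU), max-pooled.
        return torch.max(
            torch.log(1 + torch.relu(logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1
        )[0].squeeze()
    # Relevance is the dot product of the two vocabulary-sized sparse vectors.
    return torch.dot(_vector(text_a), _vector(text_b)).item()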
def get_unicoil_binary_representation(text):
    if tokenizer_unicoil is None or model_unicoil is None:
        return "UNICOIL model is not loaded. Please check the console for loading errors."
    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
    with torch.no_grad():
        # UNICOIL learns an explicit importance weight for each input token. The checkpoint
        # used here is an MLM-style BERT variant, so we load it with AutoModelForMaskedLM and
        # read its per-token vocabulary logits. This is a simplification: the reference UNICOIL
        # implementation uses a dedicated scoring head, but the MLM logits give a workable
        # per-term importance signal for this demo.
        output = model_unicoil(**inputs)
    if not hasattr(output, 'logits'):
        return "UNICOIL model output structure not as expected. 'logits' not found."
    # Aggregate the logits into a vocabulary-sized sparse vector. We reuse the SPLADE-style
    # log(1 + ReLU(logits)) transform (rather than a softplus) so that unimportant terms stay
    # exactly zero before binarization; exact UNICOIL scoring differs, but this keeps the
    # vector genuinely sparse. Padding positions are masked out and the scores are max-pooled
    # over the sequence dimension.
    sparse_weights = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
        dim=1
    )[0].squeeze()
    # --- Binarization Step ---
    # For a binary sparse representation, threshold the weights: any term whose weight exceeds
    # a small epsilon is marked as activated (1); everything else stays 0.
    threshold = 1e-6
    binary_sparse_vector = (sparse_weights > threshold).int()
    # Get indices of the activated terms in the binary vector
    binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()
    if not isinstance(binary_indices, list):
        binary_indices = [binary_indices]
    # Map token IDs back to terms for the binary representation
    binary_terms = {}
    for token_id in binary_indices:
        decoded_token = tokenizer_unicoil.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            binary_terms[decoded_token] = 1  # Value is always 1 for binary
    sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])  # Sort by term for consistent display
    formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
    if not sorted_binary_terms:
        formatted_output += "No significant terms activated for this input.\n"
    else:
        # Display up to 50 activated terms for readability
        for i, (term, _) in enumerate(sorted_binary_terms):
            if i >= 50:
                break
            formatted_output += f"- **{term}**\n"  # Only show the term; the weight is always 1
        if len(sorted_binary_terms) > 50:
            formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
    formatted_output += f"Total activated terms: {len(binary_indices)}\n"
    # Sparsity: fraction of the vocabulary that is NOT activated
    formatted_output += f"Sparsity: {1 - (len(binary_indices) / tokenizer_unicoil.vocab_size):.2%}\n"
    return formatted_output
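# Illustrative only: a minimal sketch of how two binary sparse vectors could be compared,
# scoring by the number of vocabulary terms both texts activate. The helper name
# `binary_term_overlap` is an assumption of this demo, not part of any library API.
def binary_term_overlap(text_a, text_b):
    def _binary_vector(text):
        inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model_unicoil(**inputs).logits
        weights = torch.max(
            torch.log(1 + torch.relu(logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1
        )[0].squeeze()
        return (weights > 1e-6).int()
    # Overlap = count of terms activated in both texts (an unweighted sparse dot product).
    return int((_binary_vector(text_a) & _binary_vector(text_b)).sum().item())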
# --- Unified Prediction Function for Gradio ---
def predict_representation(model_choice, text):
    if model_choice == "SPLADE":
        return get_splade_representation(text)
    elif model_choice == "UNICOIL (Binary Sparse)":
        return get_unicoil_binary_representation(text)
    else:
        return "Please select a model."
# --- Gradio Interface Setup ---
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            ["SPLADE", "UNICOIL (Binary Sparse)"],
            label="Choose Representation Model",
            value="SPLADE"  # Default selection
        ),
        gr.Textbox(
            lines=5,
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?"
        )
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
    allow_flagging="never"
)
# Launch the Gradio app
demo.launch()