Update app.py
app.py
CHANGED
@@ -1,7 +1,120 @@
 import gradio as gr
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+import torch
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+# --- Model Loading ---
+# Load the SPLADE v3 model and tokenizer.
+# "naver/splade-v3" is a common and robust SPLADE v3 checkpoint.
+# Make sure you've accepted any user access agreement on its Hugging Face Hub page.
+try:
+    tokenizer = AutoTokenizer.from_pretrained("naver/splade-v3")
+    model = AutoModelForMaskedLM.from_pretrained("naver/splade-v3")
+    model.eval()  # Set the model to evaluation mode for inference
+    print("SPLADE v3 model and tokenizer loaded successfully!")
+except Exception as e:
+    print(f"Error loading SPLADE model or tokenizer: {e}")
+    print("Please ensure you have accepted any user access agreement on the Hugging Face Hub page for 'naver/splade-v3' (https://huggingface.co/naver/splade-v3).")
+    print("If the problem persists, check your internet connection or try a different SPLADE model.")
+    tokenizer = None
+    model = None
+
+# --- Core SPLADE Representation Function ---
+def get_splade_representation(text):
+    if tokenizer is None or model is None:
+        return "SPLADE model is not loaded. Please check the console for loading errors."
+
+    # Tokenize the input text.
+    # return_tensors="pt" returns PyTorch tensors.
+    # padding=True pads to the longest sequence in the batch (for a single input, its own length).
+    # truncation=True cuts off text that exceeds the model's maximum input size.
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
+    # Move inputs to the same device as the model (CPU or GPU).
+    # This matters when running on a GPU in a production environment;
+    # on the Hugging Face Spaces free tier it is usually CPU.
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    # Run the model without gradient tracking (inference mode)
+    with torch.no_grad():
+        output = model(**inputs)
+
+    # Extract the logits from the model's output.
+    # SPLADE derives term importance from the masked-language-modeling head's logits.
+    # The aggregation function is log(1 + ReLU(logits)), max-pooled over the sequence.
+    # This transforms the raw logits into a sparse vector where higher values indicate more important terms.
+    # The attention mask ensures only actual tokens, not padding, contribute.
+
+    # Check that 'logits' is in the output (standard for AutoModelForMaskedLM)
+    if hasattr(output, 'logits'):
+        # Apply the SPLADE transformation.
+        # output.logits has shape [batch_size, sequence_length, vocab_size].
+        # Taking the max over the sequence_length dimension yields a [batch_size, vocab_size] vector.
+        # inputs["attention_mask"].unsqueeze(-1) broadcasts the mask across vocab_size for element-wise multiplication.
+        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1), dim=1)[0].squeeze()
+    else:
+        # Fallback/error message if the output structure is unexpected
+        return "Model output structure not as expected for SPLADE. 'logits' not found."
+
+    # Convert the sparse vector to a human-readable format.
+    # Only the non-zero entries matter: they are the activated terms.
+
+    # Get the indices (token IDs) of the non-zero elements in the SPLADE vector.
+    # torch.nonzero returns the coordinates of non-zero elements; squeeze() removes dimensions of size 1.
+    # .cpu().tolist() moves the tensor to the CPU and converts it to a Python list.
+    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+
+    # A very short text can yield a single index; wrap it in a list for consistent processing.
+    if not isinstance(indices, list):
+        indices = [indices]
+
+    # Get the corresponding values (weights) for these non-zero indices
+    values = splade_vector[indices].cpu().tolist()
+
+    # Create a dictionary mapping token ID to its weight
+    token_weights = dict(zip(indices, values))
+
+    # Decode token IDs back to actual words/subwords.
+    # Filter out special tokens that are not meaningful for retrieval (e.g., [CLS], [SEP], [PAD]).
+    # Add more tokens to this list if unhelpful ones appear frequently.
+    meaningful_tokens = {}
+    for token_id, weight in token_weights.items():
+        decoded_token = tokenizer.decode([token_id])
+        # Skip special tokens and empty or whitespace-only tokens
+        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+            meaningful_tokens[decoded_token] = weight
+
+    # Sort the meaningful tokens by their weight in descending order
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+
+    # Format the output for display
+    formatted_output = "SPLADE Representation (Top 20 Terms):\n"
+    if not sorted_representation:
+        formatted_output += "No significant terms found for this input.\n"
+    else:
+        for i, (term, weight) in enumerate(sorted_representation):
+            if i >= 20:  # Limit to the top 20 terms for readability
+                break
+            formatted_output += f"- **{term}**: {weight:.4f}\n"
+
+    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
+    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"  # Fraction of the vocabulary left inactive
+
+    return formatted_output
+
+# --- Gradio Interface Setup ---
+demo = gr.Interface(
+    fn=get_splade_representation,
+    inputs=gr.Textbox(
+        lines=5,
+        label="Enter your query or document text here:",
+        placeholder="e.g., What are the capital cities of Europe?"
+    ),
+    outputs=gr.Markdown(),  # Markdown output allows richer formatting (bold terms)
+    title="🌌 SPLADE v3 Sparse Representation Generator",
+    description="Enter any text (query or document) to see its SPLADE v3 sparse vector representation. The output highlights the most important terms with their learned weights.",
+    allow_flagging="never"  # Disable flagging for this demo
+)
+
+# Launch the Gradio app
 demo.launch()
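
A quick way to sanity-check the aggregation in this commit is to run it on a toy tensor instead of real model output. The sketch below is illustrative only: the shapes are made up, and torch.log1p(x) is simply a numerically safer spelling of the torch.log(1 + x) used in app.py.

import torch

# Toy stand-in for MLM logits: batch=1, seq_len=4, vocab=10
# (a real BERT-style vocabulary has ~30k entries)
logits = torch.randn(1, 4, 10)
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

# Same aggregation as app.py: log(1 + ReLU(logits)), masked, then max-pooled over the sequence
weights = torch.log1p(torch.relu(logits)) * attention_mask.unsqueeze(-1)
splade_vector = torch.max(weights, dim=1)[0].squeeze()

print(splade_vector.shape)  # torch.Size([10]) -- one weight per vocabulary entry
print((splade_vector > 0).sum().item(), "active terms")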
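The isinstance(indices, list) guard in the function looks redundant but is not: when the vector has exactly one non-zero entry, squeeze() collapses to a 0-d tensor and tolist() returns a bare int rather than a list. A minimal demonstration, unrelated to the model:

import torch

v = torch.zeros(10)
v[3] = 2.5  # exactly one active term

indices = torch.nonzero(v).squeeze().cpu().tolist()
print(indices, type(indices))  # 3 <class 'int'> -- hence the wrap-in-a-list guard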
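The reported sparsity figure divides the active-term count by tokenizer.vocab_size. As a worked example, assuming a BERT-style vocabulary of 30,522 entries (an assumption; check tokenizer.vocab_size for the actual value) and 150 active terms:

vocab_size = 30522   # assumed BERT-style vocab size; verify with tokenizer.vocab_size
active_terms = 150   # illustrative count of non-zero entries for a short input
print(f"Sparsity: {1 - active_terms / vocab_size:.2%}")  # prints "Sparsity: 99.51%"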