SiddharthAK committed
Commit 3035463 · verified · 1 Parent(s): a4de739

Update app.py

Files changed (1)
  1. app.py +116 -3
app.py CHANGED
@@ -1,7 +1,120 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")

  demo.launch()
 
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ import torch

+ # --- Model Loading ---
+ # Load the SPLADE v3 model and tokenizer.
+ # "naver/splade-v3" is a common and robust SPLADE v3 model.
+ # Make sure you've accepted any user access agreements on its Hugging Face Hub page.
+ try:
+     tokenizer = AutoTokenizer.from_pretrained("naver/splade-v3")
+     model = AutoModelForMaskedLM.from_pretrained("naver/splade-v3")
+     model.eval()  # Set the model to evaluation mode for inference
+     print("SPLADE v3 model and tokenizer loaded successfully!")
+ except Exception as e:
+     print(f"Error loading SPLADE model or tokenizer: {e}")
+     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-v3' (https://huggingface.co/naver/splade-v3).")
+     print("If the problem persists, check your internet connection or try a different SPLADE model if available.")
+     tokenizer = None
+     model = None

+ # --- Core SPLADE Representation Function ---
+ def get_splade_representation(text):
+     if tokenizer is None or model is None:
+         return "SPLADE model is not loaded. Please check the console for loading errors."
+
+     # Tokenize the input text.
+     # return_tensors="pt" ensures PyTorch tensors are returned.
+     # padding=True pads to the longest sequence in the batch (for a single input, its own length).
+     # truncation=True truncates text that is too long for the model's maximum input size.
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
+     # Move inputs to the same device as the model (CPU or GPU).
+     # This matters when running on a GPU; on the Hugging Face Spaces free tier it is usually CPU.
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     # Get the model's output without calculating gradients (inference mode).
+     with torch.no_grad():
+         output = model(**inputs)
+
+     # SPLADE uses the masked language modeling head's logits to derive term importance.
+     # The aggregation function log(1 + ReLU(logits)) turns the raw logits into a sparse
+     # vector where higher values indicate more important terms, and the attention mask
+     # ensures only actual tokens (not padding) contribute.
+
+     # Check that 'logits' is in the output (standard for AutoModelForMaskedLM).
+     if hasattr(output, 'logits'):
+         # Apply the SPLADE transformation.
+         # output.logits is [batch_size, sequence_length, vocab_size]; taking the max over the
+         # sequence_length dimension yields a [batch_size, vocab_size] vector.
+         # inputs["attention_mask"].unsqueeze(-1) expands the mask for element-wise multiplication.
+         splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1), dim=1)[0].squeeze()
+     else:
+         # Fallback/error message if the output structure is unexpected.
+         return "Model output structure not as expected for SPLADE. 'logits' not found."
+
+     # Convert the sparse vector to a human-readable format.
+     # Only the non-zero entries matter, as they represent activated terms.
+
+     # Get the indices (token IDs) of the non-zero elements in the SPLADE vector.
+     # torch.nonzero returns the coordinates of non-zero elements; squeeze() removes dimensions
+     # of size 1; .cpu().tolist() moves the tensor to CPU and converts it to a Python list.
+     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+
+     # If it's a single index (e.g., a very short text), wrap it in a list for consistent processing.
+     if not isinstance(indices, list):
+         indices = [indices]
+
+     # Get the corresponding values (weights) for these non-zero indices.
+     values = splade_vector[indices].cpu().tolist()
+
+     # Create a dictionary mapping token ID to its weight.
+     token_weights = dict(zip(indices, values))
+
+     # Decode token IDs back to actual words/subwords and filter out special tokens that are
+     # not meaningful for retrieval (e.g., [CLS], [SEP], [PAD]).
+     # You can add more tokens to this list if they appear frequently and are not helpful.
+     meaningful_tokens = {}
+     for token_id, weight in token_weights.items():
+         decoded_token = tokenizer.decode([token_id])
+         # Filter out special tokens or empty/noisy tokens.
+         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+             meaningful_tokens[decoded_token] = weight
+
+     # Sort the meaningful tokens by their weight in descending order.
+     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+
+     # Format the output for display.
+     formatted_output = "SPLADE Representation (Top 20 Terms):\n"
+     if not sorted_representation:
+         formatted_output += "No significant terms found for this input.\n"
+     else:
+         for i, (term, weight) in enumerate(sorted_representation):
+             if i >= 20:  # Limit to the top 20 terms for readability
+                 break
+             formatted_output += f"- **{term}**: {weight:.4f}\n"
+
+     formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
+     formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+     formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"  # Calculate sparsity
+
+     return formatted_output
+
+ # --- Gradio Interface Setup ---
+ demo = gr.Interface(
+     fn=get_splade_representation,
+     inputs=gr.Textbox(
+         lines=5,
+         label="Enter your query or document text here:",
+         placeholder="e.g., What are the capital cities of Europe?"
+     ),
+     outputs=gr.Markdown(),  # Use Markdown for richer text formatting (bolded terms)
+     title="🌌 SPLADE v3 Sparse Representation Generator",
+     description="Enter any text (query or document) to see its SPLADE v3 sparse vector representation. The output highlights the most important terms with their learned weights.",
+     allow_flagging="never"  # Disable flagging for this demo
+ )
+
+ # Launch the Gradio app
  demo.launch()
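
Usage note (not part of this commit): the vector built inside get_splade_representation is the same representation SPLADE uses for retrieval, where query-document relevance is the dot product of two such sparse vectors. A minimal sketch, assuming the tokenizer and model loaded by app.py above; the encode helper is hypothetical and mirrors the function above but returns the raw vector instead of formatted text:

import torch

def encode(text):
    # Tokenize and run the model exactly as in get_splade_representation.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Same SPLADE aggregation: max over tokens of log(1 + ReLU(logits)), masked by attention.
    return torch.max(torch.log(1 + torch.relu(logits)) * inputs["attention_mask"].unsqueeze(-1), dim=1)[0].squeeze()

query_vec = encode("capital cities of Europe")
doc_vec = encode("Paris is the capital of France.")
score = torch.dot(query_vec, doc_vec).item()  # dot product of sparse vectors = relevance score
print(f"SPLADE relevance score: {score:.4f}")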