SiddharthAK committed
Commit 22a278f · verified · Parent: 37c6637

hopefully fixed unicoil

Files changed (1): app.py (+42 -117)
app.py CHANGED
@@ -86,131 +86,56 @@ def get_unicoil_binary_representation(text):
         return "UNICOIL model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
 
     with torch.no_grad():
-        # UNICOIL models often output a dictionary where 'token_scores' or similar
-        # contain the learned weights for each token. The structure can vary.
-        # For 'castorini/unicoil-msmarco-passage', the token scores are typically
-        # the last hidden state of the model, which is then mapped by a linear layer
-        # into the sparse weights. We might need to manually extract those,
-        # or the model itself might be set up to produce the weights directly.
-        # Based on typical UNICOIL implementations, we usually take the output
-        # from the last layer and map it to vocabulary size.
-
-        # In many UNICOIL variations, the model itself is designed to output
-        # the "re-weight" scores for each token. Let's assume for simplicity
-        # that the model's forward pass returns something that can be interpreted
-        # as per-token scores, perhaps in `output.last_hidden_state`
-        # and then a simple linear layer (not part of the AutoModel usually)
-        # would project this to vocab size.
-        # For simplicity and to fit `AutoModel`, we'll treat the last hidden state
-        # directly as the basis for term importance for now, which is common in similar models,
-        # or if the model already has a head, we use it.
-
-        # A more robust UNICOIL implementation would involve a specific head
-        # if not using AutoModelForMaskedLM. However, AutoModel gives us the
-        # last hidden states from which we can infer.
-
-        # For UNICOIL, we're interested in the weighted token scores.
-        # `model(**inputs)` will typically return a `BaseModelOutput`
-        # or `MaskedLMOutput` if it's based on an MLM.
-        # Let's assume the model's output provides the token importance.
-
-        # A common way to get UNICOIL scores if not explicitly provided as logits:
-        # It's usually a linear layer on top of the last hidden state.
-        # Since AutoModel just gives the base model, we'll mimic the output
-        # as a direct mapping if the model doesn't have a specific head for scores.
-        # However, looking at `castorini/unicoil-msmarco-passage`
-        # its `config.json` might give hints or the model itself is structured.
-        # Often, it uses `BertForMaskedLM` and then applies `log(1+relu)` to the logits.
-        # Let's assume it behaves similar to SPLADE for simplicity of extraction for now,
-        # or we might need to load it as `AutoModelForMaskedLM` if its internal structure
-        # is indeed like that, and then apply a binarization.
-
-        # Re-evaluating: UNICOIL typically *learns* explicit token weights.
-        # The common approach for UNICOIL with Hugging Face is indeed to load it
-        # as `AutoModelForMaskedLM` and use its `logits` output, similar to SPLADE,
-        # but with a different aggregation strategy.
-        # Let's verify the model type for 'castorini/unicoil-msmarco-passage'.
-        # Its config.json and architecture implies it's a BertForMaskedLM variant.
-
-        output = model_unicoil(**inputs) # This should be a BaseModelOutputWithPooling or similar
-
-        if not hasattr(output, 'logits'):
-            # If `model_unicoil` is an `AutoModel` without a classification head,
-            # we need to add a way to get per-token scores.
-            # This is where a custom model head or a specific model class would be needed.
-            # For `castorini/unicoil-msmarco-passage`, it *is* an MLM variant.
-            # So, `output.logits` *should* be available.
-            return "UNICOIL model output structure not as expected. 'logits' not found."
-
-        # UNICOIL's output is also typically per-token scores from the MLM head.
-        # For UNICOIL, the weights are often taken directly from the logits after pooling.
-        # Unlike SPLADE's log(1+ReLU), UNICOIL's approach can be simpler,
-        # sometimes just taking the maximum of logits (or similar pooling).
-        # A common binarization for UNICOIL is based on the sign of the re-weighted scores.
-
-        # Let's mimic a common UNICOIL interpretation for obtaining sparse weights
-        # from the logits. The weights are usually sparse and positive.
-        # We can apply a threshold for binarization.
-
-        # This is a simplification; actual UNICOIL might have specific layers.
-        # For `castorini/unicoil-msmarco-passage`, it uses the `log(1+exp(logits))` formulation
-        # followed by max pooling, then often binarization based on a threshold.
-
-        # Applying a common interpretation of UNICOIL-like score generation for sparse weights:
-        # Instead of `log(1+ReLU(logits))`, it often uses `torch.log(1 + torch.exp(output.logits))`.
-        # This is essentially the softplus function, which makes values positive and sparse.
-
-        # Get the sparse weights using the UNICOIL-like transformation
-        sparse_weights = torch.max(torch.log(1 + torch.exp(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
-
-        # --- Binarization Step for UNICOIL ---
-        # For true "binary sparse", we threshold these sparse weights.
-        # A common approach is to simply take any non-zero value as 1, and zero as 0.
-        # Or, define a small threshold for binarization if values are very small but non-zero.
-        # For simplicity, let's treat anything above a very small epsilon as 1.
-
-        # Convert to binary: 1 if weight > epsilon, else 0
-        threshold = 1e-6 # Define a small threshold for binarization
-        binary_sparse_vector = (sparse_weights > threshold).int()
-
-        # Get indices of the '1's in the binary vector
-        binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()
-
-        if not isinstance(binary_indices, list):
-            binary_indices = [binary_indices] if binary_indices.numel() > 0 else []
-
-        # Map token IDs back to terms for the binary representation
-        binary_terms = {}
-        for token_id in binary_indices:
-            decoded_token = tokenizer_unicoil.decode([token_id])
-            if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
-                binary_terms[decoded_token] = 1 # Value is always 1 for binary
-
-        sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0]) # Sort by term for consistent display
-
-        formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
-        if not sorted_binary_terms:
-            formatted_output += "No significant terms activated for this input.\n"
-        else:
-            # Display up to 50 activated terms for readability
-            for i, (term, _) in enumerate(sorted_binary_terms):
-                if i >= 50:
-                    break
-                formatted_output += f"- **{term}**\n" # Only show term, as weight is always 1
-            if len(sorted_binary_terms) > 50:
                 formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
-
-        formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
-        formatted_output += f"Total activated terms: {len(binary_indices)}\n"
-        # Calculate sparsity based on the number of '1's vs. total vocabulary size
-        formatted_output += f"Sparsity: {1 - (len(binary_indices) / tokenizer_unicoil.vocab_size):.2%}\n"
 
     return formatted_output
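The removed path is effectively SPLADE-style scoring: softplus (`log(1 + exp(logits))`) over the full vocabulary dimension, masked and max-pooled over the sequence into one weight per vocabulary entry. Since softplus is strictly positive, the `1e-6` binarization threshold admits every entry whose pooled logit exceeds roughly `-13.8` (below that, softplus dips under `1e-6`), which in practice is nearly the whole vocabulary, so the reported sparsity collapses toward 0%. A quick reproduction of that symptom on dummy logits, with shapes assumed from the code above and no model required:

```python
import torch

# Mimic the removed scoring path with random logits.
# Shape mirrors the app's output.logits: [batch, seq_len, vocab_size].
logits = torch.randn(1, 8, 30522)
sparse_weights = torch.max(torch.log(1 + torch.exp(logits)), dim=1)[0].squeeze()

# softplus is strictly positive, so virtually every entry clears 1e-6.
activated = (sparse_weights > 1e-6).sum().item()
print(f"{activated} / {sparse_weights.numel()} vocab entries activated")  # ~30522
```

The replacement below drops the vocabulary-wide pooling and instead reads, for each input token, the softplus score assigned to that token's own vocabulary id: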
 
         return "UNICOIL model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
     inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
 
     with torch.no_grad():
+        output = model_unicoil(**inputs)
+
+        if not hasattr(output, "logits"):
+            return "UNICOIL model output structure not as expected. 'logits' not found."
+
+        logits = output.logits.squeeze(0)  # [seq_len, vocab_size]
+        token_ids = input_ids.squeeze(0)  # [seq_len]
+        mask = attention_mask.squeeze(0)  # [seq_len]
+
+        transformed_scores = torch.log(1 + torch.exp(logits))  # softplus
+        token_scores = transformed_scores[range(len(token_ids)), token_ids]  # only scores for input tokens
+        token_scores = token_scores * mask  # mask out padding
+
+        # Binarize: threshold scores > 0.5 (tune as needed)
+        binary_mask = (token_scores > 0.5)
+        activated_token_ids = token_ids[binary_mask].cpu().tolist()
+
+        # Map token ids to strings
+        binary_terms = {}
+        for token_id in activated_token_ids:
+            decoded_token = tokenizer_unicoil.decode([token_id])
+            if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+                binary_terms[decoded_token] = 1
+
+        sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])
+
+        formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
+        if not sorted_binary_terms:
+            formatted_output += "No significant terms activated for this input.\n"
+        else:
+            for i, (term, _) in enumerate(sorted_binary_terms):
+                if i >= 50:
                     formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
+                    break
+                formatted_output += f"- **{term}**\n"
+
+        formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
+        formatted_output += f"Total activated terms: {len(sorted_binary_terms)}\n"
+        formatted_output += f"Sparsity: {1 - (len(sorted_binary_terms) / tokenizer_unicoil.vocab_size):.2%}\n"
 
     return formatted_output
 
 
+
+
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE":
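With this change the number of activated terms is bounded by the query length rather than the vocabulary size, so the sparsity line becomes meaningful again. For reference, here is a self-contained sketch of the same scoring path outside the Gradio app. It is an assumption-laden illustration, not the app's code: the model id and the `0.5` threshold come from the diff, the function name is hypothetical, and it presumes (as the removed comments argued) that the checkpoint loads as a masked LM with usable per-token logits. It also moves the tensors to the model's device before taking the views used for indexing; the committed code captures `input_ids` and `attention_mask` before the `.to(device)` move, which is harmless on a CPU Space but would mix devices on a GPU backend.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

MODEL_ID = "castorini/unicoil-msmarco-passage"  # from the diff; head type unverified

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)

def unicoil_binary_terms(text, threshold=0.5):  # hypothetical helper, not in app.py
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Move first, then take the views used below, so indexing stays on one device.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    token_ids = inputs["input_ids"].squeeze(0)      # [seq_len]
    mask = inputs["attention_mask"].squeeze(0)      # [seq_len]
    with torch.no_grad():
        logits = model(**inputs).logits.squeeze(0)  # [seq_len, vocab_size]
    # softplus(x) == log(1 + exp(x)); gather each input token's own score.
    scores = torch.nn.functional.softplus(logits)
    positions = torch.arange(token_ids.numel(), device=token_ids.device)
    token_scores = scores[positions, token_ids] * mask
    kept_ids = token_ids[token_scores > threshold].tolist()
    terms = {tokenizer.decode([tid]) for tid in kept_ids}
    return sorted(terms - {"[CLS]", "[SEP]", "[PAD]", "[UNK]"})

print(unicoil_binary_terms("what is the capital of france"))
```

One caveat on the loading assumption: `AutoModelForMaskedLM` will accept a BERT-type checkpoint even when it ships no trained MLM head, attaching freshly initialized head weights and logging a warning, so the `hasattr(output, "logits")` guard passes either way; the load-time warnings (or the checkpoint's `config.json`) are the reliable check that the scores are real.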