Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def get_splade_representation(text):
|
|
50 |
# output.logits is typically [batch_size, sequence_length, vocab_size]
|
51 |
# We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
|
52 |
# inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
|
53 |
-
splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs
|
54 |
else:
|
55 |
# Fallback/error message if the output structure is unexpected
|
56 |
return "Model output structure not as expected for SPLADE. 'logits' not found."
|
|
|
50 |
# output.logits is typically [batch_size, sequence_length, vocab_size]
|
51 |
# We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
|
52 |
# inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
|
53 |
+
splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
|
54 |
else:
|
55 |
# Fallback/error message if the output structure is unexpected
|
56 |
return "Model output structure not as expected for SPLADE. 'logits' not found."
|