SiddharthAK committed
Commit 1728ea0 · verified · 1 Parent(s): afa84d6

Update app.py

Files changed (1)
  1. app.py +1 -1
app.py CHANGED
@@ -50,7 +50,7 @@ def get_splade_representation(text):
         # output.logits is typically [batch_size, sequence_length, vocab_size]
         # We need to take the max over the sequence_length dimension to get a [batch_size, vocab_size] vector.
         # inputs.attention_mask.unsqueeze(-1) expands the mask to match vocab_size for element-wise multiplication.
-        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs.attention_mask.unsqueeze(-1), dim=1)[0].squeeze()
+        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
     else:
         # Fallback/error message if the output structure is unexpected
         return "Model output structure not as expected for SPLADE. 'logits' not found."