SiddharthAK committed on
Commit
45b666a
·
verified ·
1 Parent(s): 22a278f

output all nonzero terms (instead of top 20)

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -46,14 +46,17 @@ def get_splade_representation(text):
46
  output = model_splade(**inputs)
47
 
48
  if hasattr(output, 'logits'):
49
- splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
 
 
 
50
  else:
51
  return "Model output structure not as expected for SPLADE. 'logits' not found."
52
 
53
  indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
54
  if not isinstance(indices, list):
55
  indices = [indices]
56
-
57
  values = splade_vector[indices].cpu().tolist()
58
  token_weights = dict(zip(indices, values))
59
 
@@ -65,15 +68,13 @@ def get_splade_representation(text):
65
 
66
  sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
67
 
68
- formatted_output = "SPLADE Representation (Top 20 Terms):\n"
69
  if not sorted_representation:
70
  formatted_output += "No significant terms found for this input.\n"
71
  else:
72
- for i, (term, weight) in enumerate(sorted_representation):
73
- if i >= 20:
74
- break
75
  formatted_output += f"- **{term}**: {weight:.4f}\n"
76
-
77
  formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
78
  formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
79
  formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
@@ -81,6 +82,8 @@ def get_splade_representation(text):
81
  return formatted_output
82
 
83
 
 
 
84
  def get_unicoil_binary_representation(text):
85
  if tokenizer_unicoil is None or model_unicoil is None:
86
  return "UNICOIL model is not loaded. Please check the console for loading errors."
 
46
  output = model_splade(**inputs)
47
 
48
  if hasattr(output, 'logits'):
49
+ splade_vector = torch.max(
50
+ torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
51
+ dim=1
52
+ )[0].squeeze()
53
  else:
54
  return "Model output structure not as expected for SPLADE. 'logits' not found."
55
 
56
  indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
57
  if not isinstance(indices, list):
58
  indices = [indices]
59
+
60
  values = splade_vector[indices].cpu().tolist()
61
  token_weights = dict(zip(indices, values))
62
 
 
68
 
69
  sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
70
 
71
+ formatted_output = "SPLADE Representation (All Non-Zero Terms):\n"
72
  if not sorted_representation:
73
  formatted_output += "No significant terms found for this input.\n"
74
  else:
75
+ for term, weight in sorted_representation:
 
 
76
  formatted_output += f"- **{term}**: {weight:.4f}\n"
77
+
78
  formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
79
  formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
80
  formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
 
82
  return formatted_output
83
 
84
 
85
+
86
+
87
  def get_unicoil_binary_representation(text):
88
  if tokenizer_unicoil is None or model_unicoil is None:
89
  return "UNICOIL model is not loaded. Please check the console for loading errors."