aynetdia commited on
Commit
c337da5
·
1 Parent(s): 5b86628

add support for last token pooling

Browse files
Files changed (1) hide show
  1. semscore.py +20 -3
semscore.py CHANGED
@@ -87,6 +87,7 @@ class SemScore(evaluate.Metric):
87
  # Load model and tokenizer from HuggingFace Hub
88
  self.model = AutoModel.from_pretrained(checkpoint)
89
  self.model.eval()
 
90
  self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
91
 
92
  @staticmethod
@@ -95,6 +96,16 @@ class SemScore(evaluate.Metric):
95
  token_embeddings = model_output[0]
96
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
97
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
 
 
 
 
 
 
 
 
 
98
 
99
  def _compute(
100
  self,
@@ -102,10 +113,12 @@ class SemScore(evaluate.Metric):
102
  references,
103
  batch_size=32,
104
  device=None,
 
105
  ):
106
  """Returns the scores"""
107
 
108
  assert len(predictions) == len(references), "predictions and references should have the same length."
 
109
  if device is not None:
110
  if "cuda" in device:
111
  assert torch.cuda.is_available()
@@ -123,8 +136,12 @@ class SemScore(evaluate.Metric):
123
  encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
124
  model_output_refs = self.model(**encoded_refs.to(device))
125
  model_output_preds = self.model(**encoded_preds.to(device))
126
- batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
127
- batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
 
 
 
 
128
  pooled_refs.append(batch_pooled_refs)
129
  pooled_preds.append(batch_pooled_preds)
130
  pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
@@ -136,4 +153,4 @@ class SemScore(evaluate.Metric):
136
  return {
137
  "semscore": round(semscore.item(), 2),
138
  "similarities": similarities.tolist()
139
- }
 
87
  # Load model and tokenizer from HuggingFace Hub
88
  self.model = AutoModel.from_pretrained(checkpoint)
89
  self.model.eval()
90
+ padding_side = "left" if self  # [NOTE(review): line truncated in page capture — condition and else-branch not visible; presumably selects left padding for last-token pooling. Confirm against commit c337da5.]
91
  self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
92
 
93
  @staticmethod
 
96
  token_embeddings = model_output[0]
97
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
98
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
99
+
100
+ @staticmethod
101
+ def _last_token_pooling(last_hidden_states, attention_mask):
102
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
103
+ if left_padding:
104
+ return last_hidden_states[:, -1]
105
+ else:
106
+ sequence_lengths = attention_mask.sum(dim=1) - 1
107
+ batch_size = last_hidden_states.shape[0]
108
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
109
 
110
  def _compute(
111
  self,
 
113
  references,
114
  batch_size=32,
115
  device=None,
116
+ pooling="mean"
117
  ):
118
  """Returns the scores"""
119
 
120
  assert len(predictions) == len(references), "predictions and references should have the same length."
121
+ assert pooling in ["mean", "last"]
122
  if device is not None:
123
  if "cuda" in device:
124
  assert torch.cuda.is_available()
 
136
  encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
137
  model_output_refs = self.model(**encoded_refs.to(device))
138
  model_output_preds = self.model(**encoded_preds.to(device))
139
+ if pooling == "mean":
140
+ batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
141
+ batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
142
+ elif pooling == "last":
143
+ batch_pooled_refs = self._last_token_pooling(model_output_refs, encoded_refs['attention_mask'])
144
+ batch_pooled_preds = self._last_token_pooling(model_output_preds, encoded_preds['attention_mask'])
145
  pooled_refs.append(batch_pooled_refs)
146
  pooled_preds.append(batch_pooled_preds)
147
  pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
 
153
  return {
154
  "semscore": round(semscore.item(), 2),
155
  "similarities": similarities.tolist()
156
+ }