add support for last token pooling
semscore.py  CHANGED  (+20 -3)
@@ -87,6 +87,7 @@ class SemScore(evaluate.Metric):
         # Load model and tokenizer from HuggingFace Hub
         self.model = AutoModel.from_pretrained(checkpoint)
         self.model.eval()
+        padding_side = "left" if self
         self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 
     @staticmethod
@@ -95,6 +96,16 @@ class SemScore(evaluate.Metric):
         token_embeddings = model_output[0]
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    @staticmethod
+    def _last_token_pooling(last_hidden_states, attention_mask):
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            return last_hidden_states[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_states.shape[0]
+            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
     def _compute(
         self,
@@ -102,10 +113,12 @@ class SemScore(evaluate.Metric):
         references,
         batch_size=32,
         device=None,
+        pooling="mean"
     ):
         """Returns the scores"""
 
         assert len(predictions) == len(references), "predictions and references should have the same length."
+        assert pooling in ["mean", "last"]
         if device is not None:
             if "cuda" in device:
                 assert torch.cuda.is_available()
@@ -123,8 +136,12 @@ class SemScore(evaluate.Metric):
             encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
             model_output_refs = self.model(**encoded_refs.to(device))
             model_output_preds = self.model(**encoded_preds.to(device))
-            batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
-            batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
+            if pooling == "mean":
+                batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
+                batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
+            elif pooling == "last":
+                batch_pooled_refs = self._last_token_pooling(model_output_refs[0], encoded_refs['attention_mask'])
+                batch_pooled_preds = self._last_token_pooling(model_output_preds[0], encoded_preds['attention_mask'])
             pooled_refs.append(batch_pooled_refs)
             pooled_preds.append(batch_pooled_preds)
         pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
@@ -136,4 +153,4 @@ class SemScore(evaluate.Metric):
         return {
             "semscore": round(semscore.item(), 2),
             "similarities": similarities.tolist()
-            }
+        }
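
For reference, a minimal standalone sketch of the last-token pooling logic added above (the function name, tensor contents, and shapes here are illustrative, not part of the commit). With left padding, every sequence's last real token sits in the final position, so the last column of the hidden states can be taken directly; with right padding, the last real token sits at a different index per row and is gathered via the attention mask:

import torch

def last_token_pooling(last_hidden_states, attention_mask):
    # If every sequence has a real (non-pad) token in the final position,
    # the batch is left-padded and the last column already holds the answer.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]
    # Right padding: locate each sequence's last non-pad token instead.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[batch, sequence_lengths]

# Toy batch: 2 right-padded sequences, max length 4, hidden size 3.
hidden = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 0],   # length 3 -> picks position 2
                     [1, 1, 1, 1]])  # length 4 -> picks position 3
print(last_token_pooling(hidden, mask))  # rows hidden[0, 2] and hidden[1, 3]

The left-padding fast path matters because decoder-style embedding models are typically tokenized with padding_side="left", in which case the final position is the last real token for every row of the batch.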
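
A minimal usage sketch, assuming the metric script is loaded from a local path (the path and example strings are illustrative). Extra keyword arguments to compute() are forwarded to _compute(), so the new pooling option is passed the same way as batch_size:

import evaluate

semscore = evaluate.load("./semscore.py")  # illustrative local path

preds = ["The cat sat on the mat."]
refs = ["A cat was sitting on the mat."]

# Default behaviour is unchanged: mean pooling.
print(semscore.compute(predictions=preds, references=refs))

# New in this commit: pool with the last non-pad token instead.
print(semscore.compute(predictions=preds, references=refs, pooling="last"))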