Update cross_encoder_reranking_train.py
cross_encoder_reranking_train.py
CHANGED
@@ -134,13 +134,9 @@ def extract_text(content_dict, text_type="full"):
         for key, value in content_dict.items():
             if key.startswith('c-'):
                 content.append(value)
-            if key=="features":
-                content+=list(content_dict[key].values())
 
-
-
-        # for key, value in content_dict["features"]:
-        #     content.append(value)
+        for key, value in content_dict["features"]:
+            content.append(value)
         return " ".join(content)
 
     elif text_type == "tac1":
@@ -232,6 +228,13 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     batch_size = last_hidden_states.shape[0]
     return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
+def cls_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+    """Extract [CLS] token representations, accounting for left padding."""
+    # Get the index of the first non-padding token in each sequence
+    cls_indices = attention_mask.float().argmax(dim=1)
+    batch_size = last_hidden_states.size(0)
+    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), cls_indices]
+
 def get_detailed_instruct(task_description: str, query: str) -> str:
     """Create an instruction-formatted query"""
     return f'Instruct: {task_description}\nQuery: {query}'
@@ -273,7 +276,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
 
         # Get embeddings
         outputs = model(**batch_dict)
-        embeddings =
+        embeddings = cls_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
 
         # Normalize embeddings
         embeddings = F.normalize(embeddings, p=2, dim=1)