darpanaswal committed
Commit 87f4272 · verified · 1 Parent(s): 3253735

Update cross_encoder_reranking_train.py

Files changed (1)
  1. cross_encoder_reranking_train.py +10 -7
cross_encoder_reranking_train.py CHANGED
@@ -134,13 +134,9 @@ def extract_text(content_dict, text_type="full"):
         for key, value in content_dict.items():
             if key.startswith('c-'):
                 content.append(value)
-            if key=="features":
-                content+=list(content_dict[key].values())
 
-        # for key, _ in content_dict.items
-
-        # for key, value in content_dict["features"]:
-        #     content.append(value)
+        for key, value in content_dict["features"]:
+            content.append(value)
 
         return " ".join(content)
 
     elif text_type == "tac1":
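A note on the new loop: if content_dict["features"] is a dict (which the removed content_dict[key].values() call suggests), iterating it directly yields only its keys, and unpacking each key into key, value will typically raise a ValueError; .items() is what yields pairs. A minimal sketch of the difference, using a made-up features dict purely for illustration:

    # Hypothetical data, not part of the committed file
    features = {"cpc": "H04L 9/08", "ipc": "G06F 21/60"}

    for item in features:
        print(item)                  # iterates keys only: "cpc", "ipc"

    for key, value in features.items():
        print(key, value)            # iterates (key, value) pairs

The committed form works as written only if content_dict["features"] holds (key, value) pairs rather than a plain dict.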
@@ -232,6 +228,13 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     batch_size = last_hidden_states.shape[0]
     return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
+def cls_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+    """Extract [CLS] token representations, accounting for left padding."""
+    # Get the index of the first non-padding token in each sequence
+    cls_indices = attention_mask.float().argmax(dim=1)
+    batch_size = last_hidden_states.size(0)
+    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), cls_indices]
+
 def get_detailed_instruct(task_description: str, query: str) -> str:
     """Create an instruction-formatted query"""
     return f'Instruct: {task_description}\nQuery: {query}'
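The new cls_token_pool mirrors last_token_pool but selects the first real token of each sequence, which is where the [CLS]/BOS token sits when the tokenizer left-pads: argmax over the attention mask returns the first index holding a 1. A minimal standalone sketch with dummy tensors (not the project's model or data):

    import torch
    from torch import Tensor

    def cls_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Extract [CLS] token representations, accounting for left padding."""
        # argmax returns the first position of the max value (1), i.e. the first non-pad token
        cls_indices = attention_mask.float().argmax(dim=1)
        batch_size = last_hidden_states.size(0)
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), cls_indices]

    hidden = torch.randn(2, 5, 8)                # (batch, seq_len, hidden)
    mask = torch.tensor([[0, 0, 1, 1, 1],        # two pad tokens on the left
                         [0, 1, 1, 1, 1]])       # one pad token on the left
    pooled = cls_token_pool(hidden, mask)
    assert torch.equal(pooled[0], hidden[0, 2])  # first real token of sequence 0
    assert torch.equal(pooled[1], hidden[1, 1])  # first real token of sequence 1

For right-padded input the same argmax lands on index 0, so the helper degrades to plain first-token pooling there as well.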
@@ -273,7 +276,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
 
         # Get embeddings
         outputs = model(**batch_dict)
-        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+        embeddings = cls_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
 
         # Normalize embeddings
         embeddings = F.normalize(embeddings, p=2, dim=1)
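The scoring path now uses the new CLS pooling in place of last-token pooling. Since the pooled embeddings are L2-normalized, cosine similarity between query and document embeddings reduces to a dot product; the ranking step itself is outside this diff, but a sketch of that pattern with dummy tensors looks like:

    import torch
    import torch.nn.functional as F

    emb = F.normalize(torch.randn(4, 8), p=2, dim=1)  # row 0: query, rows 1-3: documents
    scores = emb[1:] @ emb[0]                         # cosine similarity of each doc to the query
    ranking = scores.argsort(descending=True)         # highest-scoring document first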
 