darpanaswal committed
Commit df4ff7a · verified · 1 Parent(s): 19e7ebd

Update cross_encoder_reranking_train.py

Files changed (1):
  1. cross_encoder_reranking_train.py +17 -12
cross_encoder_reranking_train.py CHANGED
@@ -276,24 +276,29 @@ def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, weight_s
 def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
     device = next(model.parameters()).device
     cross_scores = []
-    query_emb = embed_text_list([query_text])[0]
+    query_emb = embed_text_list([query_text])[0]  # Move embedder to CPU
 
     instructed_query = get_detailed_instruct("", query_text)
 
-    for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
-        batch_docs = doc_texts[i:i+batch_size]
+    # Pre-create all input pairs (concatenation-based cross-encoder setup)
+    input_texts = [f"{instructed_query} {doc}" for doc in doc_texts]
 
-        input_texts = [instructed_query] + batch_docs
+    for i in tqdm(range(0, len(input_texts), batch_size), desc="Scoring documents", leave=False):
+        batch_input_texts = input_texts[i:i+batch_size]
 
         with torch.no_grad():
-            batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
-
-            outputs = model(**batch_dict)
-            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-            embeddings = F.normalize(embeddings, p=2, dim=1)
-
-            batch_cross_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0).cpu().numpy()
-            cross_scores.extend(batch_cross_scores)
+            batch_dict = tokenizer(batch_input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
+
+            # Mixed precision for faster inference and lower memory
+            with torch.cuda.amp.autocast():
+                outputs = model(**batch_dict)
+                embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+                embeddings = F.normalize(embeddings, p=2, dim=1)
+
+                # Since queries are repeated in each pair, compare to instructed query embedding (first one)
+                query_vector = embeddings[0].unsqueeze(0)  # Use first as query
+                batch_cross_scores = (query_vector @ embeddings.T).squeeze(0).cpu().numpy()[1:]  # Exclude self-comparison
+                cross_scores.extend(batch_cross_scores)
 
     # Semantic scores
     doc_embeddings = embed_text_list(doc_texts)
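
For reference, the hunk calls two helpers that are not shown here, get_detailed_instruct and last_token_pool. A minimal sketch of both follows, assuming they match the pattern popularized by the intfloat/e5-mistral-7b-instruct model card; the repo's actual definitions may differ.

import torch
from torch import Tensor

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Sketch, assuming the E5-style pooling convention.
    # With left padding, the final position holds the last real token of every sequence.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Otherwise, index each sequence at its own last non-padding position.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

def get_detailed_instruct(task_description: str, query: str) -> str:
    # Prepends a task instruction to the query, as instruction-tuned embedders expect.
    return f'Instruct: {task_description}\nQuery: {query}'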
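And a minimal usage sketch of the revised function. The checkpoint name and the sample query/documents are hypothetical, chosen only for illustration; cross_encoder_reranking is the function from the diff above.

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "intfloat/e5-mistral-7b-instruct"  # hypothetical checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).cuda().eval()

query = "method for ranking patent documents by relevance"  # hypothetical sample data
docs = ["A system for reranking search results ...", "A bicycle frame made of ..."]

# Higher score = more similar to the instructed query.
scores = cross_encoder_reranking(query, docs, model, tokenizer, batch_size=64, max_length=2048)

Note that, despite the name, the scoring is still bi-encoder style: each query-document concatenation is pooled into a single embedding and compared to the instructed query's embedding by dot product, rather than scored by a classification head over the pair.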