Update cross_encoder_reranking_train.py
cross_encoder_reranking_train.py: CHANGED (+17, -12)
@@ -276,24 +276,29 @@ def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, weight_s
 def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
     device = next(model.parameters()).device
     cross_scores = []
-    query_emb = embed_text_list([query_text])[0]
+    query_emb = embed_text_list([query_text])[0]  # Move embedder to CPU
 
     instructed_query = get_detailed_instruct("", query_text)
 
-
-
+    # Pre-create all input pairs (concatenation-based cross-encoder setup)
+    input_texts = [f"{instructed_query} {doc}" for doc in doc_texts]
 
-
+    for i in tqdm(range(0, len(input_texts), batch_size), desc="Scoring documents", leave=False):
+        batch_input_texts = input_texts[i:i+batch_size]
 
         with torch.no_grad():
-            batch_dict = tokenizer(
-
-
-
-
-
-
-
+            batch_dict = tokenizer(batch_input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
+
+            # Mixed precision for faster inference and lower memory
+            with torch.cuda.amp.autocast():
+                outputs = model(**batch_dict)
+                embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+                embeddings = F.normalize(embeddings, p=2, dim=1)
+
+            # Since queries are repeated in each pair, compare to instructed query embedding (first one)
+            query_vector = embeddings[0].unsqueeze(0)  # Use first as query
+            batch_cross_scores = (query_vector @ embeddings.T).squeeze(0).cpu().numpy()[1:]  # Exclude self-comparison
+            cross_scores.extend(batch_cross_scores)
 
     # Semantic scores
     doc_embeddings = embed_text_list(doc_texts)
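
The new code leans on helpers defined elsewhere in this file (get_detailed_instruct, last_token_pool, embed_text_list) that the diff does not show. For context, here is a minimal sketch of the first two, following the widely used E5-style instruct-embedding recipe; the exact bodies in this repo are an assumption and may differ:

import torch
from torch import Tensor

def get_detailed_instruct(task_description: str, query: str) -> str:
    # Prefix the raw query with a task instruction (empty here, per the call site above).
    return f'Instruct: {task_description}\nQuery: {query}'

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Reduce each sequence to the hidden state of its last non-padding token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        # With left padding, the final position is always a real token.
        return last_hidden_states[:, -1]
    # With right padding, index each row at its own length - 1.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]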
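
Because the pooled embeddings are L2-normalized (F.normalize(embeddings, p=2, dim=1)), the product query_vector @ embeddings.T computes exact cosine similarities. A small self-contained check of that identity:

import torch
import torch.nn.functional as F

emb = F.normalize(torch.randn(4, 8), p=2, dim=1)  # 4 unit vectors of dim 8
q = emb[0].unsqueeze(0)                           # (1, 8), mirrors query_vector above
dots = (q @ emb.T).squeeze(0)                     # (4,) dot products
cos = F.cosine_similarity(q.expand_as(emb), emb)  # (4,) cosine similarities
assert torch.allclose(dots, cos, atol=1e-6)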
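
The hunk header places this change just below hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, ...), whose signature suggests a weighted blend of the two score types this function produces (cross_scores and the semantic doc_embeddings scores). The header truncates the semantic weight's default; below is a minimal sketch assuming a plain convex combination with weight_semantic=0.3, which may not match the repo's actual body:

def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, weight_semantic=0.3):
    # Assumed implementation: convex combination of the two scores.
    # The real function may rescale or normalize the inputs first.
    return weight_cross * cross_encoder_score + weight_semantic * semantic_score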
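
Finally, a hedged usage sketch. The checkpoint name is hypothetical (the Space's real model is not shown in this diff), and the return value of cross_encoder_reranking depends on code below the hunk, so it is bound generically here:

from transformers import AutoModel, AutoTokenizer

model_name = "intfloat/e5-mistral-7b-instruct"  # hypothetical choice for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).cuda().eval()

docs = ["first candidate document...", "second candidate document..."]
result = cross_encoder_reranking("query text", docs, model, tokenizer, batch_size=2, max_length=2048)

One behavioral note: torch.cuda.amp.autocast() still works but is deprecated in recent PyTorch releases in favor of torch.amp.autocast('cuda'); the two are interchangeable at this call site.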