darpanaswal committed on
Commit
d753b02
·
verified ·
1 Parent(s): a0e2313

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +6 -2
cross_encoder_reranking_train.py CHANGED
@@ -17,9 +17,13 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  embedder = SentenceTransformer("intfloat/e5-large-v2").to(device)
18
 
19
 
 
 
 
20
  def embed_text_list(texts):
21
- # return embedder.encode(texts, convert_to_tensor=False, device=device)
22
- return embedder.encode([f"query: {texts}"], convert_to_tensor=False, device=device)
 
23
 
24
  def rank_by_centrality(texts):
25
  embeddings = embed_text_list(texts)
 
17
  embedder = SentenceTransformer("intfloat/e5-large-v2").to(device)
18
 
19
 
20
+ # def embed_text_list(texts):
21
+ # return embedder.encode(texts, convert_to_tensor=False, device=device)
22
+
23
  def embed_text_list(texts):
24
+ # E5 models expect a "query: " prefix on each input text for proper embedding behavior
25
+ formatted_texts = [f"query: {text}" for text in texts]
26
+ return embedder.encode(formatted_texts, convert_to_tensor=False, device=device)
27
 
28
  def rank_by_centrality(texts):
29
  embeddings = embed_text_list(texts)