darpanaswal commited on
Commit
d807d27
·
verified ·
1 Parent(s): 00a410f

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +6 -6
cross_encoder_reranking_train.py CHANGED
@@ -17,13 +17,13 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  embedder = SentenceTransformer("AI-Growth-Lab/PatentSBERTa").to(device)
18
 
19
 
20
- # def embed_text_list(texts):
21
- # return embedder.encode(texts, convert_to_tensor=False, device=device)
22
-
23
  def embed_text_list(texts):
24
- # E5 models expect "query: " prefix for proper embedding behavior
25
- formatted_texts = [f"query: {text}" for text in texts]
26
- return embedder.encode(formatted_texts, convert_to_tensor=False, device=device)
 
 
 
27
 
28
  def rank_by_centrality(texts):
29
  embeddings = embed_text_list(texts)
 
17
  embedder = SentenceTransformer("AI-Growth-Lab/PatentSBERTa").to(device)
18
 
19
 
 
 
 
20
  def embed_text_list(texts):
21
+ return embedder.encode(texts, convert_to_tensor=False, device=device)
22
+
23
+ # def embed_text_list(texts):
24
+ # # E5 models expect "query: " prefix for proper embedding behavior
25
+ # formatted_texts = [f"query: {text}" for text in texts]
26
+ # return embedder.encode(formatted_texts, convert_to_tensor=False, device=device)
27
 
28
  def rank_by_centrality(texts):
29
  embeddings = embed_text_list(texts)