darpanaswal commited on
Commit
a0a8763
·
verified ·
1 Parent(s): 4ee2e0d

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +4 -1
cross_encoder_reranking_train.py CHANGED
@@ -11,11 +11,14 @@ from transformers import AutoTokenizer, AutoModel
11
  from sklearn.cluster import AgglomerativeClustering
12
  from sklearn.metrics.pairwise import cosine_similarity
13
 
 
14
  # Load embedder once
15
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
 
 
16
 
17
  def embed_text_list(texts):
18
- return embedder.encode(texts, convert_to_tensor=False)
19
 
20
  def rank_by_centrality(texts):
21
  embeddings = embed_text_list(texts)
 
11
  from sklearn.cluster import AgglomerativeClustering
12
  from sklearn.metrics.pairwise import cosine_similarity
13
 
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
17
+ embedder = embedder.to(device)
18
+
19
 
20
  def embed_text_list(texts):
21
+ return embedder.encode(texts, convert_to_tensor=False, device=device)
22
 
23
  def rank_by_centrality(texts):
24
  embeddings = embed_text_list(texts)