darpanaswal committed on
Commit
58f39f1
·
verified ·
1 Parent(s): 44d07af

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +11 -2
cross_encoder_reranking_train.py CHANGED
@@ -37,8 +37,17 @@ def cluster_and_rank(texts, threshold=0.75):
37
  return texts
38
 
39
  embeddings = embed_text_list(texts)
40
- clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=1-threshold, metric = "cosine", linkage='average')
41
- labels = clustering.fit_predict(embeddings)
 
 
 
 
 
 
 
 
 
42
 
43
  clustered_texts = {}
44
  for label, text in zip(labels, texts):
 
37
  return texts
38
 
39
  embeddings = embed_text_list(texts)
40
+ similarity_matrix = np.dot(embeddings, np.transpose(embeddings))
41
+ # Convert similarity to distance (larger dot product = more similar => smaller distance)
42
+ distance_matrix = -similarity_matrix # negative dot product for distance-like behavior
43
+
44
+ clustering = AgglomerativeClustering(
45
+ n_clusters=None,
46
+ distance_threshold=-threshold, # lower threshold = tighter clusters
47
+ affinity='precomputed',
48
+ linkage='average'
49
+ )
50
+ labels = clustering.fit_predict(distance_matrix)
51
 
52
  clustered_texts = {}
53
  for label, text in zip(labels, texts):