AnshulS committed on
Commit
397f5c9
·
verified ·
1 Parent(s): 759fd26

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +28 -9
retriever.py CHANGED
@@ -4,16 +4,35 @@ from sentence_transformers import SentenceTransformer, util
4
  model = SentenceTransformer("all-MiniLM-L6-v2")
5
 
6
  def get_relevant_passages(query, df, top_k=20):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Concatenate all fields into a single string per row
8
- corpus = df.apply(lambda row: f"{row['description']} "
9
- f"Test types: {', '.join(row['test_type'])}. "
10
- f"Adaptive support: {row['adaptive_support']}. "
11
- f"Remote support: {row['remote_support']}. "
12
- f"Duration: {row['duration']} minutes.",
13
- axis=1).tolist()
14
-
 
 
15
  corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
16
  query_embedding = model.encode(query, convert_to_tensor=True)
17
-
18
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
19
- return df.iloc[[hit['corpus_id'] for hit in hits]]
 
 
4
  model = SentenceTransformer("all-MiniLM-L6-v2")
5
 
6
  def get_relevant_passages(query, df, top_k=20):
7
+ # Create a copy to avoid modifying the original dataframe
8
+ df_copy = df.copy()
9
+
10
+ # Ensure URL field is properly formatted
11
+ if 'url' in df_copy.columns:
12
+ # Clean up URLs if needed
13
+ df_copy['url'] = df_copy['url'].astype(str)
14
+ # Ensure URLs start with http or https
15
+ mask = ~df_copy['url'].str.startswith(('http://', 'https://'))
16
+ df_copy.loc[mask, 'url'] = 'https://www.shl.com/' + df_copy.loc[mask, 'url'].str.lstrip('/')
17
+
18
+ # Format test_type for better representation
19
+ def format_test_type(test_types):
20
+ if isinstance(test_types, list):
21
+ return ', '.join(test_types)
22
+ return str(test_types)
23
+
24
  # Concatenate all fields into a single string per row
25
+ corpus = df_copy.apply(
26
+ lambda row: f"{row['description']} "
27
+ f"Test types: {format_test_type(row['test_type'])}. "
28
+ f"Adaptive support: {row['adaptive_support']}. "
29
+ f"Remote support: {row['remote_support']}. "
30
+ f"Duration: {row['duration'] if pd.notna(row['duration']) else 'N/A'} minutes.",
31
+ axis=1
32
+ ).tolist()
33
+
34
  corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
35
  query_embedding = model.encode(query, convert_to_tensor=True)
 
36
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
37
+
38
+ return df_copy.iloc[[hit['corpus_id'] for hit in hits]]