AnshulS commited on
Commit
7e0bee0
·
verified ·
1 Parent(s): 385103f

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +5 -16
retriever.py CHANGED
@@ -7,19 +7,8 @@ def get_relevant_passages(query, df, top_k=20):
7
  # Create a copy to avoid modifying the original dataframe
8
  df_copy = df.copy()
9
 
10
- # Print shape for debugging
11
- print(f"DataFrame shape: {df_copy.shape}")
12
- print(f"DataFrame columns: {df_copy.columns.tolist()}")
13
-
14
- # Handle missing columns gracefully
15
- for col in ['description', 'test_type', 'adaptive_support', 'remote_support', 'duration']:
16
- if col not in df_copy.columns:
17
- df_copy[col] = 'N/A'
18
-
19
  # Ensure URL field is properly formatted
20
- if 'url' not in df_copy.columns:
21
- df_copy['url'] = 'https://www.shl.com/missing-url'
22
- else:
23
  # Clean up URLs if needed
24
  df_copy['url'] = df_copy['url'].astype(str)
25
  # Ensure URLs start with http or https
@@ -29,16 +18,16 @@ def get_relevant_passages(query, df, top_k=20):
29
  # Format test_type for better representation
30
  def format_test_type(test_types):
31
  if isinstance(test_types, list):
32
- return ', '.join([str(t) for t in test_types if t])
33
  return str(test_types)
34
 
35
  # Concatenate all fields into a single string per row
36
  corpus = df_copy.apply(
37
- lambda row: f"{row.get('assessment_name', '')} {row.get('description', '')} "
38
  f"Test types: {format_test_type(row['test_type'])}. "
39
  f"Adaptive support: {row['adaptive_support']}. "
40
  f"Remote support: {row['remote_support']}. "
41
- f"Duration: {row['duration']} minutes.",
42
  axis=1
43
  ).tolist()
44
 
@@ -46,4 +35,4 @@ def get_relevant_passages(query, df, top_k=20):
46
  query_embedding = model.encode(query, convert_to_tensor=True)
47
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
48
 
49
- return df_copy.iloc[[hit['corpus_id'] for hit in hits]]
 
7
  # Create a copy to avoid modifying the original dataframe
8
  df_copy = df.copy()
9
 
 
 
 
 
 
 
 
 
 
10
  # Ensure URL field is properly formatted
11
+ if 'url' in df_copy.columns:
 
 
12
  # Clean up URLs if needed
13
  df_copy['url'] = df_copy['url'].astype(str)
14
  # Ensure URLs start with http or https
 
18
  # Format test_type for better representation
19
  def format_test_type(test_types):
20
  if isinstance(test_types, list):
21
+ return ', '.join(test_types)
22
  return str(test_types)
23
 
24
  # Concatenate all fields into a single string per row
25
  corpus = df_copy.apply(
26
+ lambda row: f"{row['description']} "
27
  f"Test types: {format_test_type(row['test_type'])}. "
28
  f"Adaptive support: {row['adaptive_support']}. "
29
  f"Remote support: {row['remote_support']}. "
30
+ f"Duration: {row['duration'] if pd.notna(row['duration']) else 'N/A'} minutes.",
31
  axis=1
32
  ).tolist()
33
 
 
35
  query_embedding = model.encode(query, convert_to_tensor=True)
36
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
37
 
38
+ return df_copy.iloc[[hit['corpus_id'] for hit in hits]]