AnshulS commited on
Commit
e5766c5
·
verified ·
1 Parent(s): 6e0aaf8

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +16 -5
retriever.py CHANGED
@@ -7,8 +7,19 @@ def get_relevant_passages(query, df, top_k=20):
7
  # Create a copy to avoid modifying the original dataframe
8
  df_copy = df.copy()
9
 
 
 
 
 
 
 
 
 
 
10
  # Ensure URL field is properly formatted
11
- if 'url' in df_copy.columns:
 
 
12
  # Clean up URLs if needed
13
  df_copy['url'] = df_copy['url'].astype(str)
14
  # Ensure URLs start with http or https
@@ -18,16 +29,16 @@ def get_relevant_passages(query, df, top_k=20):
18
  # Format test_type for better representation
19
  def format_test_type(test_types):
20
  if isinstance(test_types, list):
21
- return ', '.join(test_types)
22
  return str(test_types)
23
 
24
  # Concatenate all fields into a single string per row
25
  corpus = df_copy.apply(
26
- lambda row: f"{row['description']} "
27
  f"Test types: {format_test_type(row['test_type'])}. "
28
  f"Adaptive support: {row['adaptive_support']}. "
29
  f"Remote support: {row['remote_support']}. "
30
- f"Duration: {row['duration'] if pd.notna(row['duration']) else 'N/A'} minutes.",
31
  axis=1
32
  ).tolist()
33
 
@@ -35,4 +46,4 @@ def get_relevant_passages(query, df, top_k=20):
35
  query_embedding = model.encode(query, convert_to_tensor=True)
36
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
37
 
38
- return df_copy.iloc[[hit['corpus_id'] for hit in hits]]
 
7
  # Create a copy to avoid modifying the original dataframe
8
  df_copy = df.copy()
9
 
10
+ # Print shape for debugging
11
+ print(f"DataFrame shape: {df_copy.shape}")
12
+ print(f"DataFrame columns: {df_copy.columns.tolist()}")
13
+
14
+ # Handle missing columns gracefully
15
+ for col in ['description', 'test_type', 'adaptive_support', 'remote_support', 'duration']:
16
+ if col not in df_copy.columns:
17
+ df_copy[col] = 'N/A'
18
+
19
  # Ensure URL field is properly formatted
20
+ if 'url' not in df_copy.columns:
21
+ df_copy['url'] = 'https://www.shl.com/missing-url'
22
+ else:
23
  # Clean up URLs if needed
24
  df_copy['url'] = df_copy['url'].astype(str)
25
  # Ensure URLs start with http or https
 
29
  # Format test_type for better representation
30
  def format_test_type(test_types):
31
  if isinstance(test_types, list):
32
+ return ', '.join([str(t) for t in test_types if t])
33
  return str(test_types)
34
 
35
  # Concatenate all fields into a single string per row
36
  corpus = df_copy.apply(
37
+ lambda row: f"{row.get('assessment_name', '')} {row.get('description', '')} "
38
  f"Test types: {format_test_type(row['test_type'])}. "
39
  f"Adaptive support: {row['adaptive_support']}. "
40
  f"Remote support: {row['remote_support']}. "
41
+ f"Duration: {row['duration']} minutes.",
42
  axis=1
43
  ).tolist()
44
 
 
46
  query_embedding = model.encode(query, convert_to_tensor=True)
47
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
48
 
49
+ return df_copy.iloc[[hit['corpus_id'] for hit in hits]]