AnshulS commited on
Commit
24f19c6
·
verified ·
1 Parent(s): 9d9d3fa

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +62 -25
retriever.py CHANGED
@@ -1,38 +1,75 @@
1
  import pandas as pd
2
  from sentence_transformers import SentenceTransformer, util
 
3
 
4
  model = SentenceTransformer("all-MiniLM-L6-v2")
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  def get_relevant_passages(query, df, top_k=20):
 
7
  # Create a copy to avoid modifying the original dataframe
8
  df_copy = df.copy()
9
 
10
- # Ensure URL field is properly formatted
11
- if 'url' in df_copy.columns:
12
- # Clean up URLs if needed
13
- df_copy['url'] = df_copy['url'].astype(str)
14
- # Ensure URLs start with http or https
15
- mask = ~df_copy['url'].str.startswith(('http://', 'https://'))
16
- df_copy.loc[mask, 'url'] = 'https://www.shl.com/' + df_copy.loc[mask, 'url'].str.lstrip('/')
17
-
18
- # Format test_type for better representation
19
- def format_test_type(test_types):
20
- if isinstance(test_types, list):
21
- return ', '.join(test_types)
22
- return str(test_types)
23
-
24
- # Concatenate all fields into a single string per row
25
- corpus = df_copy.apply(
26
- lambda row: f"{row['description']} "
27
- f"Test types: {format_test_type(row['test_type'])}. "
28
- f"Adaptive support: {row['adaptive_support']}. "
29
- f"Remote support: {row['remote_support']}. "
30
- f"Duration: {row['duration'] if pd.notna(row['duration']) else 'N/A'} minutes.",
31
- axis=1
32
- ).tolist()
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
 
34
  corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
35
  query_embedding = model.encode(query, convert_to_tensor=True)
36
- hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
37
 
38
- return df_copy.iloc[[hit['corpus_id'] for hit in hits]]
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  from sentence_transformers import SentenceTransformer, util
3
+ import json
4
 
5
  model = SentenceTransformer("all-MiniLM-L6-v2")
6
 
7
+ def format_test_type(test_types):
8
+ """Format test type for embedding."""
9
+ if isinstance(test_types, list):
10
+ return ', '.join(test_types)
11
+ if isinstance(test_types, str) and test_types.startswith('['):
12
+ try:
13
+ return ', '.join(eval(test_types))
14
+ except:
15
+ pass
16
+ return str(test_types)
17
+
18
  def get_relevant_passages(query, df, top_k=20):
19
+ """Find most relevant assessments using semantic search."""
20
  # Create a copy to avoid modifying the original dataframe
21
  df_copy = df.copy()
22
 
23
+ if df_copy.empty:
24
+ print("Warning: Empty dataframe passed to get_relevant_passages")
25
+ return df_copy
26
+
27
+ # Display dataframe info for debugging
28
+ print(f"Dataframe columns: {df_copy.columns}")
29
+ print(f"Dataframe sample: {df_copy.head(1).to_dict('records')}")
30
+
31
+ # Ensure test_type is properly formatted
32
+ if 'test_type' in df_copy.columns:
33
+ # Convert test_type to proper format if it's a string representation of a list
34
+ df_copy['test_type'] = df_copy['test_type'].apply(
35
+ lambda x: eval(x) if isinstance(x, str) and x.startswith('[') else
36
+ ([x] if not isinstance(x, list) else x)
37
+ )
38
+
39
+ # Concatenate all fields into a single string per row for embedding
40
+ corpus = []
41
+ for _, row in df_copy.iterrows():
42
+ try:
43
+ description = row['description'] if pd.notna(row['description']) else ""
44
+ test_types = format_test_type(row['test_type']) if 'test_type' in row else ""
45
+ adaptive = row['adaptive_support'] if 'adaptive_support' in row else "Unknown"
46
+ remote = row['remote_support'] if 'remote_support' in row else "Unknown"
47
+ duration = f"{row['duration']} minutes" if pd.notna(row.get('duration')) else "Unknown duration"
48
+
49
+ text = (f"{description} "
50
+ f"Test types: {test_types}. "
51
+ f"Adaptive support: {adaptive}. "
52
+ f"Remote support: {remote}. "
53
+ f"Duration: {duration}.")
54
+ corpus.append(text)
55
+ except Exception as e:
56
+ print(f"Error processing row: {e}")
57
+ corpus.append("Error processing assessment")
58
 
59
+ print(f"Created corpus with {len(corpus)} items")
60
+
61
+ # Generate embeddings
62
  corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
63
  query_embedding = model.encode(query, convert_to_tensor=True)
 
64
 
65
+ # Find most similar
66
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=min(top_k, len(corpus)))[0]
67
+
68
+ # Get top matches
69
+ result = df_copy.iloc[[hit['corpus_id'] for hit in hits]].copy()
70
+ print(f"Found {len(result)} relevant passages")
71
+
72
+ # Add score for debugging
73
+ result['score'] = [hit['score'] for hit in hits]
74
+
75
+ return result