retriever.py
import pandas as pd
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer("all-MiniLM-L6-v2")
def get_relevant_passages(query, df, top_k=20):
    # Create a copy to avoid modifying the original dataframe
    df_copy = df.copy()

    # Ensure the URL field is properly formatted
    if 'url' in df_copy.columns:
        # Clean up URLs if needed
        df_copy['url'] = df_copy['url'].astype(str)
        # Ensure URLs start with http or https
        mask = ~df_copy['url'].str.startswith(('http://', 'https://'))
        df_copy.loc[mask, 'url'] = 'https://www.shl.com/' + df_copy.loc[mask, 'url'].str.lstrip('/')

    # Format test_type for better representation
    def format_test_type(test_types):
        if isinstance(test_types, list):
            return ', '.join(test_types)
        return str(test_types)

    # Concatenate all fields into a single string per row
    corpus = df_copy.apply(
        lambda row: f"{row['description']} "
                    f"Test types: {format_test_type(row['test_type'])}. "
                    f"Adaptive support: {row['adaptive_support']}. "
                    f"Remote support: {row['remote_support']}. "
                    f"Duration: {row['duration'] if pd.notna(row['duration']) else 'N/A'} minutes.",
        axis=1
    ).tolist()

    # Embed the corpus and the query, then return the top_k most similar rows
    corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    return df_copy.iloc[[hit['corpus_id'] for hit in hits]]
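

# Minimal usage sketch: builds a tiny DataFrame with the columns the function
# reads (description, test_type, adaptive_support, remote_support, duration,
# and optionally url) and runs a single query. The sample rows, URLs, and the
# query string are hypothetical placeholders, not real catalogue data.
if __name__ == "__main__":
    catalog = pd.DataFrame([
        {
            "description": "Verbal reasoning assessment for graduate hires.",
            "test_type": ["Cognitive", "Verbal"],
            "adaptive_support": "Yes",
            "remote_support": "Yes",
            "duration": 30,
            "url": "products/verbal-reasoning",  # relative path; gets prefixed with https://www.shl.com/
        },
        {
            "description": "Personality questionnaire for leadership potential.",
            "test_type": ["Personality"],
            "adaptive_support": "No",
            "remote_support": "Yes",
            "duration": 25,
            "url": "https://www.shl.com/products/example",
        },
    ])

    results = get_relevant_passages("verbal ability test for graduates", catalog, top_k=1)
    print(results[["description", "url"]])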