# Source: endpointwebappshl / retriever.py
# Commit 2133db4 "Create retriever.py" by AnshulS (verified), 505 bytes.
import pandas as pd
from sentence_transformers import SentenceTransformer, util
# Pretrained sentence-embedding checkpoint used to embed both queries and
# passage descriptions.
_EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Loaded once, when this module is imported; retrieval code reads it as a
# module-level global.
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
def get_relevant_passages(query, df, top_k=20, *, corpus_embeddings=None):
    """Return the rows of *df* whose ``description`` best matches *query*.

    The query and every description are embedded with the module-level
    ``model`` and ranked with ``util.semantic_search``.

    Args:
        query: Free-text search string.
        df: DataFrame with a ``description`` column; each row is one
            candidate passage.
        top_k: Maximum number of rows to return. ``0`` yields an empty frame.
        corpus_embeddings: Optional precomputed embeddings of
            ``df["description"]``, one per row and in row order (e.g. from
            ``model.encode(..., convert_to_tensor=True)``). Supply these when
            querying the same DataFrame repeatedly, to avoid re-encoding the
            whole corpus on every call.

    Returns:
        A row slice of *df* (original index and columns kept) holding at most
        ``top_k`` rows, in the order returned by ``semantic_search``
        (highest similarity first).

    Raises:
        ValueError: if ``top_k`` is negative, or if ``corpus_embeddings``
            does not have exactly one entry per row of *df*.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    # Nothing to rank: skip the model calls entirely and return an empty
    # slice that still carries df's columns and dtypes.
    if top_k == 0 or len(df) == 0:
        return df.iloc[0:0]

    if corpus_embeddings is None:
        # Without fillna, missing descriptions would be embedded as the
        # literal text "nan" / "None"; embed them as empty strings instead.
        corpus = df["description"].fillna("").astype(str).tolist()
        corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
    elif len(corpus_embeddings) != len(df):
        raise ValueError(
            f"corpus_embeddings has {len(corpus_embeddings)} entries "
            f"but df has {len(df)} rows"
        )

    query_embedding = model.encode(query, convert_to_tensor=True)
    # A single query goes in, so only the first result list is relevant.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]

    # corpus_id is a position in the corpus, which follows df's row order,
    # so positional .iloc (not label-based .loc) selects the right rows.
    return df.iloc[[hit["corpus_id"] for hit in hits]]