AnshulS commited on
Commit
2133db4
·
verified ·
1 Parent(s): c0029e4

Create retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +12 -0
retriever.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sentence_transformers import SentenceTransformer, util
3
+
4
+ model = SentenceTransformer("all-MiniLM-L6-v2")
5
+
6
+ def get_relevant_passages(query, df, top_k=20):
7
+ corpus = df["description"].astype(str).tolist()
8
+ corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
9
+ query_embedding = model.encode(query, convert_to_tensor=True)
10
+
11
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
12
+ return df.iloc[[hit['corpus_id'] for hit in hits]]