snsynth commited on
Commit
62ec630
·
1 Parent(s): a643df2

add renranker changes

Browse files
Files changed (1) hide show
  1. rag_app/rag_2.py +12 -0
rag_app/rag_2.py CHANGED
@@ -77,6 +77,15 @@ def is_harmful(query):
77
  return any(keyword in query.lower() for keyword in harmful_keywords)
78
 
79
 
 
 
 
 
 
 
 
 
 
80
  def answer_question(query):
81
  print("loading bm25 retriever")
82
  bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
@@ -106,5 +115,8 @@ def answer_question(query):
106
  if is_harmful(query):
107
  return "This query has been flagged as unsafe."
108
 
 
 
 
109
  response = keyword_query_engine.query(query)
110
  return str(response)
 
77
  return any(keyword in query.lower() for keyword in harmful_keywords)
78
 
79
 
80
+ def is_relevant(query, index, threshold=0.7):
81
+ retriever = index.as_retriever(similarity_top_k=1)
82
+ nodes = retriever.retrieve(query)
83
+ if not nodes:
84
+ return False
85
+ similarity = nodes[0].score
86
+ return not similarity <= threshold
87
+
88
+
89
  def answer_question(query):
90
  print("loading bm25 retriever")
91
  bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
 
115
  if is_harmful(query):
116
  return "This query has been flagged as unsafe."
117
 
118
+ if not is_relevant(query, index, 0.2):
119
+ return "This query doesn't appear relevant to finance."
120
+
121
  response = keyword_query_engine.query(query)
122
  return str(response)