abhivsh committed on
Commit
b0de7f5
·
verified ·
1 Parent(s): 8230753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -72,13 +72,15 @@ def chat_query_doc(question, history):
72
  llm = ChatOpenAI(model = llm_name, temperature = 0.1, api_key = OPENAI_API_KEY)
73
  #llm = GoogleGenerativeAI(model = "gemini-pro", google_api_key = GEMINI_API_KEY)
74
  #llm = ChatGoogleGenerativeAI(model = "gemini-1.0-pro", google_api_key = GEMINI_API_KEY, temperature = 0.1, top_k = 1, top_p = 0.95)
 
75
 
76
- def get_relevant_passage(query, db):
77
- passage = db.query(query_texts=[query], n_results=1)['documents'][0][0]
 
78
  return passage
79
 
80
  # Perform embedding search
81
- passage = get_relevant_passage(question, vectordb)
82
 
83
 
84
  def make_prompt(query, relevant_passage):
@@ -100,7 +102,7 @@ def chat_query_doc(question, history):
100
 
101
  # Conversation Retrival Chain with Memory
102
  # memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
103
- retriever=vectordb.as_retriever()
104
  qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, return_source_documents=True)
105
 
106
  # # Replace input() with question variable for Gradio
 
72
  llm = ChatOpenAI(model = llm_name, temperature = 0.1, api_key = OPENAI_API_KEY)
73
  #llm = GoogleGenerativeAI(model = "gemini-pro", google_api_key = GEMINI_API_KEY)
74
  #llm = ChatGoogleGenerativeAI(model = "gemini-1.0-pro", google_api_key = GEMINI_API_KEY, temperature = 0.1, top_k = 1, top_p = 0.95)
75
+ retriever=vectordb.as_retriever(search_type="mmr")
76
 
77
+ def get_relevant_passage(query, retriever):
78
+
79
+ passage = retriever.invoke(query)[0]['page_content']
80
  return passage
81
 
82
  # Perform embedding search
83
+ passage = get_relevant_passage(question, retriever)
84
 
85
 
86
  def make_prompt(query, relevant_passage):
 
102
 
103
  # Conversation Retrival Chain with Memory
104
  # memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
105
+
106
  qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, return_source_documents=True)
107
 
108
  # # Replace input() with question variable for Gradio