Update app.py
app.py CHANGED
@@ -90,55 +90,74 @@ def process_documents(selected_files):
 # ✅ Query document
 
 
+import os
+import time
+import logging
+from gtts import gTTS
+from langchain.chains import RetrievalQA
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+# ✅ Ensure temp_file_map exists
+temp_file_map = {}
+
 def query_document(question):
-
     if vector_store is None:
         return "❌ No documents processed.", None
 
     # ✅ Fetch stored documents
-    stored_docs = vector_store.get()["documents"]
+    stored_docs = vector_store.get()["documents"]
 
     # ✅ Calculate total word count safely
-    total_words = sum(len(doc.split()) if isinstance(doc, str) else len(doc.page_content.split()) for doc in stored_docs)
-
-    # ✅
-    if total_words < 500:
-
-
-
-
-        k_value =
-
-
-
-
-
-
-
-
-
+    total_words = sum(len(doc.split()) if isinstance(doc, str) else len(doc.page_content.split()) for doc in stored_docs)
+
+    # ✅ Categorize file size
+    if total_words < 500:
+        file_size_category = "small"
+        k_value = 3
+    elif total_words < 2000:
+        file_size_category = "medium"
+        k_value = 5
+    else:
+        file_size_category = "large"
+        k_value = 10
+
+    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k_value})
+
+    # ✅ Adjust response detail based on file size
+    if file_size_category == "small":
+        prompt_prefix = "Provide a **concise** response focusing on key points."
+    elif file_size_category == "medium":
+        prompt_prefix = "Provide a **detailed response** with examples and key insights."
+    else:
+        prompt_prefix = "Provide a **comprehensive and structured response**, including step-by-step analysis and explanations."
+
+    # ✅ Final prompt
+    detailed_prompt = f"""{prompt_prefix}
+    - Ensure clarity and completeness.
+    - Highlight the most relevant information.
     **Question:** {question}
-    """
+    """
 
-    # ✅ Dynamically select model based on
-    if
-        model_name = "gemini-2.0-pro-exp-02-05"
-    else:
-        model_name = "gemini-2.0-flash"
+    # ✅ Dynamically select model based on file size
+    if file_size_category in ["small", "medium"]:
+        model_name = "gemini-2.0-pro-exp-02-05"
+    else:
+        model_name = "gemini-2.0-flash"
 
-    logging.info(f"🧠 Using Model: {model_name} for
+    logging.info(f"🧠 Using Model: {model_name} for {file_size_category} file.")
 
-    model = ChatGoogleGenerativeAI(model=model_name, google_api_key=GOOGLE_API_KEY)
-    qa_chain = RetrievalQA.from_chain_type(llm=model, retriever=retriever)
-    response = qa_chain.invoke({"query": detailed_prompt})["result"]
+    model = ChatGoogleGenerativeAI(model=model_name, google_api_key=GOOGLE_API_KEY)
+    qa_chain = RetrievalQA.from_chain_type(llm=model, retriever=retriever)
+    response = qa_chain.invoke({"query": detailed_prompt})["result"]
 
     # ✅ Convert response to speech
-    tts = gTTS(text=response, lang="en")
-    temp_audio_path = os.path.join(temp_dir, "response.mp3")
-    tts.save(temp_audio_path)
-    temp_file_map["response.mp3"] = time.time()
+    tts = gTTS(text=response, lang="en")
+    temp_audio_path = os.path.join(temp_dir, "response.mp3")
+    tts.save(temp_audio_path)
+    temp_file_map["response.mp3"] = time.time()
+
 
-    return response, temp_audio_path
+    return response, temp_audio_path
 
 
 # ✅ Gradio UI
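
For anyone who wants to sanity-check the size-based routing this commit introduces without spinning up the Space, here is a minimal standalone sketch. The helper select_retrieval_params is hypothetical (the commit inlines this logic in query_document); the word-count thresholds, k values, and model names are copied from the diff above.

# Hypothetical standalone version of the routing added in this commit;
# thresholds, k values, and model names are taken verbatim from the diff.

def select_retrieval_params(total_words: int) -> tuple[str, int, str]:
    """Return (file_size_category, k_value, model_name) for a word count."""
    if total_words < 500:
        category, k = "small", 3
    elif total_words < 2000:
        category, k = "medium", 5
    else:
        category, k = "large", 10

    # The commit routes small/medium documents to the experimental Pro model
    # and large documents to the faster Flash model.
    model = "gemini-2.0-pro-exp-02-05" if category != "large" else "gemini-2.0-flash"
    return category, k, model


if __name__ == "__main__":
    for words in (120, 1500, 8000):
        print(words, "->", select_retrieval_params(words))

One boundary worth noting: a document of exactly 2,000 words is classified as "large", so it is retrieved with k=10 and answered by gemini-2.0-flash rather than the Pro model.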