Commit · ab32d1c
1 Parent(s): 2fa0d66

reduced latency
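Read together, the diff below makes two latency-oriented changes to app.py: the BM25Retriever is built once at startup (instead of inside every call to hybrid_retrieve) and handed in through the inputs dict, and the per-subquery vector + BM25 lookups are fanned out over a thread pool with concurrent.futures. A minimal, standalone sketch of that fan-out pattern follows; retrieve_one is a hypothetical stand-in for the real per-subquery work, not a function from app.py:

import concurrent.futures

def retrieve_one(subquery: str) -> list[str]:
    # Hypothetical stand-in for the real per-subquery work in app.py
    # (a vector similarity search plus a BM25 lookup).
    return [f"hit for {subquery!r}"]

def retrieve_all(subqueries: list[str]) -> list[str]:
    # The subqueries are independent, so their (mostly I/O-bound) lookups
    # can overlap in a thread pool instead of running one after another.
    results: list[str] = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(retrieve_one, q) for q in subqueries]
        for future in concurrent.futures.as_completed(futures):
            results.extend(future.result())
    return results

if __name__ == "__main__":
    print(retrieve_all(["skills", "projects", "publications"]))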
app.py
CHANGED
@@ -5,6 +5,7 @@ import hashlib
 import gradio as gr
 import time
 from functools import partial
+import concurrent.futures
 from collections import defaultdict
 from pathlib import Path
 from typing import List, Dict, Any
@@ -77,6 +78,8 @@ def initialize_resources():
 
 vectorstore, all_chunks, all_texts, metadatas = initialize_resources()
 
+bm25_retriever = BM25Retriever.from_texts(texts=all_texts, metadatas=metadatas)
+
 # LLMs
 repharser_llm = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3") | StrOutputParser()
 instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1") | StrOutputParser()
@@ -162,10 +165,10 @@ answer_prompt_relevant = ChatPromptTemplate.from_template(
     "Answer:"
 )
 
-
 answer_prompt_fallback = ChatPromptTemplate.from_template(
     "You are Krishna’s personal AI assistant. The user asked a question unrelated to Krishna’s background.\n"
     "Respond with a touch of humor, then guide the conversation back to Krishna’s actual skills, experiences, or projects.\n\n"
+    "Make it clear that everything you mention afterward comes from Krishna's actual profile.\n\n"
     "Krishna's Background:\n{profile}\n\n"
     "User Question:\n{query}\n\n"
     "Your Answer:"
@@ -178,13 +181,13 @@ def parse_rewrites(raw_response: str) -> list[str]:
 def hybrid_retrieve(inputs, exclude_terms=None):
     # if exclude_terms is None:
     #     exclude_terms = ["cgpa", "university", "b.tech", "m.s.", "certification", "coursera", "edx", "goal", "aspiration", "linkedin", "publication", "ieee", "doi", "degree"]
-
+    bm25_retriever = inputs["bm25_retriever"]
     all_queries = inputs["all_queries"]
-    bm25_retriever = BM25Retriever.from_texts(texts=all_texts, metadatas=metadatas)
     bm25_retriever.k = inputs["k_per_query"]
     vectorstore = inputs["vectorstore"]
     alpha = inputs["alpha"]
     top_k = inputs.get("top_k", 15)
+    k_per_query = inputs["k_per_query"]
 
     scored_chunks = defaultdict(lambda: {
         "vector_scores": [],
@@ -192,23 +195,45 @@ def hybrid_retrieve(inputs, exclude_terms=None):
         "content": None,
         "metadata": None,
     })
-
-
-
+
+    # Function to process each subquery
+    def process_subquery(subquery, k_per_query=3):
+        # Vector retrieval
+        vec_hits = vectorstore.similarity_search_with_score(subquery, k=k_per_query)
+        vec_results = []
         for doc, score in vec_hits:
             key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
-
-
-
-
+            vec_results.append((key, doc, score))
+
+        # BM25 retrieval
         bm_hits = bm25_retriever.invoke(subquery)
+        bm_results = []
         for rank, doc in enumerate(bm_hits):
             key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
-            bm_score = 1.0 - (rank /
-
-
-
+            bm_score = 1.0 - (rank / k_per_query)
+            bm_results.append((key, doc, bm_score))
+
+        return vec_results, bm_results
 
+    # Process subqueries in parallel
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = [executor.submit(process_subquery, q) for q in all_queries]
+        for future in concurrent.futures.as_completed(futures):
+            vec_results, bm_results = future.result()
+
+            # Process vector results
+            for key, doc, score in vec_results:
+                scored_chunks[key]["vector_scores"].append(score)
+                scored_chunks[key]["content"] = doc.page_content
+                scored_chunks[key]["metadata"] = doc.metadata
+
+            # Process BM25 results
+            for key, doc, bm_score in bm_results:
+                scored_chunks[key]["bm25_score"] += bm_score
+                scored_chunks[key]["content"] = doc.page_content
+                scored_chunks[key]["metadata"] = doc.metadata
+
+    # Rest of the scoring and filtering logic remains the same
     all_vec_means = [np.mean(v["vector_scores"]) for v in scored_chunks.values() if v["vector_scores"]]
     max_vec = max(all_vec_means) if all_vec_means else 1
     min_vec = min(all_vec_means) if all_vec_means else 0
@@ -221,8 +246,6 @@ def hybrid_retrieve(inputs, exclude_terms=None):
         final_score = alpha * norm_vec + (1 - alpha) * bm25_score
 
         content = chunk["content"].lower()
-        # if any(term in content for term in exclude_terms):
-        #     continue
         if final_score < 0.05 or len(content.strip()) < 100:
             continue
 
@@ -334,7 +357,7 @@ def chat_interface(message, history):
         "k_per_query": 3,
         "alpha": 0.7,
         "vectorstore": vectorstore,
-        "
+        "bm25_retriever": bm25_retriever,
     }
     response = ""
     for chunk in full_pipeline.stream(inputs):
@@ -358,7 +381,7 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch(debug=True)
+    demo.launch(max_threads=4, prevent_thread_lock=True, debug=True)
 
 # with gr.Blocks(css="""
 #     html, body, .gradio-container {
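For reference, the unchanged scoring logic downstream of these results blends a normalized vector score with the rank-based BM25 score via final_score = alpha * norm_vec + (1 - alpha) * bm25_score. A toy sketch of that blend; the alpha = 0.7 and k_per_query = 3 defaults mirror the values visible in the diff, while the concrete numeric inputs are made up purely for illustration:

def bm25_rank_score(rank: int, k_per_query: int = 3) -> float:
    # Rank-based BM25 contribution as in the diff: the top hit (rank 0)
    # scores 1.0 and the score decays linearly with rank.
    return 1.0 - (rank / k_per_query)

def hybrid_score(norm_vec: float, bm25_score: float, alpha: float = 0.7) -> float:
    # Weighted blend of the normalized vector score and the BM25 score.
    return alpha * norm_vec + (1 - alpha) * bm25_score

if __name__ == "__main__":
    # Illustrative values only: a chunk with a strong vector match (0.8)
    # observed at BM25 ranks 0, 1, and 2.
    for rank in range(3):
        print(rank, round(hybrid_score(0.8, bm25_rank_score(rank)), 3))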