krishnadhulipalla committed on
Commit
ab32d1c
·
1 Parent(s): 2fa0d66

reduced latency

Browse files
Files changed (1) hide show
  1. app.py +41 -18
app.py CHANGED
@@ -5,6 +5,7 @@ import hashlib
5
  import gradio as gr
6
  import time
7
  from functools import partial
 
8
  from collections import defaultdict
9
  from pathlib import Path
10
  from typing import List, Dict, Any
@@ -77,6 +78,8 @@ def initialize_resources():
77
 
78
  vectorstore, all_chunks, all_texts, metadatas = initialize_resources()
79
 
 
 
80
  # LLMs
81
  repharser_llm = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3") | StrOutputParser()
82
  instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1") | StrOutputParser()
@@ -162,10 +165,10 @@ answer_prompt_relevant = ChatPromptTemplate.from_template(
162
  "Answer:"
163
  )
164
 
165
-
166
  answer_prompt_fallback = ChatPromptTemplate.from_template(
167
  "You are Krishna’s personal AI assistant. The user asked a question unrelated to Krishna’s background.\n"
168
  "Respond with a touch of humor, then guide the conversation back to Krishna’s actual skills, experiences, or projects.\n\n"
 
169
  "Krishna's Background:\n{profile}\n\n"
170
  "User Question:\n{query}\n\n"
171
  "Your Answer:"
@@ -178,13 +181,13 @@ def parse_rewrites(raw_response: str) -> list[str]:
178
  def hybrid_retrieve(inputs, exclude_terms=None):
179
  # if exclude_terms is None:
180
  # exclude_terms = ["cgpa", "university", "b.tech", "m.s.", "certification", "coursera", "edx", "goal", "aspiration", "linkedin", "publication", "ieee", "doi", "degree"]
181
-
182
  all_queries = inputs["all_queries"]
183
- bm25_retriever = BM25Retriever.from_texts(texts=all_texts, metadatas=metadatas)
184
  bm25_retriever.k = inputs["k_per_query"]
185
  vectorstore = inputs["vectorstore"]
186
  alpha = inputs["alpha"]
187
  top_k = inputs.get("top_k", 15)
 
188
 
189
  scored_chunks = defaultdict(lambda: {
190
  "vector_scores": [],
@@ -192,23 +195,45 @@ def hybrid_retrieve(inputs, exclude_terms=None):
192
  "content": None,
193
  "metadata": None,
194
  })
195
-
196
- for subquery in all_queries:
197
- vec_hits = vectorstore.similarity_search_with_score(subquery, k=inputs["k_per_query"])
 
 
 
198
  for doc, score in vec_hits:
199
  key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
200
- scored_chunks[key]["vector_scores"].append(score)
201
- scored_chunks[key]["content"] = doc.page_content
202
- scored_chunks[key]["metadata"] = doc.metadata
203
-
204
  bm_hits = bm25_retriever.invoke(subquery)
 
205
  for rank, doc in enumerate(bm_hits):
206
  key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
207
- bm_score = 1.0 - (rank / inputs["k_per_query"])
208
- scored_chunks[key]["bm25_score"] += bm_score
209
- scored_chunks[key]["content"] = doc.page_content
210
- scored_chunks[key]["metadata"] = doc.metadata
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  all_vec_means = [np.mean(v["vector_scores"]) for v in scored_chunks.values() if v["vector_scores"]]
213
  max_vec = max(all_vec_means) if all_vec_means else 1
214
  min_vec = min(all_vec_means) if all_vec_means else 0
@@ -221,8 +246,6 @@ def hybrid_retrieve(inputs, exclude_terms=None):
221
  final_score = alpha * norm_vec + (1 - alpha) * bm25_score
222
 
223
  content = chunk["content"].lower()
224
- # if any(term in content for term in exclude_terms):
225
- # continue
226
  if final_score < 0.05 or len(content.strip()) < 100:
227
  continue
228
 
@@ -334,7 +357,7 @@ def chat_interface(message, history):
334
  "k_per_query": 3,
335
  "alpha": 0.7,
336
  "vectorstore": vectorstore,
337
- "full_document": "",
338
  }
339
  response = ""
340
  for chunk in full_pipeline.stream(inputs):
@@ -358,7 +381,7 @@ demo = gr.ChatInterface(
358
  )
359
 
360
  if __name__ == "__main__":
361
- demo.launch(debug=True)
362
 
363
  # with gr.Blocks(css="""
364
  # html, body, .gradio-container {
 
5
  import gradio as gr
6
  import time
7
  from functools import partial
8
+ import concurrent.futures
9
  from collections import defaultdict
10
  from pathlib import Path
11
  from typing import List, Dict, Any
 
78
 
79
  vectorstore, all_chunks, all_texts, metadatas = initialize_resources()
80
 
81
+ bm25_retriever = BM25Retriever.from_texts(texts=all_texts, metadatas=metadatas)
82
+
83
  # LLMs
84
  repharser_llm = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3") | StrOutputParser()
85
  instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1") | StrOutputParser()
 
165
  "Answer:"
166
  )
167
 
 
168
  answer_prompt_fallback = ChatPromptTemplate.from_template(
169
  "You are Krishna’s personal AI assistant. The user asked a question unrelated to Krishna’s background.\n"
170
  "Respond with a touch of humor, then guide the conversation back to Krishna’s actual skills, experiences, or projects.\n\n"
171
+ "Make it clear that everything you mention afterward comes from Krishna's actual profile.\n\n"
172
  "Krishna's Background:\n{profile}\n\n"
173
  "User Question:\n{query}\n\n"
174
  "Your Answer:"
 
181
  def hybrid_retrieve(inputs, exclude_terms=None):
182
  # if exclude_terms is None:
183
  # exclude_terms = ["cgpa", "university", "b.tech", "m.s.", "certification", "coursera", "edx", "goal", "aspiration", "linkedin", "publication", "ieee", "doi", "degree"]
184
+ bm25_retriever = inputs["bm25_retriever"]
185
  all_queries = inputs["all_queries"]
 
186
  bm25_retriever.k = inputs["k_per_query"]
187
  vectorstore = inputs["vectorstore"]
188
  alpha = inputs["alpha"]
189
  top_k = inputs.get("top_k", 15)
190
+ k_per_query = inputs["k_per_query"]
191
 
192
  scored_chunks = defaultdict(lambda: {
193
  "vector_scores": [],
 
195
  "content": None,
196
  "metadata": None,
197
  })
198
+
199
+ # Function to process each subquery
200
+ def process_subquery(subquery, k_per_query=3):
201
+ # Vector retrieval
202
+ vec_hits = vectorstore.similarity_search_with_score(subquery, k=k_per_query)
203
+ vec_results = []
204
  for doc, score in vec_hits:
205
  key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
206
+ vec_results.append((key, doc, score))
207
+
208
+ # BM25 retrieval
 
209
  bm_hits = bm25_retriever.invoke(subquery)
210
+ bm_results = []
211
  for rank, doc in enumerate(bm_hits):
212
  key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
213
+ bm_score = 1.0 - (rank / k_per_query)
214
+ bm_results.append((key, doc, bm_score))
215
+
216
+ return vec_results, bm_results
217
 
218
+ # Process subqueries in parallel
219
+ with concurrent.futures.ThreadPoolExecutor() as executor:
220
+ futures = [executor.submit(process_subquery, q) for q in all_queries]
221
+ for future in concurrent.futures.as_completed(futures):
222
+ vec_results, bm_results = future.result()
223
+
224
+ # Process vector results
225
+ for key, doc, score in vec_results:
226
+ scored_chunks[key]["vector_scores"].append(score)
227
+ scored_chunks[key]["content"] = doc.page_content
228
+ scored_chunks[key]["metadata"] = doc.metadata
229
+
230
+ # Process BM25 results
231
+ for key, doc, bm_score in bm_results:
232
+ scored_chunks[key]["bm25_score"] += bm_score
233
+ scored_chunks[key]["content"] = doc.page_content
234
+ scored_chunks[key]["metadata"] = doc.metadata
235
+
236
+ # Rest of the scoring and filtering logic remains the same
237
  all_vec_means = [np.mean(v["vector_scores"]) for v in scored_chunks.values() if v["vector_scores"]]
238
  max_vec = max(all_vec_means) if all_vec_means else 1
239
  min_vec = min(all_vec_means) if all_vec_means else 0
 
246
  final_score = alpha * norm_vec + (1 - alpha) * bm25_score
247
 
248
  content = chunk["content"].lower()
 
 
249
  if final_score < 0.05 or len(content.strip()) < 100:
250
  continue
251
 
 
357
  "k_per_query": 3,
358
  "alpha": 0.7,
359
  "vectorstore": vectorstore,
360
+ "bm25_retriever": bm25_retriever,
361
  }
362
  response = ""
363
  for chunk in full_pipeline.stream(inputs):
 
381
  )
382
 
383
  if __name__ == "__main__":
384
+ demo.launch(max_threads=4, prevent_thread_lock=True, debug=True)
385
 
386
  # with gr.Blocks(css="""
387
  # html, body, .gradio-container {