Shreyas094 committed
Commit 9c1a06a · verified · 1 Parent(s): 1718c18

Update app.py

Files changed (1)
  1. app.py +48 -2
app.py CHANGED

@@ -18,6 +18,7 @@ from huggingface_hub import InferenceClient
 import inspect
 import logging
 import shutil
+from sentence_transformers import CrossEncoder
 
 
 # Set up basic configuration for logging
 
@@ -430,6 +431,22 @@ def create_web_search_vectors(search_results):
 
     return FAISS.from_documents(documents, embed)
 
+def rerank_web_results(query, documents, top_k=5):
+    # Initialize the cross-encoder model
+    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+    # Prepare input pairs for the cross-encoder
+    pairs = [[query, doc.page_content] for doc in documents]
+
+    # Compute relevance scores
+    scores = cross_encoder.predict(pairs)
+
+    # Sort documents by score
+    reranked_docs = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
+
+    # Return top_k documents
+    return [doc for doc, score in reranked_docs[:top_k]]
+
 def get_response_with_search(query, model, num_calls=3, temperature=0.2):
     search_results = duckduckgo_search(query)
     web_search_database = create_web_search_vectors(search_results)
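
For context, the cross-encoder scores each (query, passage) pair by reading both texts together, unlike the bi-encoder embeddings behind the FAISS index; this is slower but typically more accurate. A minimal sketch of the same scoring flow on toy LangChain documents; the contents and source names are made up for illustration:

from langchain.schema import Document
from sentence_transformers import CrossEncoder

# Toy stand-ins for web search results (hypothetical contents)
docs = [
    Document(page_content="Paris is the capital of France.", metadata={"source": "a.example"}),
    Document(page_content="The Eiffel Tower is a landmark in Paris.", metadata={"source": "b.example"}),
    Document(page_content="Bananas are rich in potassium.", metadata={"source": "c.example"}),
]
query = "What is the capital of France?"

# Same model and scoring flow as rerank_web_results above
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
scores = cross_encoder.predict([[query, d.page_content] for d in docs])

for doc, score in sorted(zip(docs, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {doc.metadata['source']}  {doc.page_content}")
# The off-topic banana snippet scores lowest, so a top_k cutoff drops it first.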
 
@@ -438,12 +455,15 @@ def get_response_with_search(query, model, num_calls=3, temperature=0.2):
         yield "No web search results available. Please try again.", ""
         return
 
-    retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
+    retriever = web_search_database.as_retriever(search_kwargs={"k": 20})  # Retrieve more documents for reranking
     relevant_docs = retriever.get_relevant_documents(query)
 
+    # Rerank the documents
+    reranked_docs = rerank_web_results(query, relevant_docs, top_k=5)
+
     accumulated_response = ""
 
-    for i, doc in enumerate(relevant_docs, 1):
+    for i, doc in enumerate(reranked_docs, 1):
         context = doc.page_content
         source = doc.metadata.get('source', 'Unknown source')
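
This change turns retrieval into a two-stage pipeline: stage one casts a wide net with cheap vector similarity (k=20), stage two lets the cross-encoder pick the 5 passages that actually reach the model. A sketch of the combined flow, under assumptions: HuggingFaceEmbeddings stands in for the embed object used by create_web_search_vectors, which this diff does not show, and the corpus is fabricated:

from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

# Fabricated corpus; in app.py the documents come from duckduckgo_search
docs = [
    Document(page_content=f"snippet {i}", metadata={"source": f"https://example.com/{i}"})
    for i in range(40)
]

embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")  # assumed embedder
db = FAISS.from_documents(docs, embed)

# Stage 1: wide, cheap recall from the vector store
retriever = db.as_retriever(search_kwargs={"k": 20})
candidates = retriever.get_relevant_documents("example query")

# Stage 2: precise cross-encoder rerank down to the context sent to the model
top_docs = rerank_web_results("example query", candidates, top_k=5)  # as defined in this commit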
 
 
@@ -502,6 +522,32 @@ Highlight any conflicting information or gaps in the available data."""
                 overall_summary += chunk
                 accumulated_response += f"Overall Summary:\n\n{overall_summary}\n\n"
                 yield accumulated_response, ""
+
+    # Generate an overall summary after processing all sources
+    overall_prompt = f"""Based on the summaries you've generated for each source, provide a concise overall summary that addresses the user's query: '{query}'
+    Highlight any conflicting information or gaps in the available data."""
+
+    if model == "@cf/meta/llama-3.1-8b-instruct":
+        # Use Cloudflare API for overall summary
+        overall_response = ""
+        for response in get_response_from_cloudflare(prompt="", context="", query=overall_prompt, num_calls=1, temperature=temperature, search_type="web"):
+            overall_response += response
+            accumulated_response += f"Overall Summary:\n\n{overall_response}\n\n"
+            yield accumulated_response, ""
+    else:
+        # Use Hugging Face API for overall summary
+        overall_summary = ""
+        for message in client.chat_completion(
+            messages=[{"role": "user", "content": overall_prompt}],
+            max_tokens=2000,
+            temperature=temperature,
+            stream=True,
+        ):
+            if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                chunk = message.choices[0].delta.content
+                overall_summary += chunk
+                accumulated_response += f"Overall Summary:\n\n{overall_summary}\n\n"
+                yield accumulated_response, ""
 
 def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
     logging.info(f"Entering get_response_from_pdf with query: {query}, model: {model}, selected_docs: {selected_docs}")
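
Since get_response_with_search is a generator, every yield carries the full accumulated_response built so far (per-source summaries, then the new overall summary), so a caller such as a Gradio streaming handler keeps only the latest value. An illustrative consumer; the query is hypothetical, while the model id is the Cloudflare one checked above:

# Illustrative streaming consumer; each yield is cumulative, so overwrite rather than append
final_text = ""
for partial_response, _sources in get_response_with_search(
    "latest developments in battery recycling",  # hypothetical query
    model="@cf/meta/llama-3.1-8b-instruct",      # routes to the Cloudflare branch above
    num_calls=1,
    temperature=0.2,
):
    final_text = partial_response
print(final_text)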