Update app.py

app.py CHANGED
@@ -8,8 +8,6 @@ from typing import List
 from pydantic import BaseModel, Field
 from tempfile import NamedTemporaryFile
 from langchain_community.vectorstores import FAISS
-from langchain_community.vectorstores import FAISS as WebSearchFAISS
-from langchain_core.documents import Document as WebSearchDocument
 from langchain_core.vectorstores import VectorStore
 from langchain_core.documents import Document
 from langchain_community.document_loaders import PyPDFLoader
@@ -20,8 +18,6 @@ from huggingface_hub import InferenceClient
 import inspect
 import logging
 import shutil
-import tempfile
-from typing import List, Tuple
 
 
 # Set up basic configuration for logging
@@ -275,72 +271,10 @@ def generate_chunked_response(prompt, model, max_tokens=10000, num_calls=3, temp
     print(f"Final clean response: {final_response[:100]}...")
     return final_response
 
-def get_web_search_database():
-    embed = get_embeddings()
-    temp_dir = tempfile.mkdtemp()
-
-    try:
-        # Create a dummy document to initialize the database
-        dummy_doc = Document(page_content="Dummy content", metadata={"source": "dummy"})
-        database = FAISS.from_documents([dummy_doc], embed)
-        logging.info("Successfully initialized WebSearchFAISS database with dummy document")
-        # Remove the dummy document
-        database.delete(["dummy"])
-        logging.info("Removed dummy document from database")
-    except Exception as e:
-        logging.error(f"Error initializing WebSearchFAISS: {str(e)}", exc_info=True)
-        # If initialization fails, create an empty database
-        database = FAISS(embed, None, {}, {}, None)
-        logging.info("Created empty WebSearchFAISS database manually")
-
-    return database, temp_dir
-
-def cleanup_web_search_database(temp_dir):
-    shutil.rmtree(temp_dir)
-
 def duckduckgo_search(query):
-    try:
-        with DDGS() as ddgs:
-            results = list(ddgs.text(query, max_results=5))
-
-        logging.info(f"Number of search results: {len(results)}")
-
-        database, temp_dir = get_web_search_database()
-        documents = []
-        for result in results:
-            content = f"{result['title']}\n{result['body']}"
-            doc = Document(page_content=content, metadata={"source": result['href']})
-            documents.append(doc)
-
-        logging.info(f"Number of documents created: {len(documents)}")
-        if documents:
-            try:
-                database.add_documents(documents)
-                logging.info(f"Successfully added {len(documents)} documents to the database")
-            except Exception as e:
-                logging.error(f"Error adding documents to database: {str(e)}", exc_info=True)
-                # If adding documents fails, create a new database with these documents
-                database = FAISS.from_documents(documents, get_embeddings())
-                logging.info("Created new WebSearchFAISS database with search results")
-        return database, temp_dir, results
-    except Exception as e:
-        logging.error(f"Error in duckduckgo_search: {str(e)}", exc_info=True)
-        return None, None, []
-
-def retrieve_web_search_results(database, query):
-    logging.info(f"Retrieving web search results for query: {query}")
-    retriever = database.as_retriever(search_kwargs={"k": 5})
-    relevant_docs = retriever.get_relevant_documents(query)
-
-    logging.info(f"Number of relevant documents retrieved: {len(relevant_docs)}")
-
-    if not relevant_docs:
-        logging.warning("No relevant documents found in the database")
-        return "No relevant information found."
-
-    context = "\n".join([f"{doc.page_content}\nSource: {doc.metadata['source']}" for doc in relevant_docs])
-    return context
+    with DDGS() as ddgs:
+        results = ddgs.text(query, max_results=5)
+    return results
 
 class CitingSources(BaseModel):
     sources: List[str] = Field(
@@ -377,66 +311,50 @@ def respond(message, history, model, temperature, num_calls, use_web_search, sel
     logging.info(f"User Query: {message}")
     logging.info(f"Model Used: {model}")
     logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
+
     logging.info(f"Selected Documents: {selected_docs}")
 
     try:
         if use_web_search:
-            try:
-                context = retrieve_web_search_results(database, message)
-                logging.info(f"Retrieved context length: {len(context)}")
-
-                if model == "@cf/meta/llama-3.1-8b-instruct":
-                    # Use Cloudflare API
-                    for partial_response in get_response_from_cloudflare(prompt="", context=context, query=message, num_calls=num_calls, temperature=temperature, search_type="web"):
-                        logging.debug(f"Partial response: {partial_response[:100]}...")  # Log first 100 chars
-                        yield partial_response
-                else:
-                    # Use Hugging Face API
-                    for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
-                        response = f"{main_content}\n\n{sources}"
-                        logging.debug(f"Response: {response[:100]}...")  # Log first 100 chars
-                        yield response
-            finally:
-                # Clean up the temporary database
-                cleanup_web_search_database(temp_dir)
+            for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+                response = f"{main_content}\n\n{sources}"
+                first_line = response.split('\n')[0] if response else ''
+                # logging.info(f"Generated Response (first line): {first_line}")
+                yield response
         else:
-            # PDF search logic
             embed = get_embeddings()
-            if not os.path.exists("faiss_database"):
+            if os.path.exists("faiss_database"):
+                database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+                retriever = database.as_retriever(search_kwargs={"k": 20})
+
+                # Filter relevant documents based on user selection
+                all_relevant_docs = retriever.get_relevant_documents(message)
+                relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
+
+                if not relevant_docs:
+                    yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
+                    return
+
+                context_str = "\n".join([doc.page_content for doc in relevant_docs])
+            else:
+                context_str = "No documents available."
                 yield "No documents available. Please upload PDF documents to answer questions."
                 return
 
-            database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-            retriever = database.as_retriever(search_kwargs={"k": 20})
-
-            all_relevant_docs = retriever.get_relevant_documents(message)
-            relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
-
-            if not relevant_docs:
-                yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
-                return
-
-            context_str = "\n".join([doc.page_content for doc in relevant_docs])
-            logging.info(f"Context length for PDF search: {len(context_str)}")
-
             if model == "@cf/meta/llama-3.1-8b-instruct":
                 # Use Cloudflare API
                 for partial_response in get_response_from_cloudflare(prompt="", context=context_str, query=message, num_calls=num_calls, temperature=temperature, search_type="pdf"):
+                    first_line = partial_response.split('\n')[0] if partial_response else ''
+                    # logging.info(f"Generated Response (first line): {first_line}")
                     yield partial_response
             else:
                 # Use Hugging Face API
                 for partial_response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
+                    first_line = partial_response.split('\n')[0] if partial_response else ''
+                    # logging.info(f"Generated Response (first line): {first_line}")
                    yield partial_response
-
     except Exception as e:
+        logging.error(f"Error with {model}: {str(e)}")
         if "microsoft/Phi-3-mini-4k-instruct" in model:
             logging.info("Falling back to Mistral model due to Phi-3 error")
             fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
@@ -501,42 +419,65 @@ After writing the document, please provide a list of sources used in your respon
     if not full_response:
         yield "I apologize, but I couldn't generate a response at this time. Please try again later."
 
+# New global variable for web search database
+web_search_database = None
+
+def update_web_search_vectors(search_results):
+    global web_search_database
+    embed = get_embeddings()
+
+    documents = []
+    for result in search_results:
+        if 'body' in result:
+            content = f"{result['title']}\n{result['body']}\nSource: {result['href']}"
+            documents.append(Document(page_content=content, metadata={"source": result['href']}))
+
+    if web_search_database is None:
+        web_search_database = FAISS.from_documents(documents, embed)
+    else:
+        web_search_database.add_documents(documents)
+
 def get_response_with_search(query, model, num_calls=3, temperature=0.2):
+    global web_search_database
+
+    search_results = duckduckgo_search(query)
+    update_web_search_vectors(search_results)
+
+    if web_search_database is None:
+        yield "No web search results available. Please try again.", ""
+        return
+
+    retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
+    relevant_docs = retriever.get_relevant_documents(query)
+
+    context = "\n".join([doc.page_content for doc in relevant_docs])
 
     prompt = f"""Using the following context from web search results:
+{context}
+Write a detailed and complete research document that fulfills the following user request: '{query}'
+After writing the document, please provide a list of sources used in your response."""
 
-    try:
-        if model == "@cf/meta/llama-3.1-8b-instruct":
-            # Use Cloudflare API
-            for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
-                yield response, ""  # Yield streaming response without sources
-        else:
-            # Use Hugging Face API
-            client = InferenceClient(model, token=huggingface_token)
-
-            main_content = ""
-            for i in range(num_calls):
-                for message in client.chat_completion(
-                    messages=[{"role": "user", "content": prompt}],
-                    max_tokens=10000,
-                    temperature=temperature,
-                    stream=True,
-                ):
-                    if message.choices and len(message.choices) > 0 and message.choices[0].delta and message.choices[0].delta.content:
-                        chunk = message.choices[0].delta.content
-                        main_content += chunk
-                        yield main_content, ""  # Yield partial main content without sources
-    except Exception as e:
-        logging.error(f"Error in get_response_with_search: {str(e)}", exc_info=True)
-        yield f"An error occurred while processing the search results: {str(e)}", ""
-    finally:
-        # Clean up the temporary database
-        cleanup_web_search_database(temp_dir)
+    if model == "@cf/meta/llama-3.1-8b-instruct":
+        # Use Cloudflare API
+        for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
+            yield response, ""  # Yield streaming response without sources
+    else:
+        # Use Hugging Face API
+        client = InferenceClient(model, token=huggingface_token)
+
+        main_content = ""
+        for i in range(num_calls):
+            for message in client.chat_completion(
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=10000,
+                temperature=temperature,
+                stream=True,
+            ):
+                if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                    chunk = message.choices[0].delta.content
+                    main_content += chunk
+                    yield main_content, ""  # Yield partial main content without sources
 
 def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
     logging.info(f"Entering get_response_from_pdf with query: {query}, model: {model}, selected_docs: {selected_docs}")
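Taken together, the commit replaces the temporary per-query FAISS database with a plain search helper plus a module-level index. That pipeline can be exercised on its own, roughly as below. This is a minimal sketch, not code from the commit: it assumes DDGS comes from the duckduckgo_search package, that a LangChain embeddings object is passed in, and demo_web_search_flow is a hypothetical name.

# Hypothetical sketch of the new search -> index -> retrieve pipeline.
from duckduckgo_search import DDGS
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

def demo_web_search_flow(query, embeddings):
    # Step 1: plain DuckDuckGo text search, as in the new duckduckgo_search().
    with DDGS() as ddgs:
        results = ddgs.text(query, max_results=5)

    # Step 2: wrap each hit in a Document, as update_web_search_vectors() does.
    documents = [
        Document(
            page_content=f"{r['title']}\n{r['body']}\nSource: {r['href']}",
            metadata={"source": r["href"]},
        )
        for r in results
        if "body" in r
    ]

    # Step 3: index the hits and pull back the top matches for the query.
    database = FAISS.from_documents(documents, embeddings)
    retriever = database.as_retriever(search_kwargs={"k": 5})
    return retriever.get_relevant_documents(query)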
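Note that the new get_response_with_search() streams cumulative snapshots rather than deltas: main_content grows with each chunk, and every yield carries the full text so far alongside a (currently empty) sources string. A caller therefore only needs to keep the last tuple; a hypothetical driver loop, with an example model name:

# Hypothetical driver: drain the generator and keep the final snapshot.
final_text = ""
for main_content, sources in get_response_with_search(
    "What is FAISS?", "mistralai/Mistral-7B-Instruct-v0.3"
):
    final_text = main_content  # each yield is the full text so far, not a delta
print(final_text)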
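One side effect of the module-level web_search_database is that it is never cleared: update_web_search_vectors() only ever adds documents, so hits from earlier queries stay retrievable in later ones. If stale results become a problem, a reset between queries is a one-liner (hypothetical helper, not part of this commit):

def reset_web_search_database():
    # Drop the accumulated index so the next search starts from scratch.
    global web_search_database
    web_search_database = None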