import os
import dotenv
from time import time
import streamlit as st
import logging

# Configure environment for Hugging Face Spaces
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"

# Create necessary directories
os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
os.makedirs("/tmp/source_files", exist_ok=True)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load .env and set USER_AGENT *before* importing the LangChain loaders:
# langchain_community reads USER_AGENT when its web loader module is imported
# (to build WebBaseLoader's default headers), so setting it afterwards has no
# effect on the requests WebBaseLoader sends.
dotenv.load_dotenv()
os.environ["USER_AGENT"] = "myagent"

from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
    WebBaseLoader,
    PyPDFLoader,
    Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

# Maximum number of sources (uploaded file names + URLs) that may be tracked
# in st.session_state.rag_sources.
DB_DOCS_LIMIT = 10


def clean_temp_files():
    """Delete regular files left in /tmp/source_files.

    Best-effort: every failure is logged as a warning and never raised. Errors
    are handled per file, so one file that cannot be removed does not stop the
    remaining files from being cleaned up.
    """
    for folder in ["/tmp/source_files"]:
        try:
            filenames = os.listdir(folder)
        except OSError as e:
            logger.warning(f"Error cleaning temp files: {e}")
            continue
        for filename in filenames:
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except OSError as e:
                logger.warning(f"Error cleaning temp files: {e}")


def stream_llm_response(llm_stream, messages):
    """Stream an LLM reply and record the full text in the chat history.

    Args:
        llm_stream: Object with a ``stream(messages)`` method yielding chunks
            that expose a ``content`` attribute (presumably a LangChain chat
            model; chunk content is assumed to be ``str`` -- confirm).
        messages: Conversation passed straight through to ``llm_stream.stream``.

    Yields:
        Each chunk exactly as received from ``llm_stream.stream``.

    Side effects:
        Once the stream is exhausted (only if the caller consumes the whole
        generator), appends ``{"role": "assistant", "content": <joined text>}``
        to ``st.session_state.messages``.
    """
    response_message = ""
    for chunk in llm_stream.stream(messages):
        response_message += chunk.content
        yield chunk
    st.session_state.messages.append({"role": "assistant", "content": response_message})


def load_doc_to_db():
    """Index newly uploaded files from ``st.session_state.rag_docs``.

    Each upload whose name is not yet in ``st.session_state.rag_sources`` is:
    written to /tmp/source_files, parsed with a loader chosen by MIME type or
    extension (PDF, .docx, plain text, markdown), and then recorded in
    ``rag_sources``. Once ``DB_DOCS_LIMIT`` sources are tracked, further
    uploads are rejected with ``st.error``. Per-file failures are reported via
    ``st.toast`` and logged; they do not abort the remaining uploads.

    All successfully parsed documents are split and added to the vector DB in
    one batch. Temporary files are removed afterwards.
    """
    if "rag_docs" not in st.session_state or not st.session_state.rag_docs:
        return

    docs = []
    for doc_file in st.session_state.rag_docs:
        if doc_file.name in st.session_state.rag_sources:
            continue  # already indexed
        if len(st.session_state.rag_sources) >= DB_DOCS_LIMIT:
            st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
            continue

        # The upload name is client-supplied: keep only its final component so
        # a name such as "../../etc/x" cannot escape /tmp/source_files.
        file_path = os.path.join("/tmp/source_files", os.path.basename(doc_file.name))
        try:
            with open(file_path, "wb") as file:
                file.write(doc_file.getbuffer())

            if doc_file.type == "application/pdf":
                loader = PyPDFLoader(file_path)
            elif doc_file.name.endswith(".docx"):
                loader = Docx2txtLoader(file_path)
            elif doc_file.type in ["text/plain", "text/markdown"]:
                loader = TextLoader(file_path)
            else:
                st.warning(f"Unsupported document type: {doc_file.type}")
                continue  # the finally clause still removes the temp file

            docs.extend(loader.load())
            st.session_state.rag_sources.append(doc_file.name)
            logger.info(f"Successfully loaded document: {doc_file.name}")
        except Exception as e:
            st.toast(f"Error loading document {doc_file.name}: {str(e)}", icon="⚠️")
            logger.error(f"Error loading document: {e}")
        finally:
            # isfile rather than exists: never hand a directory path to os.remove.
            if os.path.isfile(file_path):
                os.remove(file_path)

    if docs:
        _split_and_load_docs(docs)
        st.toast("Documents loaded successfully.", icon="✅")
    # Sweep anything the per-file cleanup above may have left behind.
    clean_temp_files()


def load_url_to_db():
    """Fetch the page at ``st.session_state.rag_url`` and index it.

    Does nothing if no URL is set or the URL is already in
    ``st.session_state.rag_sources``. Rejects the URL with ``st.error`` once
    ``DB_DOCS_LIMIT`` sources are tracked. Fetch and parse failures are shown
    with ``st.error`` and logged.
    """
    if "rag_url" not in st.session_state or not st.session_state.rag_url:
        return
    url = st.session_state.rag_url
    if url in st.session_state.rag_sources:
        return
    if len(st.session_state.rag_sources) >= DB_DOCS_LIMIT:
        st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
        return

    docs = []
    # NOTE(review): the URL is user-supplied and fetched server-side; consider
    # restricting schemes/hosts if this app is exposed publicly.
    try:
        loader = WebBaseLoader(url)
        docs.extend(loader.load())
        st.session_state.rag_sources.append(url)
        logger.info(f"Successfully loaded URL: {url}")
    except Exception as e:
        st.error(f"Error loading from URL {url}: {str(e)}")
        logger.error(f"Error loading URL: {e}")

    if docs:
        _split_and_load_docs(docs)
        st.toast(f"Loaded content from URL: {url}", icon="✅")


def initialize_vector_db(docs):
    """Create a Chroma vector store from ``docs`` and return it.

    Embeddings come from ``BAAI/bge-large-en-v1.5``, which runs on CPU, is
    cached under /tmp/.cache, and returns unnormalized vectors. The store uses
    collection "persistent_collection" in /tmp/chroma_persistent_db.

    NOTE(review): the directory and collection name are fixed, so every
    Streamlit session appears to write into the same on-disk collection --
    confirm this is intended.
    """
    embedding = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
        cache_folder="/tmp/.cache",
    )
    persist_dir = "/tmp/chroma_persistent_db"
    collection_name = "persistent_collection"
    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=embedding,
        persist_directory=persist_dir,
        collection_name=collection_name,
    )
    # NOTE(review): recent Chroma versions persist automatically; this call may
    # be a no-op there -- confirm against the installed version.
    vector_db.persist()
    logger.info("Vector database initialized and persisted")
    return vector_db


def _split_and_load_docs(docs):
    """Split ``docs`` into overlapping chunks and store them in the session's DB.

    Uses chunks of 1000 characters with 200 characters of overlap. Creates
    ``st.session_state.vector_db`` on first use; otherwise adds the chunks to
    the existing store.

    Returns early when splitting yields no chunks (e.g. documents with no
    extractable text). This avoids handing Chroma an empty batch, which some
    versions may reject.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = text_splitter.split_documents(docs)
    if not chunks:
        logger.warning("No text chunks produced; nothing added to the vector database")
        return

    if "vector_db" not in st.session_state:
        st.session_state.vector_db = initialize_vector_db(chunks)
    else:
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()
        logger.info("Added new documents to existing vector database")


def _get_context_retriever_chain(vector_db, llm):
    """Build a history-aware retriever over ``vector_db``.

    The prompt feeds the chat history (``messages`` placeholder) plus the
    latest ``input`` to ``llm``, asking it to generate a search query. That
    query is then run against the store's default retriever.
    """
    retriever = vector_db.as_retriever()
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information."),
    ])
    return create_history_aware_retriever(llm, retriever, prompt)


def get_conversational_rag_chain(llm): retriever_chain = _get_context_retriever_chain