import os
import dotenv
from time import time

import streamlit as st

from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
    WebBaseLoader,
    PyPDFLoader,
    Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

dotenv.load_dotenv()
os.environ["USER_AGENT"] = "myagent"

DB_DOCS_LIMIT = 10


# Stream non-RAG LLM response
def stream_llm_response(llm_stream, messages):
    response_message = ""

    for chunk in llm_stream.stream(messages):
        response_message += chunk.content
        yield chunk

    st.session_state.messages.append({"role": "assistant", "content": response_message})


# --- Document Loading and Indexing ---

def load_doc_to_db():
    if "rag_docs" in st.session_state and st.session_state.rag_docs:
        docs = []
        for doc_file in st.session_state.rag_docs:
            if doc_file.name not in st.session_state.rag_sources:
                if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                    os.makedirs("source_files", exist_ok=True)
                    file_path = f"./source_files/{doc_file.name}"
                    with open(file_path, "wb") as file:
                        file.write(doc_file.read())

                    try:
                        if doc_file.type == "application/pdf":
                            loader = PyPDFLoader(file_path)
                        elif doc_file.name.endswith(".docx"):
                            loader = Docx2txtLoader(file_path)
                        elif doc_file.type in ["text/plain", "text/markdown"]:
                            loader = TextLoader(file_path)
                        else:
                            st.warning(f"Unsupported document type: {doc_file.type}")
                            continue

                        docs.extend(loader.load())
                        st.session_state.rag_sources.append(doc_file.name)
                    except Exception as e:
                        st.toast(f"Error loading document {doc_file.name}: {e}", icon="⚠️")
                    finally:
                        os.remove(file_path)
                else:
                    st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")

        if docs:
            _split_and_load_docs(docs)
            st.toast("Documents loaded successfully.", icon="✅")


def load_url_to_db():
    if "rag_url" in st.session_state and st.session_state.rag_url:
        url = st.session_state.rag_url
        docs = []
        if url not in st.session_state.rag_sources:
            if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                try:
                    loader = WebBaseLoader(url)
                    docs.extend(loader.load())
                    st.session_state.rag_sources.append(url)
                except Exception as e:
                    st.error(f"Error loading from URL {url}: {e}")

                if docs:
                    _split_and_load_docs(docs)
                    st.toast(f"Loaded content from URL: {url}", icon="✅")
            else:
                st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")


def initialize_vector_db(docs):
    # Initialize HuggingFace embeddings
    embedding = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )

    # Shared persistent directory for long-term storage
    persist_dir = "./chroma_persistent_db"
    collection_name = "persistent_collection"

    # Create the persistent Chroma vector store
    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=embedding,
        persist_directory=persist_dir,
        collection_name=collection_name,
    )

    # Persist to disk
    vector_db.persist()
    return vector_db


def _split_and_load_docs(docs):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = text_splitter.split_documents(docs)

    if "vector_db" not in st.session_state:
        st.session_state.vector_db = initialize_vector_db(chunks)
    else:
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()  # Save changes
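

# --- Example: wiring the loaders into a Streamlit sidebar ---
# A minimal sketch, not part of the original module. It relies only on the
# session-state keys the callbacks above already read ("rag_docs", "rag_url",
# "rag_sources", "messages", "debug_mode"); the helper name
# `render_rag_sidebar` and the widget labels are illustrative.
def render_rag_sidebar():
    # Initialize the session-state entries the callbacks expect
    if "rag_sources" not in st.session_state:
        st.session_state.rag_sources = []
    if "messages" not in st.session_state:
        st.session_state.messages = []

    with st.sidebar:
        st.file_uploader(
            "Upload documents (PDF, DOCX, TXT, MD)",
            type=["pdf", "docx", "txt", "md"],
            accept_multiple_files=True,
            key="rag_docs",            # read by load_doc_to_db()
            on_change=load_doc_to_db,
        )
        st.text_input(
            "Or paste a URL",
            key="rag_url",             # read by load_url_to_db()
            on_change=load_url_to_db,
        )
        st.checkbox("Show retrieved context (debug)", key="debug_mode")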

# --- RAG Chain ---

def _get_context_retriever_chain(vector_db, llm):
    retriever = vector_db.as_retriever()
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information."),
    ])
    return create_history_aware_retriever(llm, retriever, prompt)


def get_conversational_rag_chain(llm):
    retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful assistant answering the user's queries using the provided context if available.\n {context}"""),
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
    ])
    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever_chain, stuff_documents_chain)


# Stream RAG LLM response
def stream_llm_rag_response(llm_stream, messages):
    rag_chain = get_conversational_rag_chain(llm_stream)

    # Extract latest user input and prior messages
    input_text = messages[-1].content
    history = messages[:-1]

    # --- DEBUG: Show context retrieved ---
    if st.session_state.get("debug_mode"):
        retriever = st.session_state.vector_db.as_retriever()
        retrieved_docs = retriever.get_relevant_documents(input_text)
        st.markdown("### 🔍 Retrieved Context (Debug Mode)")
        for i, doc in enumerate(retrieved_docs):
            st.markdown(f"**Chunk {i+1}:**\n```\n{doc.page_content.strip()}\n```")

    response_message = "*(RAG Response)*\n"
    response = rag_chain.stream({
        "messages": history,
        "input": input_text,
    })

    for chunk in response:
        if 'answer' in chunk:
            response_message += chunk['answer']
            yield chunk['answer']

    st.session_state.messages.append({"role": "assistant", "content": response_message})
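

# --- Example: consuming the streaming helpers in a chat loop ---
# A minimal sketch, not part of the original module. It assumes a streaming
# chat model (here langchain_openai.ChatOpenAI, which is NOT imported above);
# any LangChain chat model that implements .stream() should work. The function
# name `run_chat` and the model/temperature settings are illustrative.
def run_chat():
    from langchain_openai import ChatOpenAI                      # assumed dependency
    from langchain_core.messages import AIMessage, HumanMessage

    llm_stream = ChatOpenAI(model="gpt-4o-mini", temperature=0.3, streaming=True)

    # Replay prior turns stored as dicts in st.session_state.messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # The streaming helpers expect LangChain message objects, not dicts
        lc_messages = [
            HumanMessage(content=m["content"]) if m["role"] == "user"
            else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]

        # Use the RAG chain only once documents have been indexed
        with st.chat_message("assistant"):
            if "vector_db" in st.session_state:
                st.write_stream(stream_llm_rag_response(llm_stream, lc_messages))
            else:
                st.write_stream(stream_llm_response(llm_stream, lc_messages))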