import streamlit as st
import os
import dotenv
import uuid
import logging
# Configure environment for Hugging Face Spaces.
# NOTE: these cache locations must be set BEFORE any Hugging Face /
# transformers code is imported (including indirectly via rag_methods
# below), so downloaded models land in the writable /tmp area.
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"
# Create necessary directories (idempotent): model cache, Chroma vector
# store persistence, and a scratch area for uploaded source files.
os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
os.makedirs("/tmp/source_files", exist_ok=True)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Heavy / project imports AFTER the cache env vars above are in place.
from langchain.schema import HumanMessage, AIMessage
from langchain_groq import ChatGroq
from rag_methods import (
load_doc_to_db,
load_url_to_db,
stream_llm_response,
stream_llm_rag_response,
)
# Load GROQ_API_KEY (and any other secrets) from a local .env file.
dotenv.load_dotenv()
# --- Custom CSS Styling ---
def apply_custom_css():
    """Inject the app's custom CSS into the Streamlit page.

    NOTE(review): the original stylesheet content appears to have been
    stripped during extraction — this currently injects an empty block.
    Restore the CSS rules (chat bubble classes such as ``user-message`` /
    ``assistant-message`` are referenced elsewhere in this file).
    """
    st.markdown("""
    """, unsafe_allow_html=True)
# --- Page Setup ---
st.set_page_config(
    page_title="RAG-Xpert: An Enhanced RAG Framework",
    page_icon="๐",
    layout="centered",
    initial_sidebar_state="expanded",
)
apply_custom_css()

# App header. NOTE(review): the original HTML wrapper markup was lost in
# extraction; only the visible heading text is preserved here — restore
# the heading element/styling when the markup is recovered.
st.markdown(
    "๐ RAG-Xpert: An Enhanced Retrieval-Augmented Generation Framework ๐ค",
    unsafe_allow_html=True,
)
# --- Session Initialization ---
# Stable per-browser-session identifier (handy for scoping vector stores
# or log lines to one visitor).
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())

# Names/URLs of documents currently loaded into the knowledge base.
if "rag_sources" not in st.session_state:
    st.session_state.rag_sources = []

# Seed the conversation so the chat window is not empty on first load.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there! How can I assist you today?"},
    ]
# --- Sidebar ---
with st.sidebar:
    # Author credit. NOTE(review): original HTML markup was stripped in
    # extraction; only the visible text is reproduced — restore styling.
    st.markdown(
        """
        Developed By
        Uditanshu Pandey
        """,
        unsafe_allow_html=True,
    )

    # RAG can only be toggled on once a vector store has been built.
    is_vector_db_loaded = (
        "vector_db" in st.session_state and st.session_state.vector_db is not None
    )
    rag_status = st.toggle(
        "Enable Knowledge Enhancement (RAG)",
        value=is_vector_db_loaded,
        key="use_rag",
        disabled=not is_vector_db_loaded,
    )
    if rag_status:
        st.markdown("RAG Mode: Active โ", unsafe_allow_html=True)
    else:
        st.markdown("RAG Mode: Inactive โ", unsafe_allow_html=True)

    # Debug toggle: downstream code may surface retrieved chunks when set.
    st.toggle("Show Retrieved Context", key="debug_mode", value=False)
    st.button(
        "๐งน Clear Chat History",
        on_click=lambda: st.session_state.messages.clear(),
        type="primary",
    )

    st.markdown("๐ Knowledge Sources", unsafe_allow_html=True)

    # Both widgets feed the vector DB via their on_change callbacks.
    st.file_uploader(
        "๐ Upload Documents",
        type=["pdf", "txt", "docx", "md"],
        accept_multiple_files=True,
        on_change=load_doc_to_db,
        key="rag_docs",
    )
    st.text_input(
        "๐ Add Webpage URL",
        placeholder="https://example.com",
        on_change=load_url_to_db,
        key="rag_url",
    )

    # Summary of what is currently in the knowledge base.
    doc_count = len(st.session_state.rag_sources) if is_vector_db_loaded else 0
    with st.expander(f"๐ Knowledge Base ({doc_count} sources)"):
        if doc_count:
            for i, source in enumerate(st.session_state.rag_sources):
                st.markdown(f"**{i+1}.** {source}")
        else:
            st.info("No documents added yet. Upload files or add URLs to enhance the assistant's knowledge.")
# --- Initialize LLM ---
# Groq-hosted Llama 4 Scout chat model used for all responses.
# The API key is read from the environment (populated by dotenv above);
# a moderate temperature keeps answers conversational but focused.
llm_stream = ChatGroq(
    api_key=os.getenv("GROQ_API_KEY"),
    model_name="meta-llama/llama-4-scout-17b-16e-instruct",
    max_tokens=1024,
    temperature=0.4,
)
# --- Chat Display ---
# Re-render the whole conversation on every script run; the stored role
# selects the avatar and the CSS class applied to the bubble.
for message in st.session_state.messages:
    avatar = "๐ค" if message["role"] == "user" else "๐ค"
    css_class = "user-message" if message["role"] == "user" else "assistant-message"
    with st.chat_message(message["role"], avatar=avatar):
        # NOTE(review): the original wrapping HTML was garbled in
        # extraction; a simple div carrying the role-specific class is
        # reconstructed here — confirm against the intended markup.
        st.markdown(
            f'<div class="{css_class}">{message["content"]}</div>',
            unsafe_allow_html=True,
        )
# --- User Input Handling ---
if prompt := st.chat_input("Ask me anything..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="๐ค"):
        st.markdown(f'<div class="user-message">{prompt}</div>', unsafe_allow_html=True)

    with st.chat_message("assistant", avatar="๐ค"):
        thinking_placeholder = st.empty()
        thinking_placeholder.info("Thinking... Please wait a moment.")

        # Convert the stored dict history into LangChain message objects.
        messages = [
            HumanMessage(content=m["content"])
            if m["role"] == "user"
            else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]

        # NOTE(review): persistence of the assistant reply into
        # st.session_state.messages is presumed to happen inside the
        # stream_llm_* helpers — verify in rag_methods before adding an
        # append here (risk of double-recording).
        if not st.session_state.use_rag:
            thinking_placeholder.empty()
            st.write_stream(stream_llm_response(llm_stream, messages))
        else:
            thinking_placeholder.info("Searching knowledge base... Please wait.")
            st.write_stream(stream_llm_rag_response(llm_stream, messages))
            # BUG FIX: previously the "Searching..." status box was never
            # cleared in the RAG branch and stayed on screen next to the
            # streamed answer; remove it once streaming completes.
            thinking_placeholder.empty()