# RAG-Xpert: src/app_updated.py
import streamlit as st
import os
import dotenv
import uuid
import logging
# Configure environment for Hugging Face Spaces
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"
# Create necessary directories
os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
os.makedirs("/tmp/source_files", exist_ok=True)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from langchain.schema import HumanMessage, AIMessage
from langchain_groq import ChatGroq
from rag_methods import (
    load_doc_to_db,
    load_url_to_db,
    stream_llm_response,
    stream_llm_rag_response,
)
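# Load variables from a local .env file, if present (e.g. GROQ_API_KEY used for the Groq client below)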
dotenv.load_dotenv()
# --- Custom CSS Styling ---
def apply_custom_css():
    st.markdown("""
    <style>
        .main .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }
        h1, h2, h3, h4 {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            font-weight: 600;
        }
        .app-title {
            text-align: center;
            color: #4361ee;
            font-size: 2.2rem;
            font-weight: 700;
            margin-bottom: 1.5rem;
            padding: 1rem;
            border-radius: 10px;
            background: linear-gradient(90deg, rgba(67, 97, 238, 0.1), rgba(58, 12, 163, 0.1));
            text-shadow: 0px 0px 2px rgba(0,0,0,0.1);
        }
        .chat-container {
            border-radius: 10px;
            padding: 10px;
            margin-bottom: 1rem;
        }
        .message-container {
            padding: 0.8rem;
            margin-bottom: 0.8rem;
            border-radius: 8px;
        }
        .user-message {
            background-color: rgba(67, 97, 238, 0.15);
            border-left: 4px solid #4361ee;
        }
        .assistant-message {
            background-color: rgba(58, 12, 163, 0.1);
            border-left: 4px solid #3a0ca3;
        }
        .document-list {
            background-color: rgba(67, 97, 238, 0.05);
            border-radius: 8px;
            padding: 0.7rem;
        }
        .upload-container {
            border: 2px dashed rgba(67, 97, 238, 0.5);
            border-radius: 10px;
            padding: 1rem;
            margin-bottom: 1rem;
            text-align: center;
        }
        .status-indicator {
            font-size: 0.85rem;
            font-weight: 600;
            padding: 0.3rem 0.7rem;
            border-radius: 20px;
            display: inline-block;
            margin-bottom: 0.5rem;
        }
        .status-active {
            background-color: rgba(46, 196, 182, 0.2);
            color: #2EC4B6;
        }
        .status-inactive {
            background-color: rgba(231, 111, 81, 0.2);
            color: #E76F51;
        }
        @media screen and (max-width: 768px) {
            .app-title {
                font-size: 1.8rem;
                padding: 0.7rem;
            }
        }
    </style>
    """, unsafe_allow_html=True)
# --- Page Setup ---
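# st.set_page_config must be the first Streamlit command in the script, which is why it runs before apply_custom_css()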
st.set_page_config(
    page_title="RAG-Xpert: An Enhanced RAG Framework",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="expanded",
)
apply_custom_css()
st.markdown('<h1 class="app-title">📚 RAG-Xpert: An Enhanced Retrieval-Augmented Generation Framework 🤖</h1>', unsafe_allow_html=True)
# --- Session Initialization ---
if "session_id" not in st.session_state:
st.session_state.session_id = str(uuid.uuid4())
if "rag_sources" not in st.session_state:
st.session_state.rag_sources = []
if "messages" not in st.session_state:
st.session_state.messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there! How can I assist you today?"}
]
# --- Sidebar ---
with st.sidebar:
    st.markdown("""
    <div style="
        text-align: center;
        padding: 1rem 0;
        margin-bottom: 1.5rem;
        background: linear-gradient(to right, #4361ee22, #3a0ca322);
        border-radius: 10px;">
        <div style="font-size: 0.85rem; color: #888;">Developed By</div>
        <div style="font-size: 1.2rem; font-weight: 700; color: #4361ee;">Uditanshu Pandey</div>
    </div>
    """, unsafe_allow_html=True)

    # RAG can only be toggled on once a vector store has been built from uploaded sources
    is_vector_db_loaded = "vector_db" in st.session_state and st.session_state.vector_db is not None
    rag_status = st.toggle("Enable Knowledge Enhancement (RAG)", value=is_vector_db_loaded, key="use_rag", disabled=not is_vector_db_loaded)
    if rag_status:
        st.markdown('<div class="status-indicator status-active">RAG Mode: Active ✓</div>', unsafe_allow_html=True)
    else:
        st.markdown('<div class="status-indicator status-inactive">RAG Mode: Inactive ✗</div>', unsafe_allow_html=True)

    st.toggle("Show Retrieved Context", key="debug_mode", value=False)
    st.button("🧹 Clear Chat History", on_click=lambda: st.session_state.messages.clear(), type="primary")

    st.markdown("<h3 style='text-align: center; color: #4361ee; margin-top: 1.5rem;'>📚 Knowledge Sources</h3>", unsafe_allow_html=True)
    st.markdown('<div class="upload-container">', unsafe_allow_html=True)
    st.file_uploader("📄 Upload Documents", type=["pdf", "txt", "docx", "md"], accept_multiple_files=True, on_change=load_doc_to_db, key="rag_docs")
    st.markdown('</div>', unsafe_allow_html=True)
    st.text_input("🌐 Add Webpage URL", placeholder="https://example.com", on_change=load_url_to_db, key="rag_url")

    doc_count = len(st.session_state.rag_sources) if is_vector_db_loaded else 0
    with st.expander(f"📑 Knowledge Base ({doc_count} sources)"):
        if doc_count:
            st.markdown('<div class="document-list">', unsafe_allow_html=True)
            for i, source in enumerate(st.session_state.rag_sources):
                st.markdown(f"**{i+1}.** {source}")
            st.markdown('</div>', unsafe_allow_html=True)
        else:
            st.info("No documents added yet. Upload files or add URLs to enhance the assistant's knowledge.")
# --- Initialize LLM ---
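# The Groq-hosted chat model needs GROQ_API_KEY to be set (via the environment or the .env file loaded above);
# without it, the streaming calls further down will fail to authenticate.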
llm_stream = ChatGroq(
    model_name="meta-llama/llama-4-scout-17b-16e-instruct",
    api_key=os.getenv("GROQ_API_KEY"),
    temperature=0.4,
    max_tokens=1024,
)
# --- Chat Display ---
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
for message in st.session_state.messages:
    avatar = "👤" if message["role"] == "user" else "🤖"
    css_class = "user-message" if message["role"] == "user" else "assistant-message"
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(f'<div class="message-container {css_class}">{message["content"]}</div>', unsafe_allow_html=True)
st.markdown('</div>', unsafe_allow_html=True)
# --- User Input Handling ---
if prompt := st.chat_input("Ask me anything..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="👤"):
        st.markdown(f'<div class="message-container user-message">{prompt}</div>', unsafe_allow_html=True)

    with st.chat_message("assistant", avatar="🤖"):
        thinking_placeholder = st.empty()
        thinking_placeholder.info("Thinking... Please wait a moment.")

        # Convert the stored chat history into LangChain message objects for the LLM
        messages = [
            HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]

        if not st.session_state.use_rag:
            thinking_placeholder.empty()
            st.write_stream(stream_llm_response(llm_stream, messages))
        else:
            thinking_placeholder.info("Searching knowledge base... Please wait.")
            st.write_stream(stream_llm_rag_response(llm_stream, messages))
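# NOTE: appending the streamed assistant reply back to st.session_state.messages is assumed to happen
# inside stream_llm_response / stream_llm_rag_response (defined in rag_methods).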