import streamlit as st
import os
import dotenv
import uuid

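# On POSIX hosts (e.g. Streamlit Community Cloud), swap the stdlib sqlite3 module
# for pysqlite3 so downstream libraries see a newer SQLite build; the exact
# consumer (likely the vector store used in rag_methods) is an assumption here.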
if os.name == 'posix':
    __import__('pysqlite3')
    import sys
    sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

from langchain.schema import HumanMessage, AIMessage
from langchain_groq import ChatGroq

from rag_methods import (
    load_doc_to_db,
    load_url_to_db,
    stream_llm_response,
    stream_llm_rag_response,
)

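# Load environment variables (e.g. GROQ_API_KEY, read further below) from a local .env file.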
dotenv.load_dotenv()


def apply_custom_css():
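    """Inject custom CSS for layout, chat bubbles, status badges, and the responsive title."""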
st.markdown("""
|
|
<style>
|
|
.main .block-container {
|
|
padding-top: 2rem;
|
|
padding-bottom: 2rem;
|
|
}
|
|
h1, h2, h3, h4 {
|
|
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
|
font-weight: 600;
|
|
}
|
|
.app-title {
|
|
text-align: center;
|
|
color: #4361ee;
|
|
font-size: 2.2rem;
|
|
font-weight: 700;
|
|
margin-bottom: 1.5rem;
|
|
padding: 1rem;
|
|
border-radius: 10px;
|
|
background: linear-gradient(90deg, rgba(67, 97, 238, 0.1), rgba(58, 12, 163, 0.1));
|
|
text-shadow: 0px 0px 2px rgba(0,0,0,0.1);
|
|
}
|
|
.chat-container {
|
|
border-radius: 10px;
|
|
padding: 10px;
|
|
margin-bottom: 1rem;
|
|
}
|
|
.message-container {
|
|
padding: 0.8rem;
|
|
margin-bottom: 0.8rem;
|
|
border-radius: 8px;
|
|
}
|
|
.user-message {
|
|
background-color: rgba(67, 97, 238, 0.15);
|
|
border-left: 4px solid #4361ee;
|
|
}
|
|
.assistant-message {
|
|
background-color: rgba(58, 12, 163, 0.1);
|
|
border-left: 4px solid #3a0ca3;
|
|
}
|
|
.document-list {
|
|
background-color: rgba(67, 97, 238, 0.05);
|
|
border-radius: 8px;
|
|
padding: 0.7rem;
|
|
}
|
|
.upload-container {
|
|
border: 2px dashed rgba(67, 97, 238, 0.5);
|
|
border-radius: 10px;
|
|
padding: 1rem;
|
|
margin-bottom: 1rem;
|
|
text-align: center;
|
|
}
|
|
.status-indicator {
|
|
font-size: 0.85rem;
|
|
font-weight: 600;
|
|
padding: 0.3rem 0.7rem;
|
|
border-radius: 20px;
|
|
display: inline-block;
|
|
margin-bottom: 0.5rem;
|
|
}
|
|
.status-active {
|
|
background-color: rgba(46, 196, 182, 0.2);
|
|
color: #2EC4B6;
|
|
}
|
|
.status-inactive {
|
|
background-color: rgba(231, 111, 81, 0.2);
|
|
color: #E76F51;
|
|
}
|
|
@media screen and (max-width: 768px) {
|
|
.app-title {
|
|
font-size: 1.8rem;
|
|
padding: 0.7rem;
|
|
}
|
|
}
|
|
</style>
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
st.set_page_config(
    page_title="RAG-Xpert: An Enhanced RAG Framework",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="expanded"
)

apply_custom_css()

st.markdown('<h1 class="app-title">📚 RAG-Xpert: An Enhanced Retrieval-Augmented Generation Framework 🤖</h1>', unsafe_allow_html=True)

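# Per-session state: a unique session id, the list of loaded RAG sources, and the chat history.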
if "session_id" not in st.session_state:
|
|
st.session_state.session_id = str(uuid.uuid4())
|
|
if "rag_sources" not in st.session_state:
|
|
st.session_state.rag_sources = []
|
|
if "messages" not in st.session_state:
|
|
st.session_state.messages = [
|
|
{"role": "user", "content": "Hello"},
|
|
{"role": "assistant", "content": "Hi there! How can I assist you today?"}
|
|
]
|
|
|
|
|
|
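# Sidebar: author credit, RAG toggle, debug toggle, chat-history reset, and knowledge-source management.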
with st.sidebar:
    st.markdown("""
        <div style="
            text-align: center;
            padding: 1rem 0;
            margin-bottom: 1.5rem;
            background: linear-gradient(to right, #4361ee22, #3a0ca322);
            border-radius: 10px;">
            <div style="font-size: 0.85rem; color: #888;">Developed By</div>
            <div style="font-size: 1.2rem; font-weight: 700; color: #4361ee;">Uditanshu Pandey</div>
        </div>
    """, unsafe_allow_html=True)

    is_vector_db_loaded = "vector_db" in st.session_state and st.session_state.vector_db is not None
    rag_status = st.toggle("Enable Knowledge Enhancement (RAG)", value=is_vector_db_loaded, key="use_rag", disabled=not is_vector_db_loaded)

    if rag_status:
        st.markdown('<div class="status-indicator status-active">RAG Mode: Active ✅</div>', unsafe_allow_html=True)
    else:
        st.markdown('<div class="status-indicator status-inactive">RAG Mode: Inactive ❌</div>', unsafe_allow_html=True)

    st.toggle("Show Retrieved Context", key="debug_mode", value=False)
    st.button("🧹 Clear Chat History", on_click=lambda: st.session_state.messages.clear(), type="primary")

    st.markdown("<h3 style='text-align: center; color: #4361ee; margin-top: 1.5rem;'>📚 Knowledge Sources</h3>", unsafe_allow_html=True)
    st.markdown('<div class="upload-container">', unsafe_allow_html=True)
    st.file_uploader("📄 Upload Documents", type=["pdf", "txt", "docx", "md"], accept_multiple_files=True, on_change=load_doc_to_db, key="rag_docs")
    st.markdown('</div>', unsafe_allow_html=True)

    st.text_input("🌐 Add Webpage URL", placeholder="https://example.com", on_change=load_url_to_db, key="rag_url")

    doc_count = len(st.session_state.rag_sources) if is_vector_db_loaded else 0
    with st.expander(f"📚 Knowledge Base ({doc_count} sources)"):
        if doc_count:
            st.markdown('<div class="document-list">', unsafe_allow_html=True)
            for i, source in enumerate(st.session_state.rag_sources):
                st.markdown(f"**{i+1}.** {source}")
            st.markdown('</div>', unsafe_allow_html=True)
        else:
            st.info("No documents added yet. Upload files or add URLs to enhance the assistant's knowledge.")

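# Groq-hosted chat model used for both plain and RAG-augmented streaming responses.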
llm_stream = ChatGroq(
    model_name="meta-llama/llama-4-scout-17b-16e-instruct",
    api_key=os.getenv("GROQ_API_KEY"),
    temperature=0.4,
    max_tokens=1024,
)

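# Render the existing chat history with role-specific avatars and styling.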
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
for message in st.session_state.messages:
    avatar = "👤" if message["role"] == "user" else "🤖"
    css_class = "user-message" if message["role"] == "user" else "assistant-message"
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(f'<div class="message-container {css_class}">{message["content"]}</div>', unsafe_allow_html=True)
st.markdown('</div>', unsafe_allow_html=True)

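# Handle a new user prompt: append it to history, echo it, then stream either a plain LLM
# reply or a RAG-augmented one via the rag_methods helpers (assumed to yield text chunks
# that st.write_stream can consume).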
if prompt := st.chat_input("Ask me anything..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="👤"):
        st.markdown(f'<div class="message-container user-message">{prompt}</div>', unsafe_allow_html=True)

    with st.chat_message("assistant", avatar="🤖"):
        thinking_placeholder = st.empty()
        thinking_placeholder.info("Thinking... Please wait a moment.")
        messages = [
            HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]
        if not st.session_state.use_rag:
            thinking_placeholder.empty()
            st.write_stream(stream_llm_response(llm_stream, messages))
        else:
            thinking_placeholder.info("Searching knowledge base... Please wait.")
            st.write_stream(stream_llm_rag_response(llm_stream, messages))
            thinking_placeholder.empty()  # clear the "Searching..." notice once the answer has streamed