|
import html
import logging
import os
import uuid

import dotenv
import streamlit as st
|
|
|
|
|
os.environ["HF_HOME"] = "/tmp/.cache/huggingface" |
|
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface" |
|
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface" |
|
|
|
|
|
os.makedirs("/tmp/.cache/huggingface", exist_ok=True) |
|
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True) |
|
os.makedirs("/tmp/source_files", exist_ok=True) |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
from langchain.schema import HumanMessage, AIMessage |
|
from langchain_groq import ChatGroq |
|
from rag_methods import ( |
|
load_doc_to_db, |
|
load_url_to_db, |
|
stream_llm_response, |
|
stream_llm_rag_response, |
|
) |
|
|
|
# Load variables from a .env file (if one is found) into the process
# environment, so later os.getenv lookups can be satisfied from it.
dotenv.load_dotenv()
|
|
|
|
|
def apply_custom_css():
    """Inject the application's stylesheet into the page.

    Emits a single ``<style>`` element through ``st.markdown`` with raw HTML
    enabled. The rules cover container padding, heading fonts, the title
    banner, chat/message containers, the document list, the upload box, the
    status indicator pills, and a title override for screens up to 768px wide.
    """
    stylesheet = """
    <style>
        .main .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }
        h1, h2, h3, h4 {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            font-weight: 600;
        }
        .app-title {
            text-align: center;
            color: #4361ee;
            font-size: 2.2rem;
            font-weight: 700;
            margin-bottom: 1.5rem;
            padding: 1rem;
            border-radius: 10px;
            background: linear-gradient(90deg, rgba(67, 97, 238, 0.1), rgba(58, 12, 163, 0.1));
            text-shadow: 0px 0px 2px rgba(0,0,0,0.1);
        }
        .chat-container {
            border-radius: 10px;
            padding: 10px;
            margin-bottom: 1rem;
        }
        .message-container {
            padding: 0.8rem;
            margin-bottom: 0.8rem;
            border-radius: 8px;
        }
        .user-message {
            background-color: rgba(67, 97, 238, 0.15);
            border-left: 4px solid #4361ee;
        }
        .assistant-message {
            background-color: rgba(58, 12, 163, 0.1);
            border-left: 4px solid #3a0ca3;
        }
        .document-list {
            background-color: rgba(67, 97, 238, 0.05);
            border-radius: 8px;
            padding: 0.7rem;
        }
        .upload-container {
            border: 2px dashed rgba(67, 97, 238, 0.5);
            border-radius: 10px;
            padding: 1rem;
            margin-bottom: 1rem;
            text-align: center;
        }
        .status-indicator {
            font-size: 0.85rem;
            font-weight: 600;
            padding: 0.3rem 0.7rem;
            border-radius: 20px;
            display: inline-block;
            margin-bottom: 0.5rem;
        }
        .status-active {
            background-color: rgba(46, 196, 182, 0.2);
            color: #2EC4B6;
        }
        .status-inactive {
            background-color: rgba(231, 111, 81, 0.2);
            color: #E76F51;
        }
        @media screen and (max-width: 768px) {
            .app-title {
                font-size: 1.8rem;
                padding: 0.7rem;
            }
        }
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)
|
|
|
|
|
# Page-level settings; this is the first st.* call made by this script.
# NOTE(review): the non-ASCII glyphs in the strings below look like
# mis-decoded emoji -- confirm the file's encoding before changing them.
_page_settings = {
    "page_title": "RAG-Xpert: An Enhanced RAG Framework",
    "page_icon": "π",
    "layout": "centered",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)

# Styles must be in place before the title, which relies on .app-title.
apply_custom_css()

_title_html = '<h1 class="app-title">π RAG-Xpert: An Enhanced Retrieval-Augmented Generation Framework π€</h1>'
st.markdown(_title_html, unsafe_allow_html=True)
|
|
|
|
|
# Seed session state on the first run only; keys that already exist are left
# untouched on reruns. Factories are used so a fresh UUID / list is only
# built when the key is actually missing.
_SESSION_DEFAULTS = {
    "session_id": lambda: str(uuid.uuid4()),
    "rag_sources": list,
    "messages": lambda: [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there! How can I assist you today?"},
    ],
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
|
|
|
with st.sidebar:
    # Author credit banner, rendered as raw HTML.
    st.markdown(
        """
        <div style="
            text-align: center;
            padding: 1rem 0;
            margin-bottom: 1.5rem;
            background: linear-gradient(to right, #4361ee22, #3a0ca322);
            border-radius: 10px;">
            <div style="font-size: 0.85rem; color: #888;">Developed By</div>
            <div style="font-size: 1.2rem; font-weight: 700; color: #4361ee;">Uditanshu Pandey</div>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # The RAG toggle is only usable once session state holds a non-None
    # "vector_db"; until then it is disabled and defaults to off.
    is_vector_db_loaded = st.session_state.get("vector_db") is not None
    rag_status = st.toggle(
        "Enable Knowledge Enhancement (RAG)",
        value=is_vector_db_loaded,
        key="use_rag",
        disabled=not is_vector_db_loaded,
    )

    # NOTE(review): the trailing glyphs in the labels below look like
    # mis-decoded emoji -- confirm the file's encoding.
    if rag_status:
        _status_html = '<div class="status-indicator status-active">RAG Mode: Active β</div>'
    else:
        _status_html = '<div class="status-indicator status-inactive">RAG Mode: Inactive β</div>'
    st.markdown(_status_html, unsafe_allow_html=True)

    st.toggle("Show Retrieved Context", key="debug_mode", value=False)

    def _clear_chat_history():
        """Empty the stored conversation in place (button callback)."""
        st.session_state.messages.clear()

    st.button("π§Ή Clear Chat History", on_click=_clear_chat_history, type="primary")

    st.markdown(
        "<h3 style='text-align: center; color: #4361ee; margin-top: 1.5rem;'>π Knowledge Sources</h3>",
        unsafe_allow_html=True,
    )

    # Document upload; load_doc_to_db runs whenever the selection changes.
    st.markdown('<div class="upload-container">', unsafe_allow_html=True)
    st.file_uploader(
        "π Upload Documents",
        type=["pdf", "txt", "docx", "md"],
        accept_multiple_files=True,
        on_change=load_doc_to_db,
        key="rag_docs",
    )
    st.markdown('</div>', unsafe_allow_html=True)

    # URL ingestion; load_url_to_db runs whenever the field changes.
    st.text_input(
        "π Add Webpage URL",
        placeholder="https://example.com",
        on_change=load_url_to_db,
        key="rag_url",
    )

    # Sources are only counted while a vector store is loaded.
    if is_vector_db_loaded:
        doc_count = len(st.session_state.rag_sources)
    else:
        doc_count = 0

    with st.expander(f"π Knowledge Base ({doc_count} sources)"):
        if not doc_count:
            st.info("No documents added yet. Upload files or add URLs to enhance the assistant's knowledge.")
        else:
            st.markdown('<div class="document-list">', unsafe_allow_html=True)
            for position, source in enumerate(st.session_state.rag_sources, start=1):
                st.markdown(f"**{position}.** {source}")
            st.markdown('</div>', unsafe_allow_html=True)
|
|
|
|
|
# Chat model handed to both streaming helpers in the prompt handler.
# Fail early with a readable message when the API key is absent or empty,
# instead of passing None into the ChatGroq constructor.
_groq_api_key = os.getenv("GROQ_API_KEY")
if not _groq_api_key:
    st.error(
        "GROQ_API_KEY is not set. Add it to the environment or a .env file "
        "and reload the app."
    )
    st.stop()  # halt this script run; nothing below can work without a key

llm_stream = ChatGroq(
    model_name="meta-llama/llama-4-scout-17b-16e-instruct",
    api_key=_groq_api_key,
    temperature=0.4,
    max_tokens=1024,
)
|
|
|
|
|
# Replay the stored conversation, one styled chat bubble per message.
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
for message in st.session_state.messages:
    is_user_message = message["role"] == "user"
    # NOTE(review): both avatar literals appear identical here and look like
    # mis-decoded emoji -- confirm the intended user/assistant icons.
    avatar = "π€" if is_user_message else "π€"
    css_class = "user-message" if is_user_message else "assistant-message"
    with st.chat_message(message["role"], avatar=avatar):
        # The content comes from the user or the model and is placed inside
        # raw HTML, so escape &, < and > to keep it from being interpreted as
        # markup. quote=False: the text is element content, not an attribute.
        safe_content = html.escape(message["content"], quote=False)
        st.markdown(
            f'<div class="message-container {css_class}">{safe_content}</div>',
            unsafe_allow_html=True,
        )
st.markdown('</div>', unsafe_allow_html=True)
|
|
|
|
|
# Handle a newly submitted prompt: record it, echo it, then stream a reply.
if prompt := st.chat_input("Ask me anything..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user", avatar="π€"):
        # The prompt is untrusted text placed inside raw HTML; escape &, < and
        # > so it is displayed literally instead of being parsed as markup.
        safe_prompt = html.escape(prompt, quote=False)
        st.markdown(
            f'<div class="message-container user-message">{safe_prompt}</div>',
            unsafe_allow_html=True,
        )

    with st.chat_message("assistant", avatar="π€"):
        thinking_placeholder = st.empty()
        thinking_placeholder.info("Thinking... Please wait a moment.")

        # Convert the stored dict history (including the prompt just added)
        # into LangChain message objects; any non-"user" role maps to AIMessage.
        messages = [
            HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]

        # NOTE(review): the assistant reply is not appended to
        # st.session_state.messages here; presumably the stream helpers in
        # rag_methods do that -- confirm.
        if st.session_state.use_rag:
            thinking_placeholder.info("Searching knowledge base... Please wait.")
            try:
                st.write_stream(stream_llm_rag_response(llm_stream, messages))
            finally:
                # Remove the status banner once streaming ends (or fails), so
                # it does not linger above the answer like it previously did.
                thinking_placeholder.empty()
        else:
            thinking_placeholder.empty()
            st.write_stream(stream_llm_response(llm_stream, messages))