TechyCode committed on
Commit
a2ac738
·
verified ·
1 Parent(s): 2478ba2

Upload 3 files

Browse files
Files changed (3) hide show
  1. src/app_updated.py +202 -0
  2. src/rag_methods.py +174 -0
  3. src/requirements.txt +0 -0
src/app_updated.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import sys
import uuid

import dotenv
import streamlit as st

# Patch sqlite3 for Streamlit Cloud compatibility: Chroma needs a newer
# sqlite3 than the system build shipped on Streamlit Cloud, so swap in the
# bundled pysqlite3 when it is available.
if os.name == 'posix':
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
    except ImportError:
        # pysqlite3 is not installed (e.g. local development) -- fall back
        # to the stdlib sqlite3 instead of crashing at import time.
        pass

from langchain.schema import HumanMessage, AIMessage
from langchain_groq import ChatGroq

from rag_methods import (
    load_doc_to_db,
    load_url_to_db,
    stream_llm_response,
    stream_llm_rag_response,
)

# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
dotenv.load_dotenv()
# --- Custom CSS Styling ---
def apply_custom_css():
    """Inject the app's custom CSS theme (layout, chat bubbles, status pills)
    into the Streamlit page via a raw <style> block."""
    theme_css = """
    <style>
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
    h1, h2, h3, h4 {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        font-weight: 600;
    }
    .app-title {
        text-align: center;
        color: #4361ee;
        font-size: 2.2rem;
        font-weight: 700;
        margin-bottom: 1.5rem;
        padding: 1rem;
        border-radius: 10px;
        background: linear-gradient(90deg, rgba(67, 97, 238, 0.1), rgba(58, 12, 163, 0.1));
        text-shadow: 0px 0px 2px rgba(0,0,0,0.1);
    }
    .chat-container {
        border-radius: 10px;
        padding: 10px;
        margin-bottom: 1rem;
    }
    .message-container {
        padding: 0.8rem;
        margin-bottom: 0.8rem;
        border-radius: 8px;
    }
    .user-message {
        background-color: rgba(67, 97, 238, 0.15);
        border-left: 4px solid #4361ee;
    }
    .assistant-message {
        background-color: rgba(58, 12, 163, 0.1);
        border-left: 4px solid #3a0ca3;
    }
    .document-list {
        background-color: rgba(67, 97, 238, 0.05);
        border-radius: 8px;
        padding: 0.7rem;
    }
    .upload-container {
        border: 2px dashed rgba(67, 97, 238, 0.5);
        border-radius: 10px;
        padding: 1rem;
        margin-bottom: 1rem;
        text-align: center;
    }
    .status-indicator {
        font-size: 0.85rem;
        font-weight: 600;
        padding: 0.3rem 0.7rem;
        border-radius: 20px;
        display: inline-block;
        margin-bottom: 0.5rem;
    }
    .status-active {
        background-color: rgba(46, 196, 182, 0.2);
        color: #2EC4B6;
    }
    .status-inactive {
        background-color: rgba(231, 111, 81, 0.2);
        color: #E76F51;
    }
    @media screen and (max-width: 768px) {
        .app-title {
            font-size: 1.8rem;
            padding: 0.7rem;
        }
    }
    </style>
    """
    # unsafe_allow_html is required for raw <style> injection.
    st.markdown(theme_css, unsafe_allow_html=True)
# --- Page Setup ---
# set_page_config must be the first Streamlit rendering call in the script.
st.set_page_config(
    page_title="RAG-Xpert: An Enhanced RAG Framework",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="expanded"
)

# Inject the custom theme before rendering any content.
apply_custom_css()

# App banner (styled by the .app-title CSS class defined above).
st.markdown('<h1 class="app-title">📚 RAG-Xpert: An Enhanced Retrieval-Augmented Generation Framework 🤖</h1>', unsafe_allow_html=True)
# --- Session Initialization ---
# Streamlit reruns this script on every interaction; seed each state key only
# when missing so values survive across reruns.
session_defaults = {
    "session_id": str(uuid.uuid4()),
    "rag_sources": [],
    "messages": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there! How can I assist you today?"}
    ],
}
for state_key, default_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
# --- Sidebar ---
with st.sidebar:
    # Developer credit card.
    st.markdown("""
        <div style="
            text-align: center;
            padding: 1rem 0;
            margin-bottom: 1.5rem;
            background: linear-gradient(to right, #4361ee22, #3a0ca322);
            border-radius: 10px;">
            <div style="font-size: 0.85rem; color: #888;">Developed By</div>
            <div style="font-size: 1.2rem; font-weight: 700; color: #4361ee;">Uditanshu Pandey</div>
        </div>
    """, unsafe_allow_html=True)

    # RAG can only be enabled once at least one source has been indexed;
    # the toggle is disabled (and off) until vector_db exists in session state.
    is_vector_db_loaded = "vector_db" in st.session_state and st.session_state.vector_db is not None
    rag_status = st.toggle("Enable Knowledge Enhancement (RAG)", value=is_vector_db_loaded, key="use_rag", disabled=not is_vector_db_loaded)

    if rag_status:
        st.markdown('<div class="status-indicator status-active">RAG Mode: Active ✓</div>', unsafe_allow_html=True)
    else:
        st.markdown('<div class="status-indicator status-inactive">RAG Mode: Inactive ✗</div>', unsafe_allow_html=True)

    # Debug toggle: stream_llm_rag_response displays retrieved chunks when set.
    st.toggle("Show Retrieved Context", key="debug_mode", value=False)
    # NOTE(review): clearing wipes the seeded greeting messages as well.
    st.button("🧹 Clear Chat History", on_click=lambda: st.session_state.messages.clear(), type="primary")

    st.markdown("<h3 style='text-align: center; color: #4361ee; margin-top: 1.5rem;'>📚 Knowledge Sources</h3>", unsafe_allow_html=True)
    st.markdown('<div class="upload-container">', unsafe_allow_html=True)
    # on_change callbacks (from rag_methods) ingest new sources into the vector DB.
    st.file_uploader("📄 Upload Documents", type=["pdf", "txt", "docx", "md"], accept_multiple_files=True, on_change=load_doc_to_db, key="rag_docs")
    st.markdown('</div>', unsafe_allow_html=True)

    st.text_input("🌐 Add Webpage URL", placeholder="https://example.com", on_change=load_url_to_db, key="rag_url")

    # List the sources (files and URLs) indexed during this session.
    doc_count = len(st.session_state.rag_sources) if is_vector_db_loaded else 0
    with st.expander(f"📑 Knowledge Base ({doc_count} sources)"):
        if doc_count:
            st.markdown('<div class="document-list">', unsafe_allow_html=True)
            for i, source in enumerate(st.session_state.rag_sources):
                st.markdown(f"**{i+1}.** {source}")
            st.markdown('</div>', unsafe_allow_html=True)
        else:
            st.info("No documents added yet. Upload files or add URLs to enhance the assistant's knowledge.")
# --- Initialize LLM ---
# Fail fast with a readable message when the API key is absent, instead of a
# cryptic authentication error on the first model call.
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    st.error("GROQ_API_KEY is not set. Add it to your environment or .env file.")
    st.stop()

# Streaming Groq chat model used by both the plain and RAG response paths.
llm_stream = ChatGroq(
    model_name="meta-llama/llama-4-scout-17b-16e-instruct",
    api_key=groq_api_key,
    temperature=0.4,
    max_tokens=1024,
)
# --- Chat Display ---
# Render the running conversation; each role gets a fixed avatar and CSS class
# (anything that is not "user" is rendered as the assistant, as before).
role_decor = {
    "user": ("👤", "user-message"),
    "assistant": ("🤖", "assistant-message"),
}
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
for message in st.session_state.messages:
    avatar, css_class = role_decor.get(message["role"], ("🤖", "assistant-message"))
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(f'<div class="message-container {css_class}">{message["content"]}</div>', unsafe_allow_html=True)
st.markdown('</div>', unsafe_allow_html=True)
# --- User Input Handling ---
if prompt := st.chat_input("Ask me anything..."):
    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="👤"):
        st.markdown(f'<div class="message-container user-message">{prompt}</div>', unsafe_allow_html=True)

    with st.chat_message("assistant", avatar="🤖"):
        thinking_placeholder = st.empty()
        thinking_placeholder.info("Thinking... Please wait a moment.")
        # Convert the dict-based session history into LangChain message objects.
        messages = [
            HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]
        # `use_rag` is the sidebar toggle's session key. Both stream generators
        # append the final assistant reply to session history themselves.
        if not st.session_state.use_rag:
            thinking_placeholder.empty()
            st.write_stream(stream_llm_response(llm_stream, messages))
        else:
            thinking_placeholder.info("Searching knowledge base... Please wait.")
            st.write_stream(stream_llm_rag_response(llm_stream, messages))
src/rag_methods.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# WebBaseLoader reads USER_AGENT when langchain_community is imported, so set
# it BEFORE those imports (the original set it afterwards, which triggers the
# "USER_AGENT environment variable not set" warning and an unset agent).
os.environ["USER_AGENT"] = "myagent"

import dotenv
from time import time
import streamlit as st

from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
    WebBaseLoader,
    PyPDFLoader,
    Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

dotenv.load_dotenv()

# Maximum number of distinct sources (files + URLs) allowed in the knowledge base.
DB_DOCS_LIMIT = 10
# Stream non-RAG LLM response
def stream_llm_response(llm_stream, messages):
    """Yield raw chunks from the plain (non-RAG) LLM stream and record the reply.

    Each chunk is yielded unchanged so st.write_stream can render it
    incrementally; once the stream ends, the assembled assistant message is
    appended to the session chat history.
    """
    collected = []
    for chunk in llm_stream.stream(messages):
        collected.append(chunk.content)
        yield chunk
    st.session_state.messages.append(
        {"role": "assistant", "content": "".join(collected)}
    )
# --- Document Loading and Indexing ---
def load_doc_to_db():
    """Ingest files uploaded via the `rag_docs` uploader into the vector DB.

    Runs as the file_uploader's on_change callback. Each not-yet-seen file is
    written to ./source_files, parsed with a type-specific loader, and the
    temp copy is removed again; parsed docs are then chunked and indexed.
    """
    if "rag_docs" in st.session_state and st.session_state.rag_docs:
        docs = []
        for doc_file in st.session_state.rag_docs:
            # Skip files already ingested in this session.
            if doc_file.name not in st.session_state.rag_sources:
                if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                    os.makedirs("source_files", exist_ok=True)
                    file_path = f"./source_files/{doc_file.name}"
                    # Persist the uploaded bytes so the path-based loaders can read them.
                    with open(file_path, "wb") as file:
                        file.write(doc_file.read())
                    try:
                        if doc_file.type == "application/pdf":
                            loader = PyPDFLoader(file_path)
                        elif doc_file.name.endswith(".docx"):
                            loader = Docx2txtLoader(file_path)
                        elif doc_file.type in ["text/plain", "text/markdown"]:
                            loader = TextLoader(file_path)
                        else:
                            st.warning(f"Unsupported document type: {doc_file.type}")
                            continue
                        docs.extend(loader.load())
                        st.session_state.rag_sources.append(doc_file.name)
                    except Exception as e:
                        st.toast(f"Error loading document {doc_file.name}: {e}", icon="⚠️")
                    finally:
                        # Always clean up the temp copy, on success or failure
                        # (also runs on the `continue` for unsupported types).
                        os.remove(file_path)
                else:
                    st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
        if docs:
            _split_and_load_docs(docs)
            st.toast(f"Documents loaded successfully.", icon="✅")
def load_url_to_db():
    """Ingest the webpage at the `rag_url` text input into the vector DB.

    Runs as the text_input's on_change callback. A URL is fetched at most once
    per session and only while the source limit has not been reached.
    """
    if "rag_url" in st.session_state and st.session_state.rag_url:
        url = st.session_state.rag_url
        docs = []
        if url not in st.session_state.rag_sources:
            if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                try:
                    loader = WebBaseLoader(url)
                    docs.extend(loader.load())
                    st.session_state.rag_sources.append(url)
                except Exception as e:
                    st.error(f"Error loading from URL {url}: {e}")
                # Only index and announce success if the fetch yielded documents.
                if docs:
                    _split_and_load_docs(docs)
                    st.toast(f"Loaded content from URL: {url}", icon="✅")
            else:
                st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
def initialize_vector_db(docs):
    """Create the persistent Chroma store and index *docs* into it.

    Embeddings are computed locally on CPU with a HuggingFace model.
    Returns the Chroma vector store instance.
    """
    # Initialize HuggingFace embeddings
    embedding = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False}
    )

    # Shared persistent directory for long-term storage
    persist_dir = "./chroma_persistent_db"
    collection_name = "persistent_collection"

    # Create the persistent Chroma vector store
    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=embedding,
        persist_directory=persist_dir,
        collection_name=collection_name
    )

    # Persist to disk
    # NOTE(review): chromadb >= 0.4 persists automatically and langchain's
    # Chroma.persist() was deprecated/removed — confirm the pinned versions
    # still expose this method.
    vector_db.persist()

    return vector_db
def _split_and_load_docs(docs):
    """Chunk *docs* and index the chunks, creating the vector DB on first use."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = splitter.split_documents(docs)

    if "vector_db" in st.session_state:
        # Store already exists: append the new chunks and flush them to disk.
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()  # Save changes
    else:
        # First ingestion: build a fresh persistent store from the chunks.
        st.session_state.vector_db = initialize_vector_db(chunks)
# --- RAG Chain ---

def _get_context_retriever_chain(vector_db, llm):
    """Build a history-aware retriever that condenses the conversation into a
    standalone search query before hitting the vector store."""
    query_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information.")
    ])
    return create_history_aware_retriever(llm, vector_db.as_retriever(), query_prompt)
def get_conversational_rag_chain(llm):
    """Compose the full RAG pipeline: history-aware retrieval feeding a
    context-stuffing answer chain over the session's vector DB."""
    retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)

    answer_prompt = ChatPromptTemplate.from_messages([
        ("system",
         """You are a helpful assistant answering the user's queries using the provided context if available.\n
        {context}"""),
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}")
    ])

    answer_chain = create_stuff_documents_chain(llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, answer_chain)
# Stream RAG LLM response
def stream_llm_rag_response(llm_stream, messages):
    """Stream a retrieval-augmented answer and record it in session history.

    *messages* is the LangChain-format conversation; the last entry is the
    current user input and the rest is prior history. Yields answer text
    chunks suitable for st.write_stream.
    """
    rag_chain = get_conversational_rag_chain(llm_stream)

    # Extract latest user input and prior messages
    input_text = messages[-1].content
    history = messages[:-1]

    # --- DEBUG: Show context retrieved ---
    # NOTE(review): this runs a second, independent retrieval purely for
    # display, so it may differ from what the history-aware retriever uses;
    # `get_relevant_documents` is also deprecated in newer LangChain in
    # favour of `retriever.invoke`.
    if st.session_state.get("debug_mode"):
        retriever = st.session_state.vector_db.as_retriever()
        retrieved_docs = retriever.get_relevant_documents(input_text)
        st.markdown("### 🔍 Retrieved Context (Debug Mode)")
        for i, doc in enumerate(retrieved_docs):
            st.markdown(f"**Chunk {i+1}:**\n```\n{doc.page_content.strip()}\n```")

    # Prefix marks RAG-generated replies in the saved chat history.
    response_message = "*(RAG Response)*\n"
    response = rag_chain.stream({
        "messages": history,
        "input": input_text
    })

    # create_retrieval_chain streams dict chunks; only those carrying an
    # 'answer' key contain generated text.
    for chunk in response:
        if 'answer' in chunk:
            response_message += chunk['answer']
            yield chunk['answer']

    st.session_state.messages.append({"role": "assistant", "content": response_message})
src/requirements.txt ADDED
Binary file (6.74 kB). View file