vikramronavrsc commited on
Commit
1169500
Β·
verified Β·
1 Parent(s): 32a645f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +743 -492
app.py CHANGED
@@ -1,593 +1,844 @@
1
- # app.py - Enhanced UI with animations
2
  import os
 
3
  import shutil
 
4
  import streamlit as st
5
  import torch
6
- import atexit
 
 
 
 
 
 
7
  import time
8
- from advanced_rag import AdvancedRAG
9
- from metamask_component import metamask_connector
10
- from voice_component import voice_input_component
 
 
 
11
 
12
- # Custom CSS for enhanced UI
13
- def load_custom_css():
 
 
 
 
 
 
 
14
  st.markdown("""
15
  <style>
16
- /* Main container styling */
17
  .main {
18
  background-color: #f9fafb;
19
  }
20
 
21
  /* Card styling */
22
- .stCard {
23
- border-radius: 12px !important;
24
- box-shadow: 0 6px 16px rgba(0,0,0,0.05) !important;
25
- transition: all 0.3s ease !important;
26
- }
27
- .stCard:hover {
28
- transform: translateY(-2px);
29
- box-shadow: 0 12px 24px rgba(0,0,0,0.08) !important;
30
  }
31
 
32
- /* Chat message styling */
33
- .chat-message {
34
- padding: 16px;
35
- border-radius: 12px;
36
- margin-bottom: 10px;
37
- animation: fadeIn 0.5s ease;
38
- }
39
- .user-message {
40
- background-color: #f0f7ff;
41
- border-left: 5px solid #3b82f6;
42
  }
43
- .assistant-message {
44
- background-color: #f0fdf4;
45
- border-left: 5px solid #22c55e;
 
 
 
 
 
46
  }
47
 
48
- /* Source section styling */
49
  .source-item {
50
- padding: 12px;
51
- border-radius: 8px;
52
- background-color: #f8fafc;
53
- border: 1px solid #e2e8f0;
54
  margin-bottom: 10px;
55
- transition: all 0.2s ease;
56
- }
57
- .source-item:hover {
58
- border-color: #cbd5e1;
59
- background-color: #f1f5f9;
60
  }
 
61
  .source-header {
62
- font-weight: 600;
 
63
  display: flex;
64
  justify-content: space-between;
65
- margin-bottom: 8px;
66
- align-items: center;
67
- }
68
- .source-content {
69
- font-size: 0.9em;
70
- color: #475569;
71
- max-height: 200px;
72
- overflow-y: auto;
73
  }
 
74
  .verified-badge {
75
- background-color: #10b981;
76
  color: white;
77
  padding: 2px 8px;
78
- border-radius: 12px;
79
- font-size: 0.7em;
80
- display: inline-flex;
81
- align-items: center;
82
- gap: 4px;
83
- }
84
-
85
- /* Animated loader */
86
- @keyframes pulse-animation {
87
- 0% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0.7); }
88
- 70% { box-shadow: 0 0 0 10px rgba(59, 130, 246, 0); }
89
- 100% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0); }
90
- }
91
- .pulse {
92
- animation: pulse-animation 2s infinite;
93
  }
94
 
95
- /* Fade in animation */
96
- @keyframes fadeIn {
97
- from { opacity: 0; transform: translateY(10px); }
98
- to { opacity: 1; transform: translateY(0); }
 
99
  }
100
 
101
- /* Method selection buttons */
102
  .method-button {
103
- border-radius: 8px;
104
- padding: 8px 16px;
105
- transition: all 0.3s ease;
106
- border: none;
107
  cursor: pointer;
108
- font-weight: 500;
109
- display: inline-flex;
110
- align-items: center;
111
- gap: 8px;
112
  }
113
- .method-direct {
114
- background-color: #e0f2fe;
115
- color: #0284c7;
 
 
116
  }
117
- .method-direct:hover {
118
- background-color: #bae6fd;
 
119
  }
120
- .method-enhanced {
121
- background-color: #dbeafe;
122
- color: #2563eb;
 
 
123
  }
124
- .method-enhanced:hover {
125
- background-color: #bfdbfe;
 
126
  }
 
127
  .method-active {
128
- box-shadow: 0 0 0 2px #3b82f6;
129
  }
130
 
131
- /* Two-column layout for answer and sources */
132
- .answer-container {
133
- border-radius: 12px;
134
- background-color: white;
135
- padding: 20px;
136
- box-shadow: 0 4px 12px rgba(0,0,0,0.05);
137
- margin-bottom: 20px;
138
- }
139
- .answer-header {
140
- margin-bottom: 16px;
141
- color: #1e293b;
142
- font-weight: 600;
143
- font-size: 1.1em;
144
- }
145
- .answer-content {
146
- font-size: 1em;
147
- line-height: 1.6;
148
- color: #334155;
149
  }
150
- .sources-container {
151
- border-radius: 12px;
152
- background-color: white;
153
- padding: 20px;
154
- box-shadow: 0 4px 12px rgba(0,0,0,0.05);
155
  }
156
- .sources-header {
157
- margin-bottom: 16px;
158
- color: #1e293b;
159
- font-weight: 600;
160
- font-size: 1.1em;
161
  }
162
 
163
- /* Section animations */
164
- .animate-section {
165
- animation: fadeIn 0.5s ease;
 
166
  }
167
  </style>
168
  """, unsafe_allow_html=True)
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  # Helper function to initialize session state
171
  def initialize_session_state():
172
- """Initialize Streamlit session state variables."""
173
  if "rag" not in st.session_state:
174
  st.session_state.rag = None
175
  if "messages" not in st.session_state:
176
  st.session_state.messages = []
177
  if "temp_dir" not in st.session_state:
178
  st.session_state.temp_dir = None
179
- if "metamask_connected" not in st.session_state:
180
- st.session_state.metamask_connected = False
 
 
181
  if "retrieval_method" not in st.session_state:
182
  st.session_state.retrieval_method = "enhanced"
183
- if "voice_transcript" not in st.session_state:
184
- st.session_state.voice_transcript = ""
185
  if "current_answer" not in st.session_state:
186
  st.session_state.current_answer = None
187
 
188
  # Helper function to clean up temporary files
189
  def cleanup_temp_files():
190
- """Clean up temporary files when application exits."""
191
  if st.session_state.get('temp_dir') and os.path.exists(st.session_state.temp_dir):
192
  try:
193
  shutil.rmtree(st.session_state.temp_dir)
194
- print(f"Cleaned up temporary directory: {st.session_state.temp_dir}")
195
  except Exception as e:
196
  print(f"Error cleaning up temporary directory: {e}")
197
 
198
- # Create an animated loading spinner
199
- def animated_loader(text="Processing..."):
200
- with st.spinner(text):
201
- # Add a pulsing animation while processing
202
- st.markdown("""
203
- <div style="display: flex; justify-content: center; margin: 20px 0;">
204
- <div class="pulse" style="width: 20px; height: 20px; border-radius: 50%; background-color: #3b82f6;"></div>
205
- </div>
206
- """, unsafe_allow_html=True)
207
-
208
- # Animated section container
209
- def animated_section(key):
210
- return st.container(key=f"animated_{key}")
211
-
212
- # Create a method selection button with animation
213
- def method_button(label, icon, method, current_method):
214
- active_class = "method-active" if method == current_method else ""
215
- method_class = "method-direct" if method == "direct" else "method-enhanced"
216
 
217
- button_html = f"""
218
- <button class="method-button {method_class} {active_class}">
219
- {icon} {label}
220
- </button>
221
- """
222
- return button_html
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- # Streamlit UI
225
  def main():
226
- st.set_page_config(
227
- page_title="Advanced RAG System",
228
- layout="wide",
229
- initial_sidebar_state="expanded"
230
- )
231
-
232
- # Load custom CSS
233
- load_custom_css()
234
-
235
- # Page header with animation
236
- with animated_section("header"):
237
- st.title("πŸš€ Advanced RAG System")
238
- st.markdown("""
239
- <div style="display: flex; gap: 15px; margin-bottom: 20px;">
240
- <div style="background-color: #e0f2fe; color: #0284c7; padding: 8px 16px; border-radius: 20px; font-size: 0.9em; display: flex; align-items: center; gap: 8px;">
241
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="12 2 15.09 8.26 22 9.27 17 14.14 18.18 21.02 12 17.77 5.82 21.02 7 14.14 2 9.27 8.91 8.26 12 2"></polygon></svg>
242
- Document Analysis
243
- </div>
244
- <div style="background-color: #f0fdf4; color: #16a34a; padding: 8px 16px; border-radius: 20px; font-size: 0.9em; display: flex; align-items: center; gap: 8px;">
245
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="11" width="18" height="11" rx="2" ry="2"></rect><path d="M7 11V7a5 5 0 0 1 10 0v4"></path></svg>
246
- Blockchain Verification
247
- </div>
248
- <div style="background-color: #fef2f2; color: #dc2626; padding: 8px 16px; border-radius: 20px; font-size: 0.9em; display: flex; align-items: center; gap: 8px;">
249
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 2a3 3 0 0 0-3 3v7a3 3 0 0 0 6 0V5a3 3 0 0 0-3-3Z"></path><path d="M19 10v2a7 7 0 0 1-14 0v-2"></path><line x1="12" y1="19" x2="12" y2="22"></line></svg>
250
- Voice Input
251
- </div>
252
- </div>
253
- """, unsafe_allow_html=True)
254
 
255
  # Initialize session state
256
  initialize_session_state()
257
 
258
- # Sidebar for configuration and file upload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  with st.sidebar:
260
- with animated_section("sidebar_header"):
261
- st.header("System Configuration")
262
- st.markdown("""
263
- <div style="margin-bottom: 15px; padding: 10px; border-radius: 8px; background-color: #f1f5f9; border-left: 4px solid #3b82f6;">
264
- Configure your RAG system and upload documents
265
- </div>
266
- """, unsafe_allow_html=True)
267
 
268
- # MetaMask Connection
269
- with animated_section("metamask"):
270
- st.subheader("🦊 MetaMask Connection")
271
-
272
- # Add MetaMask connector and get connection info
273
- metamask_info = metamask_connector()
274
-
275
- # Display MetaMask connection status
276
- if metamask_info and metamask_info.get("connected"):
277
- st.success(f"βœ… Connected: {metamask_info.get('address')[:10]}...{metamask_info.get('address')[-6:]}")
278
- st.info(f"Network: {metamask_info.get('network_name')}")
279
- st.session_state.metamask_connected = True
280
- else:
281
- st.warning("⚠️ MetaMask not connected")
282
- st.session_state.metamask_connected = False
283
-
284
- # Update RAG system with MetaMask connection if needed
285
- if st.session_state.rag and metamask_info:
286
- is_connected = st.session_state.rag.update_blockchain_connection(metamask_info)
287
- if is_connected:
288
- st.success("RAG system updated with MetaMask connection")
289
 
290
- # System Configuration
291
- with animated_section("config"):
292
- st.subheader("βš™οΈ System Configuration")
293
-
294
- # GPU Detection
295
- gpu_available = torch.cuda.is_available()
296
- if gpu_available:
297
- try:
298
- gpu_info = torch.cuda.get_device_properties(0)
299
- st.markdown(f"""
300
- <div style="display: flex; align-items: center; gap: 8px; padding: 8px 12px; background-color: #ecfdf5; border-radius: 8px; margin-bottom: 15px;">
301
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="#10b981" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M17 18a5 5 0 0 1-10 0"></path><line x1="12" y1="2" x2="12" y2="9"></line><line x1="4.22" y1="10.22" x2="5.64" y2="11.64"></line><line x1="1" y1="18" x2="3" y2="18"></line><line x1="21" y1="18" x2="23" y2="18"></line><line x1="18.36" y1="11.64" x2="19.78" y2="10.22"></line><line x1="23" y1="22" x2="1" y2="22"></line><polyline points="8 6 12 2 16 6"></polyline></svg>
302
- <span style="color: #10b981; font-weight: 500;">GPU: {gpu_info.name} ({gpu_info.total_memory / 1024**3:.1f} GB)</span>
303
- </div>
304
- """, unsafe_allow_html=True)
305
- except Exception as e:
306
- st.warning(f"GPU detected but couldn't get properties")
307
- else:
308
- st.warning("No GPU detected. Running in CPU mode.")
309
-
310
- # Model selection
311
- llm_model = st.selectbox(
312
- "LLM Model",
313
- options=[
314
- "mistralai/Mistral-7B-Instruct-v0.2",
315
- "google/gemma-7b-it",
316
- "google/flan-t5-xl",
317
- "Salesforce/xgen-7b-8k-inst",
318
- "tiiuae/falcon-7b-instruct"
319
- ],
320
- index=0
321
- )
322
-
323
- embedding_model = st.selectbox(
324
- "Embedding Model",
325
- options=[
326
- "sentence-transformers/all-mpnet-base-v2",
327
- "sentence-transformers/all-MiniLM-L6-v2",
328
- "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
329
- ],
330
- index=1
331
- )
332
-
333
- use_gpu = st.checkbox("Use GPU Acceleration", value=gpu_available)
334
-
335
- # Blockchain configuration
336
- use_blockchain = st.checkbox("Enable Blockchain Verification", value=True)
337
-
338
- if use_blockchain:
339
- # Hardcoded contract address - replace with your deployed contract
340
- contract_address = os.environ.get("CONTRACT_ADDRESS", "0x123abc...") # Your pre-deployed contract
341
-
342
- st.info(f"Using contract: {contract_address[:10]}...")
343
-
344
- # Advanced options
345
- with st.expander("Advanced Options"):
346
- chunk_size = st.slider("Chunk Size", 100, 2000, 1000)
347
- chunk_overlap = st.slider("Chunk Overlap", 0, 500, 200)
348
-
349
- # Initialize button with animation
350
- if st.button("Initialize System", key="init_button"):
351
- with st.spinner("Initializing..."):
352
- animated_loader("Setting up RAG system...")
353
-
354
- if use_blockchain and not contract_address:
355
- st.error("Contract address is required for blockchain integration")
356
- else:
357
- st.session_state.rag = AdvancedRAG(
358
- llm_model_name=llm_model,
359
- embedding_model_name=embedding_model,
360
- chunk_size=chunk_size,
361
- chunk_overlap=chunk_overlap,
362
- use_gpu=use_gpu and gpu_available,
363
- use_blockchain=use_blockchain,
364
- contract_address=contract_address if use_blockchain else None
365
- )
366
-
367
- # Update with current MetaMask connection if available
368
- if use_blockchain and metamask_info:
369
- st.session_state.rag.update_blockchain_connection(metamask_info)
370
-
371
- st.success(f"System initialized with {embedding_model}")
372
 
373
- # Document Upload
374
- with animated_section("upload"):
375
- st.subheader("πŸ“„ Document Upload")
376
- uploaded_files = st.file_uploader("Select PDFs", type="pdf", accept_multiple_files=True)
377
-
378
- if uploaded_files and st.button("Process PDFs", key="process_button"):
379
- if not st.session_state.rag:
380
- with st.spinner("Initializing system first..."):
381
- animated_loader("Setting up RAG system...")
382
-
383
- st.session_state.rag = AdvancedRAG(
384
- llm_model_name=llm_model,
385
- embedding_model_name=embedding_model,
386
- chunk_size=chunk_size,
387
- chunk_overlap=chunk_overlap,
388
- use_gpu=use_gpu and gpu_available,
389
- use_blockchain=use_blockchain,
390
- contract_address=contract_address if use_blockchain else None
391
- )
392
-
393
- # Update with current MetaMask connection if available
394
- if use_blockchain and metamask_info:
395
- st.session_state.rag.update_blockchain_connection(metamask_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
- with st.spinner("Processing documents..."):
398
- animated_loader("Analyzing and indexing PDFs...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
- success = st.session_state.rag.process_pdfs(uploaded_files)
401
- if success:
402
- metrics = st.session_state.rag.get_performance_metrics()
403
- if metrics:
404
- st.success("πŸ“„ PDFs processed successfully!")
405
- with st.expander("πŸ’Ή Performance Metrics"):
406
- st.markdown(f"**Documents processed:** {metrics['documents_processed']} chunks")
407
- st.markdown(f"**Index building time:** {metrics['index_building_time']:.2f} seconds")
408
- st.markdown(f"**Total processing time:** {metrics['total_processing_time']:.2f} seconds")
409
 
410
- # Main content area - Two column layout
411
- main_col1, main_col2 = st.columns([2, 1])
 
412
 
413
- # Left column - Chat and Answer section
414
- with main_col1:
415
- # Method Selection
416
- with animated_section("method_selection"):
417
- st.markdown("### Answer Method")
418
- col1, col2 = st.columns(2)
419
-
420
- with col1:
421
- direct_html = method_button(
422
- "Direct Retrieval",
423
- '<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="11" cy="11" r="8"></circle><line x1="21" y1="21" x2="16.65" y2="16.65"></line></svg>',
424
- "direct",
425
- st.session_state.retrieval_method
426
- )
427
- if st.markdown(direct_html, unsafe_allow_html=True):
428
- st.session_state.retrieval_method = "direct"
429
- st.rerun()
430
-
431
- with col2:
432
- enhanced_html = method_button(
433
- "Enhanced Answers",
434
- '<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="13 2 3 14 12 14 11 22 21 10 12 10 13 2"></polygon></svg>',
435
- "enhanced",
436
- st.session_state.retrieval_method
437
- )
438
- if st.markdown(enhanced_html, unsafe_allow_html=True):
439
- st.session_state.retrieval_method = "enhanced"
440
- st.rerun()
441
 
442
- # Show current method description
443
- if st.session_state.retrieval_method == "direct":
444
- st.info("πŸ” **Direct Retrieval**: Shows raw document passages without processing. Fast and transparent.")
445
- else:
446
- st.info("πŸ’‘ **Enhanced Answers**: Processes content through AI for comprehensive answers. Better quality.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
- # Voice Input Section
449
- with animated_section("voice_input"):
450
- st.markdown("### Ask with Voice")
451
- voice_transcript = voice_input_component()
452
-
453
- # Update session state with voice transcript if not empty
454
- if voice_transcript and voice_transcript.strip():
455
- st.session_state.voice_transcript = voice_transcript.strip()
456
- st.experimental_rerun()
457
 
458
- # Text Input Section
459
- with animated_section("text_input"):
460
- st.markdown("### Or Type a Question")
461
- # Chat input - show the voice transcript in the text input
462
- user_input = st.text_input(
463
- "Ask a question about your documents",
464
- value=st.session_state.voice_transcript,
465
- key="text_question"
466
- )
467
 
468
- # Process user input (from text or voice)
469
- if user_input or st.session_state.voice_transcript:
470
- # Prioritize text input over voice input
471
- if user_input:
472
- query = user_input
473
- else:
474
- query = st.session_state.voice_transcript
475
- # Clear voice transcript after using it
476
- st.session_state.voice_transcript = ""
477
-
478
- # Add user message to chat history
479
- st.session_state.messages.append({"role": "user", "content": query})
480
-
481
- # Check if system is initialized
482
- if not st.session_state.rag:
483
- st.error("Please initialize the system and process PDFs first.")
484
- st.session_state.messages.append({
485
- "role": "assistant",
486
- "content": "Please initialize the system and process PDFs first."
487
- })
488
 
489
- # Get response if vector store is ready
490
- elif st.session_state.rag.vector_store:
491
- with st.spinner("Generating answer..."):
492
- animated_loader("Searching documents and generating answer...")
493
-
494
- # Get retrieval method
495
- method = st.session_state.retrieval_method
496
-
497
- # Get response using specified method
498
- response = st.session_state.rag.ask(query, method=method)
499
- st.session_state.messages.append({"role": "assistant", "content": response})
500
-
501
- # Store current answer for display
502
- st.session_state.current_answer = response
503
-
504
- # Rerun to update the UI
505
- st.experimental_rerun()
506
- else:
507
- st.error("Please upload and process PDF files first.")
508
- st.session_state.messages.append({
509
- "role": "assistant",
510
- "content": "Please upload and process PDF files first."
511
- })
512
 
513
- # Answer Display Section
514
  if st.session_state.current_answer and isinstance(st.session_state.current_answer, dict):
515
- with animated_section("answer_display"):
516
- answer = st.session_state.current_answer
517
-
518
- st.markdown("""
519
- <div class="answer-container animate-section">
520
- <div class="answer-header">
521
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="margin-right: 8px;"><circle cx="12" cy="12" r="10"></circle><path d="M9.09 9a3 3 0 0 1 5.83 1c0 2-3 3-3 3"></path><line x1="12" y1="17" x2="12.01" y2="17"></line></svg>
522
- Answer
523
- </div>
524
- <div class="answer-content">
525
- {answer_text}
526
- </div>
527
  </div>
528
- """.format(answer_text=answer["answer"]), unsafe_allow_html=True)
529
-
530
- # Display metadata
531
- meta_cols = st.columns(3)
532
- with meta_cols[0]:
533
- method_name = "Direct Retrieval" if answer["method"] == "direct" else "Enhanced Answer"
534
- st.caption(f"Method: {method_name}")
535
- with meta_cols[1]:
536
- st.caption(f"Time: {answer['query_time']:.2f} seconds")
537
- with meta_cols[2]:
538
- if "blockchain_log" in answer and answer["blockchain_log"]:
539
- blockchain_log = answer["blockchain_log"]
540
- st.caption(f"πŸ“ Logged on blockchain: {blockchain_log['tx_hash'][:8]}...")
 
541
 
542
- # Right column - Sources section
543
- with main_col2:
 
 
544
  if st.session_state.current_answer and isinstance(st.session_state.current_answer, dict):
545
- with animated_section("sources_display"):
546
- answer = st.session_state.current_answer
547
-
548
- st.markdown("""
549
- <div class="sources-container animate-section">
550
- <div class="sources-header">
551
- <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="margin-right: 8px;"><path d="M2 3h6a4 4 0 0 1 4 4v14a3 3 0 0 0-3-3H2z"></path><path d="M22 3h-6a4 4 0 0 0-4 4v14a3 3 0 0 1 3-3h7z"></path></svg>
552
- Sources
553
- </div>
554
- """, unsafe_allow_html=True)
555
-
556
- # Display sources
557
- if "sources" in answer and answer["sources"]:
558
- for i, source in enumerate(answer["sources"]):
559
- verified_badge = ""
560
- if source.get("blockchain"):
561
- verified_badge = f"""
562
- <div class="verified-badge">
563
- <svg xmlns="http://www.w3.org/2000/svg" width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path><polyline points="22 4 12 14.01 9 11.01"></polyline></svg>
564
- Verified
565
- </div>
566
- """
567
-
568
- st.markdown(f"""
569
- <div class="source-item">
570
- <div class="source-header">
571
- <div>Source {i+1}: {source['source']}</div>
572
- {verified_badge}
573
- </div>
574
- <div class="source-content">
575
- {source['content']}
576
- </div>
577
  </div>
578
- """, unsafe_allow_html=True)
579
-
580
- st.markdown("</div>", unsafe_allow_html=True)
 
 
 
 
581
  else:
582
- # Placeholder when no sources to display
583
- st.markdown("""
584
- <div style="height: 300px; display: flex; justify-content: center; align-items: center; background-color: white; border-radius: 12px; margin-top: 30px;">
585
- <div style="text-align: center; color: #94a3b8;">
586
- <svg xmlns="http://www.w3.org/2000/svg" width="40" height="40" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="margin: 0 auto 15px;"><circle cx="12" cy="12" r="10"></circle><path d="M9.09 9a3 3 0 0 1 5.83 1c0 2-3 3-3 3"></path><line x1="12" y1="17" x2="12.01" y2="17"></line></svg>
587
- <p>Ask a question to see document sources here</p>
588
- </div>
589
- </div>
590
- """, unsafe_allow_html=True)
591
 
592
  # Main entry point
593
  if __name__ == "__main__":
 
1
+ # app.py - Optimized for Hugging Face Spaces
2
  import os
3
+ import tempfile
4
  import shutil
5
+ import PyPDF2
6
  import streamlit as st
7
  import torch
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.docstore.document import Document
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain_huggingface import HuggingFaceEmbeddings
14
+ from langchain_community.llms import HuggingFaceHub
15
  import time
16
+ import psutil
17
+ import uuid
18
+ import atexit
19
+ import json
20
+ import hashlib
21
+ from web3 import Web3
22
 
23
+ # Set page configuration
24
+ st.set_page_config(
25
+ page_title="RAG System",
26
+ layout="wide",
27
+ initial_sidebar_state="expanded"
28
+ )
29
+
30
+ # Custom CSS for better UI
31
+ def load_css():
32
  st.markdown("""
33
  <style>
34
+ /* Main layout styling */
35
  .main {
36
  background-color: #f9fafb;
37
  }
38
 
39
  /* Card styling */
40
+ .card {
41
+ border-radius: 10px;
42
+ background-color: white;
43
+ padding: 20px;
44
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
45
+ margin-bottom: 20px;
 
 
46
  }
47
 
48
+ /* Two-column layout */
49
+ .answer-section {
50
+ background-color: white;
51
+ border-radius: 10px;
52
+ padding: 20px;
53
+ margin-bottom: 15px;
54
+ border-left: 4px solid #4CAF50;
55
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.05);
 
 
56
  }
57
+
58
+ .sources-section {
59
+ background-color: white;
60
+ border-radius: 10px;
61
+ padding: 15px;
62
+ margin-bottom: 15px;
63
+ border-left: 4px solid #2196F3;
64
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.05);
65
  }
66
 
 
67
  .source-item {
68
+ padding: 10px;
69
+ border-radius: 5px;
70
+ background-color: #f8f9fa;
 
71
  margin-bottom: 10px;
72
+ border: 1px solid #eee;
 
 
 
 
73
  }
74
+
75
  .source-header {
76
+ font-weight: bold;
77
+ margin-bottom: 5px;
78
  display: flex;
79
  justify-content: space-between;
 
 
 
 
 
 
 
 
80
  }
81
+
82
  .verified-badge {
83
+ background-color: #4CAF50;
84
  color: white;
85
  padding: 2px 8px;
86
+ border-radius: 10px;
87
+ font-size: 0.8em;
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  }
89
 
90
+ /* Method selection styling */
91
+ .method-container {
92
+ display: flex;
93
+ gap: 10px;
94
+ margin-bottom: 15px;
95
  }
96
 
 
97
  .method-button {
98
+ flex: 1;
99
+ text-align: center;
100
+ padding: 10px;
101
+ border-radius: 5px;
102
  cursor: pointer;
103
+ transition: all 0.3s;
 
 
 
104
  }
105
+
106
+ .direct-method {
107
+ background-color: #e3f2fd;
108
+ border: 1px solid #bbdefb;
109
+ color: #1976D2;
110
  }
111
+
112
+ .direct-method:hover {
113
+ background-color: #bbdefb;
114
  }
115
+
116
+ .enhanced-method {
117
+ background-color: #e8f5e9;
118
+ border: 1px solid #c8e6c9;
119
+ color: #388E3C;
120
  }
121
+
122
+ .enhanced-method:hover {
123
+ background-color: #c8e6c9;
124
  }
125
+
126
  .method-active {
127
+ box-shadow: 0 0 0 2px #3f51b5;
128
  }
129
 
130
+ /* Voice button styling */
131
+ .voice-button {
132
+ width: 50px;
133
+ height: 50px;
134
+ border-radius: 50%;
135
+ background-color: #f44336;
136
+ color: white;
137
+ display: flex;
138
+ align-items: center;
139
+ justify-content: center;
140
+ cursor: pointer;
141
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
142
+ transition: all 0.3s;
143
+ margin: 0 auto;
 
 
 
 
144
  }
145
+
146
+ .voice-button:hover {
147
+ transform: scale(1.05);
148
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.3);
 
149
  }
150
+
151
+ /* Header styling */
152
+ h1, h2, h3 {
153
+ color: #333;
 
154
  }
155
 
156
+ /* Button styling */
157
+ .stButton>button {
158
+ border-radius: 5px;
159
+ font-weight: 500;
160
  }
161
  </style>
162
  """, unsafe_allow_html=True)
163
 
164
+ # Simple blockchain utility
165
+ class BlockchainVerifier:
166
+ def __init__(self, contract_address=None):
167
+ self.contract_address = contract_address
168
+ self.is_connected = False
169
+ self.user_address = None
170
+
171
+ def connect_wallet(self, wallet_address):
172
+ """Simulate connecting to a wallet"""
173
+ self.is_connected = True
174
+ self.user_address = wallet_address
175
+ return True
176
+
177
+ def compute_file_hash(self, file_path):
178
+ """Compute SHA-256 hash of file"""
179
+ sha256_hash = hashlib.sha256()
180
+ with open(file_path, "rb") as f:
181
+ for byte_block in iter(lambda: f.read(4096), b""):
182
+ sha256_hash.update(byte_block)
183
+ return sha256_hash.hexdigest()
184
+
185
+ def verify_document(self, document_id, file_path):
186
+ """Simulate document verification on blockchain"""
187
+ if not self.is_connected:
188
+ return {"status": False, "error": "Wallet not connected"}
189
+
190
+ # Calculate hash
191
+ document_hash = self.compute_file_hash(file_path)
192
+
193
+ # Simulate transaction
194
+ tx_hash = "0x" + "".join([format(i, "02x") for i in os.urandom(32)])
195
+
196
+ return {
197
+ "status": True,
198
+ "tx_hash": tx_hash,
199
+ "document_id": document_id,
200
+ "document_hash": document_hash,
201
+ "block_number": 12345678
202
+ }
203
+
204
+ def log_query(self, query_text, answer_text):
205
+ """Simulate logging a query on blockchain"""
206
+ if not self.is_connected:
207
+ return {"status": False, "error": "Wallet not connected"}
208
+
209
+ # Create query data and hash
210
+ query_id = f"query_{int(time.time())}"
211
+ query_data = {
212
+ "query": query_text,
213
+ "answer": answer_text,
214
+ "timestamp": int(time.time())
215
+ }
216
+ query_hash = hashlib.sha256(json.dumps(query_data).encode()).hexdigest()
217
+
218
+ # Simulate transaction
219
+ tx_hash = "0x" + "".join([format(i, "02x") for i in os.urandom(32)])
220
+
221
+ return {
222
+ "status": True,
223
+ "tx_hash": tx_hash,
224
+ "query_id": query_id,
225
+ "query_hash": query_hash,
226
+ "block_number": 12345678
227
+ }
228
+
229
+ # RAG System Class
230
+ class OptimizedRAG:
231
+ def __init__(self,
232
+ llm_model_name="google/flan-t5-base",
233
+ embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
234
+ chunk_size=1000,
235
+ chunk_overlap=200,
236
+ use_gpu=True,
237
+ use_blockchain=False,
238
+ contract_address=None):
239
+ """
240
+ Initialize the RAG system optimized for Hugging Face Spaces
241
+ """
242
+ self.llm_model_name = llm_model_name
243
+ self.embedding_model_name = embedding_model_name
244
+ self.use_gpu = use_gpu and torch.cuda.is_available()
245
+ self.use_blockchain = use_blockchain
246
+
247
+ # Device selection for embeddings
248
+ self.device = "cuda" if self.use_gpu else "cpu"
249
+
250
+ # Initialize text splitter
251
+ self.text_splitter = RecursiveCharacterTextSplitter(
252
+ chunk_size=chunk_size,
253
+ chunk_overlap=chunk_overlap,
254
+ length_function=len,
255
+ )
256
+
257
+ # Initialize embeddings model
258
+ self.embeddings = HuggingFaceEmbeddings(
259
+ model_name=embedding_model_name,
260
+ model_kwargs={"device": self.device}
261
+ )
262
+
263
+ # Initialize LLM using HuggingFaceHub
264
+ try:
265
+ # Use HF_TOKEN from environment variables
266
+ hf_token = os.environ.get("HF_TOKEN")
267
+ if not hf_token:
268
+ st.warning("No HuggingFace token found. Using model without authentication.")
269
+
270
+ self.llm = HuggingFaceHub(
271
+ repo_id=llm_model_name,
272
+ huggingfacehub_api_token=hf_token,
273
+ model_kwargs={"temperature": 0.7, "max_length": 512}
274
+ )
275
+ except Exception as e:
276
+ st.error(f"Error initializing LLM: {str(e)}")
277
+ st.info("Trying to initialize with default model...")
278
+ # Fallback to a smaller model
279
+ self.llm = HuggingFaceHub(
280
+ repo_id="google/flan-t5-small",
281
+ model_kwargs={"temperature": 0.7, "max_length": 256}
282
+ )
283
+
284
+ # Initialize vector store and stats
285
+ self.vector_store = None
286
+ self.documents_processed = 0
287
+ self.processing_times = {}
288
+
289
+ # Initialize blockchain verifier
290
+ self.blockchain = None
291
+ if use_blockchain:
292
+ self.blockchain = BlockchainVerifier(contract_address=contract_address)
293
+
294
+ def connect_wallet(self, wallet_address):
295
+ """Connect wallet for blockchain verification"""
296
+ if self.blockchain:
297
+ return self.blockchain.connect_wallet(wallet_address)
298
+ return False
299
+
300
+ def process_pdfs(self, pdf_files):
301
+ """Process PDF files and create vector store"""
302
+ all_docs = []
303
+
304
+ with st.status("Processing PDF files...") as status:
305
+ # Create temporary directory
306
+ temp_dir = tempfile.mkdtemp()
307
+ st.session_state['temp_dir'] = temp_dir
308
+
309
+ # Track processing stats
310
+ start_time = time.time()
311
+ mem_before = psutil.virtual_memory().used / (1024 * 1024 * 1024) # GB
312
+
313
+ # Process each PDF
314
+ for i, pdf_file in enumerate(pdf_files):
315
+ try:
316
+ # Save uploaded file
317
+ pdf_path = os.path.join(temp_dir, pdf_file.name)
318
+ with open(pdf_path, "wb") as f:
319
+ f.write(pdf_file.getbuffer())
320
+
321
+ status.update(label=f"Processing {pdf_file.name} ({i+1}/{len(pdf_files)})...")
322
+
323
+ # Extract text from PDF
324
+ text = ""
325
+ with open(pdf_path, "rb") as f:
326
+ pdf = PyPDF2.PdfReader(f)
327
+ for page_num in range(len(pdf.pages)):
328
+ page = pdf.pages[page_num]
329
+ page_text = page.extract_text()
330
+ if page_text:
331
+ text += page_text + "\n\n"
332
+
333
+ # Create and split documents
334
+ docs = [Document(page_content=text, metadata={"source": pdf_file.name})]
335
+ split_docs = self.text_splitter.split_documents(docs)
336
+ all_docs.extend(split_docs)
337
+
338
+ # Verify on blockchain if enabled
339
+ if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
340
+ document_id = f"{pdf_file.name}_{uuid.uuid4().hex[:8]}"
341
+ verification = self.blockchain.verify_document(document_id, pdf_path)
342
+
343
+ if verification.get('status'):
344
+ st.sidebar.success(f"βœ… {pdf_file.name} verified on blockchain")
345
+
346
+ # Add blockchain metadata
347
+ for doc in split_docs:
348
+ doc.metadata["blockchain"] = {
349
+ "verified": True,
350
+ "document_id": document_id,
351
+ "document_hash": verification.get("document_hash", ""),
352
+ "tx_hash": verification.get("tx_hash", ""),
353
+ "block_number": verification.get("block_number", 0)
354
+ }
355
+
356
+ except Exception as e:
357
+ st.sidebar.error(f"Error processing {pdf_file.name}: {str(e)}")
358
+
359
+ # Create vector store
360
+ if all_docs:
361
+ status.update(label="Building vector index...")
362
+ try:
363
+ index_start_time = time.time()
364
+ self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
365
+ index_time = time.time() - index_start_time
366
+
367
+ # Track memory usage
368
+ mem_after = psutil.virtual_memory().used / (1024 * 1024 * 1024)
369
+ mem_used = mem_after - mem_before
370
+
371
+ # Save performance metrics
372
+ total_time = time.time() - start_time
373
+ self.processing_times["index_building"] = index_time
374
+ self.processing_times["total_time"] = total_time
375
+ self.processing_times["memory_used_gb"] = mem_used
376
+ self.documents_processed = len(all_docs)
377
+
378
+ status.update(label=f"Completed processing {len(all_docs)} chunks", state="complete")
379
+ return True
380
+ except Exception as e:
381
+ st.error(f"Error creating vector store: {str(e)}")
382
+ return False
383
+ else:
384
+ status.update(label="No content extracted from PDFs", state="error")
385
+ return False
386
+
387
+ def direct_retrieval(self, query):
388
+ """Direct retrieval method - returns raw document chunks"""
389
+ if not self.vector_store:
390
+ return "Please upload and process PDF files first."
391
+
392
+ try:
393
+ # Start timing
394
+ query_start_time = time.time()
395
+
396
+ # Retrieve relevant documents
397
+ retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
398
+ docs = retriever.get_relevant_documents(query)
399
+
400
+ # Format sources and answer
401
+ sources = []
402
+ answer = "Here are the most relevant passages:\n\n"
403
+
404
+ for i, doc in enumerate(docs):
405
+ # Get blockchain info if available
406
+ blockchain_info = None
407
+ if "blockchain" in doc.metadata:
408
+ blockchain_info = {
409
+ "verified": doc.metadata["blockchain"]["verified"],
410
+ "document_id": doc.metadata["blockchain"]["document_id"],
411
+ "tx_hash": doc.metadata["blockchain"]["tx_hash"]
412
+ }
413
+
414
+ # Add to answer and sources
415
+ answer += f"Passage {i+1} (from {doc.metadata.get('source', 'Unknown')}):\n{doc.page_content}\n\n"
416
+ sources.append({
417
+ "content": doc.page_content,
418
+ "source": doc.metadata.get("source", "Unknown"),
419
+ "blockchain": blockchain_info
420
+ })
421
+
422
+ # Calculate query time
423
+ query_time = time.time() - query_start_time
424
+
425
+ # Log query to blockchain if enabled
426
+ blockchain_log = None
427
+ if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
428
+ log_result = self.blockchain.log_query(query, answer)
429
+ if log_result.get("status"):
430
+ blockchain_log = {
431
+ "logged": True,
432
+ "query_id": log_result.get("query_id", ""),
433
+ "tx_hash": log_result.get("tx_hash", "")
434
+ }
435
+
436
+ return {
437
+ "answer": answer,
438
+ "sources": sources,
439
+ "query_time": query_time,
440
+ "blockchain_log": blockchain_log,
441
+ "method": "direct"
442
+ }
443
+
444
+ except Exception as e:
445
+ st.error(f"Error in direct retrieval: {str(e)}")
446
+ return f"Error: {str(e)}"
447
+
448
+ def enhanced_retrieval(self, query):
449
+ """Enhanced retrieval - processes through LLM for better answers"""
450
+ if not self.vector_store:
451
+ return "Please upload and process PDF files first."
452
+
453
+ try:
454
+ # Create prompt template
455
+ prompt_template = """
456
+ Answer the question based on the context below.
457
+
458
+ Context:
459
+ {context}
460
+
461
+ Question: {question}
462
+
463
+ Answer:
464
+ """
465
+ PROMPT = PromptTemplate(
466
+ template=prompt_template,
467
+ input_variables=["context", "question"]
468
+ )
469
+
470
+ # Start timing
471
+ query_start_time = time.time()
472
+
473
+ # Create QA chain
474
+ qa = RetrievalQA.from_chain_type(
475
+ llm=self.llm,
476
+ chain_type="stuff",
477
+ retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
478
+ chain_type_kwargs={"prompt": PROMPT},
479
+ return_source_documents=True
480
+ )
481
+
482
+ # Get answer
483
+ response = qa({"query": query})
484
+ answer = response["result"]
485
+ source_docs = response["source_documents"]
486
+
487
+ # Calculate query time
488
+ query_time = time.time() - query_start_time
489
+
490
+ # Format sources
491
+ sources = []
492
+ for i, doc in enumerate(source_docs):
493
+ # Get blockchain info if available
494
+ blockchain_info = None
495
+ if "blockchain" in doc.metadata:
496
+ blockchain_info = {
497
+ "verified": doc.metadata["blockchain"]["verified"],
498
+ "document_id": doc.metadata["blockchain"]["document_id"],
499
+ "tx_hash": doc.metadata["blockchain"]["tx_hash"]
500
+ }
501
+
502
+ sources.append({
503
+ "content": doc.page_content,
504
+ "source": doc.metadata.get("source", "Unknown"),
505
+ "blockchain": blockchain_info
506
+ })
507
+
508
+ # Log query to blockchain if enabled
509
+ blockchain_log = None
510
+ if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
511
+ log_result = self.blockchain.log_query(query, answer)
512
+ if log_result.get("status"):
513
+ blockchain_log = {
514
+ "logged": True,
515
+ "query_id": log_result.get("query_id", ""),
516
+ "tx_hash": log_result.get("tx_hash", "")
517
+ }
518
+
519
+ return {
520
+ "answer": answer,
521
+ "sources": sources,
522
+ "query_time": query_time,
523
+ "blockchain_log": blockchain_log,
524
+ "method": "enhanced"
525
+ }
526
+
527
+ except Exception as e:
528
+ st.error(f"Error in enhanced retrieval: {str(e)}")
529
+ return f"Error: {str(e)}"
530
+
531
+ def ask(self, query, method="enhanced"):
532
+ """Ask a question using the specified method"""
533
+ if method == "direct":
534
+ return self.direct_retrieval(query)
535
+ else:
536
+ return self.enhanced_retrieval(query)
537
+
538
  # Helper function to initialize session state
539
  def initialize_session_state():
540
+ """Initialize Streamlit session state variables"""
541
  if "rag" not in st.session_state:
542
  st.session_state.rag = None
543
  if "messages" not in st.session_state:
544
  st.session_state.messages = []
545
  if "temp_dir" not in st.session_state:
546
  st.session_state.temp_dir = None
547
+ if "wallet_connected" not in st.session_state:
548
+ st.session_state.wallet_connected = False
549
+ if "wallet_address" not in st.session_state:
550
+ st.session_state.wallet_address = None
551
  if "retrieval_method" not in st.session_state:
552
  st.session_state.retrieval_method = "enhanced"
 
 
553
  if "current_answer" not in st.session_state:
554
  st.session_state.current_answer = None
555
 
556
  # Helper function to clean up temporary files
557
  def cleanup_temp_files():
558
+ """Clean up temporary files when application exits"""
559
  if st.session_state.get('temp_dir') and os.path.exists(st.session_state.temp_dir):
560
  try:
561
  shutil.rmtree(st.session_state.temp_dir)
 
562
  except Exception as e:
563
  print(f"Error cleaning up temporary directory: {e}")
564
 
565
+ # Create a simple wallet connector UI
566
+ def wallet_connector():
567
+ st.sidebar.subheader("πŸ”— Blockchain Connection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
 
569
+ if st.session_state.wallet_connected:
570
+ st.sidebar.success(f"βœ… Connected: {st.session_state.wallet_address[:10]}...")
571
+ if st.sidebar.button("Disconnect Wallet"):
572
+ st.session_state.wallet_connected = False
573
+ st.session_state.wallet_address = None
574
+ st.rerun()
575
+ else:
576
+ st.sidebar.info("Connect wallet to verify documents on blockchain")
577
+ if st.sidebar.button("Connect Wallet"):
578
+ # Generate a mock wallet address
579
+ wallet_address = "0x" + "".join([format(i, "02x") for i in os.urandom(20)])
580
+ st.session_state.wallet_address = wallet_address
581
+ st.session_state.wallet_connected = True
582
+
583
+ # Connect to RAG system if initialized
584
+ if st.session_state.rag:
585
+ st.session_state.rag.connect_wallet(wallet_address)
586
+
587
+ st.rerun()
588
 
589
+ # Main application UI
590
  def main():
591
+ # Load CSS
592
+ load_css()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
 
594
  # Initialize session state
595
  initialize_session_state()
596
 
597
+ # Page header
598
+ st.title("πŸ“š Advanced RAG System")
599
+ st.markdown("""
600
+ <div style="display: flex; gap: 10px; margin-bottom: 20px;">
601
+ <div style="background-color: #e3f2fd; padding: 5px 10px; border-radius: 15px; font-size: 0.9em;">
602
+ πŸ“„ Document Analysis
603
+ </div>
604
+ <div style="background-color: #e8f5e9; padding: 5px 10px; border-radius: 15px; font-size: 0.9em;">
605
+ πŸ”— Blockchain Verification
606
+ </div>
607
+ <div style="background-color: #fff3e0; padding: 5px 10px; border-radius: 15px; font-size: 0.9em;">
608
+ 🎀 Voice Input
609
+ </div>
610
+ </div>
611
+ """, unsafe_allow_html=True)
612
+
613
+ # Sidebar for configuration
614
  with st.sidebar:
615
+ # Wallet connector
616
+ wallet_connector()
 
 
 
 
 
617
 
618
+ # System configuration
619
+ st.sidebar.subheader("βš™οΈ System Configuration")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
+ # GPU Detection
622
+ gpu_available = torch.cuda.is_available()
623
+ if gpu_available:
624
+ st.sidebar.success(f"GPU detected and available")
625
+ else:
626
+ st.sidebar.warning("No GPU detected. Running in CPU mode.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
 
628
+ # Model selection with faster models
629
+ llm_model = st.sidebar.selectbox(
630
+ "LLM Model",
631
+ options=[
632
+ "google/flan-t5-base",
633
+ "google/flan-t5-small",
634
+ "distilbert/distilgpt2",
635
+ "google/flan-ul2"
636
+ ],
637
+ index=0
638
+ )
639
+
640
+ embedding_model = st.sidebar.selectbox(
641
+ "Embedding Model",
642
+ options=[
643
+ "sentence-transformers/all-MiniLM-L6-v2",
644
+ "sentence-transformers/paraphrase-MiniLM-L3-v2",
645
+ "sentence-transformers/all-mpnet-base-v2"
646
+ ],
647
+ index=0
648
+ )
649
+
650
+ use_gpu = st.sidebar.checkbox("Use GPU Acceleration", value=gpu_available)
651
+ use_blockchain = st.sidebar.checkbox("Enable Blockchain", value=True)
652
+
653
+ # Contract address - hardcoded for simplicity
654
+ contract_address = "0x123abc..." # Your pre-deployed contract
655
+
656
+ # Initialize button
657
+ if st.sidebar.button("Initialize System"):
658
+ with st.spinner("Setting up RAG system..."):
659
+ st.session_state.rag = OptimizedRAG(
660
+ llm_model_name=llm_model,
661
+ embedding_model_name=embedding_model,
662
+ chunk_size=1000,
663
+ chunk_overlap=200,
664
+ use_gpu=use_gpu and gpu_available,
665
+ use_blockchain=use_blockchain,
666
+ contract_address=contract_address if use_blockchain else None
667
+ )
668
+
669
+ # Connect wallet if already connected
670
+ if st.session_state.wallet_connected:
671
+ st.session_state.rag.connect_wallet(st.session_state.wallet_address)
672
 
673
+ st.sidebar.success(f"βœ… System initialized!")
674
+
675
+ # Document upload
676
+ st.sidebar.subheader("πŸ“„ Document Upload")
677
+ uploaded_files = st.sidebar.file_uploader("Select PDFs", type="pdf", accept_multiple_files=True)
678
+
679
+ if uploaded_files and st.sidebar.button("Process Documents"):
680
+ if not st.session_state.rag:
681
+ with st.spinner("Initializing system first..."):
682
+ st.session_state.rag = OptimizedRAG(
683
+ llm_model_name=llm_model,
684
+ embedding_model_name=embedding_model,
685
+ chunk_size=1000,
686
+ chunk_overlap=200,
687
+ use_gpu=use_gpu and gpu_available,
688
+ use_blockchain=use_blockchain,
689
+ contract_address=contract_address if use_blockchain else None
690
+ )
691
 
692
+ # Connect wallet if already connected
693
+ if st.session_state.wallet_connected:
694
+ st.session_state.rag.connect_wallet(st.session_state.wallet_address)
695
+
696
+ success = st.session_state.rag.process_pdfs(uploaded_files)
697
+ if success:
698
+ st.sidebar.success("πŸ“„ Documents processed successfully!")
 
 
699
 
700
+ # Method Selection
701
+ st.markdown("### Retrieval Method")
702
+ col1, col2 = st.columns(2)
703
 
704
+ with col1:
705
+ direct_class = "method-button direct-method"
706
+ if st.session_state.retrieval_method == "direct":
707
+ direct_class += " method-active"
708
+
709
+ if st.markdown(f"""
710
+ <div class="{direct_class}" onclick="this.classList.add('method-active')">
711
+ πŸ” Direct Retrieval
712
+ </div>
713
+ """, unsafe_allow_html=True):
714
+ st.session_state.retrieval_method = "direct"
715
+ st.rerun()
716
+
717
+ with col2:
718
+ enhanced_class = "method-button enhanced-method"
719
+ if st.session_state.retrieval_method == "enhanced":
720
+ enhanced_class += " method-active"
 
 
 
 
 
 
 
 
 
 
 
721
 
722
+ if st.markdown(f"""
723
+ <div class="{enhanced_class}" onclick="this.classList.add('method-active')">
724
+ πŸ’‘ Enhanced Answers
725
+ </div>
726
+ """, unsafe_allow_html=True):
727
+ st.session_state.retrieval_method = "enhanced"
728
+ st.rerun()
729
+
730
+ # Method description
731
+ if st.session_state.retrieval_method == "direct":
732
+ st.info("πŸ” **Direct Retrieval**: Shows raw document passages. Fast and transparent.")
733
+ else:
734
+ st.info("πŸ’‘ **Enhanced Answers**: Processes content through AI for better quality answers.")
735
+
736
+ # Main Two-Column Layout
737
+ answer_col, sources_col = st.columns([2, 1])
738
+
739
+ # Answer column
740
+ with answer_col:
741
+ st.markdown("### Ask a Question")
742
 
743
+ # Text input
744
+ user_input = st.text_input("Enter your question about the documents")
 
 
 
 
 
 
 
745
 
746
+ # Simple voice input simulation
747
+ voice_toggle = st.checkbox("Enable voice input")
748
+ if voice_toggle:
749
+ st.markdown("""
750
+ <div style="display: flex; flex-direction: column; align-items: center; margin: 15px 0;">
751
+ <div class="voice-button">🎀</div>
752
+ <div style="margin-top: 10px; color: #666;">Click to speak</div>
753
+ </div>
754
+ """, unsafe_allow_html=True)
755
 
756
+ if st.button("Simulate Voice Input"):
757
+ user_input = "What are the main topics covered in the documents?"
758
+ st.info(f"Voice input received: {user_input}")
759
+ st.rerun()
760
+
761
+ # Process query
762
+ if user_input:
763
+ # Add user message to history
764
+ st.session_state.messages.append({"role": "user", "content": user_input})
765
+
766
+ # Check if system is initialized
767
+ if not st.session_state.rag:
768
+ st.error("Please initialize the system and process PDFs first.")
 
 
 
 
 
 
 
769
 
770
+ # Get response if vector store is ready
771
+ elif st.session_state.rag.vector_store:
772
+ with st.spinner("Generating answer..."):
773
+ # Get retrieval method
774
+ method = st.session_state.retrieval_method
775
+
776
+ # Get answer
777
+ response = st.session_state.rag.ask(user_input, method=method)
778
+ st.session_state.messages.append({"role": "assistant", "content": response})
779
+
780
+ # Store current answer
781
+ st.session_state.current_answer = response
782
+
783
+ # Rerun to update UI
784
+ st.rerun()
785
+ else:
786
+ st.error("Please upload and process PDF files first.")
 
 
 
 
 
 
787
 
788
+ # Display current answer
789
  if st.session_state.current_answer and isinstance(st.session_state.current_answer, dict):
790
+ answer = st.session_state.current_answer
791
+
792
+ st.markdown("""
793
+ <div class="answer-section">
794
+ <h3>Answer</h3>
795
+ <div style="white-space: pre-line;">
796
+ {answer_text}
 
 
 
 
 
797
  </div>
798
+ <div style="margin-top: 10px; font-size: 0.8em; color: #666;">
799
+ Method: {method_name} | Time: {query_time:.2f}s
800
+ </div>
801
+ </div>
802
+ """.format(
803
+ answer_text=answer["answer"],
804
+ method_name="Direct Retrieval" if answer["method"] == "direct" else "Enhanced Answer",
805
+ query_time=answer["query_time"]
806
+ ), unsafe_allow_html=True)
807
+
808
+ # Blockchain verification display
809
+ if "blockchain_log" in answer and answer["blockchain_log"]:
810
+ blockchain_log = answer["blockchain_log"]
811
+ st.success(f"βœ… Query logged on blockchain | Transaction: {blockchain_log['tx_hash'][:10]}...")
812
 
813
+ # Sources column
814
+ with sources_col:
815
+ st.markdown("### Sources")
816
+
817
  if st.session_state.current_answer and isinstance(st.session_state.current_answer, dict):
818
+ answer = st.session_state.current_answer
819
+
820
+ # Display sources
821
+ if "sources" in answer and answer["sources"]:
822
+ for i, source in enumerate(answer["sources"]):
823
+ verified_badge = ""
824
+ if source.get("blockchain"):
825
+ verified_badge = '<span class="verified-badge">βœ“ Verified</span>'
826
+
827
+ st.markdown(f"""
828
+ <div class="source-item">
829
+ <div class="source-header">
830
+ Source {i+1}: {source['source']}
831
+ {verified_badge}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832
  </div>
833
+ <div style="font-size: 0.9em;">
834
+ {source['content'][:200]}...
835
+ </div>
836
+ </div>
837
+ """, unsafe_allow_html=True)
838
+ else:
839
+ st.info("No sources available for this query.")
840
  else:
841
+ st.info("Ask a question to see sources here.")
 
 
 
 
 
 
 
 
842
 
843
  # Main entry point
844
  if __name__ == "__main__":