# DocAgent / app.py
# Author: ragunath-ravi — last commit 2da7fa0 ("Update app.py", verified); original size 22.5 kB.
import gradio as gr
import os
import json
import uuid
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Generator
import logging
# Import required libraries
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.chains import LLMChain, RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceHub
# Import document parsers
import PyPDF2
from pptx import Presentation
import pandas as pd
from docx import Document as DocxDocument
import io
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face API token. The lower-case "hf_token" variable is checked first
# (as before); the conventional "HF_TOKEN" variable is accepted as a fallback.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HuggingFace token not found in environment variables "
        "(set 'hf_token' or 'HF_TOKEN')"
    )

# Remote LLM called through the Hugging Face Hub inference API.
# NOTE(review): HuggingFaceHub is deprecated in langchain_community in favour of
# HuggingFaceEndpoint, and text-generation backends typically expect
# "max_new_tokens" rather than "max_length" — confirm against the deployed API.
llm = HuggingFaceHub(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    huggingfacehub_api_token=HF_TOKEN,
    model_kwargs={"temperature": 0.7, "max_length": 512},
)

# sentence-transformers model used to embed document chunks for the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
class MCPMessage:
    """Model Context Protocol message exchanged between agents on the bus.

    Args:
        sender: Name of the publishing agent.
        receiver: Name of the agent whose subscribers will receive it.
        msg_type: Message kind, e.g. "INGESTION_REQUEST" (stored as ``type``).
        trace_id: Correlation id; a fresh UUID4 string is generated when omitted,
            so replies can reuse the request's id.
        payload: Arbitrary message data; a new empty dict when omitted.
    """

    def __init__(self, sender: str, receiver: str, msg_type: str,
                 trace_id: Optional[str] = None,
                 payload: Optional[Dict[str, Any]] = None):
        self.sender = sender
        self.receiver = receiver
        self.type = msg_type
        self.trace_id = trace_id or str(uuid.uuid4())
        # A fresh dict per message so instances never share a default payload.
        self.payload = payload or {}
        # Local (naive) creation time in ISO-8601 format.
        self.timestamp = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dict (payload included by reference)."""
        return {
            "sender": self.sender,
            "receiver": self.receiver,
            "type": self.type,
            "trace_id": self.trace_id,
            "payload": self.payload,
            "timestamp": self.timestamp,
        }

    def __repr__(self) -> str:
        # Payload omitted on purpose: it can hold entire document sets.
        return (
            f"{type(self).__name__}(sender={self.sender!r}, receiver={self.receiver!r}, "
            f"type={self.type!r}, trace_id={self.trace_id!r})"
        )
class MessageBus:
    """In-memory, synchronous message bus for MCP communication.

    ``publish`` calls every callback subscribed under ``message.receiver``
    directly, so all subscribers have handled a message by the time
    ``publish`` returns.

    Args:
        max_history: Maximum number of published messages kept in
            ``self.messages`` (oldest dropped first); ``None`` keeps all.
    """

    def __init__(self, max_history: Optional[int] = 1000):
        # Recent published messages, oldest first. Bounded because this bus is
        # long-lived and payloads may carry entire lists of document chunks.
        self.messages = []
        self.max_history = max_history
        # receiver name -> list of callbacks, in subscription order
        self.subscribers = {}

    def publish(self, message: "MCPMessage"):
        """Record *message* and deliver it to the receiver's subscribers."""
        self.messages.append(message)
        if self.max_history is not None and len(self.messages) > self.max_history:
            del self.messages[: len(self.messages) - self.max_history]
        logger.info(f"Message published: {message.sender} -> {message.receiver} [{message.type}]")
        # Iterate over a snapshot so a callback that subscribes more handlers
        # does not change the set of callbacks for this delivery.
        for callback in list(self.subscribers.get(message.receiver, ())):
            callback(message)

    def subscribe(self, agent_name: str, callback):
        """Register *callback* to receive messages addressed to *agent_name*."""
        self.subscribers.setdefault(agent_name, []).append(callback)
# Process-wide bus instance; the agents are handed this object to communicate.
message_bus: MessageBus = MessageBus()
class IngestionAgent:
    """Agent responsible for document parsing and preprocessing.

    Subscribes to the bus as ``"IngestionAgent"``. On an ``INGESTION_REQUEST``
    it converts each supported file in ``payload["files"]`` to plain text,
    splits the text into overlapping chunks and publishes them to the
    ``RetrievalAgent`` as LangChain ``Document`` objects in an
    ``INGESTION_COMPLETE`` message carrying the request's ``trace_id``.
    """

    def __init__(self, message_bus: "MessageBus"):
        self.name = "IngestionAgent"
        self.message_bus = message_bus
        self.message_bus.subscribe(self.name, self.handle_message)
        # Chunks of up to 1000 characters; consecutive chunks share up to 200
        # characters so text cut at a boundary keeps some surrounding context.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )

    def handle_message(self, message: "MCPMessage"):
        """Handle bus messages; only ``INGESTION_REQUEST`` is acted on."""
        if message.type == "INGESTION_REQUEST":
            self.process_documents(message)

    def parse_pdf(self, file_path: str) -> str:
        """Return the text of all PDF pages, separated by newlines ("" on error)."""
        try:
            with open(file_path, "rb") as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # Treat a page with no extractable text as empty instead of
                # letting a None result abort the whole file, and keep a
                # separator so words at page breaks are not glued together.
                return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        except Exception as e:
            logger.error(f"Error parsing PDF: {e}")
            return ""

    def parse_pptx(self, file_path: str) -> str:
        """Return the text of every text-bearing shape on every slide, one per line."""
        try:
            prs = Presentation(file_path)
            return "".join(
                shape.text + "\n"
                for slide in prs.slides
                for shape in slide.shapes
                if hasattr(shape, "text")
            )
        except Exception as e:
            logger.error(f"Error parsing PPTX: {e}")
            return ""

    def parse_csv(self, file_path: str) -> str:
        """Return a CSV file rendered as a plain-text table ("" on error)."""
        try:
            df = pd.read_csv(file_path)
            return df.to_string()
        except Exception as e:
            logger.error(f"Error parsing CSV: {e}")
            return ""

    def parse_docx(self, file_path: str) -> str:
        """Return the paragraph text of a DOCX document, one paragraph per line."""
        try:
            doc = DocxDocument(file_path)
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            logger.error(f"Error parsing DOCX: {e}")
            return ""

    def parse_txt(self, file_path: str) -> str:
        """Return the contents of a TXT/Markdown file decoded as UTF-8.

        Undecodable bytes are replaced with U+FFFD rather than discarding the
        whole file; returns "" if the file cannot be opened.
        """
        try:
            with open(file_path, "r", encoding="utf-8", errors="replace") as file:
                return file.read()
        except Exception as e:
            logger.error(f"Error parsing TXT: {e}")
            return ""

    def _parser_for(self, file_ext: str):
        """Return the bound parse_* method for a lower-case extension, or None."""
        parsers = {
            ".pdf": self.parse_pdf,
            ".pptx": self.parse_pptx,
            ".csv": self.parse_csv,
            ".docx": self.parse_docx,
            ".txt": self.parse_txt,
            ".md": self.parse_txt,
        }
        return parsers.get(file_ext)

    def process_documents(self, message: "MCPMessage"):
        """Parse and chunk ``payload["files"]`` and send the chunks onward.

        Unsupported or empty files are skipped with a warning. An
        ``INGESTION_COMPLETE`` message is always published, possibly with an
        empty ``documents`` list.
        """
        files = message.payload.get("files", [])
        processed_docs = []

        for file_path in files:
            file_ext = os.path.splitext(file_path)[1].lower()
            parser = self._parser_for(file_ext)
            if parser is None:
                logger.warning(f"Unsupported file type: {file_ext}")
                continue

            text = parser(file_path)
            # Whitespace-only output (e.g. a PDF with no text layer) would only
            # produce useless chunks, so skip it.
            if not text or not text.strip():
                logger.warning(f"No text extracted from {file_path}")
                continue

            chunks = self.text_splitter.split_text(text)
            processed_docs.extend(
                Document(page_content=chunk, metadata={"source": file_path})
                for chunk in chunks
            )

        # Send processed documents to RetrievalAgent
        response = MCPMessage(
            sender=self.name,
            receiver="RetrievalAgent",
            msg_type="INGESTION_COMPLETE",
            trace_id=message.trace_id,
            payload={"documents": processed_docs},
        )
        self.message_bus.publish(response)
class RetrievalAgent:
    """Agent responsible for embedding and semantic retrieval using LangChain.

    On ``INGESTION_COMPLETE`` it builds a FAISS index over the received
    ``Document`` chunks, plus two chains on top of it:
    - a ``RetrievalQA`` chain;
    - a ``ConversationalRetrievalChain`` that keeps chat history in ``self.memory``.

    It then notifies the ``CoordinatorAgent`` with ``VECTORSTORE_READY``. On
    ``RETRIEVAL_REQUEST`` it answers with one of the chains and always replies
    to the ``CoordinatorAgent`` with a ``CHAIN_RESPONSE`` message.
    """

    def __init__(self, message_bus: "MessageBus"):
        self.name = "RetrievalAgent"
        self.message_bus = message_bus
        self.message_bus.subscribe(self.name, self.handle_message)
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None
        self.conversation_chain = None
        # Created once and reused by every rebuilt conversational chain, so chat
        # history carries over when new documents are uploaded. output_key tells
        # the memory which output to store, since the chain also returns sources.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

    def handle_message(self, message: "MCPMessage"):
        """Route ``INGESTION_COMPLETE`` and ``RETRIEVAL_REQUEST`` messages."""
        if message.type == "INGESTION_COMPLETE":
            self.create_vector_store(message)
        elif message.type == "RETRIEVAL_REQUEST":
            self.process_query(message)

    def create_vector_store(self, message: "MCPMessage"):
        """Build the FAISS index and chains from ``payload["documents"]``.

        Does nothing (beyond a warning) for an empty document list. On failure
        the error is logged and the previous index/chains are kept intact.
        """
        documents = message.payload.get("documents", [])
        if not documents:
            logger.warning("No documents received; vector store not (re)built")
            return
        try:
            # Build everything into locals first so a failure part-way through
            # cannot leave a new retriever paired with chains from an old index.
            vector_store = FAISS.from_documents(documents, embeddings)
            retriever = vector_store.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3},
            )
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                verbose=True,
            )
            conversation_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=retriever,
                memory=self.memory,
                return_source_documents=True,
                verbose=True,
            )

            self.vector_store = vector_store
            self.retriever = retriever
            self.qa_chain = qa_chain
            self.conversation_chain = conversation_chain
            logger.info(f"Vector store and chains created with {len(documents)} documents")

            self.message_bus.publish(MCPMessage(
                sender=self.name,
                receiver="CoordinatorAgent",
                msg_type="VECTORSTORE_READY",
                trace_id=message.trace_id,
                payload={"status": "ready"},
            ))
        except Exception:
            logger.exception("Error creating vector store")

    def process_query(self, message: "MCPMessage"):
        """Answer ``payload["query"]`` and reply with a ``CHAIN_RESPONSE``.

        ``payload["use_conversation"]`` (default True) selects the
        conversational chain over the plain QA chain. A reply is published in
        every case, including errors, so the requester never waits in vain.
        """
        query = message.payload.get("query", "")
        use_conversation = message.payload.get("use_conversation", True)

        if not query:
            self._reply(message.trace_id, query, "Empty query received.", [])
            return
        if not self.qa_chain:
            self._reply(
                message.trace_id,
                query,
                "No documents have been indexed yet. Please upload documents first.",
                [],
            )
            return

        try:
            if use_conversation and self.conversation_chain:
                # Conversational chain: reads/writes chat history in self.memory.
                result = self.conversation_chain.invoke({"question": query})
                answer = result["answer"]
            else:
                result = self.qa_chain.invoke({"query": query})
                answer = result["result"]
            sources = self._format_sources(result.get("source_documents", []))
        except Exception as e:
            logger.exception("Error processing query")
            self._reply(message.trace_id, query, f"Error processing query: {str(e)}", [])
            return

        self._reply(message.trace_id, query, answer, sources)

    def _reply(self, trace_id: str, query: str, answer: str,
               sources: List[Dict[str, str]]):
        """Publish a ``CHAIN_RESPONSE`` for *trace_id* to the CoordinatorAgent."""
        self.message_bus.publish(MCPMessage(
            sender=self.name,
            receiver="CoordinatorAgent",
            msg_type="CHAIN_RESPONSE",
            trace_id=trace_id,
            payload={"query": query, "answer": answer, "sources": sources},
        ))

    @staticmethod
    def _format_sources(source_docs: List[Any],
                        preview_chars: int = 200) -> List[Dict[str, str]]:
        """Turn retrieved documents into ``{"content", "source"}`` previews.

        Content longer than *preview_chars* is truncated and suffixed with
        "..."; documents without a ``source`` metadata entry get "Unknown".
        """
        sources = []
        for doc in source_docs:
            content = doc.page_content
            if len(content) > preview_chars:
                content = content[:preview_chars] + "..."
            sources.append({
                "content": content,
                "source": doc.metadata.get("source", "Unknown"),
            })
        return sources
class CoordinatorAgent:
    """Coordinator agent that orchestrates the entire workflow.

    Sends ``INGESTION_REQUEST``s to the IngestionAgent and
    ``RETRIEVAL_REQUEST``s to the RetrievalAgent. It tracks
    ``VECTORSTORE_READY`` notifications and collects ``CHAIN_RESPONSE``
    payloads keyed by ``trace_id``.

    Args:
        message_bus: Bus to subscribe to and publish on.
        response_timeout: Seconds ``handle_query`` waits for a response.
    """

    def __init__(self, message_bus: "MessageBus", response_timeout: float = 30.0):
        self.name = "CoordinatorAgent"
        self.message_bus = message_bus
        self.message_bus.subscribe(self.name, self.handle_message)
        self.vector_store_ready = False
        # Most recent CHAIN_RESPONSE payload; kept for backward compatibility.
        self.current_response = None
        # CHAIN_RESPONSE payloads by trace_id, so a query only ever picks up its
        # own answer even if other queries are in flight or a reply arrives late.
        self.responses: Dict[str, Dict[str, Any]] = {}
        self.response_timeout = response_timeout

    def handle_message(self, message: "MCPMessage"):
        """Record readiness and store chain responses by trace_id."""
        if message.type == "VECTORSTORE_READY":
            self.vector_store_ready = True
        elif message.type == "CHAIN_RESPONSE":
            self.current_response = message.payload
            self.responses[message.trace_id] = message.payload

    def process_files(self, files):
        """Ask the IngestionAgent to ingest *files* and return a status line.

        Each item may be an object exposing ``.name`` (a file path) or a plain
        path string.
        """
        if not files:
            return "No files uploaded."

        file_paths = [getattr(file, "name", file) for file in files]

        # Send ingestion request
        message = MCPMessage(
            sender=self.name,
            receiver="IngestionAgent",
            msg_type="INGESTION_REQUEST",
            payload={"files": file_paths},
        )
        self.message_bus.publish(message)

        return f"Processing {len(files)} files: {', '.join([os.path.basename(fp) for fp in file_paths])}"

    def handle_query(self, query: str):
        """Send *query* for retrieval and return the formatted answer text."""
        import time

        if not self.vector_store_ready:
            return "Please upload and process documents first."

        message = MCPMessage(
            sender=self.name,
            receiver="RetrievalAgent",
            msg_type="RETRIEVAL_REQUEST",
            payload={"query": query, "use_conversation": True},
        )
        trace_id = message.trace_id
        self.message_bus.publish(message)

        # The bus delivers synchronously, so the answer is normally present
        # already; poll only in case a reply arrives later.
        deadline = time.monotonic() + self.response_timeout
        response = self.responses.pop(trace_id, None)
        while response is None and time.monotonic() < deadline:
            time.sleep(0.1)
            response = self.responses.pop(trace_id, None)

        if response is None:
            return "Timeout: No response received from the system."

        self.current_response = None  # Reset for next query
        return self._format_answer(response)

    @staticmethod
    def _format_answer(response: Dict[str, Any]) -> str:
        """Render a CHAIN_RESPONSE payload as the answer plus a numbered source list."""
        answer = response.get("answer", "No answer generated.")
        sources = response.get("sources", [])
        if not sources:
            return answer
        source_lines = "".join(
            f"{i}. {source['source']}: {source['content']}\n"
            for i, source in enumerate(sources, 1)
        )
        return answer + "\n\n**Sources:**\n" + source_lines
# Wire up the pipeline: each agent's constructor subscribes it to the shared bus,
# so instantiating them (ingestion, retrieval, coordinator — in that order) is all
# that is required.
ingestion_agent, retrieval_agent, coordinator_agent = (
    IngestionAgent(message_bus),
    RetrievalAgent(message_bus),
    CoordinatorAgent(message_bus),
)
# Stylesheet for the dark, ChatGPT-style layout built by create_interface().
_APP_CSS = """
/* Dark theme styling */
.gradio-container {
    background-color: #1a1a1a !important;
    color: #ffffff !important;
    height: 100vh !important;
    max-width: none !important;
    padding: 0 !important;
}
/* Main container */
.main-container {
    display: flex;
    flex-direction: column;
    height: 100vh;
    background: linear-gradient(135deg, #1a1a1a 0%, #2d2d2d 100%);
}
/* Header */
.header {
    background: rgba(255, 193, 7, 0.1);
    border-bottom: 1px solid rgba(255, 193, 7, 0.2);
    padding: 1rem 2rem;
    backdrop-filter: blur(10px);
}
.header h1 {
    color: #ffc107;
    margin: 0;
    font-size: 1.5rem;
    font-weight: 600;
}
.header p {
    color: #cccccc;
    margin: 0.25rem 0 0 0;
    font-size: 0.9rem;
}
/* Chat area */
.chat-container {
    flex: 1;
    display: flex;
    flex-direction: column;
    max-width: 800px;
    margin: 0 auto;
    width: 100%;
    padding: 1rem;
}
/* Chatbot styling */
.gradio-chatbot {
    flex: 1 !important;
    background: transparent !important;
    border: none !important;
    margin-bottom: 1rem;
}
/* Input area */
.input-area {
    background: rgba(45, 45, 45, 0.6);
    border-radius: 16px;
    padding: 1rem;
    border: 1px solid rgba(255, 193, 7, 0.2);
    backdrop-filter: blur(10px);
}
/* File upload */
.upload-area {
    background: rgba(255, 193, 7, 0.05);
    border: 2px dashed rgba(255, 193, 7, 0.3);
    border-radius: 12px;
    padding: 1rem;
    margin-bottom: 1rem;
    transition: all 0.3s ease;
}
/* Buttons */
.primary-btn {
    background: linear-gradient(135deg, #ffc107 0%, #ff8f00 100%) !important;
    color: #000000 !important;
    border: none !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
}
/* Text inputs */
.gradio-textbox input, .gradio-textbox textarea {
    background: rgba(45, 45, 45, 0.8) !important;
    color: #ffffff !important;
    border: 1px solid rgba(255, 193, 7, 0.2) !important;
    border-radius: 8px !important;
}
/* Processing indicator */
.processing-indicator {
    background: rgba(255, 193, 7, 0.1);
    border: 1px solid rgba(255, 193, 7, 0.3);
    border-radius: 8px;
    padding: 0.75rem;
    margin: 0.5rem 0;
    color: #ffc107;
    text-align: center;
}
"""


def create_interface():
    """Create the ChatGPT-style Gradio interface and return the Blocks app.

    All work is delegated to the module-level ``coordinator_agent``:
    - uploads trigger ingestion;
    - messages are answered via ``handle_query``.

    The returned app has its queue enabled, which the streaming upload
    handler requires.
    """
    import html

    with gr.Blocks(
        theme=gr.themes.Base(),
        css=_APP_CSS,
        title="Agentic RAG Assistant",
    ) as iface:
        # Header
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="header">
                    <h1>🤖 Agentic RAG Assistant</h1>
                    <p>Upload documents and ask questions - powered by LangChain Multi-Agent Architecture</p>
                </div>
                """)

        # Main chat container: conversation above, inputs below.
        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot(
                    value=[],
                    height=500,
                    show_copy_button=True,
                )

                # Input area
                with gr.Column():
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".pptx", ".csv", ".docx", ".txt", ".md"],
                        label="Upload Documents",
                    )
                    # Processing status
                    processing_status = gr.HTML(visible=False)

                    # Message input row
                    with gr.Row():
                        msg_input = gr.Textbox(
                            placeholder="Upload documents above, then ask your questions here...",
                            label="Message",
                            scale=4,
                        )
                        send_btn = gr.Button("Send", scale=1, variant="primary")

        # Whether a vector store is available for answering questions.
        doc_processed = gr.State(False)

        def handle_file_upload(files):
            """Stream a "processing" banner, then the ingestion outcome."""
            if not files:
                yield gr.update(visible=False), False
                return

            processing_html = f"""
            <div class="processing-indicator">
                📄 Processing {len(files)} documents... Please wait.
            </div>
            """
            # Block questions while ingestion runs.
            yield gr.update(value=processing_html, visible=True), False

            try:
                # The message bus is synchronous, so ingestion and indexing have
                # finished by the time this call returns — no need to sleep.
                coordinator_agent.process_files(files)
            except Exception as e:
                error_html = f"""
                <div style="background: rgba(244, 67, 54, 0.1); border: 1px solid rgba(244, 67, 54, 0.3);
                border-radius: 8px; padding: 0.75rem; color: #f44336; text-align: center;">
                    ❌ Error processing documents: {html.escape(str(e))}
                </div>
                """
                yield gr.update(value=error_html, visible=True), False
                return

            # NOTE: vector_store_ready stays True after any earlier successful
            # upload, so only a failure with no prior index is detected here.
            if not coordinator_agent.vector_store_ready:
                warning_html = """
                <div style="background: rgba(244, 67, 54, 0.1); border: 1px solid rgba(244, 67, 54, 0.3);
                border-radius: 8px; padding: 0.75rem; color: #f44336; text-align: center;">
                    ⚠️ The documents could not be indexed. Check the logs and try different files.
                </div>
                """
                yield gr.update(value=warning_html, visible=True), False
                return

            success_html = """
            <div style="background: rgba(76, 175, 80, 0.1); border: 1px solid rgba(76, 175, 80, 0.3);
            border-radius: 8px; padding: 0.75rem; color: #4caf50; text-align: center;">
                ✅ Documents processed successfully! You can now ask questions.
            </div>
            """
            yield gr.update(value=success_html, visible=True), True

        def respond(message, history, doc_ready):
            """Append the user's message and the assistant's reply to the chat."""
            history = history or []
            if not message or not message.strip():
                return history, message
            if not doc_ready:
                # Show the user's own text, with the warning as the bot reply.
                return history + [[message, "Please upload and process documents first."]], ""

            response = coordinator_agent.handle_query(message)
            return history + [[message, response]], ""

        # File upload triggers processing
        file_upload.change(
            handle_file_upload,
            inputs=[file_upload],
            outputs=[processing_status, doc_processed],
        )

        # Both the Send button and pressing Enter submit the message.
        for trigger in (send_btn.click, msg_input.submit):
            trigger(
                respond,
                inputs=[msg_input, chatbot, doc_processed],
                outputs=[chatbot, msg_input],
            )

    # Generator (streaming) handlers need the queue; harmless where it is default.
    iface.queue()
    return iface
# Entry point: build the UI and serve it on all interfaces, port 7860.
# share=True additionally requests a public Gradio share link.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_interface()
    demo.launch(**launch_options)