Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import json | |
import uuid | |
import asyncio | |
from datetime import datetime | |
from typing import List, Dict, Any, Optional, Generator | |
import logging | |
# Import required libraries | |
from huggingface_hub import InferenceClient | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.docstore.document import Document | |
from langchain.chains import LLMChain, RetrievalQA, ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain.prompts import PromptTemplate | |
from langchain_community.llms import HuggingFaceHub | |
# Import document parsers | |
import PyPDF2 | |
from pptx import Presentation | |
import pandas as pd | |
from docx import Document as DocxDocument | |
import io | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get HuggingFace token from environment. "hf_token" is checked first (as
# before); the conventional HF_TOKEN variable read by huggingface_hub is
# accepted as a fallback so the app also runs outside its original Space.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HuggingFace token not found in environment variables")

# Initialize HuggingFace LLM; this single instance backs every retrieval chain.
llm = HuggingFaceHub(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    huggingfacehub_api_token=HF_TOKEN,
    # max_new_tokens is the documented text-generation parameter and bounds
    # only the generated answer. A total-length cap would also have to cover
    # the "stuffed" RAG prompt (up to 3 retrieved ~1000-character chunks plus
    # the question), which by itself can approach 512 tokens.
    model_kwargs={"temperature": 0.7, "max_new_tokens": 512},
)

# Initialize embeddings (passed to FAISS.from_documents when the index is built).
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
class MCPMessage:
    """Model Context Protocol message exchanged between agents on the bus.

    Attributes:
        sender: Name of the publishing agent.
        receiver: Name of the agent the message is addressed to (the bus
            routes deliveries on this field).
        type: Message type string, e.g. ``"INGESTION_REQUEST"``.
        trace_id: Correlation id; a fresh UUID4 string when not supplied.
        payload: Arbitrary message data; a new empty dict when not supplied.
        timestamp: Local (naive) creation time in ISO-8601 format.
    """

    def __init__(self, sender: str, receiver: str, msg_type: str,
                 trace_id: Optional[str] = None,
                 payload: Optional[Dict[str, Any]] = None):
        self.sender = sender
        self.receiver = receiver
        self.type = msg_type
        # Falsy ids ("" or None) get a fresh UUID so every message is traceable.
        self.trace_id = trace_id or str(uuid.uuid4())
        # A new dict per message, so instances never share a mutable default.
        self.payload = payload or {}
        self.timestamp = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return the message fields as a plain dict (payload is not copied)."""
        return {
            "sender": self.sender,
            "receiver": self.receiver,
            "type": self.type,
            "trace_id": self.trace_id,
            "payload": self.payload,
            "timestamp": self.timestamp,
        }

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(sender={self.sender!r}, "
            f"receiver={self.receiver!r}, type={self.type!r}, "
            f"trace_id={self.trace_id!r})"
        )
class MessageBus:
    """Synchronous, in-memory publish/subscribe bus for MCP messages.

    ``publish`` appends every message to ``messages`` (kept for the life of
    the bus) and then calls, in registration order and on the publishing
    thread, each callback subscribed under the message's ``receiver`` name.
    Messages for receivers with no subscribers are recorded but not delivered.
    """

    def __init__(self):
        self.messages = []     # every published MCPMessage, oldest first
        self.subscribers = {}  # receiver name -> list of callbacks

    def publish(self, message: MCPMessage):
        """Record ``message`` and deliver it to its receiver's callbacks."""
        self.messages.append(message)
        logger.info(f"Message published: {message.sender} -> {message.receiver} [{message.type}]")
        for callback in self.subscribers.get(message.receiver, ()):
            callback(message)

    def subscribe(self, agent_name: str, callback):
        """Register ``callback`` for messages addressed to ``agent_name``."""
        self.subscribers.setdefault(agent_name, []).append(callback)


# Global message bus shared by the agents constructed at module level.
message_bus = MessageBus()
class IngestionAgent:
    """Agent responsible for document parsing and preprocessing.

    Handles ``INGESTION_REQUEST`` messages whose payload carries ``"files"``
    (a list of file paths). Each supported file is converted to plain text,
    split into overlapping chunks, and the resulting LangChain ``Document``
    objects are sent to ``RetrievalAgent`` in an ``INGESTION_COMPLETE``
    message that reuses the request's trace id.

    Every ``parse_*`` method returns ``""`` (after logging) when the file
    cannot be read, so one bad file never aborts a whole batch.
    """

    def __init__(self, message_bus: MessageBus):
        self.name = "IngestionAgent"
        self.message_bus = message_bus
        self.message_bus.subscribe(self.name, self.handle_message)
        # 1000-unit chunks; neighbouring chunks share 200 units of overlap.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )

    def handle_message(self, message: MCPMessage):
        """Handle incoming MCP messages; only INGESTION_REQUEST is acted on."""
        if message.type == "INGESTION_REQUEST":
            self.process_documents(message)

    def parse_pdf(self, file_path: str) -> str:
        """Extract the text of every page of a PDF, one page per line block."""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # extract_text() can yield None for pages without a text layer
                # in some PyPDF2 versions; treat those pages as empty instead of
                # letting ``str + None`` discard the whole document. Pages are
                # newline-separated so words don't fuse across page breaks.
                return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        except Exception as e:
            logger.error(f"Error parsing PDF: {e}")
            return ""

    def parse_pptx(self, file_path: str) -> str:
        """Extract text from every text-bearing shape on every slide."""
        try:
            prs = Presentation(file_path)
            return "".join(
                shape.text + "\n"
                for slide in prs.slides
                for shape in slide.shapes
                if hasattr(shape, "text")
            )
        except Exception as e:
            logger.error(f"Error parsing PPTX: {e}")
            return ""

    def parse_csv(self, file_path: str) -> str:
        """Render a CSV file as a plain-text table via pandas."""
        try:
            df = pd.read_csv(file_path)
            return df.to_string()
        except Exception as e:
            logger.error(f"Error parsing CSV: {e}")
            return ""

    def parse_docx(self, file_path: str) -> str:
        """Extract the paragraphs of a DOCX document, one per line."""
        try:
            doc = DocxDocument(file_path)
            return "".join(f"{paragraph.text}\n" for paragraph in doc.paragraphs)
        except Exception as e:
            logger.error(f"Error parsing DOCX: {e}")
            return ""

    def parse_txt(self, file_path: str) -> str:
        """Read a TXT/Markdown file as UTF-8 text.

        Undecodable bytes are replaced (U+FFFD) rather than discarding the
        entire file because of a single bad byte.
        """
        try:
            with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
                return file.read()
        except Exception as e:
            logger.error(f"Error parsing TXT: {e}")
            return ""

    def process_documents(self, message: MCPMessage):
        """Parse, chunk and forward the files listed in ``message.payload``.

        Always publishes ``INGESTION_COMPLETE`` (possibly with an empty
        document list) to ``RetrievalAgent`` under the request's trace id.
        """
        files = message.payload.get("files", [])
        # File extension (lower-cased) -> parser.
        parsers = {
            '.pdf': self.parse_pdf,
            '.pptx': self.parse_pptx,
            '.csv': self.parse_csv,
            '.docx': self.parse_docx,
            '.txt': self.parse_txt,
            '.md': self.parse_txt,
        }
        processed_docs = []
        for file_path in files:
            file_ext = os.path.splitext(file_path)[1].lower()
            parser = parsers.get(file_ext)
            if parser is None:
                logger.warning(f"Unsupported file type: {file_ext}")
                continue
            text = parser(file_path)
            if not text:
                logger.warning(f"No text extracted from {file_path}")
                continue
            # Split text into chunks, each tagged with its originating file.
            chunks = self.text_splitter.split_text(text)
            processed_docs.extend(
                Document(page_content=chunk, metadata={"source": file_path})
                for chunk in chunks
            )

        # Send processed documents to RetrievalAgent
        response = MCPMessage(
            sender=self.name,
            receiver="RetrievalAgent",
            msg_type="INGESTION_COMPLETE",
            trace_id=message.trace_id,
            payload={"documents": processed_docs},
        )
        self.message_bus.publish(response)
class RetrievalAgent:
    """Agent responsible for embedding and semantic retrieval using LangChain.

    On ``INGESTION_COMPLETE`` it builds a FAISS index over the received
    documents, plus two chains on top of its retriever: a one-shot
    ``RetrievalQA`` chain and a memory-backed ``ConversationalRetrievalChain``.
    It then announces ``VECTORSTORE_READY`` to the coordinator.

    On ``RETRIEVAL_REQUEST`` it answers the query and always replies with a
    ``CHAIN_RESPONSE`` carrying the request's trace id. This includes error
    and "nothing indexed" cases, so the requester never has to wait for a
    timeout.
    """

    # Number of characters of each source chunk echoed back in a response.
    SOURCE_PREVIEW_CHARS = 200

    def __init__(self, message_bus: MessageBus):
        self.name = "RetrievalAgent"
        self.message_bus = message_bus
        self.message_bus.subscribe(self.name, self.handle_message)
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None
        self.conversation_chain = None
        # Chat history for the conversational chain; output_key selects which
        # chain output ("answer") is stored alongside each question.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

    def handle_message(self, message: MCPMessage):
        """Dispatch incoming MCP messages by type; other types are ignored."""
        if message.type == "INGESTION_COMPLETE":
            self.create_vector_store(message)
        elif message.type == "RETRIEVAL_REQUEST":
            self.process_query(message)

    def create_vector_store(self, message: MCPMessage):
        """Build the FAISS index and QA chains from ``payload["documents"]``.

        Replaces any previously built index. Failures are logged and leave
        the previous state untouched; ``VECTORSTORE_READY`` is only published
        on success.
        """
        documents = message.payload.get("documents", [])
        if not documents:
            logger.warning("INGESTION_COMPLETE carried no documents; vector store not built")
            return
        try:
            self.vector_store = FAISS.from_documents(documents, embeddings)
            self.retriever = self.vector_store.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3},
            )
            # Create QA chain
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=self.retriever,
                return_source_documents=True,
                verbose=True,
            )
            # Create conversational chain
            self.conversation_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=self.retriever,
                memory=self.memory,
                return_source_documents=True,
                verbose=True,
            )
            logger.info(f"Vector store and chains created with {len(documents)} documents")
            # Notify completion
            self.message_bus.publish(MCPMessage(
                sender=self.name,
                receiver="CoordinatorAgent",
                msg_type="VECTORSTORE_READY",
                trace_id=message.trace_id,
                payload={"status": "ready"},
            ))
        except Exception as e:
            logger.exception(f"Error creating vector store: {e}")

    def process_query(self, message: MCPMessage):
        """Answer ``payload["query"]`` and reply with a ``CHAIN_RESPONSE``.

        Uses the conversational chain when ``payload["use_conversation"]``
        (default True) is set and that chain exists, else the plain QA chain.
        """
        query = message.payload.get("query", "")
        use_conversation = message.payload.get("use_conversation", True)

        # Reply even on invalid requests; a silent return would leave the
        # caller waiting for a response that never comes.
        if not query:
            self._send_response(message.trace_id, query, "Please enter a question.", [])
            return
        if not self.qa_chain:
            self._send_response(
                message.trace_id, query,
                "No documents have been indexed yet. Please upload documents first.", [],
            )
            return

        try:
            if use_conversation and self.conversation_chain:
                # Use conversational chain for context-aware responses
                result = self.conversation_chain.invoke({"question": query})
                answer = result["answer"]
            else:
                # Use simple QA chain
                result = self.qa_chain.invoke({"query": query})
                answer = result["result"]
            sources = self._format_sources(result.get("source_documents", []))
        except Exception as e:
            logger.exception(f"Error processing query: {e}")
            self._send_response(message.trace_id, query, f"Error processing query: {str(e)}", [])
            return

        self._send_response(message.trace_id, query, answer, sources)

    def _format_sources(self, source_docs) -> List[Dict[str, str]]:
        """Turn retrieved documents into ``{"content", "source"}`` previews."""
        limit = self.SOURCE_PREVIEW_CHARS
        sources = []
        for doc in source_docs:
            content = doc.page_content
            # Only mark the preview as truncated when something was cut off.
            if len(content) > limit:
                content = content[:limit] + "..."
            sources.append({
                "content": content,
                "source": doc.metadata.get("source", "Unknown"),
            })
        return sources

    def _send_response(self, trace_id: str, query: str, answer: str,
                       sources: List[Dict[str, str]]):
        """Publish a CHAIN_RESPONSE for ``trace_id`` to the coordinator."""
        self.message_bus.publish(MCPMessage(
            sender=self.name,
            receiver="CoordinatorAgent",
            msg_type="CHAIN_RESPONSE",
            trace_id=trace_id,
            payload={
                "query": query,
                "answer": answer,
                "sources": sources,
            },
        ))
class CoordinatorAgent:
    """Coordinator agent that orchestrates the entire workflow.

    Sends ingestion requests for uploaded files, tracks whether a vector
    store is ready, and turns user questions into ``RETRIEVAL_REQUEST``
    messages. Replies are matched to the waiting query by ``trace_id``.
    """

    def __init__(self, message_bus: MessageBus):
        self.name = "CoordinatorAgent"
        self.message_bus = message_bus
        self.message_bus.subscribe(self.name, self.handle_message)
        self.vector_store_ready = False
        # Most recent CHAIN_RESPONSE payload (informational only).
        self.current_response = None
        # trace_ids of queries currently waiting for an answer.
        self._pending_queries = set()
        # trace_id -> CHAIN_RESPONSE payload, for pending queries only. Keying
        # by trace_id stops concurrent or late replies being taken by the
        # wrong query.
        self._responses: Dict[str, Dict[str, Any]] = {}

    def handle_message(self, message: MCPMessage):
        """Record readiness notices and replies published to the coordinator."""
        if message.type == "VECTORSTORE_READY":
            self.vector_store_ready = True
        elif message.type == "CHAIN_RESPONSE":
            self.current_response = message.payload
            # Keep only replies somebody is still waiting for; anything else
            # would never be collected.
            if message.trace_id in self._pending_queries:
                self._responses[message.trace_id] = message.payload

    def process_files(self, files):
        """Send an INGESTION_REQUEST for the uploaded ``files``.

        ``files`` may contain plain path strings or file-like objects exposing
        a ``.name`` path; both are accepted. Returns a short status string.
        """
        if not files:
            return "No files uploaded."
        file_paths = [getattr(file, "name", file) for file in files]
        # Send ingestion request
        message = MCPMessage(
            sender=self.name,
            receiver="IngestionAgent",
            msg_type="INGESTION_REQUEST",
            payload={"files": file_paths},
        )
        self.message_bus.publish(message)
        return f"Processing {len(files)} files: {', '.join([os.path.basename(fp) for fp in file_paths])}"

    def handle_query(self, query: str, timeout: float = 30.0):
        """Answer ``query`` via the RetrievalAgent and format the reply.

        Args:
            query: The user's question.
            timeout: Seconds to wait for the matching CHAIN_RESPONSE.

        Returns:
            The answer text, followed by a markdown "Sources" list when source
            snippets were returned. Returns a status message instead when no
            documents are ready or no reply arrives within ``timeout``.
        """
        import time

        if not self.vector_store_ready:
            return "Please upload and process documents first."

        # Send retrieval request
        message = MCPMessage(
            sender=self.name,
            receiver="RetrievalAgent",
            msg_type="RETRIEVAL_REQUEST",
            payload={"query": query, "use_conversation": True},
        )
        trace_id = message.trace_id
        # Register interest *before* publishing: with an in-process bus the
        # reply may be delivered while publish() is still running.
        self._pending_queries.add(trace_id)
        try:
            self.message_bus.publish(message)
            deadline = time.monotonic() + timeout
            while trace_id not in self._responses and time.monotonic() < deadline:
                time.sleep(0.1)
        finally:
            # Stop accepting replies for this trace before collecting, so a
            # straggler can't be stored and then never read.
            self._pending_queries.discard(trace_id)
        response = self._responses.pop(trace_id, None)

        if response is None:
            return "Timeout: No response received from the system."

        # Format response with sources
        answer = response.get("answer", "No answer generated.")
        sources = response.get("sources", [])
        if sources:
            source_lines = [
                f"{i}. {source['source']}: {source['content']}"
                for i, source in enumerate(sources, 1)
            ]
            answer += "\n\n**Sources:**\n" + "\n".join(source_lines) + "\n"
        return answer
# Initialize agents. Each constructor subscribes its agent to the shared
# message_bus under the agent's own name, so building all three is what wires
# the ingestion -> retrieval -> coordinator pipeline together.
ingestion_agent = IngestionAgent(message_bus)
retrieval_agent = RetrievalAgent(message_bus)
coordinator_agent = CoordinatorAgent(message_bus)
def create_interface():
    """Create ChatGPT-style Gradio interface.

    Builds a dark-themed ``gr.Blocks`` app with a chat window, a multi-file
    upload widget and a message box, all wired to the module-level
    ``coordinator_agent``. Returns the Blocks instance without launching it.
    """
    import html  # stdlib: escapes exception text before rendering it as HTML

    def _status_banner(ok: bool, text: str) -> str:
        """Return a green (``ok``) or red status banner containing ``text``."""
        rgb, color = ("76, 175, 80", "#4caf50") if ok else ("244, 67, 54", "#f44336")
        return f"""
        <div style="background: rgba({rgb}, 0.1); border: 1px solid rgba({rgb}, 0.3);
        border-radius: 8px; padding: 0.75rem; color: {color}; text-align: center;">
        {text}
        </div>
        """

    with gr.Blocks(
        theme=gr.themes.Base(),
        css="""
        /* Dark theme styling */
        .gradio-container {
            background-color: #1a1a1a !important;
            color: #ffffff !important;
            height: 100vh !important;
            max-width: none !important;
            padding: 0 !important;
        }
        /* Main container */
        .main-container {
            display: flex;
            flex-direction: column;
            height: 100vh;
            background: linear-gradient(135deg, #1a1a1a 0%, #2d2d2d 100%);
        }
        /* Header */
        .header {
            background: rgba(255, 193, 7, 0.1);
            border-bottom: 1px solid rgba(255, 193, 7, 0.2);
            padding: 1rem 2rem;
            backdrop-filter: blur(10px);
        }
        .header h1 {
            color: #ffc107;
            margin: 0;
            font-size: 1.5rem;
            font-weight: 600;
        }
        .header p {
            color: #cccccc;
            margin: 0.25rem 0 0 0;
            font-size: 0.9rem;
        }
        /* Chat area */
        .chat-container {
            flex: 1;
            display: flex;
            flex-direction: column;
            max-width: 800px;
            margin: 0 auto;
            width: 100%;
            padding: 1rem;
        }
        /* Chatbot styling */
        .gradio-chatbot {
            flex: 1 !important;
            background: transparent !important;
            border: none !important;
            margin-bottom: 1rem;
        }
        /* Input area */
        .input-area {
            background: rgba(45, 45, 45, 0.6);
            border-radius: 16px;
            padding: 1rem;
            border: 1px solid rgba(255, 193, 7, 0.2);
            backdrop-filter: blur(10px);
        }
        /* File upload */
        .upload-area {
            background: rgba(255, 193, 7, 0.05);
            border: 2px dashed rgba(255, 193, 7, 0.3);
            border-radius: 12px;
            padding: 1rem;
            margin-bottom: 1rem;
            transition: all 0.3s ease;
        }
        /* Buttons */
        .primary-btn {
            background: linear-gradient(135deg, #ffc107 0%, #ff8f00 100%) !important;
            color: #000000 !important;
            border: none !important;
            border-radius: 8px !important;
            font-weight: 600 !important;
        }
        /* Text inputs */
        .gradio-textbox input, .gradio-textbox textarea {
            background: rgba(45, 45, 45, 0.8) !important;
            color: #ffffff !important;
            border: 1px solid rgba(255, 193, 7, 0.2) !important;
            border-radius: 8px !important;
        }
        /* Processing indicator */
        .processing-indicator {
            background: rgba(255, 193, 7, 0.1);
            border: 1px solid rgba(255, 193, 7, 0.3);
            border-radius: 8px;
            padding: 0.75rem;
            margin: 0.5rem 0;
            color: #ffc107;
            text-align: center;
        }
        """,
        title="Agentic RAG Assistant"
    ) as iface:
        # Header
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="header">
                    <h1>🤖 Agentic RAG Assistant</h1>
                    <p>Upload documents and ask questions - powered by LangChain Multi-Agent Architecture</p>
                </div>
                """)

        # Main chat container
        with gr.Row():
            with gr.Column():
                # Chatbot
                chatbot = gr.Chatbot(
                    value=[],
                    height=500,
                    show_copy_button=True
                )

                # Input area
                with gr.Column():
                    # File upload
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".pptx", ".csv", ".docx", ".txt", ".md"],
                        label="Upload Documents"
                    )
                    # Processing status
                    processing_status = gr.HTML(visible=False)
                    # Message input row
                    with gr.Row():
                        msg_input = gr.Textbox(
                            placeholder="Upload documents above, then ask your questions here...",
                            label="Message",
                            scale=4
                        )
                        send_btn = gr.Button("Send", scale=1, variant="primary")

        # State to track document processing
        doc_processed = gr.State(False)

        # Event handlers
        def handle_file_upload(files):
            """Ingest ``files``; return (status-banner update, docs-ready flag)."""
            if not files:
                return gr.update(visible=False), False
            try:
                # MessageBus.publish invokes subscribers synchronously, so the
                # ingestion -> indexing -> VECTORSTORE_READY chain has already
                # run when process_files returns; no fixed sleep is needed.
                coordinator_agent.process_files(files)
            except Exception as e:
                logger.exception("Error processing uploaded documents")
                banner = _status_banner(
                    False, f"❌ Error processing documents: {html.escape(str(e))}"
                )
                return gr.update(value=banner, visible=True), False
            if not coordinator_agent.vector_store_ready:
                # Nothing was indexed (unsupported files, no extractable text,
                # or index creation failed), so don't claim success.
                # NOTE: the flag is never reset, so after one successful
                # upload a later failing one still reports success.
                banner = _status_banner(
                    False,
                    "❌ Error processing documents: no text could be extracted from the uploaded files.",
                )
                return gr.update(value=banner, visible=True), False
            banner = _status_banner(
                True, "✅ Documents processed successfully! You can now ask questions."
            )
            return gr.update(value=banner, visible=True), True

        def respond(message, history, doc_ready):
            """Append (message, answer) to ``history``; return (history, textbox)."""
            history = history or []
            if not message or not message.strip():
                return history, message
            if not doc_ready:
                # Show the user's message with the reminder as the assistant's
                # reply, rather than rendering the reminder as a user message.
                return history + [[message, "Please upload and process documents first."]], ""
            # Get response from coordinator
            response = coordinator_agent.handle_query(message)
            return history + [[message, response]], ""

        # File upload triggers processing
        file_upload.change(
            handle_file_upload,
            inputs=[file_upload],
            outputs=[processing_status, doc_processed]
        )
        # Send message
        send_btn.click(
            respond,
            inputs=[msg_input, chatbot, doc_processed],
            outputs=[chatbot, msg_input]
        )
        msg_input.submit(
            respond,
            inputs=[msg_input, chatbot, doc_processed],
            outputs=[chatbot, msg_input]
        )
    return iface
# Launch the application when run as a script (not on import): build the UI,
# then serve it on all network interfaces at port 7860 with sharing enabled.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )