import base64
import re
import shutil
import time
import uuid
from pathlib import Path
from typing import AsyncGenerator, List, Optional, Tuple

import bcrypt
import gradio as gr
from fpdf import FPDF
from gradio import ChatMessage
from gradio_pdf import PDF

# Directory where generated PDF reports are written (created at import time).
REPORT_DIR = Path("reports")
REPORT_DIR.mkdir(exist_ok=True)

# NOTE(review): credentials are hard-coded in source. The values below are
# bcrypt hashes (not plaintext), but they should still be moved to a secret
# store / environment configuration.
SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'
USERS = {
    'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6',
    'pna': b'$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C',
    'dr_rajat': b'$2b$12$MC7djiqmIR7154Syul5WmeKZX8DXEs48GWbFpO3nRtFLbB5W/2suW',
}


class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.
    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests
            tools_dict (dict): Dictionary of available tools for image processing
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None
        # Separate storage for original and display paths
        self.original_file_path = None  # For LLM (.dcm or other)
        self.display_file_path = None  # For UI (always viewable format)

    def handle_upload(self, file_path: str) -> Optional[Tuple[str, dict, dict]]:
        """
        Handle new file upload and set appropriate paths.

        The uploaded file is copied into ``self.upload_dir``. For ``.dcm``
        files the "DicomProcessorTool" entry of ``tools_dict`` is run and its
        ``image_path`` output is used as the display path.

        Args:
            file_path (str): Path to the uploaded file

        Returns:
            None if no file was uploaded; otherwise a 3-tuple of the display
            path for the UI followed by two ``gr.update(interactive=True)``
            values.
        """
        if not file_path:
            return None

        source = Path(file_path)
        timestamp = int(time.time())

        # Save original file with proper suffix
        suffix = source.suffix.lower()
        saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
        shutil.copy2(file_path, saved_path)
        self.original_file_path = str(saved_path)

        # Handle DICOM conversion for display only
        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            self.display_file_path = output["image_path"]
        else:
            self.display_file_path = str(saved_path)

        return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add a new message to the chat history.

        Args:
            message (str): Text message to add
            display_image (str): Path to image being displayed
            history (List[dict]): Current chat history

        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history and textbox component
            (the textbox is returned non-interactive while the turn is processed)
        """
        image_path = self.original_file_path or display_image
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message is not None:
            history.append({"role": "user", "content": message})
        return history, gr.Textbox(value=message, interactive=False)

    async def process_message(
        self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Process a message and generate responses.

        Args:
            message (str): User message to process
            display_image (Optional[str]): Path to currently displayed image
            chat_history (List[ChatMessage]): Current chat history

        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history,
            display path, and empty string (used to clear the textbox). The
            error path yields the same three-element shape.
        """
        chat_history = chat_history or []

        # Initialize thread if needed
        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        messages = []
        image_path = self.original_file_path or display_image
        if image_path is not None:
            # Send path for tools
            messages.append({"role": "user", "content": f"image_path: {image_path}"})

            # Load and encode image for multimodal
            # NOTE(review): the MIME type is always declared as image/jpeg, even
            # when image_path is an original .dcm upload — confirm the model
            # endpoint tolerates this.
            with open(image_path, "rb") as img_file:
                img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                        }
                    ],
                }
            )

        if message is not None:
            messages.append({"role": "user", "content": [{"type": "text", "text": message}]})

        try:
            for event in self.agent.workflow.stream(
                {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if "process" in event:
                    content = event["process"]["messages"][-1].content
                    if content:
                        # Strip internal temp/ file paths from the assistant text
                        content = re.sub(r"temp/[^\s]*", "", content)
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history, self.display_file_path, ""

                elif "execute" in event:
                    # Distinct loop name so the `message` parameter is not shadowed.
                    for tool_message in event["execute"]["messages"]:
                        tool_name = tool_message.name
                        # SECURITY: eval() executes arbitrary code. Tool output is
                        # treated as trusted here; prefer ast.literal_eval if tool
                        # results are guaranteed to be plain Python literals.
                        tool_result = eval(tool_message.content)[0]

                        if tool_result:
                            metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
                            formatted_result = " ".join(
                                line.strip() for line in str(tool_result).splitlines()
                            ).strip()
                            metadata["description"] = formatted_result
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content=formatted_result,
                                    metadata=metadata,
                                )
                            )

                            # For image_visualizer, use display path
                            if tool_name == "image_visualizer":
                                self.display_file_path = tool_result["image_path"]
                                chat_history.append(
                                    ChatMessage(
                                        role="assistant",
                                        content={"path": self.display_file_path},
                                    )
                                )

                            yield chat_history, self.display_file_path, ""

        except Exception as e:
            chat_history.append(
                ChatMessage(
                    role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
                )
            )
            # Three values, matching the (chatbot, image_display, txt) outputs
            # this generator is wired to; a 2-tuple here would itself error.
            yield chat_history, self.display_file_path, ""


def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    Builds a login page and a (initially hidden) main page containing the
    chat, image tab and report tab, then wires all event handlers.

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing

    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    interface = ChatInterface(agent, tools_dict)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # NOTE(review): auth_state is set on login but no handler below checks
        # it; the main page is only hidden, not access-controlled.
        auth_state = gr.State(False)

        with gr.Column(visible=True) as login_page:
            gr.Markdown("## 🔐 Login")
            username = gr.Textbox(label="Username")
            password = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_error = gr.Markdown(visible=False)

        with gr.Column(visible=False) as main_page:
            gr.Markdown(
                """
            # 🏥 MOLx - Powered by MedRAX
            """
            )

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        type="messages",
                        label="Agent",
                        avatar_images=(
                            None,
                            "assets/medrax_logo.jpg",
                        ),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False,
                            )

                with gr.Column(scale=3):
                    with gr.Tabs():
                        with gr.Tab(label="Image section"):
                            image_display = gr.Image(
                                label="Image", type="filepath", height=685, container=True
                            )
                            # NOTE: dedicated "Upload X-Ray" / "Upload DICOM"
                            # UploadButtons are currently disabled; images are
                            # provided through the gr.Image component itself.
                            with gr.Row():
                                analyze_btn = gr.Button("Analyze 1")
                                analyze2_btn = gr.Button("Analyze 2")
                                segment_btn = gr.Button("Segment")
                            with gr.Row():
                                clear_btn = gr.Button("Clear Chat")
                                new_thread_btn = gr.Button("New Patient")

                        with gr.Tab(label="Report section"):
                            generate_report_btn = gr.Button("Generate Report")
                            conclusion_tb = gr.Textbox(label="Conclusion", interactive=False)
                            # NOTE: the diseases table and the reject/feedback
                            # controls are currently disabled.
                            with gr.Row():
                                approve_btn = gr.Button("Approve", visible=False)
                            download_pdf_btn = gr.DownloadButton(
                                label="📥 Download PDF", visible=False
                            )
                            pdf_preview = PDF(visible=False)

        # Event handlers
        def authenticate(username, password):
            """Check credentials against USERS and toggle login/main pages.

            Returns updates for (login_page, main_page, login_error, auth_state).
            """
            hashed = USERS.get(username)
            if hashed and bcrypt.checkpw((password or "").encode(), hashed):
                return (
                    gr.update(visible=False),  # hide login
                    gr.update(visible=True),  # show main
                    gr.update(visible=False),  # hide error
                    True,  # set state
                )
            # Layout outputs get explicit updates, never bare None: a Column is
            # not a value component, so None is not a safe output for it.
            return (
                gr.update(visible=True),  # keep login visible
                gr.update(visible=False),  # keep main hidden
                gr.update(value="❌ Incorrect username or password", visible=True),
                False,
            )

        def clear_chat():
            """Forget the stored image paths and clear chat and image display."""
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            """Start a new agent thread and reset chat and report widgets.

            Returns values for (chatbot, image_display, conclusion_tb,
            approve_btn, download_pdf_btn, pdf_preview).
            """
            interface.current_thread_id = str(time.time())
            return (
                [],
                interface.display_file_path,
                gr.update(value=None, interactive=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
            )

        def handle_file_upload(file):
            """Forward an uploaded file to ChatInterface.handle_upload.

            Currently not wired to any component (upload buttons are disabled).
            """
            return interface.handle_upload(file.name)

        def generate_report():
            """Ask the agent to summarize the current thread.

            Fills the conclusion textbox (made editable) and shows the
            Approve button.
            """
            result = interface.agent.summarize_message(interface.current_thread_id)
            return (
                gr.update(value=result["Conclusion"], lines=4, interactive=True),
                gr.update(visible=True),
            )

        def records_to_pdf(conclusion) -> Path:
            """
            Writes a PDF report under ./reports/ and returns the Path.
            """
            pdf = FPDF()
            pdf.set_auto_page_break(auto=True, margin=15)
            pdf.add_page()
            pdf.set_font(family="Helvetica", size=12)

            pdf.cell(0, 10, "Chest-X-ray Report", ln=1, align="C")
            pdf.ln(4)

            pdf.set_font(family="Helvetica", style="")
            # The built-in Helvetica font only covers Latin-1. Model-generated
            # text often contains characters outside it (curly quotes, dashes),
            # which would make PDF generation raise; replace them instead.
            safe_text = (conclusion or "").encode("latin-1", errors="replace").decode("latin-1")
            pdf.multi_cell(0, 8, safe_text)

            pdf_path = REPORT_DIR / f"report_{uuid.uuid4().hex}.pdf"
            pdf.output(str(pdf_path))
            return pdf_path

        def build_pdf_and_preview(conclusion):
            """Render the approved conclusion to PDF and reveal download/preview."""
            pdf_path = records_to_pdf(conclusion)
            return (
                gr.update(value=pdf_path, visible=True),  # for DownloadButton
                gr.update(value=str(pdf_path), visible=True),  # for PDF preview
            )

        # The two helpers below belong to the (currently disabled) reject flow.
        def show_reject_ui():
            return gr.update(visible=True, value=""), gr.update(visible=True), gr.update(visible=True)

        def hide_reject_ui():
            return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(visible=False)

        chat_inputs = [txt, image_display, chatbot]

        def run_chat_turn(register):
            """Chain one full chat turn onto an event.

            Args:
                register: event-registration callable (``txt.submit`` or the
                    ``.then`` of a previous event) used for the first step.

            The chain is: add the user message -> stream the agent response ->
            re-enable the textbox.
            """
            return (
                register(interface.add_message, inputs=chat_inputs, outputs=[chatbot, txt])
                .then(
                    interface.process_message,
                    inputs=chat_inputs,
                    outputs=[chatbot, image_display, txt],
                )
                .then(lambda: gr.Textbox(interactive=True), None, [txt])
            )

        def wire_prompt_button(button, prompt):
            """Make `button` prefill the textbox with `prompt` and run a chat turn."""
            prefill = button.click(lambda: gr.update(value=prompt), None, txt)
            return run_chat_turn(prefill.then)

        login_button.click(
            authenticate,
            [username, password],
            [login_page, main_page, login_error, auth_state],
        )

        run_chat_turn(txt.submit)

        wire_prompt_button(
            analyze_btn,
            "Analyze this xray and give me a detailed response. "
            "Use the medgemma_xray_expert tool",
        )
        wire_prompt_button(
            analyze2_btn,
            "Analyze this xray and give me a detailed response. "
            "Use the chest_xray_expert tool",
        )
        wire_prompt_button(segment_btn, "Segment the major affected lung")

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(
            new_thread,
            outputs=[
                chatbot,
                image_display,
                conclusion_tb,
                approve_btn,
                download_pdf_btn,
                pdf_preview,
            ],
        )

        generate_report_btn.click(generate_report, outputs=[conclusion_tb, approve_btn])
        approve_btn.click(
            build_pdf_and_preview,
            inputs=[conclusion_tb],
            outputs=[download_pdf_btn, pdf_preview],
        )

    return demo