# MOLx-Powered_by_MedRAX / interface.py
# Author: vaibhavm29
# Commit 3c58ccd: "bug fixes"
import ast
import base64
import mimetypes
import re, uuid
import shutil
import time
from pathlib import Path
from typing import AsyncGenerator, List, Optional, Tuple

import bcrypt
import gradio as gr
from fpdf import FPDF
from gradio import ChatMessage
from gradio_pdf import PDF
# Generated PDF reports are written here; created at import time.
REPORT_DIR = Path("reports")
REPORT_DIR.mkdir(exist_ok=True)

# NOTE(review): not referenced in this module; bcrypt hashes embed their own
# salt. Every hash in USERS begins with this same salt, i.e. one salt was reused
# for all accounts - consider re-hashing with bcrypt.gensalt() per user.
SALT = b"$2b$12$MC7djiqmIR7154Syul5Wme"

# Login name -> bcrypt hash, checked with bcrypt.checkpw at login.
USERS = dict(
    test_user=b"$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6",
    pna=b"$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C",
    dr_rajat=b"$2b$12$MC7djiqmIR7154Syul5WmeKZX8DXEs48GWbFpO3nRtFLbB5W/2suW",
)
class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.

    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files: the
    original file (possibly ``.dcm``) is the path handed to the agent's tools,
    while a viewable image is what the UI and the multimodal model receive.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests. ``process_message``
                calls ``agent.workflow.stream(...)`` on it.
            tools_dict (dict): Dictionary of available tools for image processing,
                keyed by tool name (``handle_upload`` uses ``"DicomProcessorTool"``).
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None
        # Separate storage for original and display paths
        self.original_file_path = None  # For LLM tools (.dcm or other)
        self.display_file_path = None  # For UI (always viewable format)

    def handle_upload(self, file_path: str) -> Optional[tuple]:
        """
        Copy an uploaded file into ``temp/`` and set the original/display paths.

        DICOM files are stored as-is for the agent's tools and converted to a
        viewable image (via ``DicomProcessorTool``) for display.

        Args:
            file_path (str): Path to the uploaded file

        Returns:
            None if no file was uploaded; otherwise a 3-tuple of the display path
            for the UI followed by two ``gr.update(interactive=True)`` updates.
        """
        if not file_path:
            return None

        suffix = Path(file_path).suffix.lower()
        # Timestamp plus a short random token so two uploads within the same
        # second (with the same suffix) cannot overwrite each other.
        saved_path = (
            self.upload_dir / f"upload_{int(time.time())}_{uuid.uuid4().hex[:8]}{suffix}"
        )
        shutil.copy2(file_path, saved_path)
        self.original_file_path = str(saved_path)

        # Handle DICOM conversion for display only
        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            self.display_file_path = output["image_path"]
        else:
            self.display_file_path = str(saved_path)
        return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add a new message to the chat history.

        The current image (original upload if one was registered, otherwise the
        displayed image) is appended as its own user entry before the text.

        Args:
            message (str): Text message to add
            display_image (str): Path to image being displayed
            history (List[dict]): Current chat history (mutated in place)

        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history and a non-interactive
            textbox that still holds ``message`` (it is read again by
            ``process_message``).
        """
        image_path = self.original_file_path or display_image
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message is not None:
            history.append({"role": "user", "content": message})
        return history, gr.Textbox(value=message, interactive=False)

    def _build_agent_messages(self, message: Optional[str], display_image: Optional[str]) -> list:
        """Build the user messages sent to the agent workflow for one turn."""
        messages = []
        image_path = self.original_file_path or display_image
        if image_path is not None:
            # Send path for tools (they may need the original, e.g. a .dcm file)
            messages.append({"role": "user", "content": f"image_path: {image_path}"})

            # The multimodal copy must be a viewable image. Raw DICOM bytes are
            # not, so use the converted display image for .dcm originals.
            vision_path = image_path
            if Path(image_path).suffix.lower() == ".dcm":
                vision_path = self.display_file_path or display_image
            if vision_path and Path(vision_path).suffix.lower() != ".dcm":
                mime_type = mimetypes.guess_type(vision_path)[0]
                if not mime_type or not mime_type.startswith("image/"):
                    mime_type = "image/jpeg"  # previous hard-coded default
                with open(vision_path, "rb") as img_file:
                    img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                            }
                        ],
                    }
                )

        if message is not None:
            messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
        return messages

    @staticmethod
    def _parse_tool_result(content):
        """
        Return the result part of a tool message's content.

        The content is presumably the ``repr`` of a ``(result, metadata)`` tuple,
        hence the first element is taken. It is parsed with ``ast.literal_eval``
        (Python literals only) rather than ``eval`` so text coming back from
        tools or the model can never execute code. Content that is not a literal
        is returned unchanged so it is still shown to the user.
        """
        try:
            parsed = ast.literal_eval(content)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return content
        if isinstance(parsed, (tuple, list)):
            return parsed[0] if parsed else None
        return parsed

    async def process_message(
        self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Process a message and stream the agent's responses into the chat.

        Args:
            message (str): User message to process
            display_image (Optional[str]): Path to currently displayed image
            chat_history (List[ChatMessage]): Current chat history

        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history,
            image path to display, and an empty string (clears the textbox).
            Every yield, including the error path, has these three parts to match
            the three Gradio outputs.
        """
        chat_history = chat_history or []

        # Initialize thread if needed
        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        def current_display() -> Optional[str]:
            # Prefer an image produced by a tool; otherwise keep showing the
            # image the user provided (yielding None would clear the component).
            return self.display_file_path or display_image

        try:
            messages = self._build_agent_messages(message, display_image)
            for event in self.agent.workflow.stream(
                {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if "process" in event:
                    content = event["process"]["messages"][-1].content
                    if content:
                        # Hide internal temp/ file paths from the user
                        content = re.sub(r"temp/[^\s]*", "", content)
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history, current_display(), ""

                elif "execute" in event:
                    for tool_msg in event["execute"]["messages"]:
                        tool_name = tool_msg.name
                        tool_result = self._parse_tool_result(tool_msg.content)

                        if tool_result:
                            formatted_result = " ".join(
                                line.strip() for line in str(tool_result).splitlines()
                            ).strip()
                            metadata = {
                                "title": f"🖼️ Image from tool: {tool_name}",
                                "description": formatted_result,
                            }
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content=formatted_result,
                                    metadata=metadata,
                                )
                            )

                        # For image_visualizer, use display path
                        if (
                            tool_name == "image_visualizer"
                            and isinstance(tool_result, dict)
                            and tool_result.get("image_path")
                        ):
                            self.display_file_path = tool_result["image_path"]
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content={"path": self.display_file_path},
                                )
                            )

                        yield chat_history, current_display(), ""

        except Exception as e:
            # Top-level UI boundary: report any failure in the chat instead of
            # breaking the Gradio event.
            chat_history.append(
                ChatMessage(
                    role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
                )
            )
            yield chat_history, current_display(), ""
def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    A login form is shown first. A successful bcrypt check against ``USERS``
    hides it and reveals the main page (chat, image tools and report tab).

    NOTE(review): login only toggles component visibility in the browser. None
    of the handlers below check ``auth_state``, so they can likely still be
    invoked through Gradio's API without logging in. Consider
    ``demo.launch(auth=...)`` or an ``auth_state`` check in each handler.

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing

    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    interface = ChatInterface(agent, tools_dict)

    # fpdf's core Helvetica font only supports Latin-1 and raises on anything
    # else. Map typographic characters common in LLM output to ASCII; any other
    # unsupported character is replaced with "?" in records_to_pdf.
    pdf_char_map = str.maketrans(
        {
            "\u2018": "'",
            "\u2019": "'",
            "\u201c": '"',
            "\u201d": '"',
            "\u2013": "-",
            "\u2014": "-",
            "\u2022": "*",
            "\u2026": "...",
        }
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        auth_state = gr.State(False)

        with gr.Column(visible=True) as login_page:
            gr.Markdown("## 🔐 Login")
            username = gr.Textbox(label="Username")
            password = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_error = gr.Markdown(visible=False)

        with gr.Column(visible=False) as main_page:
            gr.Markdown("# 🏥 MOLx - Powered by MedRAX")
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        type="messages",
                        label="Agent",
                        avatar_images=(
                            None,
                            "assets/medrax_logo.jpg",
                        ),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False,
                            )
                with gr.Column(scale=3):
                    with gr.Tabs():
                        with gr.Tab(label="Image section"):
                            image_display = gr.Image(
                                label="Image", type="filepath", height=685, container=True
                            )
                            # Upload buttons are disabled; images are uploaded
                            # directly into image_display.
                            # with gr.Row():
                            #     upload_button = gr.UploadButton(
                            #         "📎 Upload X-Ray",
                            #         file_types=["image"],
                            #     )
                            #     dicom_upload = gr.UploadButton(
                            #         "📄 Upload DICOM",
                            #         file_types=["file"],
                            #     )
                            with gr.Row():
                                analyze_btn = gr.Button("Analyze 1")
                                analyze2_btn = gr.Button("Analyze 2")
                                segment_btn = gr.Button("Segment")
                            with gr.Row():
                                clear_btn = gr.Button("Clear Chat")
                                new_thread_btn = gr.Button("New Patient")
                        with gr.Tab(label="Report section"):
                            generate_report_btn = gr.Button("Generate Report")
                            # diseases_df = gr.Dataframe(
                            #     headers=["Disease", "Info"],
                            #     datatype=["str", "str"],
                            #     interactive=False, visible=False, max_height=220)
                            conclusion_tb = gr.Textbox(label="Conclusion", interactive=False)
                            with gr.Row():
                                approve_btn = gr.Button("Approve", visible=False)
                                # reject_btn = gr.Button("Reject", visible=False)
                            download_pdf_btn = gr.DownloadButton(
                                label="📥 Download PDF", visible=False
                            )
                            pdf_preview = PDF(visible=False)
                            # rejection_text = gr.Textbox(
                            #     show_label=False,
                            #     visible=False,
                            #     placeholder="Tell us what is wrong with the report",
                            #     container=False,
                            #     interactive=True
                            # )
                            # with gr.Row():
                            #     submit_reject_btn = gr.Button("Submit", visible=False)
                            #     cancel_reject_btn = gr.Button("Cancel", visible=False)

        # Event handlers
        def authenticate(username, password):
            hashed = USERS.get(username)
            try:
                ok = bool(hashed) and bcrypt.checkpw((password or "").encode(), hashed)
            except ValueError:
                # Some bcrypt releases reject passwords longer than 72 bytes.
                ok = False
            if ok:
                return (
                    gr.update(visible=False),  # hide login
                    gr.update(visible=True),  # show main
                    gr.update(visible=False),  # hide error
                    True,  # set state
                )
            # gr.update() leaves the two pages unchanged; None is not a valid
            # value for a layout (Column) output.
            return (
                gr.update(),
                gr.update(),
                gr.update(value="❌ Incorrect username or password", visible=True),
                False,
            )

        def clear_chat():
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            # A new thread id starts a fresh agent conversation; the report UI
            # is reset and hidden.
            interface.current_thread_id = str(time.time())
            return (
                [],
                interface.display_file_path,
                gr.update(value=None, interactive=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
            )

        def handle_file_upload(file):
            # Older Gradio passes a tempfile-like object, newer passes a str path.
            return interface.handle_upload(getattr(file, "name", file))

        def generate_report():
            result = interface.agent.summarize_message(interface.current_thread_id)
            return (
                gr.update(value=result["Conclusion"], lines=4, interactive=True),
                gr.update(visible=True),
            )

        def records_to_pdf(conclusion) -> Path:
            """
            Write a PDF report under ./reports/ and return its Path.

            Characters the core Helvetica font cannot encode (non-Latin-1) are
            transliterated when common, otherwise replaced with "?".
            """
            text = (conclusion or "").translate(pdf_char_map)
            text = text.encode("latin-1", "replace").decode("latin-1")

            pdf = FPDF()
            pdf.set_auto_page_break(auto=True, margin=15)
            pdf.add_page()
            pdf.set_font(family="Helvetica", size=12)
            pdf.cell(0, 10, "Chest-X-ray Report", ln=1, align="C")
            pdf.ln(4)
            # Disease table and "Conclusion" heading are currently disabled:
            # pdf.set_font(family="Helvetica", style="B")
            # pdf.cell(60, 8, "Disease")
            # pdf.cell(0, 8, "Information", ln=1)
            # pdf.set_font(family="Helvetica", style="")
            # for idx, row in table.iterrows():
            #     pdf.multi_cell(0, 8, f"{row['Disease']}: {row['Info']}")
            # pdf.ln(4)
            # pdf.set_font(family="Helvetica", style="B")
            # pdf.cell(0, 8, "Conclusion", ln=1)
            pdf.set_font(family="Helvetica", style="")
            pdf.multi_cell(0, 8, text)

            pdf_path = REPORT_DIR / f"report_{uuid.uuid4().hex}.pdf"
            pdf.output(str(pdf_path))
            return pdf_path

        def build_pdf_and_preview(conclusion):
            pdf_path = records_to_pdf(conclusion)
            return (
                gr.update(value=pdf_path, visible=True),  # for DownloadButton
                gr.update(value=str(pdf_path), visible=True),  # for PDF preview
            )

        def show_reject_ui():
            return gr.update(visible=True, value=""), gr.update(visible=True), gr.update(visible=True)

        def hide_reject_ui():
            return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(visible=False)

        def wire_prompt_button(button, prompt):
            """On click, put ``prompt`` in the textbox and run the same chat turn as Enter."""
            button.click(lambda: gr.update(value=prompt), None, txt).then(
                interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
            ).then(
                interface.process_message,
                inputs=[txt, image_display, chatbot],
                outputs=[chatbot, image_display, txt],
            ).then(lambda: gr.Textbox(interactive=True), None, [txt])

        login_button.click(
            authenticate, [username, password], [login_page, main_page, login_error, auth_state]
        )

        chat_msg = txt.submit(
            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
        )
        bot_msg = chat_msg.then(
            interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=[chatbot, image_display, txt],
        )
        bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])

        wire_prompt_button(
            analyze_btn,
            "Analyze this xray and give me a detailed response. Use the medgemma_xray_expert tool",
        )
        wire_prompt_button(
            analyze2_btn,
            "Analyze this xray and give me a detailed response. Use the chest_xray_expert tool",
        )
        wire_prompt_button(segment_btn, "Segment the major affected lung")

        # upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display])
        # dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display])
        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(
            new_thread,
            outputs=[chatbot, image_display, conclusion_tb, approve_btn, download_pdf_btn, pdf_preview],
        )
        generate_report_btn.click(generate_report, outputs=[conclusion_tb, approve_btn])
        approve_btn.click(
            build_pdf_and_preview,
            inputs=[conclusion_tb],
            outputs=[download_pdf_btn, pdf_preview],
        )
        # reject_btn.click(show_reject_ui, outputs=[rejection_text, submit_reject_btn, cancel_reject_btn])
        # cancel_reject_btn.click(hide_reject_ui, outputs=[rejection_text, submit_reject_btn, cancel_reject_btn])

    return demo