|
import ast
import base64
import mimetypes
import re
import shutil
import time
import uuid
from pathlib import Path
from typing import AsyncGenerator, List, Optional, Tuple

import bcrypt
import gradio as gr
from fpdf import FPDF
from gradio import ChatMessage
from gradio_pdf import PDF
|
|
|
# Generated PDF reports are written here (relative to the working directory).
REPORT_DIR = Path("reports")
REPORT_DIR.mkdir(exist_ok=True)

# NOTE(review): SALT is not referenced by the code visible in this file, and
# every hash in USERS below was generated with this same salt, so identical
# passwords would yield identical hashes. Consider re-hashing each account
# with bcrypt.gensalt() and loading credentials from configuration instead
# of committing them to source.
SALT = b"$2b$12$MC7djiqmIR7154Syul5Wme"

# Username -> bcrypt password hash; verified with bcrypt.checkpw at login.
USERS = dict(
    test_user=b"$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6",
    pna=b"$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C",
    dr_rajat=b"$2b$12$MC7djiqmIR7154Syul5WmeKZX8DXEs48GWbFpO3nRtFLbB5W/2suW",
)
|
|
|
class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.

    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests
            tools_dict (dict): Dictionary of available tools for image processing
        """
        self.agent = agent
        self.tools_dict = tools_dict
        # Uploaded files are copied here (relative to the working directory).
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        # Set lazily by process_message() the first time a message is sent.
        self.current_thread_id = None

        # Copy of the file the user uploaded (may be a .dcm); tools get this path.
        self.original_file_path = None
        # Image shown in the UI (the DicomProcessorTool output for DICOM uploads).
        self.display_file_path = None

    def handle_upload(self, file_path: Optional[str]) -> Tuple[Optional[str], dict, dict]:
        """
        Copy an uploaded file into the upload directory and make it the current image.

        DICOM files (``.dcm``) are run through ``DicomProcessorTool`` so the UI
        gets a displayable image, while ``original_file_path`` keeps the DICOM.

        Args:
            file_path (str): Path to the uploaded file, or a falsy value if none.

        Returns:
            Tuple of (display path or None, update, update). Both updates set
            ``interactive=True`` after a successful upload and are no-op
            ``gr.update()`` calls when no file was given.
        """
        if not file_path:
            # Same arity as the success path so multi-output wiring keeps working.
            return None, gr.update(), gr.update()

        source = Path(file_path)
        suffix = source.suffix.lower()
        saved_path = self.upload_dir / f"upload_{int(time.time())}{suffix}"
        shutil.copy2(source, saved_path)
        self.original_file_path = str(saved_path)

        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            self.display_file_path = output["image_path"]
        else:
            self.display_file_path = str(saved_path)

        return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add a new message to the chat history.

        The current image (uploaded file first, otherwise the displayed image)
        is appended before the text so it shows up in the chat.

        Args:
            message (str): Text message to add
            display_image (str): Path to image being displayed
            history (List[dict]): Current chat history (mutated in place)

        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history and a locked textbox
            that still holds ``message`` so the next step can read it.
        """
        image_path = self.original_file_path or display_image
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message is not None:
            history.append({"role": "user", "content": message})

        return history, gr.Textbox(value=message, interactive=False)

    async def process_message(
        self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Process a message and stream the agent's responses into the chat.

        Args:
            message (str): User message to process
            display_image (Optional[str]): Path to currently displayed image
            chat_history (List[ChatMessage]): Current chat history

        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history,
            display path, and an empty string that clears the textbox. Errors
            are appended to the chat as an "Error" message rather than raised.
        """
        chat_history = chat_history or []

        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        image_path = self.original_file_path or display_image

        try:
            # Built inside the try so an unreadable image is reported in the chat.
            messages = self._build_agent_messages(message, image_path)

            # NOTE(review): workflow.stream() is iterated synchronously here, so
            # each agent step blocks the event loop while it runs.
            for event in self.agent.workflow.stream(
                {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if "process" in event:
                    content = event["process"]["messages"][-1].content
                    if content:
                        # Strip internal temp/ file paths from the user-facing reply.
                        content = re.sub(r"temp/[^\s]*", "", content)
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history, self.display_file_path, ""

                elif "execute" in event:
                    for tool_message in event["execute"]["messages"]:
                        self._append_tool_result(tool_message, chat_history)
                        yield chat_history, self.display_file_path, ""

        except Exception as e:
            chat_history.append(
                ChatMessage(
                    role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
                )
            )
            # One value per output component, same as the success path; a short
            # tuple here would make Gradio reject the update.
            yield chat_history, self.display_file_path, ""

    def _build_agent_messages(self, message: Optional[str], image_path: Optional[str]) -> List[dict]:
        """Build the user messages sent to the agent for one turn."""
        messages = []

        if image_path is not None:
            # Tools read the image from disk, so the path is passed as text.
            messages.append({"role": "user", "content": f"image_path: {image_path}"})
            # The multimodal model also receives the pixels inline.
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": self._image_data_url(self._vision_image_path(image_path))},
                        }
                    ],
                }
            )

        if message is not None:
            messages.append({"role": "user", "content": [{"type": "text", "text": message}]})

        return messages

    def _vision_image_path(self, image_path: str) -> str:
        """Pick the file to show the vision model: raw DICOM bytes are not an image."""
        if Path(image_path).suffix.lower() == ".dcm" and self.display_file_path:
            return self.display_file_path
        return image_path

    @staticmethod
    def _image_data_url(path: str) -> str:
        """Return the file at ``path`` as a base64 ``data:`` URL with a matching MIME type."""
        mime_type, _ = mimetypes.guess_type(path)
        if not mime_type or not mime_type.startswith("image/"):
            # Unknown type: fall back to the JPEG label used previously.
            mime_type = "image/jpeg"
        with open(path, "rb") as img_file:
            encoded = base64.b64encode(img_file.read()).decode("utf-8")
        return f"data:{mime_type};base64,{encoded}"

    def _append_tool_result(self, tool_message, chat_history: List[ChatMessage]) -> None:
        """Render one tool message into the chat and track images from the visualizer."""
        tool_name = tool_message.name
        # The content is parsed as a Python literal; element 0 is used as the
        # tool output (presumably an (output, metadata) tuple repr — confirm).
        # ast.literal_eval only accepts literals, unlike eval(), which would
        # execute any code that ended up in the tool message.
        tool_result = ast.literal_eval(tool_message.content)[0]
        if not tool_result:
            return

        # Collapse multi-line reprs into a single line for the chat bubble.
        formatted_result = " ".join(
            line.strip() for line in str(tool_result).splitlines()
        ).strip()
        chat_history.append(
            ChatMessage(
                role="assistant",
                content=formatted_result,
                metadata={
                    "title": f"🖼️ Image from tool: {tool_name}",
                    "description": formatted_result,
                },
            )
        )

        if tool_name == "image_visualizer":
            self.display_file_path = tool_result["image_path"]
            chat_history.append(
                ChatMessage(role="assistant", content={"path": self.display_file_path})
            )
|
|
|
|
|
def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    The UI has a login page and a main page. The main page holds the chat,
    an image tab with quick-prompt buttons, and a report tab that turns the
    agent's conclusion into a downloadable, previewable PDF.

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing

    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    # NOTE(review): one ChatInterface is shared by every browser session, so the
    # uploaded image and thread id are global rather than per user.
    interface = ChatInterface(agent, tools_dict)

    # Helvetica is a core PDF font limited to Latin-1. LLM output often contains
    # typographic punctuation that FPDF would reject, so map the common cases.
    latin1_fallbacks = str.maketrans(
        {
            "\u2018": "'",
            "\u2019": "'",
            "\u201c": '"',
            "\u201d": '"',
            "\u2013": "-",
            "\u2014": "-",
            "\u2022": "-",
            "\u2026": "...",
            "\u2192": "->",
            "\u2264": "<=",
            "\u2265": ">=",
        }
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # NOTE(review): auth_state is written on login but never read by other
        # handlers, so this page only hides the UI; the event handlers are
        # presumably still reachable through Gradio's API. Consider
        # demo.launch(auth=...) for server-side authentication.
        auth_state = gr.State(False)

        with gr.Column(visible=True) as login_page:
            gr.Markdown("## 🔐 Login")
            username = gr.Textbox(label="Username")
            password = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_error = gr.Markdown(visible=False)

        with gr.Column(visible=False) as main_page:
            gr.Markdown(
                """
                # 🏥 MOLx - Powered by MedRAX
                """
            )

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        type="messages",
                        label="Agent",
                        avatar_images=(
                            None,
                            "assets/medrax_logo.jpg",
                        ),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False,
                            )

                with gr.Column(scale=3):
                    with gr.Tabs():
                        with gr.Tab(label="Image section"):
                            image_display = gr.Image(
                                label="Image", type="filepath", height=685, container=True
                            )
                            with gr.Row():
                                analyze_btn = gr.Button("Analyze 1")
                                analyze2_btn = gr.Button("Analyze 2")
                                segment_btn = gr.Button("Segment")
                            with gr.Row():
                                clear_btn = gr.Button("Clear Chat")
                                new_thread_btn = gr.Button("New Patient")

                        with gr.Tab(label="Report section"):
                            generate_report_btn = gr.Button("Generate Report")
                            conclusion_tb = gr.Textbox(label="Conclusion", interactive=False)
                            with gr.Row():
                                approve_btn = gr.Button("Approve", visible=False)
                                download_pdf_btn = gr.DownloadButton(
                                    label="📥 Download PDF", visible=False
                                )
                            pdf_preview = PDF(visible=False)

        def authenticate(entered_user, entered_password):
            """Swap to the main page if the credentials match a USERS entry."""
            hashed = USERS.get(entered_user)
            # Guard against a None textbox value before encoding.
            if hashed and bcrypt.checkpw((entered_password or "").encode(), hashed):
                return (
                    gr.update(visible=False),
                    gr.update(visible=True),
                    gr.update(visible=False),
                    True,
                )
            # gr.update() is the explicit "leave unchanged" value for the page columns.
            return (
                gr.update(),
                gr.update(),
                gr.update(value="❌ Incorrect username or password", visible=True),
                False,
            )

        def clear_chat():
            """Forget the current image and empty the chat (the agent thread is kept)."""
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            """Start a fresh agent thread and reset the report tab."""
            interface.current_thread_id = str(time.time())
            # NOTE(review): the previous image is kept and will be sent with the
            # next message; confirm this is intended for "New Patient".
            return (
                [],
                interface.display_file_path,
                gr.update(value=None, interactive=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
            )

        def handle_file_upload(file):
            """Forward an upload to the interface (currently not wired to any widget here)."""
            # Upload widgets may pass a tempfile-like object with `.name` or a
            # plain path string; None (cleared upload) passes through unchanged.
            return interface.handle_upload(getattr(file, "name", file))

        def generate_report():
            """Ask the agent to summarize the current thread into the Conclusion box."""
            result = interface.agent.summarize_message(interface.current_thread_id)
            return (
                gr.update(value=result["Conclusion"], lines=4, interactive=True),
                gr.update(visible=True),
            )

        def to_latin1(text):
            """Make ``text`` encodable by FPDF's core fonts (unknown characters become '?')."""
            text = "" if text is None else str(text)
            return text.translate(latin1_fallbacks).encode("latin-1", "replace").decode("latin-1")

        def records_to_pdf(conclusion) -> Path:
            """
            Write a PDF report under REPORT_DIR and return its Path.

            The file name is a random UUID, so reports never overwrite each other.
            """
            pdf = FPDF()
            pdf.set_auto_page_break(auto=True, margin=15)
            pdf.add_page()
            pdf.set_font(family="Helvetica", size=12)

            pdf.cell(0, 10, "Chest-X-ray Report", ln=1, align="C")
            pdf.ln(4)

            pdf.set_font(family="Helvetica", style="")
            pdf.multi_cell(0, 8, to_latin1(conclusion))

            pdf_path = REPORT_DIR / f"report_{uuid.uuid4().hex}.pdf"
            pdf.output(str(pdf_path))
            return pdf_path

        def build_pdf_and_preview(conclusion):
            """Render the approved conclusion to PDF and reveal the download/preview widgets."""
            pdf_path = records_to_pdf(conclusion)
            return (
                gr.update(value=pdf_path, visible=True),
                gr.update(value=str(pdf_path), visible=True),
            )

        # The two helpers below are not wired to any event in this function.
        def show_reject_ui():
            return gr.update(visible=True, value=""), gr.update(visible=True), gr.update(visible=True)

        def hide_reject_ui():
            return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(visible=False)

        login_button.click(
            authenticate, [username, password], [login_page, main_page, login_error, auth_state]
        )

        def _finish_chat_turn(event):
            """After ``event`` (which ran add_message), run the agent and re-enable the textbox."""
            event.then(
                interface.process_message,
                inputs=[txt, image_display, chatbot],
                outputs=[chatbot, image_display, txt],
            ).then(lambda: gr.Textbox(interactive=True), None, [txt])

        _finish_chat_turn(
            txt.submit(
                interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
            )
        )

        def _wire_prompt_button(button, prompt):
            """Clicking ``button`` types ``prompt`` into the textbox and runs a chat turn."""
            _finish_chat_turn(
                button.click(lambda: gr.update(value=prompt), None, txt).then(
                    interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
                )
            )

        _wire_prompt_button(
            analyze_btn,
            "Analyze this xray and give me a detailed response. Use the medgemma_xray_expert tool",
        )
        _wire_prompt_button(
            analyze2_btn,
            "Analyze this xray and give me a detailed response. Use the chest_xray_expert tool",
        )
        _wire_prompt_button(segment_btn, "Segment the major affected lung")

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(
            new_thread,
            outputs=[chatbot, image_display, conclusion_tb, approve_btn, download_pdf_btn, pdf_preview],
        )

        generate_report_btn.click(generate_report, outputs=[conclusion_tb, approve_btn])
        approve_btn.click(
            build_pdf_and_preview,
            inputs=[conclusion_tb],
            outputs=[download_pdf_btn, pdf_preview],
        )

    return demo
|
|