import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
from PIL import Image
import json
import re
import pandas as pd
from datetime import datetime
import plotly.express as px
# Load text model.
# NOTE: this Hub repo ships custom tokenizer code, so trust_remote_code=True is
# passed. The checkpoint has no fine-tuned classification head: the 4-way head
# created here is randomly initialized, so its predictions are placeholders
# until the model is fine-tuned on labeled data.
text_model_name = "microsoft/BiomedVLP-CXR-BERT-specialized"
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
text_model = AutoModelForSequenceClassification.from_pretrained(
    text_model_name, num_labels=4, trust_remote_code=True
)
# Load image model.
# NOTE: the original choice here, "aehrc/cxrmate-tf", is a report-generation
# model and does not load under AutoModelForImageClassification. A generic ViT
# backbone is substituted as a placeholder (its 2-way head is randomly
# initialized); replace it with a fine-tuned chest X-ray or skin-disease
# classifier for real use.
image_model_name = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(image_model_name)
image_model = AutoModelForImageClassification.from_pretrained(
    image_model_name, num_labels=2, ignore_mismatched_sizes=True
)
# Define labels
text_labels = ["Positive", "Negative", "Neutral", "Informative"]  # for text analysis
image_labels = ["Normal", "Abnormal"]  # for X-ray or skin images
# Store conversation state
conversation_state = {
    "history": [],
    "texts": [],
    "image_uploaded": False,
    "last_analysis": None,
    "analysis_log": [],
}
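# NOTE: module-level mutable state like this is shared across every visitor to
# the app; in a multi-user Space, concurrent sessions would see each other's
# logs. Carrying this dict in gr.State would give per-session isolation
# (assumption: single-user research use here).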
# Extract key terms via a keyword whitelist
def extract_key_terms(text):
    terms = re.findall(
        r'\b(fever|cough|fatigue|headache|sore throat|chest pain|shortness of breath'
        r'|rash|lesion|study|treatment|trial|astronaut|microgravity)\b',
        text,
        re.IGNORECASE,
    )
    return terms
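# Illustrative example (not run at import time):
#   extract_key_terms("Persistent cough and mild fever") -> ["cough", "fever"]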
# Generate context-aware follow-up questions
def generate_follow_up(terms, history):
    if not terms:
        return "Please provide medical text (e.g., symptoms, abstract) or upload an image."
    lowered = [t.lower() for t in terms]
    if "astronaut" in lowered or "microgravity" in lowered:
        return "Are you researching space medicine? Please describe physiological data or symptoms in microgravity."
    if len(terms) < 3:
        return "Can you provide more details (e.g., duration of symptoms or study context)?"
    if not conversation_state["image_uploaded"]:
        return "Would you like to upload an image (e.g., X-ray or skin photo) for analysis?"
    return "Would you like to analyze another text or image, or export the analysis log?"
# Main analysis function
def analyze_medical_input(user_input, image=None, chat_history=None, export_format="None"):
    global conversation_state
    if not chat_history:  # kept for the gr.State input; not used beyond this guard
        chat_history = []
    # Process text input
    text_response = ""
    text_chart = ""
    if user_input.strip():
        terms = extract_key_terms(user_input)
        conversation_state["texts"].extend(terms)
        inputs = text_tokenizer(user_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = text_model(**inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)[0]  # compute softmax once, reuse below
        predicted_class_idx = int(probs.argmax().item())
        confidence = probs[predicted_class_idx].item()
        scores = probs.tolist()
        conversation_state["last_analysis"] = {
            "type": "text",
            "label": text_labels[predicted_class_idx],
            "confidence": confidence,
            "scores": scores,
            "input": user_input,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        conversation_state["analysis_log"].append(conversation_state["last_analysis"])
        text_response = f"Text Analysis: {text_labels[predicted_class_idx]} (Confidence: {confidence:.2%})"
        # Text visualization (Chart.js)
        chart_data = {
            "type": "bar",
            "data": {
                "labels": text_labels,
                "datasets": [{
                    "label": "Confidence Scores",
                    "data": scores,
                    "backgroundColor": ["#4CAF50", "#F44336", "#2196F3", "#FF9800"],
                    "borderColor": ["#388E3C", "#D32F2F", "#1976D2", "#F57C00"],
                    "borderWidth": 1,
                }]
            },
            "options": {
                "scales": {
                    "y": {"beginAtZero": True, "max": 1, "title": {"display": True, "text": "Confidence"}},
                    "x": {"title": {"display": True, "text": "Text Categories"}},
                },
                "plugins": {"title": {"display": True, "text": "Text Analysis Confidence"}},
            },
        }
        text_chart = f"""
        <canvas id='textChart' width='400' height='200'></canvas>
        <script src='https://cdn.jsdelivr.net/npm/chart.js'></script>
        <script>
        new Chart(document.getElementById('textChart'), {json.dumps(chart_data)});
        </script>
        """
    # Process image input
    image_response = ""
    image_chart = ""
    if image is not None:
        conversation_state["image_uploaded"] = True
        inputs = image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = image_model(**inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)[0]
        predicted_class_idx = int(probs.argmax().item())
        confidence = probs[predicted_class_idx].item()
        scores = probs.tolist()
        conversation_state["last_analysis"] = {
            "type": "image",
            "label": image_labels[predicted_class_idx],
            "confidence": confidence,
            "scores": scores,
            "input": "image",
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        conversation_state["analysis_log"].append(conversation_state["last_analysis"])
        image_response = f"Image Analysis: {image_labels[predicted_class_idx]} (Confidence: {confidence:.2%})"
        # Image visualization (Chart.js)
        chart_data = {
            "type": "bar",
            "data": {
                "labels": image_labels,
                "datasets": [{
                    "label": "Confidence Scores",
                    "data": scores,
                    "backgroundColor": ["#4CAF50", "#F44336"],
                    "borderColor": ["#388E3C", "#D32F2F"],
                    "borderWidth": 1,
                }]
            },
            "options": {
                "scales": {
                    "y": {"beginAtZero": True, "max": 1, "title": {"display": True, "text": "Confidence"}},
                    "x": {"title": {"display": True, "text": "Image Categories"}},
                },
                "plugins": {"title": {"display": True, "text": "Image Analysis Confidence"}},
            },
        }
        image_chart = f"""
        <canvas id='imageChart' width='400' height='200'></canvas>
        <script src='https://cdn.jsdelivr.net/npm/chart.js'></script>
        <script>
        new Chart(document.getElementById('imageChart'), {json.dumps(chart_data)});
        </script>
        """
    # Generate trend visualization (Plotly)
    trend_html = ""
    if len(conversation_state["analysis_log"]) > 1:
        df = pd.DataFrame(conversation_state["analysis_log"])
        fig = px.line(
            df, x="timestamp", y="confidence", color="type",
            title="Analysis Confidence Over Time",
            labels={"confidence": "Confidence Score", "timestamp": "Time"},
        )
        trend_html = fig.to_html(full_html=False)
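        # fig.to_html(full_html=False) returns an embeddable <div> fragment that
        # bundles plotly.js inline by default, so no extra <script> include is needed.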
    # Combine responses (rendered via gr.HTML, so use <br> rather than \n for line breaks)
    response = "<br>".join([r for r in [text_response, image_response] if r])
    if not response:
        response = "No analysis yet. Please provide text or upload an image."
    response += f"<br><br>Follow-Up: {generate_follow_up(conversation_state['texts'], conversation_state['history'])}"
    response += f"<br><br>{text_chart}\n{image_chart}\n{trend_html}"
    # Add disclaimer
    disclaimer = ("⚠️ This tool is for research purposes only and does not provide "
                  "medical diagnoses. Consult a healthcare professional for medical advice.")
    response += f"<br><br>{disclaimer}"
    conversation_state["history"].append((user_input, response))
    # Handle export: gr.File expects a file path, so write the log to disk first
    export_file = None
    if export_format != "None" and conversation_state["analysis_log"]:
        df = pd.DataFrame(conversation_state["analysis_log"])
        if export_format == "JSON":
            export_file = "analysis_log.json"
            df.to_json(export_file, orient="records")
        else:  # CSV
            export_file = "analysis_log.csv"
            df.to_csv(export_file, index=False)
    # Always return a (response, file) pair to match the two output components
    return response, export_file
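# Illustrative call (hypothetical inputs; returns (html_string, export_path_or_None)):
#   html, path = analyze_medical_input("Patient reports fever and cough", export_format="CSV")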
# Custom CSS for professional UI
css = """
body { background-color: #f0f2f5; font-family: 'Segoe UI', Arial, sans-serif; }
.gradio-container { max-width: 900px; margin: auto; padding: 30px; background: white; border-radius: 10px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); }
h1 { color: #1a3c5e; text-align: center; font-size: 2em; }
input, textarea { border-radius: 8px; border: 1px solid #ccc; padding: 10px; }
button { background: linear-gradient(90deg, #3498db, #2980b9); color: white; border-radius: 8px; padding: 12px; font-weight: bold; }
button:hover { background: linear-gradient(90deg, #2980b9, #1a6ea6); }
#export_dropdown { width: 150px; margin-top: 10px; }
"""
# Create Gradio interface
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Ultra-Advanced Medical Research Chatbot")
    gr.Markdown(
        "Analyze medical texts or images for research purposes. Supports symptom analysis, "
        "literature review, and space medicine research. Not for medical diagnosis."
    )
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(lines=5, label="Medical Text", placeholder="Enter symptoms, medical abstract, or space medicine data...")
            image_input = gr.Image(type="pil", label="Upload X-ray or Skin Image")
            # elem_id lets the #export_dropdown CSS rule above actually apply
            export_dropdown = gr.Dropdown(choices=["None", "JSON", "CSV"], label="Export Log", value="None", elem_id="export_dropdown")
            submit_button = gr.Button("Analyze")
        with gr.Column(scale=3):
            output = gr.HTML()
            # Render a File component here so the export path returned by the handler is shown
            file_output = gr.File(label="Exported Log")
    submit_button.click(
        fn=analyze_medical_input,
        inputs=[text_input, image_input, gr.State(), export_dropdown],
        outputs=[output, file_output],
    )
# Launch the interface
iface.launch()
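# On Hugging Face Spaces the plain launch() above is sufficient; when running
# locally, iface.launch(share=True) would additionally expose a temporary
# public link (optional).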