# Hugging Face Spaces page residue ("Spaces: Sleeping"); not part of the app.
import html
import os
import re
import time

import pandas as pd
import streamlit as st
import torch

from models import MedicalConsultationPipeline
from utils import (
    highlight_text_with_entities,
    format_duration,
    create_risk_gauge,
    create_risk_probability_chart,
    save_consultation,
    load_consultation_history,
    init_session_state,
    RISK_COLORS,
)
# Page configuration: browser-tab title/icon, wide layout, sidebar open by default.
_PAGE_CONFIG = dict(
    page_title="AI Medical Consultation",
    page_icon="🩺",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS
def load_css(path="style.css"):
    """Inject a custom stylesheet into the page.

    Args:
        path: Location of the CSS file, relative to the working directory.

    A missing stylesheet only affects cosmetics, so it is reported with a
    warning instead of aborting the whole app.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            css = f.read()
    except FileNotFoundError:
        st.warning(f"Custom stylesheet '{path}' not found; using default styling.")
        return
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
# Local directories checked, in priority order, for a fine-tuned T5 checkpoint.
_T5_CHECKPOINT_DIRS = (
    "./finetuned_t5-small",  # user-provided fine-tuned model
    "./t5-small-medical-recommendation",
    "./models/t5-small-medical-recommendation",
    "./fine_tuned_models/t5-small",
    "./output",
    "./fine_tuning_output",
)


def find_fine_tuned_model(candidates=None, default="t5-small"):
    """Return the first existing local directory that may hold a fine-tuned T5 model.

    Args:
        candidates: Paths to check, in priority order. ``None`` means
            ``_T5_CHECKPOINT_DIRS``.
        default: Returned when no candidate directory exists (the base model
            name by default).

    Returns:
        The first candidate that is an existing directory, otherwise *default*.
    """
    search_paths = _T5_CHECKPOINT_DIRS if candidates is None else candidates
    for path in search_paths:
        # isdir rather than exists: a stray *file* called e.g. "./output"
        # is not a model checkpoint and must not be handed to the loader.
        if os.path.isdir(path):
            return path
    return default
# Session bootstrap. This runs before the sidebar below reads
# st.session_state (loaded_models, consultation_history, ...);
# init_session_state presumably seeds those keys — confirm in utils.
init_session_state()
# Inject the custom stylesheet into the page.
load_css()
# Sidebar: model settings and consultation history.
with st.sidebar:
    st.image("https://img.icons8.com/fluency/96/000000/hospital-3.png", width=80)
    st.title("AI Medical Assistant")
    st.markdown("---")

    with st.expander("⚙️ Settings", expanded=False):
        # Model settings. Every widget is disabled once the models are loaded,
        # because the pipeline is only built once per session.
        st.subheader("Model Settings")
        symptom_model = st.selectbox(
            "Symptom Extraction Model",
            ["dmis-lab/biobert-v1.1"],
            index=0,
            disabled=st.session_state.loaded_models,
        )
        risk_model = st.selectbox(
            "Risk Classification Model",
            ["microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"],
            index=0,
            disabled=st.session_state.loaded_models,
        )

        # Look for a local fine-tuned T5 checkpoint; the base model is always offered.
        available_t5_model = find_fine_tuned_model()
        recommendation_model_options = ["t5-small (base model)"]
        if available_t5_model != "t5-small":
            # List the fine-tuned checkpoint first so it is the default selection.
            recommendation_model_options.insert(0, f"{available_t5_model} (fine-tuned)")
        recommendation_model_label = st.selectbox(
            "Recommendation Model",
            recommendation_model_options,
            index=0,
            disabled=st.session_state.loaded_models,
        )
        # Map the display label back to the model path/name passed to the pipeline.
        if "(fine-tuned)" in recommendation_model_label:
            recommendation_model = available_t5_model
        else:
            recommendation_model = "t5-small"

        # Device selection: "GPU" falls back to CPU when CUDA is unavailable.
        device = st.radio(
            "Compute Device",
            ["CPU", "GPU (if available)"],
            index=1 if torch.cuda.is_available() else 0,
            disabled=st.session_state.loaded_models,
        )
        device = "cuda" if device == "GPU (if available)" and torch.cuda.is_available() else "cpu"

        if st.session_state.loaded_models:
            st.info("注意:设置已锁定,因为模型已加载。要更改设置,请刷新页面。")

    # Consultation history section.
    st.markdown("---")
    st.subheader("📋 Consultation History")

    if st.button("Refresh History"):
        st.session_state.consultation_history = load_consultation_history()
        st.success("History refreshed!")

    # Load the history lazily: on first render, and on every rerun while it is empty.
    if not st.session_state.consultation_history:
        st.session_state.consultation_history = load_consultation_history()

    if not st.session_state.consultation_history:
        st.info("No previous consultations found.")
    else:
        # Show at most 10 entries (the most recent ones, assuming
        # load_consultation_history returns newest first — TODO confirm).
        for i, consultation in enumerate(st.session_state.consultation_history[:10]):
            # Missing or malformed timestamps become None/NaT; NaT.strftime would raise.
            parsed_time = pd.to_datetime(consultation.get("timestamp"), errors="coerce")
            timestamp = "Unknown" if pd.isna(parsed_time) else parsed_time.strftime("%Y-%m-%d %H:%M")
            risk_level = (consultation.get("risk") or {}).get("risk_level", "Unknown")
            risk_color = RISK_COLORS.get(risk_level, "#6c757d")

            input_text = str(consultation.get("input_text") or "")
            preview = input_text[:50] + ("..." if len(input_text) > 50 else "")

            # Stored text is user/model supplied: escape it before rendering as HTML.
            history_item = f"""
            <div class='history-item'>
                <strong>Patient Input:</strong> {html.escape(preview)}<br>
                <strong>Time:</strong> {timestamp}<br>
                <strong>Risk Level:</strong> <span style='color:{risk_color};'>{html.escape(str(risk_level))}</span>
            </div>
            """
            st.markdown(history_item, unsafe_allow_html=True)

            # st.markdown returns an (always truthy) element handle, not a click
            # event, so a real button with a unique key selects the entry.
            if st.button("View this consultation", key=f"history_view_{i}", use_container_width=True):
                st.session_state.current_result = consultation
# Main page: title, "how it works" / example-input cards, and the active model configuration.
st.markdown("<h1 class='main-header'>AI-Powered Medical Consultation</h1>", unsafe_allow_html=True)

_HOW_IT_WORKS_HTML = """
<div class="card">
    <h2 class="card-header">How it Works</h2>
    <p>This AI-powered medical consultation system helps you understand your symptoms and provides guidance on next steps.</p>
    <p><strong>Simply describe your symptoms</strong> in natural language and the system will:</p>
    <ol>
        <li>Extract key symptoms and duration information</li>
        <li>Assess your risk level</li>
        <li>Generate personalized medical recommendations</li>
    </ol>
    <p><em>Note: This system is for informational purposes only and does not replace professional medical advice.</em></p>
</div>
"""

_EXAMPLE_INPUTS_HTML = """
<div class="card">
    <h2 class="card-header">Example Inputs</h2>
    <ul>
        <li>"I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."</li>
        <li>"My child has had a high fever of 39°C since yesterday and is coughing a lot."</li>
        <li>"I've noticed a persistent rash on my arm for the past 3 days, it's itchy and slightly swollen."</li>
    </ul>
</div>
"""

col1, col2 = st.columns([2, 1])
col1.markdown(_HOW_IT_WORKS_HTML, unsafe_allow_html=True)
col2.markdown(_EXAMPLE_INPUTS_HTML, unsafe_allow_html=True)

# Show which models the sidebar settings resolved to (labels are Chinese UI text).
recommendation_variant = "(微调模型)" if recommendation_model != "t5-small" else "(基础模型)"
model_info = f"""
<div class="card">
    <h2 class="card-header">当前模型配置</h2>
    <ul>
        <li><strong>症状抽取模型:</strong> {symptom_model}</li>
        <li><strong>风险分类模型:</strong> {risk_model}</li>
        <li><strong>推荐生成模型:</strong> {recommendation_model} {recommendation_variant}</li>
        <li><strong>计算设备:</strong> {device.upper()}</li>
    </ul>
</div>
"""
col2.markdown(model_info, unsafe_allow_html=True)
# Build the consultation pipeline for the given model selection and device.
def load_pipeline(_symptom_model, _risk_model, _recommendation_model, _device):
    """Instantiate ``MedicalConsultationPipeline`` for one model configuration.

    Args:
        _symptom_model: Model name used for symptom extraction.
        _risk_model: Model name used for risk classification.
        _recommendation_model: Model name or local path used for recommendations.
        _device: Device string chosen in the sidebar ("cpu" or "cuda").

    Returns:
        A newly constructed pipeline. No caching happens here.

    NOTE(review): the underscore-prefixed parameter names match Streamlit's
    convention for arguments excluded from ``st.cache_*`` hashing, but this
    function is not decorated, so each call constructs new models.
    """
    pipeline_kwargs = dict(
        symptom_model=_symptom_model,
        risk_model=_risk_model,
        recommendation_model=_recommendation_model,
        device=_device,
    )
    return MedicalConsultationPipeline(**pipeline_kwargs)
# Build the pipeline once per browser session; st.session_state keeps it across reruns.
if not st.session_state.loaded_models:
    # Bind the name up front so later references never raise NameError
    # when loading fails below.
    pipeline = None
    try:
        with st.spinner("Loading AI models... This may take a minute..."):
            pipeline = load_pipeline(symptom_model, risk_model, recommendation_model, device)
        st.session_state.pipeline = pipeline
        st.session_state.loaded_models = True
        st.success("✅ Models loaded successfully!")
    except Exception as e:
        # loaded_models stays False, so the next rerun tries loading again.
        st.error(f"Error loading models: {str(e)}")
else:
    pipeline = st.session_state.pipeline
# Input section: free-text symptom description plus the "Analyze" trigger.
st.markdown("<h2 class='subheader'>Describe Your Symptoms</h2>", unsafe_allow_html=True)

patient_input = st.text_area(
    "Please describe your symptoms, including when they started and any other relevant information:",
    height=150,
    placeholder="Example: I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous.",
)

# The button sits in the middle of three equal columns to center it.
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
    process_button = st.button("Analyze Symptoms", type="primary", use_container_width=True)

if process_button and not st.session_state.is_processing:
    # Read the pipeline from session state: it is absent when model loading failed.
    active_pipeline = st.session_state.get("pipeline")
    if not patient_input.strip():
        st.warning("Please describe your symptoms before running the analysis.")
    elif active_pipeline is None:
        st.error("The AI models are not loaded, so your input cannot be analyzed. Please refresh the page to retry.")
    else:
        st.session_state.is_processing = True
        try:
            with st.spinner("Analyzing your symptoms..."):
                start_time = time.perf_counter()
                result = active_pipeline.process(patient_input)
                elapsed_time = time.perf_counter() - start_time
        except Exception as e:
            st.error(f"Error processing your input: {str(e)}")
        else:
            st.session_state.current_result = result
            # A failed save must not be reported as a failed analysis.
            try:
                save_consultation(result)
            except Exception as e:
                st.warning(f"The analysis succeeded, but it could not be saved to history: {e}")
            st.success(f"Analysis completed in {elapsed_time:.2f} seconds!")
        finally:
            # Always clear the flag so the button keeps working after an error.
            st.session_state.is_processing = False
def _escape(value):
    """Return ``str(value)`` HTML-escaped, for interpolation into ``unsafe_allow_html`` markup."""
    return html.escape(str(value))


def _highlight_spans(text, spans, css_class="duration-highlight"):
    """Wrap each ``text[start:end]`` span in ``<span class='css_class'>``, escaping all text.

    Args:
        text: Raw (unescaped) text that the span offsets refer to.
        spans: Iterable of dicts with integer ``"start"`` and ``"end"`` keys.
        css_class: CSS class applied to every highlighted span.

    Returns:
        An HTML string. Offsets are applied to the raw text and each segment is
        escaped separately, so the offsets stay valid. A span overlapping an
        earlier one is clipped to start where the earlier one ended.
    """
    pieces = []
    cursor = 0
    for span in sorted(spans, key=lambda s: s["start"]):
        start = max(span["start"], cursor)
        end = span["end"]
        if end <= start:
            continue  # empty, or fully covered by a previous span
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(f"<span class='{css_class}'>{html.escape(text[start:end])}</span>")
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def _format_paragraphs(text, sentences_per_paragraph=2):
    """Group the sentences of *text* into HTML paragraphs separated by ``<br><br>``.

    A sentence ends at ``.``, ``!`` or ``?`` followed by whitespace; the
    punctuation is kept. All text is HTML-escaped.
    """
    size = max(1, sentences_per_paragraph)
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    paragraphs = [" ".join(sentences[i:i + size]) for i in range(0, len(sentences), size)]
    return "<br><br>".join(html.escape(p) for p in paragraphs)


# Results section - shown whenever a consultation result is in session state.
if st.session_state.current_result:
    result = st.session_state.current_result
    # `or {}` also covers keys that are present but set to None.
    extraction = result.get("extraction") or {}
    risk_data = result.get("risk") or {}
    structured = result.get("structured_recommendation") or {}
    patient_text = str(result.get("input_text") or "")
    recommendation = result.get("recommendation") or "No recommendations available."

    st.markdown("<h2 class='subheader'>Consultation Results</h2>", unsafe_allow_html=True)
    tabs = st.tabs(["Overview", "Symptoms Analysis", "Risk Assessment", "Recommendations"])

    # Overview tab - summary of all results.
    with tabs[0]:
        col1, col2 = st.columns([3, 2])
        with col1:
            st.markdown("""
            <div class="card">
                <h3 class="card-header">Patient Description</h3>
            """, unsafe_allow_html=True)
            # NOTE(review): this helper's output is rendered as raw HTML; it is
            # assumed to escape the patient's text itself — confirm in utils.
            highlighted_text = highlight_text_with_entities(
                result.get("input_text", ""),
                extraction.get("symptoms", []),
            )
            st.markdown(f"<p>{highlighted_text}</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

            # Recommendations card (model-generated text, so escaped).
            st.markdown("""
            <div class="card">
                <h3 class="card-header">Medical Recommendations</h3>
                <div class="recommendation-container">
            """, unsafe_allow_html=True)
            st.markdown(f"<p>{_escape(recommendation)}</p>", unsafe_allow_html=True)
            st.markdown("""
                </div>
                <p><em>Note: This is AI-generated guidance and should not replace professional medical advice.</em></p>
            </div>
            """, unsafe_allow_html=True)

        with col2:
            # Risk level card.
            risk_level = risk_data.get("risk_level", "Unknown")
            confidence = float(risk_data.get("confidence") or 0.0)
            st.markdown(f"""
            <div class="card">
                <h3 class="card-header">Risk Assessment</h3>
                <div style="text-align: center;">
                    <span class="risk-{_escape(risk_level).lower()}" style="font-size: 1.8rem;">{_escape(risk_level)}</span>
                    <p>Confidence: {confidence:.1%}</p>
                </div>
            """, unsafe_allow_html=True)
            risk_gauge = create_risk_gauge(risk_level, confidence)
            st.plotly_chart(risk_gauge, use_container_width=True, key="overview_risk_gauge")
            st.markdown("</div>", unsafe_allow_html=True)

            # Extracted symptoms summary.
            st.markdown("""
            <div class="card">
                <h3 class="card-header">Key Findings</h3>
            """, unsafe_allow_html=True)
            symptoms = extraction.get("symptoms", [])
            duration = extraction.get("duration", [])
            if symptoms:
                st.markdown("<strong>Identified Symptoms:</strong>", unsafe_allow_html=True)
                for symptom in symptoms:
                    st.markdown(
                        f"• {_escape(symptom['text'])} ({symptom['score']:.1%} confidence)",
                        unsafe_allow_html=True,
                    )
            else:
                st.info("No specific symptoms identified")
            st.markdown("<br><strong>Duration Information:</strong>", unsafe_allow_html=True)
            st.markdown(f"<p>{format_duration(duration)}</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

    # Symptoms Analysis tab.
    with tabs[1]:
        st.markdown("""
        <div class="card">
            <h3 class="card-header">Detailed Symptom Analysis</h3>
        """, unsafe_allow_html=True)
        symptoms = extraction.get("symptoms", [])
        if symptoms:
            symptom_df = pd.DataFrame([
                {
                    "Symptom": s["text"],
                    "Confidence": s["score"],
                    "Start Position": s["start"],
                    "End Position": s["end"],
                }
                for s in symptoms
            ])
            symptom_df = symptom_df.sort_values("Confidence", ascending=False)
            st.dataframe(symptom_df, use_container_width=True)

            # A bar chart is only informative with more than one symptom.
            if len(symptoms) > 1:
                st.markdown("<h4>Symptom Confidence Scores</h4>", unsafe_allow_html=True)
                chart_data = symptom_df[["Symptom", "Confidence"]].set_index("Symptom")
                st.bar_chart(chart_data, use_container_width=True)
        else:
            st.info("No specific symptoms were detected in the input text.")
        st.markdown("</div>", unsafe_allow_html=True)

        # Duration information card.
        st.markdown("""
        <div class="card">
            <h3 class="card-header">Duration Analysis</h3>
        """, unsafe_allow_html=True)
        duration = extraction.get("duration", [])
        if duration:
            duration_df = pd.DataFrame([
                {
                    "Duration": d["text"],
                    "Start Position": d["start"],
                    "End Position": d["end"],
                }
                for d in duration
            ])
            st.dataframe(duration_df, use_container_width=True)

            st.markdown("<h4>Original Text with Duration Highlighted</h4>", unsafe_allow_html=True)
            # Highlight on the raw text, escaping each segment separately.
            duration_text = _highlight_spans(patient_text, duration)
            st.markdown(f"<p>{duration_text}</p>", unsafe_allow_html=True)
        else:
            st.info("No specific duration information was detected in the input text.")
        st.markdown("</div>", unsafe_allow_html=True)

    # Risk Assessment tab.
    with tabs[2]:
        st.markdown("""
        <div class="card">
            <h3 class="card-header">Risk Level Assessment</h3>
        """, unsafe_allow_html=True)
        risk_level = risk_data.get("risk_level", "Unknown")
        confidence = float(risk_data.get("confidence") or 0.0)
        probabilities = risk_data.get("all_probabilities", {})

        col1, col2 = st.columns(2)
        with col1:
            risk_gauge = create_risk_gauge(risk_level, confidence)
            st.plotly_chart(risk_gauge, use_container_width=True, key="risk_assessment_gauge")
        with col2:
            prob_chart = create_risk_probability_chart(probabilities)
            st.plotly_chart(prob_chart, use_container_width=True, key="risk_probability_chart")

        st.markdown("<h4>Risk Levels Explained</h4>", unsafe_allow_html=True)
        risk_descriptions = {
            "Low": """
            <div style="border-left: 3px solid #7FD8BE; padding-left: 10px; margin: 10px 0;">
                <strong style="color: #7FD8BE;">Low Risk</strong>: Your symptoms suggest a condition that is likely non-urgent.
                While it's good to stay vigilant, these types of conditions typically don't require immediate medical attention
                and can often be managed with self-care or a routine appointment within the next few days or weeks.
            </div>
            """,
            "Medium": """
            <div style="border-left: 3px solid #FFC857; padding-left: 10px; margin: 10px 0;">
                <strong style="color: #FFC857;">Medium Risk</strong>: Your symptoms indicate a condition that may need medical attention
                soon, but may not be an emergency. Consider scheduling an appointment with your primary care provider within 24-48 hours,
                or visit an urgent care facility if your symptoms worsen or if you cannot schedule a timely appointment.
            </div>
            """,
            "High": """
            <div style="border-left: 3px solid #E84855; padding-left: 10px; margin: 10px 0;">
                <strong style="color: #E84855;">High Risk</strong>: Your symptoms suggest a potentially serious condition that requires
                prompt medical attention. Consider seeking emergency care or calling emergency services if symptoms are severe or rapidly
                worsening, especially if they include difficulty breathing, severe pain, or altered consciousness.
            </div>
            """,
        }
        # Show the description for the assessed level first, then the others.
        if risk_level in risk_descriptions:
            st.markdown(risk_descriptions[risk_level], unsafe_allow_html=True)
        for level, desc in risk_descriptions.items():
            if level != risk_level:
                st.markdown(desc, unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)

        st.warning("""
        **Important Disclaimer**: This risk assessment is based on AI analysis and should be used as a guidance only.
        It is not a definitive medical diagnosis. Always consult with a healthcare professional for proper evaluation,
        especially if you experience severe symptoms, symptoms that persist or worsen, or if you're unsure about your condition.
        """)

    # Recommendations tab.
    with tabs[3]:
        st.markdown("""
        <div class="card">
            <h3 class="card-header">Detailed Recommendations</h3>
        """, unsafe_allow_html=True)
        # Two sentences per paragraph for readability.
        formatted_recommendation = _format_paragraphs(str(recommendation), sentences_per_paragraph=2)
        st.markdown(f"<p>{formatted_recommendation}</p>", unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)

        # Department suggestions generated by the model (not rule-based).
        st.markdown("""
        <div class="card">
            <h3 class="card-header">Suggested Medical Departments</h3>
        """, unsafe_allow_html=True)
        departments = structured.get("departments") or ["General Medicine / Primary Care"]
        for dept in departments:
            st.markdown(f"• **{_escape(dept)}**", unsafe_allow_html=True)
        st.markdown("<br><em>Note: Department suggestions are based on your symptoms and risk level. Consult with a healthcare provider for proper referral.</em>", unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)

        # Self-care suggestions.
        st.markdown("""
        <div class="card">
            <h3 class="card-header">Self-Care Suggestions</h3>
        """, unsafe_allow_html=True)
        self_care_tips = structured.get("self_care") or []
        if self_care_tips:
            # One markdown call so the <li> items actually sit inside the <ul>.
            tip_items = "".join(f"<li>{_escape(tip)}</li>" for tip in self_care_tips)
            st.markdown(f"<ul>{tip_items}</ul>", unsafe_allow_html=True)
        else:
            # No model-generated tips: fall back to generic advice for the risk level.
            # A missing level defaults to "Medium"; any unrecognized level gets the high-risk advice.
            fallback_level = risk_data.get("risk_level", "Medium")
            if fallback_level == "Low":
                st.markdown("""
                <p>While waiting for your symptoms to improve:</p>
                <ul>
                    <li>Ensure you're getting adequate rest</li>
                    <li>Stay hydrated by drinking plenty of water</li>
                    <li>Monitor your symptoms and note any changes</li>
                    <li>Consider over-the-counter medications appropriate for your symptoms</li>
                    <li>Maintain a balanced diet to support your immune system</li>
                </ul>
                """, unsafe_allow_html=True)
            elif fallback_level == "Medium":
                st.markdown("""
                <p>While arranging medical care:</p>
                <ul>
                    <li>Rest and avoid strenuous activities</li>
                    <li>Stay hydrated and maintain proper nutrition</li>
                    <li>Take your temperature and other vital signs if possible</li>
                    <li>Write down any changes in symptoms and when they occur</li>
                    <li>Have someone stay with you if your symptoms are concerning</li>
                </ul>
                """, unsafe_allow_html=True)
            else:  # High risk
                st.markdown("""
                <p>While seeking emergency care:</p>
                <ul>
                    <li>Don't wait - seek medical attention immediately</li>
                    <li>Have someone drive you to the emergency room if safe to do so</li>
                    <li>Call emergency services if symptoms are severe</li>
                    <li>Bring a list of your current medications if possible</li>
                    <li>Follow any first aid protocols appropriate for your symptoms</li>
                </ul>
                """, unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)
# Footer: credits the models actually selected in the sidebar, so it never
# claims a fine-tuned T5 model when the base "t5-small" is in use.
recommendation_credit = (
    f"fine-tuned T5-small ({recommendation_model})"
    if recommendation_model != "t5-small"
    else "T5-small (base model)"
)
st.markdown(f"""
<div class="footer">
    <p>AI Medical Consultation System | Created with Streamlit | Not a substitute for professional medical advice</p>
    <p>Powered by: {symptom_model}, {risk_model}, and {recommendation_credit}</p>
</div>
""", unsafe_allow_html=True)