Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForTokenClassification, | |
AutoModelForSequenceClassification, | |
AutoModelForSeq2SeqLM, | |
pipeline | |
) | |
import re | |
import os | |
import json | |
from typing import Dict, List, Tuple, Any | |
class SymptomExtractor:
    """Extract symptom spans and duration mentions from a patient description.

    Symptom spans come from a Hugging Face ``"ner"`` pipeline built on a
    BioBERT token-classification model; only tokens whose label starts with
    ``"B-"`` or ``"I-"`` are turned into symptoms. Durations are found with
    regular expressions.

    NOTE(review): the model is loaded through
    ``AutoModelForTokenClassification``; whether its labels follow the B-/I-
    convention depends on the checkpoint -- confirm, otherwise no symptoms are
    ever extracted.
    """

    # Regexes for duration mentions, applied in this order. Plural units are
    # written ``days?`` etc. because Python's ``|`` is leftmost-first: an
    # alternation like ``(day|days)`` stops at "day" and would report "2 week"
    # for "2 weeks". ``\b`` keeps the keywords from matching inside other words.
    _DURATION_PATTERNS = (
        re.compile(r"(\d+)\s*(days?|weeks?|months?|years?)\b", re.IGNORECASE),
        re.compile(r"\bsince\s+(\w+)", re.IGNORECASE),
        re.compile(r"\bfor\s+(\w+)", re.IGNORECASE),
    )

    def __init__(self, model_name="dmis-lab/biobert-v1.1", device=None):
        """Load the tokenizer, the model and the NER pipeline.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device spec; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Symptom Extractor model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name).to(self.device)
        self.nlp = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self._pipeline_device_index(self.device),
        )
        print("Symptom Extractor model loaded successfully.")

    @staticmethod
    def _pipeline_device_index(device) -> int:
        """Map a torch device spec to the integer ``device`` argument of ``pipeline``.

        CUDA devices map to their ordinal (``"cuda"`` -> 0, ``"cuda:1"`` -> 1)
        so the pipeline runs on the same GPU the model was moved to; every
        other device maps to -1 (CPU).
        """
        dev = torch.device(device)
        if dev.type == "cuda":
            return dev.index if dev.index is not None else 0
        return -1

    @staticmethod
    def _is_adjacent(previous: Dict[str, Any], entity: Dict[str, Any]) -> bool:
        """Return True if ``entity`` directly follows ``previous`` in the token stream.

        The NER pipeline drops ``"O"`` tokens, so two tagged tokens can be
        non-contiguous; an ``I-`` token across such a gap must not be glued
        onto the earlier span. Without ``index`` information, fall back to
        plain IOB chaining.
        """
        prev_index = previous.get("index")
        index = entity.get("index")
        if prev_index is None or index is None:
            return True
        return index == prev_index + 1

    @staticmethod
    def _merge_tokens(tokens: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collapse a run of B-/I- token predictions into one symptom span.

        WordPiece continuation pieces (``"##..."``) are attached to the
        previous piece without a space; whole words are space-separated. The
        span score is the arithmetic mean of the token scores.
        """
        text = ""
        for token in tokens:
            word = token["word"]
            if word.startswith("##"):
                text += word[2:]
            elif text:
                text += " " + word
            else:
                text = word
        return {
            "text": text,
            "start": tokens[0]["start"],
            "end": tokens[-1]["end"],
            "score": sum(token["score"] for token in tokens) / len(tokens),
        }

    def extract_symptoms(self, text: str) -> Dict[str, Any]:
        """Extract symptom spans and duration mentions from ``text``.

        Args:
            text: Free-text patient description.

        Returns:
            ``{"symptoms": [...], "duration": [...]}``. Each symptom is a dict
            with ``text``, ``start``, ``end`` and ``score``; each duration is a
            dict with ``text``, ``start`` and ``end`` (character offsets into
            ``text``).
        """
        # Group the per-token NER output into IOB spans. Tokens whose label is
        # neither B- nor I- are ignored. An I- token continues the latest span
        # only when it is adjacent to it; otherwise it starts a new span.
        groups: List[List[Dict[str, Any]]] = []
        for entity in self.nlp(text):
            label = entity["entity"]
            if label.startswith("B-"):
                groups.append([entity])
            elif label.startswith("I-"):
                if groups and self._is_adjacent(groups[-1][-1], entity):
                    groups[-1].append(entity)
                else:
                    groups.append([entity])
        symptoms = [self._merge_tokens(group) for group in groups]

        duration_info = [
            {"text": match.group(0), "start": match.start(), "end": match.end()}
            for pattern in self._DURATION_PATTERNS
            for match in pattern.finditer(text)
        ]

        return {
            "symptoms": symptoms,
            "duration": duration_info
        }
class RiskClassifier:
    """Classify a patient description as Low, Medium or High risk.

    A PubMedBERT sequence-classification model with a 3-way head produces a
    raw prediction. Because that head is not fine-tuned (per the original
    author's note), the prediction is then adjusted by keyword rules.
    """

    # Keyword lists for the rule-based post-processing. Each keyword is
    # matched at the start of a word (``\b`` prefix) instead of as a raw
    # substring, so "crash" no longer counts as "rash" and "Sundays" no longer
    # counts as "days", while inflections such as "painful" still match.
    _HIGH_RISK_KEYWORDS = (
        "severe", "extreme", "intense", "unbearable", "emergency",
        "chest pain", "difficulty breathing", "can't breathe",
        "losing consciousness", "fainted", "seizure", "stroke", "heart attack",
        "allergic reaction", "bleeding heavily", "blood", "poisoning",
        "overdose", "suicide", "self-harm", "hallucinations",
    )
    _MEDIUM_RISK_KEYWORDS = (
        "worsening", "spreading", "persistent", "chronic", "recurring",
        "infection", "fever", "swelling", "rash", "pain", "vomiting",
        "diarrhea", "dizzy", "headache", "concerning", "worried",
        "weeks", "days", "increasing", "progressing",
    )
    _LOW_RISK_KEYWORDS = (
        "mild", "slight", "minor", "occasional", "intermittent",
        "improving", "better", "sometimes", "rarely", "manageable",
    )

    # Minimum probability reported for the (possibly rule-adjusted) class.
    _MIN_ADJUSTED_CONFIDENCE = 0.6

    def __init__(self, model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", device=None):
        """Load the tokenizer and a 3-label sequence-classification model.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device spec; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Risk Classifier model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=3  # Low, Medium, High
        ).to(self.device)
        self.id2label = {0: "Low", 1: "Medium", 2: "High"}
        print("Risk Classifier model loaded successfully.")

    @staticmethod
    def _count_keyword_hits(text_lower: str, keywords) -> int:
        """Count how many distinct ``keywords`` occur at a word start in ``text_lower``."""
        return sum(
            1 for keyword in keywords
            if re.search(r"\b" + re.escape(keyword), text_lower)
        )

    @classmethod
    def _boost_probability(cls, probabilities, index: int) -> np.ndarray:
        """Return probabilities where class ``index`` has at least the minimum confidence.

        The chosen class gets ``max(_MIN_ADJUSTED_CONFIDENCE, p)`` and the
        remaining mass is rescaled proportionally, so the vector still sums to
        1 and the chosen class stays the most probable one. Setting the floor
        first and then renormalising the whole vector would pull the chosen
        class back below the floor, sometimes below another class.
        """
        probs = np.asarray(probabilities, dtype=np.float64)
        target = max(cls._MIN_ADJUSTED_CONFIDENCE, float(probs[index]))
        others = np.ones(len(probs), dtype=bool)
        others[index] = False
        rest = float(probs[others].sum())
        n_others = int(others.sum())
        adjusted = probs.copy()
        if rest > 0.0:
            adjusted[others] *= (1.0 - target) / rest
        elif n_others:
            adjusted[others] = (1.0 - target) / n_others
        adjusted[index] = target
        return adjusted

    def classify_risk(self, text: str) -> Dict[str, Any]:
        """Classify the risk level of ``text``.

        Args:
            text: Free-text patient description.

        Returns:
            Dict with ``risk_level`` (rule-adjusted label), ``confidence``
            (probability of that label), ``all_probabilities`` (label -> prob)
            and ``original_prediction`` (the raw model label).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()
        model_prediction = torch.argmax(logits, dim=1).item()

        # The model is not fine-tuned, so rule-based post-processing adjusts
        # the risk level. Typographic apostrophes are normalised so that
        # "can’t breathe" matches "can't breathe".
        text_lower = text.lower().replace("\u2019", "'")
        high_risk_matches = self._count_keyword_hits(text_lower, self._HIGH_RISK_KEYWORDS)
        medium_risk_matches = self._count_keyword_hits(text_lower, self._MEDIUM_RISK_KEYWORDS)
        low_risk_matches = self._count_keyword_hits(text_lower, self._LOW_RISK_KEYWORDS)

        # Adjust the prediction from the keyword counts (first matching rule wins).
        # NOTE(review): a lone high-risk keyword with no medium-risk keyword
        # (e.g. just "overdose") keeps the raw model prediction -- confirm this
        # is intended.
        adjusted_prediction = model_prediction
        if high_risk_matches >= 2:
            adjusted_prediction = 2  # High risk
        elif high_risk_matches == 1 and medium_risk_matches >= 2:
            adjusted_prediction = 2  # High risk
        elif medium_risk_matches >= 3:
            adjusted_prediction = 1  # Medium risk
        elif medium_risk_matches >= 1 and low_risk_matches <= 1:
            adjusted_prediction = 1  # Medium risk
        elif low_risk_matches >= 2 and high_risk_matches == 0:
            adjusted_prediction = 0  # Low risk

        # A long, detailed description may indicate a more complex case:
        # never report Low for texts longer than 40 words.
        if len(text.split()) > 40 and adjusted_prediction == 0:
            adjusted_prediction = 1  # Escalate to Medium

        adjusted_probabilities = self._boost_probability(probabilities, adjusted_prediction)

        return {
            "risk_level": self.id2label[adjusted_prediction],
            "confidence": float(adjusted_probabilities[adjusted_prediction]),
            "all_probabilities": {
                self.id2label[i]: float(prob)
                for i, prob in enumerate(adjusted_probabilities)
            },
            "original_prediction": self.id2label[model_prediction]
        }
class RecommendationGenerator:
    """Generate medical recommendations with a (possibly fine-tuned) t5-small.

    The generated free-text advice is combined with rule-based department
    suggestions (keyword lookup over the symptom text) and risk-dependent
    self-care tips.
    """

    # Local directories checked, in order, for a fine-tuned checkpoint when the
    # caller asks for one of the auto-detect model names below.
    _LOCAL_FINETUNED_CANDIDATES = (
        "./finetuned_t5-small",
        "./t5-small-medical-recommendation",
        "./models/t5-small-medical-recommendation",
        "./fine_tuned_models/t5-small",
        "./output",
        "./fine_tuning_output",
    )
    # Model names meaning "use a local fine-tuned model if one exists"; any
    # other model_path is an explicit choice and is used as given.
    _AUTO_DETECT_MODEL_NAMES = ("t5-small", "t5-small-medical-recommendation")

    # Symptom keyword -> department. Keywords are matched at the start of a
    # word in the lower-cased symptom text; the iteration order of this dict
    # determines the order of the returned department list.
    _SYMPTOM_TO_DEPARTMENT = {
        "headache": "Neurology",
        "dizziness": "Neurology",
        "confusion": "Neurology",
        "memory": "Neurology",
        "numbness": "Neurology",
        "tingling": "Neurology",
        "seizure": "Neurology",
        "nerve": "Neurology",
        "chest pain": "Cardiology",
        "heart": "Cardiology",
        "palpitation": "Cardiology",
        "arrhythmia": "Cardiology",
        "high blood pressure": "Cardiology",
        "hypertension": "Cardiology",
        "heart attack": "Cardiology",
        "cardiovascular": "Cardiology",
        "cough": "Pulmonology",
        "breathing": "Pulmonology",
        "shortness": "Pulmonology",
        "lung": "Pulmonology",
        "respiratory": "Pulmonology",
        "asthma": "Pulmonology",
        "pneumonia": "Pulmonology",
        "copd": "Pulmonology",
        "stomach": "Gastroenterology",
        "abdomen": "Gastroenterology",
        "nausea": "Gastroenterology",
        "vomit": "Gastroenterology",
        "diarrhea": "Gastroenterology",
        "constipation": "Gastroenterology",
        "heartburn": "Gastroenterology",
        "liver": "Gastroenterology",
        "digestive": "Gastroenterology",
        "joint": "Orthopedics",
        "bone": "Orthopedics",
        "muscle": "Orthopedics",
        "pain": "Orthopedics",
        "back": "Orthopedics",
        "arthritis": "Orthopedics",
        "fracture": "Orthopedics",
        "sprain": "Orthopedics",
        "rash": "Dermatology",
        "skin": "Dermatology",
        "itching": "Dermatology",
        "itch": "Dermatology",
        "acne": "Dermatology",
        "eczema": "Dermatology",
        "psoriasis": "Dermatology",
        "fever": "General Medicine / Primary Care",
        "infection": "General Medicine / Primary Care",
        "sore throat": "General Medicine / Primary Care",
        "flu": "General Medicine / Primary Care",
        "cold": "General Medicine / Primary Care",
        "fatigue": "General Medicine / Primary Care",
        "pregnancy": "Obstetrics / Gynecology",
        "menstruation": "Obstetrics / Gynecology",
        "period": "Obstetrics / Gynecology",
        "vaginal": "Obstetrics / Gynecology",
        "menopause": "Obstetrics / Gynecology",
        "depression": "Psychiatry",
        "anxiety": "Psychiatry",
        "mood": "Psychiatry",
        "stress": "Psychiatry",
        "sleep": "Psychiatry",
        "insomnia": "Psychiatry",
        "mental": "Psychiatry",
        "ear": "Otolaryngology (ENT)",
        "nose": "Otolaryngology (ENT)",
        "throat": "Otolaryngology (ENT)",
        "hearing": "Otolaryngology (ENT)",
        "sinus": "Otolaryngology (ENT)",
        "eye": "Ophthalmology",
        "vision": "Ophthalmology",
        "blindness": "Ophthalmology",
        "blurry": "Ophthalmology",
        "urination": "Urology",
        "kidney": "Urology",
        "bladder": "Urology",
        "urine": "Urology",
        "prostate": "Urology"
    }

    # Self-care tips per risk level; up to five are sampled per recommendation.
    _SELF_CARE_BY_RISK = {
        "Low": [
            "Ensure you're getting adequate rest",
            "Stay hydrated by drinking plenty of water",
            "Monitor your symptoms and note any changes",
            "Consider over-the-counter medications appropriate for your symptoms",
            "Maintain a balanced diet to support your immune system",
            "Try gentle exercises if appropriate for your condition",
            "Avoid activities that worsen your symptoms",
            "Keep track of any patterns in your symptoms"
        ],
        "Medium": [
            "Rest and avoid strenuous activities",
            "Stay hydrated and maintain proper nutrition",
            "Take your temperature and other vital signs if possible",
            "Write down any changes in symptoms and when they occur",
            "Have someone stay with you if your symptoms are concerning",
            "Prepare a list of your symptoms and medications for your doctor",
            "Avoid self-medicating beyond basic over-the-counter remedies",
            "Consider arranging transportation to your medical appointment"
        ],
        "High": [
            "Don't wait - seek medical attention immediately",
            "Have someone drive you to the emergency room if safe to do so",
            "Call emergency services if symptoms are severe",
            "Bring a list of your current medications if possible",
            "Follow any first aid protocols appropriate for your symptoms",
            "Don't eat or drink anything if you might need surgery",
            "Take prescribed emergency medications if applicable (like an inhaler for asthma)",
            "Try to remain calm and focused on getting help"
        ]
    }

    def __init__(self, model_path="t5-small", device=None):
        """Load the seq2seq model, preferring a local fine-tuned checkpoint.

        Args:
            model_path: Hugging Face model id or local path. For the
                auto-detect names (``"t5-small"``,
                ``"t5-small-medical-recommendation"``) the local candidate
                directories are searched first; any other value is used as
                given.
            device: Torch device spec; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.

        If loading fails for any reason, the base ``t5-small`` model is loaded
        instead (best-effort fallback).
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Recommendation Generator model on {self.device}...")
        model_path = self._resolve_model_path(model_path)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
            print(f"Recommendation Generator model '{model_path}' loaded successfully.")
        except Exception as e:
            print(f"Error loading model from {model_path}: {str(e)}")
            print("Falling back to base t5-small model...")
            self.tokenizer = AutoTokenizer.from_pretrained("t5-small")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(self.device)
            print("Base t5-small model loaded successfully as fallback.")
        # Per-instance copies, so customising one generator's tables does not
        # leak into other instances via the shared class-level defaults.
        self.symptom_to_department = dict(self._SYMPTOM_TO_DEPARTMENT)
        self.self_care_by_risk = {
            level: list(tips) for level, tips in self._SELF_CARE_BY_RISK.items()
        }

    @classmethod
    def _resolve_model_path(cls, model_path: str) -> str:
        """Pick the checkpoint to load for ``model_path``.

        An explicit (non auto-detect) ``model_path`` is returned unchanged, so
        an unrelated local directory such as ``./output`` can never override
        the caller's choice. For auto-detect names, the first candidate
        directory that contains a ``config.json`` (i.e. a saved model) wins;
        otherwise ``"t5-small"`` is used.
        """
        if model_path not in cls._AUTO_DETECT_MODEL_NAMES:
            return model_path
        for path in cls._LOCAL_FINETUNED_CANDIDATES:
            if os.path.isfile(os.path.join(path, "config.json")):
                print(f"Found fine-tuned model at: {path}")
                return path
        if model_path == "t5-small-medical-recommendation":
            print("Fine-tuned model not found locally. Falling back to base t5-small...")
        return "t5-small"

    def _extract_departments_from_symptoms(self, symptoms_text: str) -> List[str]:
        """Return the departments relevant to ``symptoms_text``.

        Args:
            symptoms_text: Symptom description text.

        Returns:
            Department names in a deterministic order (the order of the keyword
            mapping), without duplicates. Falls back to
            ``["General Medicine / Primary Care"]`` when no keyword matches.
        """
        symptoms_lower = symptoms_text.lower()
        # dict used as an insertion-ordered set: a plain set would make the
        # output order depend on string hash randomisation.
        departments: Dict[str, None] = {}
        for keyword, department in self.symptom_to_department.items():
            # Match at a word start so e.g. "heart"/"years" don't hit "ear".
            if re.search(r"\b" + re.escape(keyword), symptoms_lower):
                departments.setdefault(department, None)
        if not departments:
            return ["General Medicine / Primary Care"]
        return list(departments)

    def _get_self_care_suggestions(self, risk_level: str) -> List[str]:
        """Return up to five self-care tips for ``risk_level``.

        Args:
            risk_level: Risk level (Low, Medium, High). Unknown levels fall
                back to the Medium tips.

        Returns:
            A new list; five tips are sampled at random when more are
            available, so repeated calls do not always return the same text.
        """
        import random

        if risk_level not in self.self_care_by_risk:
            risk_level = "Medium"
        suggestions = self.self_care_by_risk[risk_level]
        if len(suggestions) > 5:
            return random.sample(suggestions, 5)
        # Copy so callers cannot mutate the instance's tip table.
        return list(suggestions)

    def _format_structured_recommendation(self, medical_advice: str, departments: List[str], self_care: List[str], risk_level: str) -> str:
        """Render the structured recommendation as display text.

        Args:
            medical_advice: Main (generated) medical advice.
            departments: Suggested departments.
            self_care: Self-care tips, one bullet each.
            risk_level: Risk level (Low, Medium, High).

        Returns:
            The advice, a RECOMMENDED DEPARTMENTS paragraph and a bulleted
            SELF-CARE SUGGESTIONS section, newline-terminated.
        """
        if risk_level == 'High':
            urgency = 'immediate attention'
        elif risk_level == 'Medium':
            urgency = 'medical care soon'
        else:
            urgency = 'monitoring'
        lines = [
            medical_advice.strip(),
            "",
            f"RECOMMENDED DEPARTMENTS: Based on your symptoms, consider consulting the following departments: {', '.join(departments)}.",
            "",
            f"SELF-CARE SUGGESTIONS: While {risk_level.lower()} risk level requires {urgency}, you can also:",
        ]
        lines.extend(f"- {suggestion}" for suggestion in self_care)
        return "\n".join(lines) + "\n"

    def generate_recommendation(self,
                                symptoms: str,
                                risk_level: str,
                                max_length: int = 150) -> Dict[str, Any]:
        """
        Generate a comprehensive medical recommendation based on symptoms and risk level.

        Args:
            symptoms: Symptom description text
            risk_level: Risk level (Low, Medium, High)
            max_length: Maximum length for generated text

        Returns:
            Dictionary with ``text`` (formatted recommendation) and
            ``structured`` (``medical_advice``, ``departments``, ``self_care``).
        """
        # Prompt format fed to the seq2seq model.
        # NOTE(review): presumably must match the fine-tuning data format -- confirm.
        input_text = f"Symptoms: {symptoms} Risk: {risk_level}"

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        medical_advice = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        departments = self._extract_departments_from_symptoms(symptoms)
        # High risk: put Emergency Medicine first.
        if risk_level == "High" and "Emergency Medicine" not in departments:
            departments.insert(0, "Emergency Medicine")

        self_care_suggestions = self._get_self_care_suggestions(risk_level)

        structured_recommendation = {
            "medical_advice": medical_advice,
            "departments": departments,
            "self_care": self_care_suggestions
        }
        formatted_text = self._format_structured_recommendation(
            medical_advice,
            departments,
            self_care_suggestions,
            risk_level
        )
        return {
            "text": formatted_text,
            "structured": structured_recommendation
        }
class MedicalConsultationPipeline:
    """End-to-end consultation: symptom extraction, risk triage, advice.

    Wires a ``SymptomExtractor``, a ``RiskClassifier`` and a
    ``RecommendationGenerator`` together, all placed on one shared device.
    """

    def __init__(self,
                 symptom_model="dmis-lab/biobert-v1.1",
                 risk_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                 recommendation_model="t5-small",
                 device=None):
        """Build the three stage models.

        Args:
            symptom_model: Model id/path for the symptom extractor.
            risk_model: Model id/path for the risk classifier.
            recommendation_model: Model id/path for the recommendation generator.
            device: Torch device spec; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        if device:
            self.device = device
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        print(f"Initializing Medical Consultation Pipeline on {self.device}...")

        shared_device = self.device
        self.symptom_extractor = SymptomExtractor(model_name=symptom_model, device=shared_device)
        self.risk_classifier = RiskClassifier(model_name=risk_model, device=shared_device)
        self.recommendation_generator = RecommendationGenerator(
            model_path=recommendation_model, device=shared_device
        )
        print("Medical Consultation Pipeline initialized successfully.")

    def process(self, text: str) -> Dict[str, Any]:
        """Run ``text`` through extraction, risk classification and recommendation.

        The recommendation model receives the extracted symptom texts joined
        with ", ", or the raw ``text`` when that summary is empty.

        Returns:
            Dict with ``extraction``, ``risk``, ``recommendation`` (formatted
            text), ``structured_recommendation`` and ``input_text``.
        """
        extraction = self.symptom_extractor.extract_symptoms(text)
        risk = self.risk_classifier.classify_risk(text)

        found_texts = [item["text"] for item in extraction["symptoms"]]
        summary = ", ".join(found_texts) or text

        generated = self.recommendation_generator.generate_recommendation(
            symptoms=summary,
            risk_level=risk["risk_level"],
        )

        result = {}
        result["extraction"] = extraction
        result["risk"] = risk
        result["recommendation"] = generated["text"]
        result["structured_recommendation"] = generated["structured"]
        result["input_text"] = text
        return result
# Example usage
def main() -> None:
    """Run one demo consultation on a sample description and print the results.

    The demo lives in a function so its variables stay local. Previously a
    module-level ``pipeline = MedicalConsultationPipeline()`` rebound the
    imported ``transformers.pipeline``. ``SymptomExtractor.__init__`` looks up
    that name as a global, so any later extractor construction in the same
    process would break.
    """
    consultation = MedicalConsultationPipeline()
    sample_text = "I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
    result = consultation.process(sample_text)
    print("Extracted symptoms:", [s["text"] for s in result["extraction"]["symptoms"]])
    print("Duration info:", [d["text"] for d in result["extraction"]["duration"]])
    print("Risk level:", result["risk"]["risk_level"], f"(Confidence: {result['risk']['confidence']:.2f})")
    print("Recommendation:", result["recommendation"])


if __name__ == "__main__":
    # Demo only: this does not run when the module is imported (e.g. by the Streamlit app).
    main()