# finalProject / models.py
# Author: BillyZ1129 - added in "Upload 7 files" (commit 79bcb1b, verified)
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
pipeline
)
import re
import os
import json
from typing import Dict, List, Tuple, Any
class SymptomExtractor:
    """Model for extracting symptoms from patient descriptions using BioBERT."""

    # Duration expressions searched in the raw text, case-insensitively.
    # Units use optional plurals ("weeks?") so the whole word is captured: an
    # alternation such as "week|weeks" is tried left to right, matches "week"
    # first and would truncate "2 weeks" to "2 week".
    DURATION_PATTERNS = (
        re.compile(r"(\d+)\s*(days?|weeks?|months?|years?)\b", re.IGNORECASE),
        re.compile(r"since\s+(\w+)", re.IGNORECASE),
        re.compile(r"for\s+(\w+)", re.IGNORECASE),
    )

    def __init__(self, model_name="dmis-lab/biobert-v1.1", device=None):
        """Load the tokenizer, token-classification model and NER pipeline.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string; when falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Symptom Extractor model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name).to(self.device)
        # NOTE(review): extract_symptoms only recognises labels starting with
        # "B-"/"I-"; confirm the checkpoint's label set actually uses them.
        # Pipeline device: 0 when self.device is exactly "cuda", else -1 (CPU).
        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, device=0 if self.device == "cuda" else -1)
        print("Symptom Extractor model loaded successfully.")

    def extract_symptoms(self, text: str) -> Dict[str, Any]:
        """Extract symptom spans and duration mentions from ``text``.

        Symptom spans are assembled from the NER pipeline's BIO-tagged tokens.
        A ``B-`` token opens a span, and subsequent ``I-`` tokens extend the
        most recent span. WordPiece continuations (``##xyz``) are glued to the
        previous piece without a space. The span score is the arithmetic mean
        of its token scores.

        Args:
            text: Free-text patient description.

        Returns:
            A dict with two keys:
            "symptoms": list of {"text", "start", "end", "score"} dicts.
            "duration": list of {"text", "start", "end"} dicts, grouped by
            pattern in DURATION_PATTERNS order.
        """
        # Each entry: (span dict without score, list of token scores).
        spans: List[Tuple[Dict[str, Any], List[float]]] = []
        for entity in self.nlp(text):
            tag = entity["entity"]
            word = entity["word"]
            score = float(entity["score"])
            is_subword = word.startswith("##")
            if tag.startswith("B-"):
                spans.append((
                    {
                        "text": word[2:] if is_subword else word,
                        "start": entity["start"],
                        "end": entity["end"],
                    },
                    [score],
                ))
            elif tag.startswith("I-") and spans:
                # An I- token with no preceding B- span is ignored.
                span, span_scores = spans[-1]
                span["text"] += word[2:] if is_subword else " " + word
                span["end"] = entity["end"]
                span_scores.append(score)

        symptoms = [
            {**span, "score": sum(span_scores) / len(span_scores)}
            for span, span_scores in spans
        ]

        duration_info = [
            {"text": match.group(0), "start": match.start(), "end": match.end()}
            for pattern in self.DURATION_PATTERNS
            for match in pattern.finditer(text)
        ]

        return {
            "symptoms": symptoms,
            "duration": duration_info
        }
class RiskClassifier:
"""Model for classifying patient risk level using PubMedBERT."""
def __init__(self, model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", device=None):
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading Risk Classifier model on {self.device}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=3 # Low, Medium, High
).to(self.device)
self.id2label = {0: "Low", 1: "Medium", 2: "High"}
print("Risk Classifier model loaded successfully.")
def classify_risk(self, text: str) -> Dict[str, Any]:
"""Classify the risk level based on the input text."""
inputs = self.tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()
model_prediction = torch.argmax(logits, dim=1).item()
# 由于模型没有经过微调,我们添加基于规则的后处理来调整风险级别
# 检查文本中是否存在高风险关键词
high_risk_keywords = [
"severe", "extreme", "intense", "unbearable", "emergency",
"chest pain", "difficulty breathing", "can't breathe",
"losing consciousness", "fainted", "seizure", "stroke", "heart attack",
"allergic reaction", "bleeding heavily", "blood", "poisoning",
"overdose", "suicide", "self-harm", "hallucinations"
]
medium_risk_keywords = [
"worsening", "spreading", "persistent", "chronic", "recurring",
"infection", "fever", "swelling", "rash", "pain", "vomiting",
"diarrhea", "dizzy", "headache", "concerning", "worried",
"weeks", "days", "increasing", "progressing"
]
low_risk_keywords = [
"mild", "slight", "minor", "occasional", "intermittent",
"improving", "better", "sometimes", "rarely", "manageable"
]
text_lower = text.lower()
# 计算匹配的关键词数量
high_risk_matches = sum(keyword in text_lower for keyword in high_risk_keywords)
medium_risk_matches = sum(keyword in text_lower for keyword in medium_risk_keywords)
low_risk_matches = sum(keyword in text_lower for keyword in low_risk_keywords)
# 根据关键词匹配调整风险预测
adjusted_prediction = model_prediction
if high_risk_matches >= 2:
adjusted_prediction = 2 # High risk
elif high_risk_matches == 1 and medium_risk_matches >= 2:
adjusted_prediction = 2 # High risk
elif medium_risk_matches >= 3:
adjusted_prediction = 1 # Medium risk
elif medium_risk_matches >= 1 and low_risk_matches <= 1:
adjusted_prediction = 1 # Medium risk
elif low_risk_matches >= 2 and high_risk_matches == 0:
adjusted_prediction = 0 # Low risk
# 如果文本很长(详细描述),可能表明情况更复杂,风险更高
if len(text.split()) > 40 and adjusted_prediction == 0:
adjusted_prediction = 1 # 升级到Medium风险
# 对调整后的概率进行修正
adjusted_probabilities = probabilities.copy()
# 增强对应风险级别的概率
adjusted_probabilities[adjusted_prediction] = max(0.6, adjusted_probabilities[adjusted_prediction])
# 规范化概率使其总和为1
adjusted_probabilities = adjusted_probabilities / adjusted_probabilities.sum()
return {
"risk_level": self.id2label[adjusted_prediction],
"confidence": float(adjusted_probabilities[adjusted_prediction]),
"all_probabilities": {
self.id2label[i]: float(prob)
for i, prob in enumerate(adjusted_probabilities)
},
"original_prediction": self.id2label[model_prediction]
}
class RecommendationGenerator:
    """Model for generating medical recommendations using fine-tuned t5-small."""

    # The default checkpoint, and an alias meaning "the fine-tuned model if one
    # is available locally". Only these two requests trigger local discovery;
    # any other model_path is used exactly as given.
    DEFAULT_MODEL = "t5-small"
    FINETUNED_ALIAS = "t5-small-medical-recommendation"

    # Local directories searched, in order, for a fine-tuned checkpoint.
    LOCAL_MODEL_CANDIDATES = (
        "./finetuned_t5-small",  # user-specified fine-tuned model path
        "./t5-small-medical-recommendation",
        "./models/t5-small-medical-recommendation",
        "./fine_tuned_models/t5-small",
        "./output",
        "./fine_tuning_output",
    )

    # Symptom keyword -> department. Keywords are matched at the start of a
    # word, so prefixes such as "vomit" also cover "vomiting".
    SYMPTOM_TO_DEPARTMENT = {
        "headache": "Neurology",
        "dizziness": "Neurology",
        "confusion": "Neurology",
        "memory": "Neurology",
        "numbness": "Neurology",
        "tingling": "Neurology",
        "seizure": "Neurology",
        "nerve": "Neurology",
        "chest pain": "Cardiology",
        "heart": "Cardiology",
        "palpitation": "Cardiology",
        "arrhythmia": "Cardiology",
        "high blood pressure": "Cardiology",
        "hypertension": "Cardiology",
        "heart attack": "Cardiology",
        "cardiovascular": "Cardiology",
        "cough": "Pulmonology",
        "breathing": "Pulmonology",
        "shortness": "Pulmonology",
        "lung": "Pulmonology",
        "respiratory": "Pulmonology",
        "asthma": "Pulmonology",
        "pneumonia": "Pulmonology",
        "copd": "Pulmonology",
        "stomach": "Gastroenterology",
        "abdomen": "Gastroenterology",
        "nausea": "Gastroenterology",
        "vomit": "Gastroenterology",
        "diarrhea": "Gastroenterology",
        "constipation": "Gastroenterology",
        "heartburn": "Gastroenterology",
        "liver": "Gastroenterology",
        "digestive": "Gastroenterology",
        "joint": "Orthopedics",
        "bone": "Orthopedics",
        "muscle": "Orthopedics",
        "pain": "Orthopedics",
        "back": "Orthopedics",
        "arthritis": "Orthopedics",
        "fracture": "Orthopedics",
        "sprain": "Orthopedics",
        "rash": "Dermatology",
        "skin": "Dermatology",
        "itching": "Dermatology",
        "itch": "Dermatology",
        "acne": "Dermatology",
        "eczema": "Dermatology",
        "psoriasis": "Dermatology",
        "fever": "General Medicine / Primary Care",
        "infection": "General Medicine / Primary Care",
        "sore throat": "General Medicine / Primary Care",
        "flu": "General Medicine / Primary Care",
        "cold": "General Medicine / Primary Care",
        "fatigue": "General Medicine / Primary Care",
        "pregnancy": "Obstetrics / Gynecology",
        "menstruation": "Obstetrics / Gynecology",
        "period": "Obstetrics / Gynecology",
        "vaginal": "Obstetrics / Gynecology",
        "menopause": "Obstetrics / Gynecology",
        "depression": "Psychiatry",
        "anxiety": "Psychiatry",
        "mood": "Psychiatry",
        "stress": "Psychiatry",
        "sleep": "Psychiatry",
        "insomnia": "Psychiatry",
        "mental": "Psychiatry",
        "ear": "Otolaryngology (ENT)",
        "nose": "Otolaryngology (ENT)",
        "throat": "Otolaryngology (ENT)",
        "hearing": "Otolaryngology (ENT)",
        "sinus": "Otolaryngology (ENT)",
        "eye": "Ophthalmology",
        "vision": "Ophthalmology",
        "blindness": "Ophthalmology",
        "blurry": "Ophthalmology",
        "urination": "Urology",
        "kidney": "Urology",
        "bladder": "Urology",
        "urine": "Urology",
        "prostate": "Urology",
    }

    # Self-care suggestions per risk level.
    SELF_CARE_BY_RISK = {
        "Low": [
            "Ensure you're getting adequate rest",
            "Stay hydrated by drinking plenty of water",
            "Monitor your symptoms and note any changes",
            "Consider over-the-counter medications appropriate for your symptoms",
            "Maintain a balanced diet to support your immune system",
            "Try gentle exercises if appropriate for your condition",
            "Avoid activities that worsen your symptoms",
            "Keep track of any patterns in your symptoms",
        ],
        "Medium": [
            "Rest and avoid strenuous activities",
            "Stay hydrated and maintain proper nutrition",
            "Take your temperature and other vital signs if possible",
            "Write down any changes in symptoms and when they occur",
            "Have someone stay with you if your symptoms are concerning",
            "Prepare a list of your symptoms and medications for your doctor",
            "Avoid self-medicating beyond basic over-the-counter remedies",
            "Consider arranging transportation to your medical appointment",
        ],
        "High": [
            "Don't wait - seek medical attention immediately",
            "Have someone drive you to the emergency room if safe to do so",
            "Call emergency services if symptoms are severe",
            "Bring a list of your current medications if possible",
            "Follow any first aid protocols appropriate for your symptoms",
            "Don't eat or drink anything if you might need surgery",
            "Take prescribed emergency medications if applicable (like an inhaler for asthma)",
            "Try to remain calm and focused on getting help",
        ],
    }

    def __init__(self, model_path="t5-small", device=None):
        """Load the seq2seq model used for the free-text medical advice.

        Args:
            model_path: Model id or local path. "t5-small" (default) or
                "t5-small-medical-recommendation" prefer a fine-tuned
                checkpoint found in LOCAL_MODEL_CANDIDATES. Any other value
                is loaded as given.
            device: Torch device string; when falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Recommendation Generator model on {self.device}...")

        model_path = self._resolve_model_path(model_path)

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
            print(f"Recommendation Generator model '{model_path}' loaded successfully.")
        except Exception as e:
            # Deliberate best-effort fallback: a broken or missing checkpoint
            # should not stop the app, so fall back to the base model.
            print(f"Error loading model from {model_path}: {str(e)}")
            print("Falling back to base t5-small model...")
            self.tokenizer = AutoTokenizer.from_pretrained("t5-small")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(self.device)
            print("Base t5-small model loaded successfully as fallback.")

        # Per-instance copies, so mutating one instance's tables cannot leak
        # into other instances through the shared class constants.
        self.symptom_to_department = dict(self.SYMPTOM_TO_DEPARTMENT)
        self.self_care_by_risk = {
            level: list(tips) for level, tips in self.SELF_CARE_BY_RISK.items()
        }

    @classmethod
    def _resolve_model_path(cls, model_path: str, candidates=None) -> str:
        """Decide which checkpoint to load.

        An explicit request (anything other than DEFAULT_MODEL or
        FINETUNED_ALIAS) is returned unchanged. Otherwise, the first
        candidate directory containing a ``config.json`` wins. Requiring the
        file skips generic directories such as ``./output`` that exist but
        hold no model. The alias falls back to DEFAULT_MODEL when nothing is
        found.

        Args:
            model_path: The requested model id or path.
            candidates: Directories to search; defaults to
                LOCAL_MODEL_CANDIDATES.
        """
        if model_path not in (cls.DEFAULT_MODEL, cls.FINETUNED_ALIAS):
            return model_path
        search = cls.LOCAL_MODEL_CANDIDATES if candidates is None else candidates
        for path in search:
            if os.path.isfile(os.path.join(path, "config.json")):
                print(f"Found fine-tuned model at: {path}")
                return path
        if model_path == cls.FINETUNED_ALIAS:
            print("Fine-tuned model not found locally. Falling back to base t5-small...")
            return cls.DEFAULT_MODEL
        return model_path

    def _extract_departments_from_symptoms(self, symptoms_text: str) -> List[str]:
        """Map a symptom description to candidate hospital departments.

        Keywords must start at a word boundary, so "ear" no longer matches
        inside "heart" or "years". Departments are returned in a
        deterministic order: the order they are first reached in
        ``self.symptom_to_department``.

        Args:
            symptoms_text: Symptom description text.

        Returns:
            List of department names. Returns
            ["General Medicine / Primary Care"] when nothing matches.
        """
        symptoms_lower = symptoms_text.lower()
        # dict.fromkeys de-duplicates while keeping insertion order (a set's
        # iteration order varies between runs due to string hash randomisation).
        departments = dict.fromkeys(
            department
            for keyword, department in self.symptom_to_department.items()
            if re.search(r"\b" + re.escape(keyword), symptoms_lower)
        )
        return list(departments) or ["General Medicine / Primary Care"]

    def _get_self_care_suggestions(self, risk_level: str) -> List[str]:
        """Return up to five self-care suggestions for ``risk_level``.

        Unknown risk levels use the "Medium" suggestions. When more than five
        suggestions exist, a random sample of five is returned so repeated
        calls do not produce identical text.

        Args:
            risk_level: Risk level ("Low", "Medium", "High").

        Returns:
            A new list of suggestion strings.
        """
        import random

        if risk_level not in self.self_care_by_risk:
            risk_level = "Medium"  # default to medium-risk suggestions
        suggestions = self.self_care_by_risk[risk_level]
        if len(suggestions) > 5:
            return random.sample(suggestions, 5)
        return list(suggestions)

    def _format_structured_recommendation(self, medical_advice: str, departments: List[str], self_care: List[str], risk_level: str) -> str:
        """Render the structured recommendation as plain text.

        Args:
            medical_advice: Main generated medical advice.
            departments: Recommended departments.
            self_care: Self-care suggestions.
            risk_level: Risk level ("Low", "Medium", "High").

        Returns:
            The advice, a departments paragraph, and a bulleted self-care
            section. Every line ends with a newline.
        """
        if risk_level == "High":
            urgency = "immediate attention"
        elif risk_level == "Medium":
            urgency = "medical care soon"
        else:
            urgency = "monitoring"

        lines = [
            medical_advice.strip(),
            "",
            f"RECOMMENDED DEPARTMENTS: Based on your symptoms, consider consulting the following departments: {', '.join(departments)}.",
            "",
            f"SELF-CARE SUGGESTIONS: While {risk_level.lower()} risk level requires {urgency}, you can also:",
        ]
        lines.extend(f"- {suggestion}" for suggestion in self_care)
        return "\n".join(lines) + "\n"

    def generate_recommendation(self,
                                symptoms: str,
                                risk_level: str,
                                max_length: int = 150) -> Dict[str, Any]:
        """
        Generate a comprehensive medical recommendation based on symptoms and risk level.

        Args:
            symptoms: Symptom description text
            risk_level: Risk level (Low, Medium, High)
            max_length: Maximum length for generated text

        Returns:
            Dictionary with "text" (formatted recommendation) and "structured"
            ({"medical_advice", "departments", "self_care"}).
        """
        # Prompt for the seq2seq model.
        input_text = f"Symptoms: {symptoms} Risk: {risk_level}"
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        medical_advice = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        departments = self._extract_departments_from_symptoms(symptoms)
        # High risk always puts Emergency Medicine first.
        if risk_level == "High" and "Emergency Medicine" not in departments:
            departments.insert(0, "Emergency Medicine")

        self_care_suggestions = self._get_self_care_suggestions(risk_level)

        structured_recommendation = {
            "medical_advice": medical_advice,
            "departments": departments,
            "self_care": self_care_suggestions
        }
        formatted_text = self._format_structured_recommendation(
            medical_advice,
            departments,
            self_care_suggestions,
            risk_level
        )
        return {
            "text": formatted_text,
            "structured": structured_recommendation
        }
class MedicalConsultationPipeline:
    """End-to-end consultation: symptom extraction, risk classification, then recommendation."""

    def __init__(self,
                 symptom_model="dmis-lab/biobert-v1.1",
                 risk_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                 recommendation_model="t5-small",
                 device=None):
        """Build the three stage models on a shared device.

        Args:
            symptom_model: Model id/path for SymptomExtractor.
            risk_model: Model id/path for RiskClassifier.
            recommendation_model: Model id/path for RecommendationGenerator.
            device: Torch device string; when falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Initializing Medical Consultation Pipeline on {self.device}...")

        shared = self.device
        self.symptom_extractor = SymptomExtractor(model_name=symptom_model, device=shared)
        self.risk_classifier = RiskClassifier(model_name=risk_model, device=shared)
        self.recommendation_generator = RecommendationGenerator(model_path=recommendation_model, device=shared)
        print("Medical Consultation Pipeline initialized successfully.")

    def process(self, text: str) -> Dict[str, Any]:
        """Run ``text`` through extraction, risk classification and recommendation.

        The recommendation is generated from the comma-joined extracted
        symptom texts, or from the raw input when that summary is empty.

        Returns:
            Dict with keys "extraction", "risk", "recommendation" (formatted
            text), "structured_recommendation" and "input_text".
        """
        extracted = self.symptom_extractor.extract_symptoms(text)
        assessed = self.risk_classifier.classify_risk(text)

        symptom_texts = [item["text"] for item in extracted["symptoms"]]
        summary = ", ".join(symptom_texts) or text

        generated = self.recommendation_generator.generate_recommendation(
            symptoms=summary,
            risk_level=assessed["risk_level"],
        )

        result = dict(extraction=extracted, risk=assessed)
        result["recommendation"] = generated["text"]
        result["structured_recommendation"] = generated["structured"]
        result["input_text"] = text
        return result
# Example usage
if __name__ == "__main__":
    # Manual smoke test; this is not executed when the module is imported
    # (e.g. by the Streamlit app).
    # The instance is named `consultation`, not `pipeline`: rebinding the
    # module-global `pipeline` would shadow `transformers.pipeline`, which
    # SymptomExtractor.__init__ calls, and break any later construction.
    consultation = MedicalConsultationPipeline()
    sample_text = "I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
    result = consultation.process(sample_text)
    print("Extracted symptoms:", [s["text"] for s in result["extraction"]["symptoms"]])
    print("Duration info:", [d["text"] for d in result["extraction"]["duration"]])
    print("Risk level:", result["risk"]["risk_level"], f"(Confidence: {result['risk']['confidence']:.2f})")
    print("Recommendation:", result["recommendation"])