oussamaor's picture
Upload 12 files
f368eec verified
"""Drug Interaction Chatbot for natural language interaction with the drug interaction system."""
import io
import os
import re
import uuid
from typing import Dict, List, Tuple, Optional, Any

import matplotlib.pyplot as plt
import networkx as nx

from .biomedical_llm import BiomedicalLLM
from .drug_interaction_db import DrugInteractionDatabase
from .ddi_processor import DDIProcessor
class DrugInteractionChatbot:
    """Chatbot interface for drug interaction analysis system.

    ``process_message`` routes a free-text message to one of several intent
    handlers (clinical-note analysis, single-drug information, interaction
    visualization, single-drug interaction listing, pairwise interaction
    lookup) and renders the result as Markdown-style text. Messages that no
    handler claims receive a short help text.
    """

    # Substrings (checked against the lower-cased message) marking a clinical-notes request.
    _CLINICAL_NOTE_TERMS = ("clinical note", "patient note", "extract from", "analyze note", "medical record")
    # Captures everything after a "clinical note(s):"-style marker as the note body.
    _NOTE_PATTERN = re.compile(r"(?:clinical note|patient note|medical record)[s]?:?\s*([\s\S]+)$", re.IGNORECASE)
    # "Tell me about X" style requests; group 2 is the drug phrase (applied to lower-cased text).
    _DRUG_INFO_PATTERN = re.compile(r"(tell me about|information on|what is|info about|details on)\s+(.+?)(?:\?|$)")
    # Broad keyword test for interaction questions (applied to lower-cased text).
    _INTERACTION_QUERY_PATTERN = re.compile(r'take|interact|safe|drug|interaction|medicine|pill|medication')
    # "What interacts with X" style requests; group 4 is the drug phrase.
    _INTERACTION_LIST_PATTERN = re.compile(r"(what|show|list|tell).+?(interaction|interacts).+?(with|for|of)\s+(.+?)(?:\?|$)")
    # Visualization request that mentions drugs/interactions but not necessarily a target.
    _VISUALIZATION_PATTERN = re.compile(r'(visualize|visualization|graph|chart|network|diagram).+?(drug|interaction|medicine)')
    # Visualization request naming a target; group 3 is the drug phrase.
    _VISUALIZATION_TARGET_PATTERN = re.compile(r'(visualize|visualization|graph|chart|network|diagram).+?(for|of|between)\s+(.+?)(?:\?|$)')
    # Placeholder value treated as "no data" in drug-information fields.
    _UNAVAILABLE = "Information not available"
    # (key, label) pairs for single-valued drug-information fields, in display order.
    _INFO_SCALAR_FIELDS = (("drug_class", "Drug Class"), ("mechanism", "Mechanism of Action"))
    # (key, heading) pairs for list-valued drug-information fields, in display order.
    _INFO_LIST_FIELDS = (
        ("indications", "Common Indications"),
        ("side_effects", "Common Side Effects"),
        ("common_interactions", "Common Interactions"),
        ("contraindications", "Contraindications"),
    )
    # Severity (lower-cased) -> marker; anything else is shown as mild.
    _SEVERITY_EMOJI = {"severe": "🔴", "moderate": "🟠"}
    _MILD_EMOJI = "🟡"
    _HELP_TEXT = (
        "I'm a drug interaction assistant powered by biomedical language models. You can ask me about:\n\n"
        "1. Potential interactions between medications (e.g., 'Can I take aspirin and warfarin together?')\n"
        "2. Information about specific drugs (e.g., 'Tell me about metformin')\n"
        "3. Analysis of clinical notes (e.g., 'Analyze these clinical notes: [paste notes here]')\n"
        "4. Visualizations of drug interaction networks (e.g., 'Show me a visualization for warfarin')"
    )

    def __init__(self, model_name: str = "stanford-crfm/BioMedLM"):
        """Initialize the Drug Interaction Chatbot with Biomedical LLM.

        Args:
            model_name: Model identifier forwarded to ``BiomedicalLLM``.
        """
        self.db = DrugInteractionDatabase()
        self.bio_llm = BiomedicalLLM(model_name)
        self.processor = DDIProcessor(self.db, self.bio_llm)

    def process_message(self, message: str) -> str:
        """Process a user message and provide an appropriate response.

        Handlers are tried in order and the first non-None reply wins. The
        visualization and single-drug listing handlers run before the generic
        interaction query: that query's keyword pattern (``interact``,
        ``drug``, ...) matches nearly every such request and previously made
        both handlers unreachable.

        Args:
            message: Free-text user message.

        Returns:
            The Markdown-formatted reply.
        """
        handlers = (
            self._handle_clinical_notes,
            self._handle_drug_information,
            self._handle_visualization_request,
            self._handle_drug_interaction_list,
            self._handle_interaction_query,
        )
        for handler in handlers:
            reply = handler(message)
            if reply is not None:
                return reply
        return self._HELP_TEXT

    def _handle_clinical_notes(self, message: str) -> Optional[str]:
        """Summarize medications and interactions in pasted clinical notes, else None."""
        lowered = message.lower()
        if not any(term in lowered for term in self._CLINICAL_NOTE_TERMS):
            return None
        note_match = self._NOTE_PATTERN.search(message)
        if not note_match:
            # A note keyword without a recognizable note body: let other handlers try.
            return None

        extracted_info = self.processor.extract_drugs_from_clinical_notes(note_match.group(1).strip())
        medications = extracted_info.get("medications") or []
        reported = extracted_info.get("potential_interactions") or []

        parts = ["📋 **Analysis of Clinical Notes**\n\n"]
        if medications:
            parts.append("**Medications Identified:**\n")
            for med in medications:
                name = med.get("name", "Unknown")
                dosage = med.get("dosage", "Not specified")
                frequency = med.get("frequency", "Not specified")
                if dosage != "Not specified" or frequency != "Not specified":
                    parts.append(f"- {name}: {dosage} {frequency}\n")
                else:
                    parts.append(f"- {name}\n")
            parts.append("\n")
        else:
            parts.append("No medications were identified in the clinical notes.\n\n")

        if reported:
            parts.append("**Potential Interactions:**\n")
            for interaction in reported:
                drug1 = interaction.get("drug1", "Unknown")
                drug2 = interaction.get("drug2", "Unknown")
                concern = interaction.get("concern", "Potential interaction")
                parts.append(f"- {drug1} + {drug2}: {concern}\n")
            parts.append("\n")
        else:
            # The extractor reported none: check every medication pair against the database.
            names = [med.get("name") for med in medications if med.get("name")]
            found = []
            for i, first in enumerate(names):
                for second in names[i + 1:]:
                    interactions, _ = self.db.get_interactions(first, second)
                    for _d1, _d2, desc, severity, _source in interactions or []:
                        found.append(f"- {first} + {second}: {desc} ({severity})")
            if found:
                parts.append("**Potential Interactions:**\n")
                parts.append("\n".join(found) + "\n\n")
            else:
                parts.append("No potential interactions were identified among the medications.\n\n")

        parts.append("Please consult with a healthcare professional for a comprehensive review of drug interactions and medical advice.")
        return "".join(parts)

    def _handle_drug_information(self, message: str) -> Optional[str]:
        """Describe a single drug for "tell me about X" style requests, else None."""
        match = self._DRUG_INFO_PATTERN.search(message.lower())
        if not match:
            return None
        drug_name = match.group(2).strip()
        # Fall back to the user's spelling when the database has no canonical name.
        canonical = self.db.search_drug(drug_name) or drug_name
        drug_info = self.processor.get_drug_information(canonical)
        if not drug_info:
            return f"I couldn't find detailed information about {drug_name}. Please check the spelling or try another medication."

        parts = [f"📊 **Information about {drug_info.get('drug_name', canonical)}**\n\n"]
        for key, label in self._INFO_SCALAR_FIELDS:
            value = drug_info.get(key)
            if value and value != self._UNAVAILABLE:
                parts.append(f"**{label}:** {value}\n\n")
        for key, heading in self._INFO_LIST_FIELDS:
            items = drug_info.get(key)
            # Only the first entry is compared with the placeholder, as before.
            if items and items[0] != self._UNAVAILABLE:
                parts.append(f"**{heading}:**\n")
                parts.extend(f"- {item}\n" for item in items)
                parts.append("\n")
        parts.append("This information is for educational purposes only. Always consult a healthcare professional for medical advice.")
        return "".join(parts)

    def _handle_visualization_request(self, message: str) -> Optional[str]:
        """Generate a drug-specific or general network for visualization requests, else None.

        Returns None (deferring to later handlers) when the named target
        cannot be resolved but the message also reads as an interaction query,
        e.g. "visualize the interaction between A and B".
        """
        lowered = message.lower()
        target = self._VISUALIZATION_TARGET_PATTERN.search(lowered)
        if target is None and not self._VISUALIZATION_PATTERN.search(lowered):
            return None
        if target is None:
            return self._network_reply(
                "I've generated a general drug interaction network visualization showing connections between several common drugs. Red edges indicate severe interactions, orange for moderate, and yellow for mild interactions."
            )
        drug_name = target.group(3).strip()
        canonical = self.db.search_drug(drug_name)
        if not canonical:
            if self._INTERACTION_QUERY_PATTERN.search(lowered):
                return None
            return f"I couldn't find information about '{drug_name}' in our database."
        return self._network_reply(
            f"I've generated a network visualization for {canonical}'s interactions. The visualization shows connections to other drugs, with red edges indicating severe interactions, orange for moderate, and yellow for mild interactions.",
            canonical,
        )

    def _network_reply(self, success_message: str, drug_name: Optional[str] = None) -> str:
        """Build a network via the processor and return *success_message* or the error text."""
        try:
            if drug_name is None:
                _graph, error = self.processor.generate_network()
            else:
                _graph, error = self.processor.generate_network(drug_name, depth=1)
        except Exception as exc:
            return f"Sorry, I encountered an error while generating the visualization: {str(exc)}"
        return error if error else success_message

    def _handle_drug_interaction_list(self, message: str) -> Optional[str]:
        """List every known interaction of one drug, grouped by severity, else None.

        Returns None when the drug phrase does not resolve in the database, so
        two-drug phrasings ("interactions of A with B") reach the pairwise handler.
        """
        match = self._INTERACTION_LIST_PATTERN.search(message.lower())
        if not match:
            return None
        canonical = self.db.search_drug(match.group(4).strip())
        if not canonical:
            return None
        interactions, _ = self.db.get_all_interactions(canonical)
        if not interactions:
            return f"No known interactions were found for {canonical} in our database."

        groups = {"severe": [], "moderate": [], "mild": []}
        for _, other_drug, desc, severity, source in interactions:
            level = str(severity).lower()
            groups[level if level in ("severe", "moderate") else "mild"].append((other_drug, desc, source))

        parts = [f"**Known interactions for {canonical}:**\n\n"]
        sections = (
            ("severe", "🔴 **Severe interactions:**\n"),
            ("moderate", "🟠 **Moderate interactions:**\n"),
            ("mild", "🟡 **Mild interactions:**\n"),
        )
        for level, header in sections:
            entries = groups[level]
            if entries:
                parts.append(header)
                parts.extend(f"- **{drug}**: {desc} ({source})\n" for drug, desc, source in entries)
                parts.append("\n")
        parts.append("Please consult with a healthcare professional for personalized advice.")
        return "".join(parts)

    def _handle_interaction_query(self, message: str) -> Optional[str]:
        """Answer pairwise interaction questions via ``processor.process_query``, else None."""
        if not self._INTERACTION_QUERY_PATTERN.search(message.lower()):
            return None
        result = self.processor.process_query(message)
        status = result.get("status")
        if status in ("error", "not_found"):
            return result["message"]
        if status == "no_interaction":
            return (f"Based on our database and biomedical literature analysis, no known interactions were found between {result['drugs'][0]} "
                    f"and {result['drugs'][1]}. However, always consult with a healthcare "
                    f"professional before combining medications.")
        if status == "found":
            return self._format_found_interactions(result)
        # Unrecognized status: no reply from this handler.
        return None

    def _format_found_interactions(self, result: Dict[str, Any]) -> str:
        """Render a ``status == "found"`` query result, with optional management advice."""
        drug1, drug2 = result["drugs"]
        parts = [f"⚠️ **Potential interaction found between {drug1} and {drug2}:**\n\n"]
        for interaction in result.get("interactions") or []:
            severity = interaction["severity"]
            emoji = self._SEVERITY_EMOJI.get(str(severity).lower(), self._MILD_EMOJI)
            parts.append(f"{emoji} **{severity} interaction:** {interaction['description']}\n")
            parts.append(f" Source: {interaction['source']}\n\n")

        management = self._lookup_management(drug1, drug2)
        if management:
            parts.append(f"📝 **Management Recommendation:** {management}\n\n")
        parts.append("⚕️ Please consult with a healthcare professional before taking these medications together.")
        if self._network_available(drug1):
            # The graph is not rendered here; callers use generate_visualization/get_visualization_bytes.
            parts.append("\n\nA visualization of this interaction has been generated.")
        return "".join(parts)

    def _lookup_management(self, drug1: str, drug2: str) -> Optional[str]:
        """Best-effort management advice from the LLM literature lookup; None on any failure."""
        try:
            literature_info = self.bio_llm.extract_ddi_from_literature(drug1, drug2)
            reported = literature_info.get("interactions") if literature_info else None
            return reported[0].get("management") if reported else None
        except Exception:
            # Enrichment only: a failing model must never block the main answer.
            return None

    def _network_available(self, drug_name: str) -> bool:
        """Best-effort check that the processor can build a non-empty network for *drug_name*."""
        try:
            graph, _ = self.processor.generate_network(drug_name, depth=1)
        except Exception:
            return False
        return bool(graph)

    def _render_network(self, graph: Any, target: Any) -> None:
        """Draw *graph* with the shared styling and save it as PNG to *target*.

        *target* is anything ``savefig`` accepts (a path or a binary file-like
        object). The figure is always closed, even if drawing fails, so pyplot
        does not accumulate open figures.
        """
        fig = plt.figure(figsize=(12, 10))
        try:
            pos = nx.spring_layout(graph, seed=42)
            node_sizes = [graph.nodes[node].get('size', 10) for node in graph.nodes()]
            node_colors = [graph.nodes[node].get('color', 'blue') for node in graph.nodes()]
            nx.draw_networkx_nodes(graph, pos, node_size=node_sizes, node_color=node_colors, alpha=0.8)
            # Edge colour/width come from per-edge attributes set by the processor.
            edge_data = [data for _, _, data in graph.edges(data=True)]
            edge_colors = [data.get('color', 'gray') for data in edge_data]
            edge_widths = [data.get('weight', 1) for data in edge_data]
            nx.draw_networkx_edges(graph, pos, edge_color=edge_colors, width=edge_widths, alpha=0.7)
            nx.draw_networkx_labels(graph, pos, font_size=10, font_family="sans-serif")
            plt.axis('off')
            plt.tight_layout()
            fig.savefig(target, format='png', dpi=150)
        finally:
            plt.close(fig)

    def generate_visualization(self, drug_name=None, depth=1, output_dir="static/visualizations"):
        """Generate a visualization of drug interactions and save it as a PNG file.

        Args:
            drug_name: Centre drug forwarded to ``processor.generate_network``.
            depth: Neighbourhood depth forwarded to ``processor.generate_network``.
            output_dir: Directory for the image; created if it does not exist.

        Returns:
            ``(filename, None)`` on success, or ``(None, error)`` when the
            processor reports an error.
        """
        graph, error = self.processor.generate_network(drug_name, depth)
        if error:
            return None, error
        os.makedirs(output_dir, exist_ok=True)
        # A random UUID keeps concurrent requests from overwriting each other's files.
        filename = f"{output_dir}/{uuid.uuid4()}.png"
        self._render_network(graph, filename)
        return filename, None

    def get_visualization_bytes(self, drug_name=None, depth=1):
        """Get visualization as bytes for web display.

        Returns:
            ``(buffer, None)`` where *buffer* is an ``io.BytesIO`` rewound to
            the start of the PNG data, or ``(None, error)`` when the processor
            reports an error.
        """
        graph, error = self.processor.generate_network(drug_name, depth)
        if error:
            return None, error
        buf = io.BytesIO()
        self._render_network(graph, buf)
        buf.seek(0)
        return buf, None