Spaces:
Sleeping
Sleeping
File size: 6,761 Bytes
f368eec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
"""Flask web application for the drug interaction system."""
import os
import uuid
import io
from flask import Flask, request, jsonify, render_template, send_file
import matplotlib.pyplot as plt
from ..models.chatbot import DrugInteractionChatbot
def create_app():
"""Create and configure the Flask application"""
app = Flask(__name__,
static_folder='static',
template_folder='templates')
# Initialize the chatbot
chatbot = DrugInteractionChatbot()
# Ensure the visualization directory exists
os.makedirs(os.path.join(app.static_folder, "visualizations"), exist_ok=True)
@app.route('/')
def home():
"""Render the home page"""
return render_template('index.html')
@app.route('/api/ask', methods=['POST'])
def ask():
"""Process a user message and return a response"""
data = request.json
user_message = data.get('message', '')
if not user_message:
return jsonify({'error': 'No message provided'}), 400
response = chatbot.process_message(user_message)
# Check if we need to generate a visualization
visualization_needed = False
drug_name = None
if "interaction found between" in response:
# Extract drug name from response
import re
match = re.search(r'interaction found between (.+?) and', response)
if match:
drug_name = match.group(1)
visualization_needed = True
result = {
'response': response,
'visualization': None
}
if visualization_needed and drug_name:
# Generate a unique ID for this visualization
viz_id = str(uuid.uuid4())
# Create and save the visualization
G, error = chatbot.processor.generate_network(drug_name)
if not error:
# Save the visualization to a file
viz_path = os.path.join(app.static_folder, "visualizations", f"{viz_id}.png")
plt.savefig(viz_path)
plt.close()
# Add the URL to the result
result['visualization'] = f"/static/visualizations/{viz_id}.png"
return jsonify(result)
@app.route('/api/visualize/<drug_name>')
def visualize(drug_name):
"""Generate a visualization for a specific drug"""
# Create the visualization
G, error = chatbot.processor.generate_network(drug_name)
if error:
return jsonify({'error': error}), 404
# Save to a BytesIO object
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
plt.close()
return send_file(buf, mimetype='image/png')
@app.route('/api/analyze-note', methods=['POST'])
def analyze_note():
"""Analyze a clinical note for drug interactions"""
data = request.json
clinical_note = data.get('note', '')
if not clinical_note:
return jsonify({'error': 'No clinical note provided'}), 400
# Extract medications and interactions from note
results = chatbot.processor.extract_drugs_from_clinical_notes(clinical_note)
# Enhance results with database information
meds = [med.get("name") for med in results["medications"] if med.get("name")]
db_interactions = []
# Check all pairs of medications
for i in range(len(meds)):
for j in range(i+1, len(meds)):
interactions, _ = chatbot.db.get_interactions(meds[i], meds[j])
for d1, d2, desc, severity, source in interactions:
db_interactions.append({
"drug1": meds[i],
"drug2": meds[j],
"description": desc,
"severity": severity,
"source": source
})
# Add database interactions to results
results["database_interactions"] = db_interactions
return jsonify(results)
@app.route('/api/drug-info/<drug_name>')
def drug_info(drug_name):
"""Get information about a specific drug"""
# Get drug information
drug_info = chatbot.processor.get_drug_information(drug_name)
if not drug_info:
return jsonify({'error': f'No information found for {drug_name}'}), 404
return jsonify(drug_info)
@app.route('/api/interaction-network')
def interaction_network():
"""Generate a network visualization of drug interactions"""
drug_name = request.args.get('drug', None)
depth = int(request.args.get('depth', 1))
if drug_name:
G, error = chatbot.processor.generate_network(drug_name, depth)
else:
G, error = chatbot.processor.generate_network()
if error:
return jsonify({'error': error}), 404
# Create visualization
plt.figure(figsize=(12, 10))
# Get positions
import networkx as nx
pos = nx.spring_layout(G, seed=42)
# Draw nodes
node_sizes = [G.nodes[node].get('size', 10) for node in G.nodes()]
node_colors = [G.nodes[node].get('color', 'blue') for node in G.nodes()]
nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color=node_colors, alpha=0.8)
# Draw edges with colors based on severity
edge_colors = []
edge_widths = []
for u, v, data in G.edges(data=True):
edge_colors.append(data.get('color', 'gray'))
edge_widths.append(data.get('weight', 1))
nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=edge_widths, alpha=0.7)
# Add labels
nx.draw_networkx_labels(G, pos, font_size=10, font_family="sans-serif")
# Save to BytesIO
buf = io.BytesIO()
plt.axis('off')
plt.tight_layout()
plt.savefig(buf, format='png', dpi=150)
buf.seek(0)
plt.close()
return send_file(buf, mimetype='image/png')
return app
if __name__ == "__main__":
app = create_app()
port = int(os.environ.get('PORT', 5000))
app.run(host='0.0.0.0', port=port) |