"""Flask web application for the drug interaction system.""" import os import uuid import io from flask import Flask, request, jsonify, render_template, send_file import matplotlib.pyplot as plt from ..models.chatbot import DrugInteractionChatbot def create_app(): """Create and configure the Flask application""" app = Flask(__name__, static_folder='static', template_folder='templates') # Initialize the chatbot chatbot = DrugInteractionChatbot() # Ensure the visualization directory exists os.makedirs(os.path.join(app.static_folder, "visualizations"), exist_ok=True) @app.route('/') def home(): """Render the home page""" return render_template('index.html') @app.route('/api/ask', methods=['POST']) def ask(): """Process a user message and return a response""" data = request.json user_message = data.get('message', '') if not user_message: return jsonify({'error': 'No message provided'}), 400 response = chatbot.process_message(user_message) # Check if we need to generate a visualization visualization_needed = False drug_name = None if "interaction found between" in response: # Extract drug name from response import re match = re.search(r'interaction found between (.+?) and', response) if match: drug_name = match.group(1) visualization_needed = True result = { 'response': response, 'visualization': None } if visualization_needed and drug_name: # Generate a unique ID for this visualization viz_id = str(uuid.uuid4()) # Create and save the visualization G, error = chatbot.processor.generate_network(drug_name) if not error: # Save the visualization to a file viz_path = os.path.join(app.static_folder, "visualizations", f"{viz_id}.png") plt.savefig(viz_path) plt.close() # Add the URL to the result result['visualization'] = f"/static/visualizations/{viz_id}.png" return jsonify(result) @app.route('/api/visualize/') def visualize(drug_name): """Generate a visualization for a specific drug""" # Create the visualization G, error = chatbot.processor.generate_network(drug_name) if error: return jsonify({'error': error}), 404 # Save to a BytesIO object buf = io.BytesIO() plt.savefig(buf, format='png') buf.seek(0) plt.close() return send_file(buf, mimetype='image/png') @app.route('/api/analyze-note', methods=['POST']) def analyze_note(): """Analyze a clinical note for drug interactions""" data = request.json clinical_note = data.get('note', '') if not clinical_note: return jsonify({'error': 'No clinical note provided'}), 400 # Extract medications and interactions from note results = chatbot.processor.extract_drugs_from_clinical_notes(clinical_note) # Enhance results with database information meds = [med.get("name") for med in results["medications"] if med.get("name")] db_interactions = [] # Check all pairs of medications for i in range(len(meds)): for j in range(i+1, len(meds)): interactions, _ = chatbot.db.get_interactions(meds[i], meds[j]) for d1, d2, desc, severity, source in interactions: db_interactions.append({ "drug1": meds[i], "drug2": meds[j], "description": desc, "severity": severity, "source": source }) # Add database interactions to results results["database_interactions"] = db_interactions return jsonify(results) @app.route('/api/drug-info/') def drug_info(drug_name): """Get information about a specific drug""" # Get drug information drug_info = chatbot.processor.get_drug_information(drug_name) if not drug_info: return jsonify({'error': f'No information found for {drug_name}'}), 404 return jsonify(drug_info) @app.route('/api/interaction-network') def interaction_network(): """Generate a network visualization of drug interactions""" drug_name = request.args.get('drug', None) depth = int(request.args.get('depth', 1)) if drug_name: G, error = chatbot.processor.generate_network(drug_name, depth) else: G, error = chatbot.processor.generate_network() if error: return jsonify({'error': error}), 404 # Create visualization plt.figure(figsize=(12, 10)) # Get positions import networkx as nx pos = nx.spring_layout(G, seed=42) # Draw nodes node_sizes = [G.nodes[node].get('size', 10) for node in G.nodes()] node_colors = [G.nodes[node].get('color', 'blue') for node in G.nodes()] nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color=node_colors, alpha=0.8) # Draw edges with colors based on severity edge_colors = [] edge_widths = [] for u, v, data in G.edges(data=True): edge_colors.append(data.get('color', 'gray')) edge_widths.append(data.get('weight', 1)) nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=edge_widths, alpha=0.7) # Add labels nx.draw_networkx_labels(G, pos, font_size=10, font_family="sans-serif") # Save to BytesIO buf = io.BytesIO() plt.axis('off') plt.tight_layout() plt.savefig(buf, format='png', dpi=150) buf.seek(0) plt.close() return send_file(buf, mimetype='image/png') return app if __name__ == "__main__": app = create_app() port = int(os.environ.get('PORT', 5000)) app.run(host='0.0.0.0', port=port)