import base64
import functools
import gc
import hashlib
import io
import json
import logging
import os
import random
import re
import sys
import tempfile
import time
import traceback
import warnings
from collections import OrderedDict
from datetime import datetime
from typing import Optional, List, Dict, Any, Union

import cv2
import faiss
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms, models
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration, T5Tokenizer

# Suppress unnecessary warnings
warnings.filterwarnings("ignore", category=UserWarning)


# === Configuration ===
class Config:
    """Configuration for MediQuery system"""

    # Model configuration
    IMAGE_MODEL = "chexnet"  # Options: "chexnet", "densenet"
    TEXT_MODEL = "biobert"  # Options: "biobert", "clinicalbert"
    GEN_MODEL = "flan-t5-base-finetuned"  # Base generation model

    # Resource management
    CACHE_SIZE = 50  # Reduced from 200 for deployment
    CACHE_EXPIRY_TIME = 1800  # Cache expiry time in seconds (30 minutes)
    CACHE_MAX_BYTES = int(1e9)  # Approximate byte budget per cache (1 GB), enforced when OPTIMIZE_MEMORY is on
    LAZY_LOADING = True  # Enable lazy loading of models
    USE_HALF_PRECISION = True  # Use half precision for models if available

    # Feature flags
    DEBUG = True  # Enable detailed debugging
    PHI_DETECTION_ENABLED = True  # Enable PHI detection
    ANATOMY_MAPPING_ENABLED = True  # Enable anatomical mapping

    # Thresholds and parameters
    CONFIDENCE_THRESHOLD = 0.4  # Threshold for flagging low confidence
    TOP_K_RETRIEVAL = 10  # Reduced from 30 for deployment
    MAX_CONTEXT_DOCS = 3  # Reduced from 5 for deployment

    # Advanced retrieval settings
    DYNAMIC_RERANKING = True  # Dynamically adjust reranking weights
    DIVERSITY_PENALTY = 0.1  # Penalty for duplicate content

    # Performance optimization
    BATCH_SIZE = 1  # Reduced from 4 for deployment
    OPTIMIZE_MEMORY = True  # Optimize memory usage
    USE_CACHING = True  # Use caching for embeddings and queries

    # Path settings
    DEFAULT_KNOWLEDGE_BASE_DIR = "./knowledge_base"
    DEFAULT_MODEL_PATH = "./models/flan-t5-finetuned"
    LOG_DIR = "./logs"

    # Advanced settings
    EMBEDDING_AGGREGATION = "weighted_avg"  # Options: "avg", "weighted_avg", "cls", "pooled"
    EMBEDDING_NORMALIZE = True  # Normalize embeddings to unit length

    # Error recovery settings
    MAX_RETRIES = 2  # Reduced from 3 for deployment
    RECOVERY_WAIT_TIME = 1  # Seconds to wait between retries


# Set up logging with improved formatting
os.makedirs(Config.LOG_DIR, exist_ok=True)
logging.basicConfig(
    level=logging.DEBUG if Config.DEBUG else logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(Config.LOG_DIR, f"mediquery_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("MediQuery")


def debug_print(msg):
    """Log *msg* at DEBUG level and echo it to stdout, only when Config.DEBUG is set."""
    if Config.DEBUG:
        logger.debug(msg)
        print(f"DEBUG: {msg}")


# === Helper Functions for Conditions ===
def get_mimic_cxr_conditions():
    """Return the comprehensive list of conditions in MIMIC-CXR dataset"""
    return [
        "atelectasis", "cardiomegaly", "consolidation", "edema",
        "enlarged cardiomediastinum", "fracture", "lung lesion",
        "lung opacity", "no finding", "pleural effusion",
        "pleural other", "pneumonia", "pneumothorax", "support devices"
    ]


def get_condition_synonyms():
    """Return synonyms for conditions to improve matching"""
    return {
        "atelectasis": ["atelectatic change", "collapsed lung", "lung collapse"],
        "cardiomegaly": ["enlarged heart", "cardiac enlargement", "heart enlargement"],
        "consolidation": ["airspace opacity", "air-space opacity", "alveolar opacity"],
        "edema": ["pulmonary edema", "fluid overload", "vascular congestion"],
        "fracture": ["broken bone", "bone fracture", "rib fracture"],
        "lung opacity": ["pulmonary opacity", "opacification", "lung opacification"],
        "pleural effusion": ["pleural fluid", "fluid in pleural space", "effusion"],
        "pneumonia": ["pulmonary infection", "lung infection", "bronchopneumonia"],
        "pneumothorax": ["air in pleural space", "collapsed lung", "ptx"],
        "support devices": ["tube", "line", "catheter", "pacemaker", "device"]
    }


def get_anatomical_regions():
    """Return mapping of anatomical regions with descriptions and conditions"""
    return {
        "upper_right_lung": {
            "description": "Upper right lung field",
            "conditions": ["pneumonia", "lung lesion", "pneumothorax", "atelectasis"]
        },
        "upper_left_lung": {
            "description": "Upper left lung field",
            "conditions": ["pneumonia", "lung lesion", "pneumothorax", "atelectasis"]
        },
        "middle_right_lung": {
            "description": "Middle right lung field",
            "conditions": ["pneumonia", "lung opacity", "atelectasis"]
        },
        "lower_right_lung": {
            "description": "Lower right lung field",
            "conditions": ["pneumonia", "pleural effusion", "atelectasis"]
        },
        "lower_left_lung": {
            "description": "Lower left lung field",
            "conditions": ["pneumonia", "pleural effusion", "atelectasis"]
        },
        "heart": {
            "description": "Cardiac silhouette",
            "conditions": ["cardiomegaly", "enlarged cardiomediastinum"]
        },
        "hilar": {
            "description": "Hilar regions",
            "conditions": ["enlarged cardiomediastinum", "adenopathy"]
        },
        "costophrenic_angles": {
            "description": "Costophrenic angles",
            "conditions": ["pleural effusion", "pneumothorax"]
        },
        "spine": {
            "description": "Spine",
            "conditions": ["fracture", "degenerative changes"]
        },
        "diaphragm": {
            "description": "Diaphragm",
            "conditions": ["elevated diaphragm", "flattened diaphragm"]
        }
    }


# === PHI Detection and Anonymization ===
# Compiled once at import time; reused for every call to detect_phi.
_PHI_PATTERNS = {
    phi_type: re.compile(pattern)
    for phi_type, pattern in {
        'name': r'\b[A-Z][a-z]+ [A-Z][a-z]+\b',
        'mrn': r'\b[A-Z]{0,3}[0-9]{4,10}\b',
        'ssn': r'\b[0-9]{3}[-]?[0-9]{2}[-]?[0-9]{4}\b',
        'date': r'\b(0?[1-9]|1[0-2])[\/\-](0?[1-9]|[12]\d|3[01])[\/\-](19|20)\d{2}\b',
        'phone': r'\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b',
        # TLD class was `[A-Z|a-z]`, which also accepted a literal '|'
        'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'address': r'\b\d+\s+[A-Z][a-z]+\s+[A-Z][a-z]+\.?\b'
    }.items()
}


def detect_phi(text):
    """Detect potential PHI (Protected Health Information) in *text*.

    Returns:
        dict mapping a PHI type ('name', 'mrn', 'ssn', 'date', 'phone', 'email',
        'address') to the list of full matched substrings; types with no match
        are omitted.
    """
    phi_detected = {}
    for phi_type, pattern in _PHI_PATTERNS.items():
        # group(0) is the whole match. re.findall would return tuples of capture
        # groups for the 'date' and 'phone' patterns, which cannot be redacted.
        matches = [m.group(0) for m in pattern.finditer(text)]
        if matches:
            phi_detected[phi_type] = matches
    return phi_detected


def anonymize_text(text):
    """Replace potential PHI in *text* with "[REDACTED]".

    Returns "" for empty/None input and the text unchanged when PHI detection
    is disabled.
    """
    if not text:
        return ""

    if not Config.PHI_DETECTION_ENABLED:
        return text

    try:
        phi_detected = detect_phi(text)

        unique_matches = {match for matches in phi_detected.values() for match in matches}
        anonymized = text
        # Longest first, so a short match contained in a longer one cannot
        # break up the longer span before it is redacted.
        for match in sorted(unique_matches, key=len, reverse=True):
            anonymized = anonymized.replace(match, "[REDACTED]")

        return anonymized
    except Exception as e:
        # NOTE(review): this fails open (returns the original text, PHI included).
        debug_print(f"Error in anonymize_text: {str(e)}")
        return text


# === LRU Cache Implementation with Enhanced Features ===
class LRUCache:
    """LRU (Least Recently Used) cache with per-entry TTL and approximate size tracking.

    Entries older than ``expiry_time`` seconds count as misses and are dropped
    when accessed. ``put`` evicts the least recently used entries while the
    entry count is at ``capacity`` or, when ``Config.OPTIMIZE_MEMORY`` is on,
    while the estimated total size would exceed ``max_bytes``.
    """

    def __init__(self, capacity=Config.CACHE_SIZE, expiry_time=Config.CACHE_EXPIRY_TIME,
                 max_bytes=Config.CACHE_MAX_BYTES):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.expiry_time = expiry_time  # in seconds
        self.max_bytes = max_bytes  # approximate byte budget (see _estimate_size)
        self.timestamps = {}
        self.size_tracking = {
            "current_size_bytes": 0,
            "max_size_bytes": 0,
            "items_evicted": 0,
            "cache_hits": 0,
            "cache_misses": 0
        }

    def get(self, key):
        """Return the cached value for *key*, or None on miss/expiry; updates hit statistics."""
        if key not in self.cache:
            self.size_tracking["cache_misses"] += 1
            return None

        if self.is_expired(key):
            self._remove_with_tracking(key)
            self.size_tracking["cache_misses"] += 1
            return None

        # Re-insert to mark as most recently used
        self.size_tracking["cache_hits"] += 1
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def put(self, key, value):
        """Insert or replace *key*, evicting LRU entries as needed to respect the limits."""
        value_size = self._estimate_size(value)

        if key in self.cache:
            old_value = self.cache.pop(key)
            self.size_tracking["current_size_bytes"] -= self._estimate_size(old_value)
            self.timestamps.pop(key, None)

        if self.capacity <= 0:
            # Nothing can be stored; avoids evicting forever below.
            return

        # The `self.cache` guard stops the loop once nothing is left to evict
        # (e.g. a single value larger than the byte budget); otherwise it never ends.
        while self.cache and (
            len(self.cache) >= self.capacity or (
                Config.OPTIMIZE_MEMORY and
                self.size_tracking["current_size_bytes"] + value_size > self.max_bytes
            )
        ):
            self._evict_least_recently_used()

        self.cache[key] = value
        self.timestamps[key] = datetime.now().timestamp()
        self.size_tracking["current_size_bytes"] += value_size

        if self.size_tracking["current_size_bytes"] > self.size_tracking["max_size_bytes"]:
            self.size_tracking["max_size_bytes"] = self.size_tracking["current_size_bytes"]

    def is_expired(self, key):
        """Return True when *key* has no timestamp or is older than expiry_time seconds."""
        if key not in self.timestamps:
            return True

        current_time = datetime.now().timestamp()
        return (current_time - self.timestamps[key]) > self.expiry_time

    def _evict_least_recently_used(self):
        """Remove the least recently used entry (front of the OrderedDict), if any."""
        if not self.cache:
            return

        oldest_key = next(iter(self.cache))
        # Remove through the tracking helper so the byte count and eviction
        # counter are updated (popping first would make the helper a no-op).
        self._remove_with_tracking(oldest_key)

    def _remove_with_tracking(self, key):
        """Remove *key* and its timestamp, updating size and eviction statistics."""
        if key in self.cache:
            value = self.cache.pop(key)
            self.size_tracking["current_size_bytes"] -= self._estimate_size(value)
            self.size_tracking["items_evicted"] += 1

        if key in self.timestamps:
            self.timestamps.pop(key)

    def remove(self, key):
        """Remove item from cache"""
        self._remove_with_tracking(key)

    def clear(self):
        """Clear all entries and reset the current byte count (hit/miss counters are kept)."""
        self.cache.clear()
        self.timestamps.clear()
        self.size_tracking["current_size_bytes"] = 0

    def get_stats(self):
        """Return a dict of cache size, capacity, eviction and hit-rate statistics."""
        return {
            "size_bytes": self.size_tracking["current_size_bytes"],
            "max_size_bytes": self.size_tracking["max_size_bytes"],
            "items": len(self.cache),
            "capacity": self.capacity,
            "items_evicted": self.size_tracking["items_evicted"],
            "hit_rate": self.size_tracking["cache_hits"] / (self.size_tracking["cache_hits"] + self.size_tracking["cache_misses"] + 1e-8)
        }

    def _estimate_size(self, obj):
        """Estimate memory size of *obj* in bytes (recursive for list/tuple/dict; shallow otherwise)."""
        if obj is None:
            return 0

        if isinstance(obj, np.ndarray):
            return obj.nbytes
        elif isinstance(obj, torch.Tensor):
            return obj.element_size() * obj.nelement()
        elif isinstance(obj, (str, bytes)):
            return len(obj)
        elif isinstance(obj, (list, tuple)):
            return sum(self._estimate_size(x) for x in obj)
        elif isinstance(obj, dict):
            return sum(self._estimate_size(k) + self._estimate_size(v) for k, v in obj.items())
        else:
            # Fallback - rough, shallow estimate
            return sys.getsizeof(obj)


# === Improved Lazy Model Loading ===
class LazyModel:
    """Lazy loading wrapper for Hugging Face style models (``from_pretrained``) with retry on load."""

    def __init__(self, model_name, model_class, device, **kwargs):
        self.model_name = model_name
        self.model_class = model_class
        self.device = device
        self.kwargs = kwargs  # forwarded to model_class.from_pretrained
        self._model = None
        self.last_error = None
        self.last_used = datetime.now()
        debug_print(f"LazyModel initialized for {model_name}")

    def _ensure_loaded(self, retries=Config.MAX_RETRIES):
        """Load the model on first use (at least one attempt) and return it.

        Raises:
            RuntimeError: if every load attempt fails (chained to the last error).
        """
        if self._model is None:
            debug_print(f"Lazy loading model: {self.model_name}")
            attempts = max(1, retries)

            for attempt in range(attempts):
                try:
                    model = self.model_class.from_pretrained(self.model_name, **self.kwargs)

                    if Config.OPTIMIZE_MEMORY:
                        if Config.USE_HALF_PRECISION and self.device.type == 'cuda' and hasattr(model, 'half'):
                            model = model.half()
                            debug_print(f"Using half precision for {self.model_name}")

                    model = model.to(self.device)
                    model.eval()
                    # Publish only a fully prepared model; a failure above must not
                    # leave a half-initialized model that later calls would reuse.
                    self._model = model
                    debug_print(f"Model {self.model_name} loaded successfully")
                    self.last_error = None
                    break
                except Exception as e:
                    self.last_error = str(e)
                    debug_print(f"Error loading model {self.model_name} (attempt {attempt+1}/{attempts}): {str(e)}")
                    if attempt < attempts - 1:
                        time.sleep(Config.RECOVERY_WAIT_TIME)
                    else:
                        raise RuntimeError(f"Failed to load model {self.model_name} after {attempts} attempts: {str(e)}") from e

        self.last_used = datetime.now()
        return self._model

    def __call__(self, *args, **kwargs):
        """Call the underlying model, loading it first if needed."""
        model = self._ensure_loaded()
        return model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Forward ``generate`` to the model; on failure reload once and retry."""
        model = self._ensure_loaded()
        try:
            return model.generate(*args, **kwargs)
        except Exception as e:
            debug_print(f"Generation failed, reloading model: {str(e)}")
            self.unload()
            model = self._ensure_loaded()
            return model.generate(*args, **kwargs)

    def to(self, device):
        """Record *device* and move the model there if already loaded; returns self."""
        self.device = device
        if self._model is not None:
            self._model = self._model.to(device)
        return self

    def eval(self):
        """Set the loaded model (if any) to evaluation mode; returns self."""
        if self._model is not None:
            self._model.eval()
        return self

    def unload(self):
        """Drop the model reference and release cached GPU memory."""
        if self._model is not None:
            del self._model
            self._model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            debug_print(f"Model {self.model_name} unloaded")


# === MediQuery Core System ===
class MediQuery:
    """Core MediQuery system for medical image and text analysis"""

    def __init__(self, knowledge_base_dir=Config.DEFAULT_KNOWLEDGE_BASE_DIR, model_path=Config.DEFAULT_MODEL_PATH):
        self.knowledge_base_dir = knowledge_base_dir
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_print(f"Using device: {self.device}")

        os.makedirs(knowledge_base_dir, exist_ok=True)
        model_dir = os.path.dirname(model_path)
        # dirname is "" for a bare name such as "model"; os.makedirs("") raises
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        self.embedding_cache = LRUCache(capacity=Config.CACHE_SIZE)
        self.query_cache = LRUCache(capacity=Config.CACHE_SIZE)

        self._init_models()
        self._init_knowledge_base()

        debug_print("MediQuery system initialized")

    def _init_models(self):
        """Initialize the image model eagerly and the text/generation models lazily."""
        debug_print("Initializing models...")

        # Image model
        if Config.IMAGE_MODEL == "chexnet":
            # NOTE(review): no CheXNet weights are loaded here, so this DenseNet
            # feature extractor runs with random weights.
            self.image_model = models.densenet121(pretrained=False)
            try:
                self.image_model = nn.Sequential(*list(self.image_model.children())[:-1])
                debug_print("CheXNet model initialized")
            except Exception as e:
                debug_print(f"Error initializing CheXNet: {str(e)}")
                # Fallback to standard DenseNet
                self.image_model = nn.Sequential(*list(models.densenet121(pretrained=True).children())[:-1])
        else:
            self.image_model = nn.Sequential(*list(models.densenet121(pretrained=True).children())[:-1])

        self.image_model = self.image_model.to(self.device).eval()

        # Text model - lazy loaded
        text_model_name = "dmis-lab/biobert-v1.1" if Config.TEXT_MODEL == "biobert" else "emilyalsentzer/Bio_ClinicalBERT"
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
        self.text_model = LazyModel(
            text_model_name,
            AutoModel,
            self.device
        )

        # Generation model - lazy loaded; fall back to the base model when no fine-tuned copy exists
        if os.path.exists(self.model_path):
            gen_model_path = self.model_path
        else:
            gen_model_path = "google/flan-t5-base"

        self.gen_tokenizer = T5Tokenizer.from_pretrained(gen_model_path)
        self.gen_model = LazyModel(
            gen_model_path,
            T5ForConditionalGeneration,
            self.device
        )

        # Image transformation (ImageNet normalization statistics)
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        debug_print("Models initialized")

    def _init_knowledge_base(self):
        """Create the in-memory demonstration knowledge base; FAISS indices are built on first use."""
        debug_print("Initializing knowledge base...")

        # Minimal demonstration knowledge base; a real deployment would load
        # the knowledge base files instead.
        self.text_data = pd.DataFrame({
            'combined_text': [
                "The chest X-ray shows clear lung fields with no evidence of consolidation, effusion, or pneumothorax. The heart size is normal. No acute cardiopulmonary abnormality.",
                "Bilateral patchy airspace opacities consistent with multifocal pneumonia. No pleural effusion or pneumothorax. Heart size is normal.",
                "Cardiomegaly with pulmonary vascular congestion and bilateral pleural effusions, consistent with congestive heart failure. No pneumothorax or pneumonia.",
                "Right upper lobe opacity concerning for pneumonia. No pleural effusion or pneumothorax. Heart size is normal.",
                "Left lower lobe atelectasis. No pneumothorax or pleural effusion. Heart size is normal.",
                "Bilateral pleural effusions with bibasilar atelectasis. Cardiomegaly present. Findings consistent with heart failure.",
                "Right pneumothorax with partial lung collapse. No pleural effusion. Heart size is normal.",
                "Endotracheal tube, central venous catheter, and nasogastric tube in place. No pneumothorax or pleural effusion.",
                "Hyperinflated lungs with flattened diaphragms, consistent with COPD. No acute infiltrate or effusion.",
                "Multiple rib fractures on the right side. No pneumothorax or hemothorax. Lung fields are clear."
            ],
            'valid_index': list(range(10))
        })

        self.image_index = None  # Will be created on first use
        self.text_index = None  # Will be created on first use

        debug_print("Knowledge base initialized")

    def _create_dummy_indices(self):
        """Build L2 FAISS indices over random float32 embeddings, one row per knowledge-base entry."""
        # Text embeddings (768 dimensions for BERT-based models)
        text_dim = 768
        text_embeddings = np.random.rand(len(self.text_data), text_dim).astype('float32')

        # Image embeddings (1024 dimensions for DenseNet121)
        image_dim = 1024
        image_embeddings = np.random.rand(len(self.text_data), image_dim).astype('float32')

        self.text_index = faiss.IndexFlatL2(text_dim)
        self.text_index.add(text_embeddings)

        self.image_index = faiss.IndexFlatL2(image_dim)
        self.image_index.add(image_embeddings)

        debug_print("Dummy FAISS indices created")

    def _retrieve_texts(self, index, embedding):
        """Return knowledge-base texts nearest to *embedding* in *index*, closest first.

        Asks for at most ``index.ntotal`` neighbours and drops FAISS's -1
        padding labels, which ``iloc`` would otherwise map to the last row.
        """
        query = np.asarray(embedding, dtype=np.float32).reshape(1, -1)  # FAISS requires float32
        k = min(Config.TOP_K_RETRIEVAL, index.ntotal)
        if k <= 0:
            return []
        _, indices = index.search(query, k)
        return [self.text_data.iloc[idx]['combined_text'] for idx in indices[0] if idx >= 0]

    def process_image(self, image_path):
        """Analyze the X-ray image stored at *image_path*.

        Returns:
            dict with "analysis", "attention_map" (base64 PNG or None),
            "confidence" (placeholder) and "similar_cases"; or {"error": msg}
            if processing fails.
        """
        try:
            debug_print(f"Processing image: {image_path}")

            with open(image_path, "rb") as fh:
                image_bytes = fh.read()
            # Key the cache on image content, not the path: paths are reused
            # (temp files) or client-chosen, which would return another image's result.
            cache_key = f"img_{hashlib.sha256(image_bytes).hexdigest()}"

            if Config.USE_CACHING:
                cached_result = self.query_cache.get(cache_key)
                if cached_result:
                    debug_print("Using cached image result")
                    return cached_result

            with Image.open(io.BytesIO(image_bytes)) as raw_image:
                image = raw_image.convert('RGB')
            image_tensor = self.image_transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                feature_map = self.image_model(image_tensor)
                # 224x224 input -> 7x7 feature map; pool to one vector per image
                image_embedding = nn.functional.avg_pool2d(feature_map, kernel_size=7).squeeze().cpu().numpy()

            if self.image_index is None:
                self._create_dummy_indices()

            retrieved_texts = self._retrieve_texts(self.image_index, image_embedding)

            context = "\n\n".join(retrieved_texts[:Config.MAX_CONTEXT_DOCS])
            prompt = f"Analyze this chest X-ray based on similar cases:\n\n{context}\n\nProvide a detailed radiological assessment including findings and impression:"
            analysis = self._generate_text(prompt)

            # Placeholder visualization (see _generate_attention_map)
            attention_map = self._generate_attention_map(image)

            result = {
                "analysis": analysis,
                "attention_map": attention_map,
                "confidence": 0.85,  # Placeholder
                "similar_cases": retrieved_texts[:3]  # Return top 3 similar cases
            }

            if Config.USE_CACHING:
                self.query_cache.put(cache_key, result)

            return result
        except Exception as e:
            # Full traceback goes to the log only; callers (API clients) get the short message.
            logger.error(f"Error processing image: {str(e)}\n{traceback.format_exc()}")
            return {"error": f"Error processing image: {str(e)}"}

    def process_query(self, query_text):
        """Answer a free-text medical question using retrieved knowledge-base context.

        Returns:
            dict with "response", "confidence" (placeholder) and "sources"; or
            {"error": msg} if processing fails.
        """
        try:
            # Anonymize first so raw PHI never reaches the logs or the cache, and
            # the cache lookup and store below use the same key.
            query_text = anonymize_text(query_text)
            debug_print(f"Processing query: {query_text}")
            cache_key = f"txt_{query_text}"

            if Config.USE_CACHING:
                cached_result = self.query_cache.get(cache_key)
                if cached_result:
                    debug_print("Using cached query result")
                    return cached_result

            query_embedding = self._generate_text_embedding(query_text)

            if self.text_index is None:
                self._create_dummy_indices()

            retrieved_texts = self._retrieve_texts(self.text_index, query_embedding)

            context = "\n\n".join(retrieved_texts[:Config.MAX_CONTEXT_DOCS])
            prompt = f"Answer this medical question based on the following information:\n\nQuestion: {query_text}\n\nRelevant information:\n{context}\n\nDetailed answer:"
            response = self._generate_text(prompt)

            result = {
                "response": response,
                "confidence": 0.9,  # Placeholder
                "sources": retrieved_texts[:3]  # Return top 3 sources
            }

            if Config.USE_CACHING:
                self.query_cache.put(cache_key, result)

            return result
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}\n{traceback.format_exc()}")
            return {"error": f"Error processing query: {str(e)}"}

    def _generate_text_embedding(self, text):
        """Return a float32 mean-pooled last-hidden-state embedding of *text*.

        On failure, returns a random 768-dim float32 vector (fallback).
        """
        try:
            if Config.USE_CACHING:
                cached_embedding = self.embedding_cache.get(f"txt_emb_{text}")
                if cached_embedding is not None:
                    return cached_embedding

            inputs = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.text_model(**inputs)
                # .float(): the model may run in fp16 on CUDA, but FAISS needs float32
                embedding = outputs.last_hidden_state.mean(dim=1).float().cpu().numpy()[0]

            if Config.USE_CACHING:
                self.embedding_cache.put(f"txt_emb_{text}", embedding)

            return embedding
        except Exception as e:
            debug_print(f"Error generating text embedding: {str(e)}")
            return np.random.rand(768).astype('float32')

    def _generate_text(self, prompt):
        """Generate a response for *prompt* with beam search; returns an apology string on failure."""
        try:
            inputs = self.gen_tokenizer(
                prompt,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                output_ids = self.gen_model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=256,
                    num_beams=4,
                    early_stopping=True
                )

            return self.gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        except Exception as e:
            debug_print(f"Error generating text: {str(e)}")
            return "I apologize, but I'm unable to generate a response at this time. Please try again later."

    def _generate_attention_map(self, image):
        """Return a base64 PNG of *image* overlaid with a heatmap, or None on failure.

        NOTE: placeholder - the heatmap is three random radial blobs, not model attention.
        """
        try:
            img_np = np.array(image.resize((224, 224)))

            heatmap = np.zeros((224, 224), dtype=np.float32)
            rows, cols = np.mgrid[0:224, 0:224]
            for _ in range(3):
                x, y = np.random.randint(50, 174, 2)
                radius = np.random.randint(20, 50)
                dist = np.sqrt((rows - x) ** 2 + (cols - y) ** 2)
                # Linear falloff from 1 at the centre to 0 at `radius`
                heatmap += np.clip(1 - dist / radius, 0, None)

            peak = heatmap.max()
            if peak > 0:
                heatmap = heatmap / peak

            heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)

            # OpenCV works in BGR; convert the PIL (RGB) pixels before blending
            img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            overlay = cv2.addWeighted(img_bgr, 0.7, heatmap_colored, 0.3, 0)

            ok, buffer = cv2.imencode('.png', overlay)
            if not ok:
                debug_print("Error generating attention map: PNG encoding failed")
                return None
            return base64.b64encode(buffer).decode('utf-8')
        except Exception as e:
            debug_print(f"Error generating attention map: {str(e)}")
            return None

    def cleanup(self):
        """Unload lazy models, clear caches and release GPU memory."""
        debug_print("Cleaning up resources...")

        if hasattr(self, 'text_model') and isinstance(self.text_model, LazyModel):
            self.text_model.unload()

        if hasattr(self, 'gen_model') and isinstance(self.gen_model, LazyModel):
            self.gen_model.unload()

        if hasattr(self, 'embedding_cache'):
            self.embedding_cache.clear()

        if hasattr(self, 'query_cache'):
            self.query_cache.clear()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        debug_print("Cleanup complete")


def _remove_temp_file(path):
    """Best-effort deletion of a temporary file; failures are logged, not raised."""
    try:
        os.remove(path)
    except OSError as e:
        debug_print(f"Could not remove temporary file {path}: {str(e)}")


# === FastAPI Application ===
app = FastAPI(title="MediQuery API", description="API for MediQuery AI medical assistant")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify the actual frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize MediQuery system
mediquery = MediQuery()


# Define API models
class QueryRequest(BaseModel):
    text: str


class QueryResponse(BaseModel):
    response: str
    confidence: float
    sources: List[str]
    error: Optional[str] = None


class ImageAnalysisResponse(BaseModel):
    analysis: str
    attention_map: Optional[str] = None
    confidence: float
    similar_cases: List[str]
    error: Optional[str] = None


# NOTE(review): these handlers are `async def` but call blocking model code,
# which blocks the event loop for the duration of each request.
@app.post("/api/query", response_model=QueryResponse)
async def process_text_query(query: QueryRequest):
    """Process a text query and return relevant information"""
    result = mediquery.process_query(query.text)
    if "error" in result:
        # Fill the required response fields so the error is reported instead
        # of failing response-model validation with a 500.
        return {"response": "", "confidence": 0.0, "sources": [], **result}
    return result


@app.post("/api/analyze-image", response_model=ImageAnalysisResponse)
async def analyze_image(file: UploadFile = File(...)):
    """Analyze an uploaded X-ray image and return results"""
    # Never build the path from file.filename: it is client-controlled (path
    # traversal) and identical names from concurrent uploads would collide.
    fd, temp_file = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        result = mediquery.process_image(temp_file)
    finally:
        _remove_temp_file(temp_file)

    if "error" in result:
        return {"analysis": "", "attention_map": None, "confidence": 0.0, "similar_cases": [], **result}
    return result


@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "ok", "version": "1.0.0"}


# === Gradio Interface ===
def create_gradio_interface():
    """Build the Gradio Blocks UI (image analysis and text query tabs) backed by `mediquery`."""

    def process_image_gradio(image):
        """Analyze a PIL image from the UI; returns (analysis text, attention-map image or None)."""
        if image is None:
            return "Please upload a chest X-ray image.", None

        # Unique temp file per request; a fixed path would race between users.
        fd, temp_file = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        try:
            image.save(temp_file)
            result = mediquery.process_image(temp_file)
        finally:
            _remove_temp_file(temp_file)

        analysis = result.get("analysis", "Error processing image")
        attention_map_b64 = result.get("attention_map")

        attention_map = None
        if attention_map_b64:
            try:
                attention_map = Image.open(io.BytesIO(base64.b64decode(attention_map_b64)))
            except Exception as e:
                debug_print(f"Could not decode attention map: {str(e)}")

        return analysis, attention_map

    def process_query_gradio(query):
        """Answer a text query from the UI."""
        result = mediquery.process_query(query)
        return result.get("response", "Error processing query")

    with gr.Blocks(title="MediQuery") as demo:
        gr.Markdown("# MediQuery - AI Medical Assistant")

        with gr.Tab("Image Analysis"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload Chest X-ray")
                    image_button = gr.Button("Analyze X-ray")
                with gr.Column():
                    text_output = gr.Textbox(label="Analysis Results", lines=10)
                    image_output = gr.Image(label="Attention Map")

            image_button.click(
                fn=process_image_gradio,
                inputs=[image_input],
                outputs=[text_output, image_output]
            )

        with gr.Tab("Text Query"):
            query_input = gr.Textbox(label="Medical Query", lines=3, placeholder="e.g., What does pneumonia look like on a chest X-ray?")
            query_button = gr.Button("Submit Query")
            query_output = gr.Textbox(label="Response", lines=10)

            query_button.click(
                fn=process_query_gradio,
                inputs=[query_input],
                outputs=[query_output]
            )

            gr.Markdown("## Example Queries")
            gr.Examples(
                examples=[
                    ["What does pleural effusion look like?"],
                    ["How to differentiate pneumonia from tuberculosis?"],
                    ["What are the signs of cardiomegaly on X-ray?"]
                ],
                inputs=[query_input]
            )

    return demo


# Create Gradio interface
demo = create_gradio_interface()

# Mount Gradio app to FastAPI (API routes registered above take precedence)
app = gr.mount_gradio_app(app, demo, path="/")


# Startup and shutdown events
@app.on_event("startup")
async def startup_event():
    """Initialize resources on startup"""
    debug_print("API starting up...")


@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown"""
    debug_print("API shutting down...")
    mediquery.cleanup()


# Run the FastAPI app with uvicorn when executed directly
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)