import os
import logging
import numpy as np
import torch
from torch import nn
import pandas as pd
from torchvision import transforms, models
from PIL import Image
import faiss
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration, T5Tokenizer
import gradio as gr
import cv2
import traceback
from datetime import datetime
import re
import random
import functools
import gc
from collections import OrderedDict
import json
import sys
import time
from tqdm.auto import tqdm
import warnings
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict, Any, Union
import base64
import io

# Suppress unnecessary warnings
warnings.filterwarnings("ignore", category=UserWarning)
# === Configuration ===
class Config:
    """Configuration for MediQuery system"""
    # Model configuration
    IMAGE_MODEL = "chexnet"  # Options: "chexnet", "densenet"
    TEXT_MODEL = "biobert"  # Options: "biobert", "clinicalbert"
    GEN_MODEL = "flan-t5-base-finetuned"  # Base generation model

    # Resource management
    CACHE_SIZE = 50  # Reduced from 200 for deployment
    CACHE_EXPIRY_TIME = 1800  # Cache expiry time in seconds (30 minutes)
    LAZY_LOADING = True  # Enable lazy loading of models
    USE_HALF_PRECISION = True  # Use half precision for models if available

    # Feature flags
    DEBUG = True  # Enable detailed debugging
    PHI_DETECTION_ENABLED = True  # Enable PHI detection
    ANATOMY_MAPPING_ENABLED = True  # Enable anatomical mapping

    # Thresholds and parameters
    CONFIDENCE_THRESHOLD = 0.4  # Threshold for flagging low confidence
    TOP_K_RETRIEVAL = 10  # Reduced from 30 for deployment
    MAX_CONTEXT_DOCS = 3  # Reduced from 5 for deployment

    # Advanced retrieval settings
    DYNAMIC_RERANKING = True  # Dynamically adjust reranking weights
    DIVERSITY_PENALTY = 0.1  # Penalty for duplicate content

    # Performance optimization
    BATCH_SIZE = 1  # Reduced from 4 for deployment
    OPTIMIZE_MEMORY = True  # Optimize memory usage
    USE_CACHING = True  # Use caching for embeddings and queries

    # Path settings
    DEFAULT_KNOWLEDGE_BASE_DIR = "./knowledge_base"
    DEFAULT_MODEL_PATH = "./models/flan-t5-finetuned"
    LOG_DIR = "./logs"

    # Advanced settings
    EMBEDDING_AGGREGATION = "weighted_avg"  # Options: "avg", "weighted_avg", "cls", "pooled"
    EMBEDDING_NORMALIZE = True  # Normalize embeddings to unit length

    # Error recovery settings
    MAX_RETRIES = 2  # Reduced from 3 for deployment
    RECOVERY_WAIT_TIME = 1  # Seconds to wait between retries
# Set up logging with improved formatting
os.makedirs(Config.LOG_DIR, exist_ok=True)
logging.basicConfig(
    level=logging.DEBUG if Config.DEBUG else logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(Config.LOG_DIR, f"mediquery_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("MediQuery")

def debug_print(msg):
    """Print and log debug messages"""
    if Config.DEBUG:
        logger.debug(msg)
        print(f"DEBUG: {msg}")
# === Helper Functions for Conditions ===
def get_mimic_cxr_conditions():
    """Return the comprehensive list of conditions in the MIMIC-CXR dataset"""
    return [
        "atelectasis",
        "cardiomegaly",
        "consolidation",
        "edema",
        "enlarged cardiomediastinum",
        "fracture",
        "lung lesion",
        "lung opacity",
        "no finding",
        "pleural effusion",
        "pleural other",
        "pneumonia",
        "pneumothorax",
        "support devices"
    ]

def get_condition_synonyms():
    """Return synonyms for conditions to improve matching"""
    return {
        "atelectasis": ["atelectatic change", "collapsed lung", "lung collapse"],
        "cardiomegaly": ["enlarged heart", "cardiac enlargement", "heart enlargement"],
        "consolidation": ["airspace opacity", "air-space opacity", "alveolar opacity"],
        "edema": ["pulmonary edema", "fluid overload", "vascular congestion"],
        "fracture": ["broken bone", "bone fracture", "rib fracture"],
        "lung opacity": ["pulmonary opacity", "opacification", "lung opacification"],
        "pleural effusion": ["pleural fluid", "fluid in pleural space", "effusion"],
        "pneumonia": ["pulmonary infection", "lung infection", "bronchopneumonia"],
        "pneumothorax": ["air in pleural space", "collapsed lung", "ptx"],
        "support devices": ["tube", "line", "catheter", "pacemaker", "device"]
    }

def get_anatomical_regions():
    """Return mapping of anatomical regions with descriptions and conditions"""
    return {
        "upper_right_lung": {
            "description": "Upper right lung field",
            "conditions": ["pneumonia", "lung lesion", "pneumothorax", "atelectasis"]
        },
        "upper_left_lung": {
            "description": "Upper left lung field",
            "conditions": ["pneumonia", "lung lesion", "pneumothorax", "atelectasis"]
        },
        "middle_right_lung": {
            "description": "Middle right lung field",
            "conditions": ["pneumonia", "lung opacity", "atelectasis"]
        },
        "lower_right_lung": {
            "description": "Lower right lung field",
            "conditions": ["pneumonia", "pleural effusion", "atelectasis"]
        },
        "lower_left_lung": {
            "description": "Lower left lung field",
            "conditions": ["pneumonia", "pleural effusion", "atelectasis"]
        },
        "heart": {
            "description": "Cardiac silhouette",
            "conditions": ["cardiomegaly", "enlarged cardiomediastinum"]
        },
        "hilar": {
            "description": "Hilar regions",
            "conditions": ["enlarged cardiomediastinum", "adenopathy"]
        },
        "costophrenic_angles": {
            "description": "Costophrenic angles",
            "conditions": ["pleural effusion", "pneumothorax"]
        },
        "spine": {
            "description": "Spine",
            "conditions": ["fracture", "degenerative changes"]
        },
        "diaphragm": {
            "description": "Diaphragm",
            "conditions": ["elevated diaphragm", "flattened diaphragm"]
        }
    }
# === PHI Detection and Anonymization ===
def detect_phi(text):
    """Detect potential PHI (Protected Health Information) in text"""
    # Patterns for PHI detection (non-capturing groups so re.findall returns the full match,
    # not just the last group)
    patterns = {
        'name': r'\b[A-Z][a-z]+ [A-Z][a-z]+\b',
        'mrn': r'\b[A-Z]{0,3}[0-9]{4,10}\b',
        'ssn': r'\b[0-9]{3}[-]?[0-9]{2}[-]?[0-9]{4}\b',
        'date': r'\b(?:0?[1-9]|1[0-2])[\/\-](?:0?[1-9]|[12]\d|3[01])[\/\-](?:19|20)\d{2}\b',
        'phone': r'\b(?:\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b',
        'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'address': r'\b\d+\s+[A-Z][a-z]+\s+[A-Z][a-z]+\.?\b'
    }

    # Check each pattern
    phi_detected = {}
    for phi_type, pattern in patterns.items():
        matches = re.findall(pattern, text)
        if matches:
            phi_detected[phi_type] = matches
    return phi_detected

def anonymize_text(text):
    """Replace potential PHI with [REDACTED]"""
    if not text:
        return ""
    if not Config.PHI_DETECTION_ENABLED:
        return text
    try:
        # Detect PHI
        phi_detected = detect_phi(text)
        # Replace PHI with [REDACTED]
        anonymized = text
        for phi_type, matches in phi_detected.items():
            for match in matches:
                anonymized = anonymized.replace(match, "[REDACTED]")
        return anonymized
    except Exception as e:
        debug_print(f"Error in anonymize_text: {str(e)}")
        return text
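
# Example (illustrative) of the anonymizer in use:
#   anonymize_text("Scan reviewed for John Smith (jsmith@example.com)")
#   -> "Scan reviewed for [REDACTED] ([REDACTED])"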
# === LRU Cache Implementation with Enhanced Features ===
class LRUCache:
    """LRU (Least Recently Used) Cache implementation with TTL and size tracking"""
    def __init__(self, capacity=Config.CACHE_SIZE, expiry_time=Config.CACHE_EXPIRY_TIME):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.expiry_time = expiry_time  # in seconds
        self.timestamps = {}
        self.size_tracking = {
            "current_size_bytes": 0,
            "max_size_bytes": 0,
            "items_evicted": 0,
            "cache_hits": 0,
            "cache_misses": 0
        }

    def get(self, key):
        """Get item from cache with statistics tracking"""
        if key not in self.cache:
            self.size_tracking["cache_misses"] += 1
            return None
        # Check expiry
        if self.is_expired(key):
            self._remove_with_tracking(key)
            self.size_tracking["cache_misses"] += 1
            return None
        # Move to end (recently used)
        self.size_tracking["cache_hits"] += 1
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def put(self, key, value):
        """Add item to cache with size tracking"""
        # Calculate approximate size of the value
        value_size = self._estimate_size(value)
        if key in self.cache:
            old_value = self.cache.pop(key)
            old_size = self._estimate_size(old_value)
            self.size_tracking["current_size_bytes"] -= old_size
        # Make space if needed
        while len(self.cache) >= self.capacity or (
            Config.OPTIMIZE_MEMORY and
            self.size_tracking["current_size_bytes"] + value_size > 1e9  # 1 GB limit
        ):
            if not self.cache:
                break
            self._evict_least_recently_used()
        # Add new item and timestamp
        self.cache[key] = value
        self.timestamps[key] = datetime.now().timestamp()
        self.size_tracking["current_size_bytes"] += value_size
        # Update max size
        if self.size_tracking["current_size_bytes"] > self.size_tracking["max_size_bytes"]:
            self.size_tracking["max_size_bytes"] = self.size_tracking["current_size_bytes"]

    def is_expired(self, key):
        """Check if item has expired"""
        if key not in self.timestamps:
            return True
        current_time = datetime.now().timestamp()
        return (current_time - self.timestamps[key]) > self.expiry_time

    def _evict_least_recently_used(self):
        """Remove least recently used item with tracking"""
        if not self.cache:
            return
        # Oldest item is at the front of the OrderedDict; remove it through the
        # tracked path so the byte counter and timestamps stay consistent
        oldest_key = next(iter(self.cache))
        self._remove_with_tracking(oldest_key)

    def _remove_with_tracking(self, key):
        """Remove item with size tracking"""
        if key in self.cache:
            value = self.cache.pop(key)
            value_size = self._estimate_size(value)
            self.size_tracking["current_size_bytes"] -= value_size
            self.size_tracking["items_evicted"] += 1
        if key in self.timestamps:
            self.timestamps.pop(key)

    def remove(self, key):
        """Remove item from cache"""
        self._remove_with_tracking(key)

    def clear(self):
        """Clear the cache"""
        self.cache.clear()
        self.timestamps.clear()
        self.size_tracking["current_size_bytes"] = 0

    def get_stats(self):
        """Get cache statistics"""
        return {
            "size_bytes": self.size_tracking["current_size_bytes"],
            "max_size_bytes": self.size_tracking["max_size_bytes"],
            "items": len(self.cache),
            "capacity": self.capacity,
            "items_evicted": self.size_tracking["items_evicted"],
            "hit_rate": self.size_tracking["cache_hits"] /
                        (self.size_tracking["cache_hits"] + self.size_tracking["cache_misses"] + 1e-8)
        }

    def _estimate_size(self, obj):
        """Estimate memory size of an object in bytes"""
        if obj is None:
            return 0
        if isinstance(obj, np.ndarray):
            return obj.nbytes
        elif isinstance(obj, torch.Tensor):
            return obj.element_size() * obj.nelement()
        elif isinstance(obj, (str, bytes)):
            return len(obj)
        elif isinstance(obj, (list, tuple)):
            return sum(self._estimate_size(x) for x in obj)
        elif isinstance(obj, dict):
            return sum(self._estimate_size(k) + self._estimate_size(v) for k, v in obj.items())
        else:
            # Fallback - rough estimate
            return sys.getsizeof(obj)
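
# Example (illustrative) cache behaviour:
#   cache = LRUCache(capacity=2)
#   cache.put("a", np.zeros(10)); cache.put("b", np.zeros(10))
#   cache.put("c", np.zeros(10))   # evicts "a", the least recently used entry
#   cache.get("a")                 # -> None (counted as a cache miss)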
# === Improved Lazy Model Loading ===
class LazyModel:
    """Lazy loading wrapper for models with proper method forwarding and error recovery"""
    def __init__(self, model_name, model_class, device, **kwargs):
        self.model_name = model_name
        self.model_class = model_class
        self.device = device
        self.kwargs = kwargs
        self._model = None
        self.last_error = None
        self.last_used = datetime.now()
        debug_print(f"LazyModel initialized for {model_name}")

    def _ensure_loaded(self, retries=Config.MAX_RETRIES):
        """Ensure model is loaded with retry mechanism"""
        if self._model is None:
            debug_print(f"Lazy loading model: {self.model_name}")
            for attempt in range(retries):
                try:
                    self._model = self.model_class.from_pretrained(self.model_name, **self.kwargs)
                    # Apply memory optimizations
                    if Config.OPTIMIZE_MEMORY:
                        # Convert to half precision if available and enabled
                        if Config.USE_HALF_PRECISION and self.device.type == 'cuda' and hasattr(self._model, 'half'):
                            self._model = self._model.half()
                            debug_print(f"Using half precision for {self.model_name}")
                    self._model = self._model.to(self.device)
                    self._model.eval()  # Set to evaluation mode
                    debug_print(f"Model {self.model_name} loaded successfully")
                    self.last_error = None
                    break
                except Exception as e:
                    self.last_error = str(e)
                    debug_print(f"Error loading model {self.model_name} (attempt {attempt+1}/{retries}): {str(e)}")
                    if attempt < retries - 1:
                        # Wait before retrying
                        time.sleep(Config.RECOVERY_WAIT_TIME)
                    else:
                        raise RuntimeError(f"Failed to load model {self.model_name} after {retries} attempts: {str(e)}")
        # Update last used timestamp
        self.last_used = datetime.now()
        return self._model

    def __call__(self, *args, **kwargs):
        """Call the model"""
        model = self._ensure_loaded()
        return model(*args, **kwargs)

    # Forward common model methods
    def generate(self, *args, **kwargs):
        """Forward generate method to model with error recovery"""
        model = self._ensure_loaded()
        try:
            return model.generate(*args, **kwargs)
        except Exception as e:
            # If generation fails, try reloading the model once
            debug_print(f"Generation failed, reloading model: {str(e)}")
            self.unload()
            model = self._ensure_loaded()
            return model.generate(*args, **kwargs)

    def to(self, device):
        """Move model to specified device"""
        self.device = device
        if self._model is not None:
            self._model = self._model.to(device)
        return self

    def eval(self):
        """Set model to evaluation mode"""
        if self._model is not None:
            self._model.eval()
        return self

    def unload(self):
        """Unload model from memory"""
        if self._model is not None:
            del self._model
            self._model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            debug_print(f"Model {self.model_name} unloaded")
# === MediQuery Core System ===
class MediQuery:
    """Core MediQuery system for medical image and text analysis"""
    def __init__(self, knowledge_base_dir=Config.DEFAULT_KNOWLEDGE_BASE_DIR, model_path=Config.DEFAULT_MODEL_PATH):
        self.knowledge_base_dir = knowledge_base_dir
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_print(f"Using device: {self.device}")
        # Create directories if they don't exist
        os.makedirs(knowledge_base_dir, exist_ok=True)
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        # Initialize caches
        self.embedding_cache = LRUCache(capacity=Config.CACHE_SIZE)
        self.query_cache = LRUCache(capacity=Config.CACHE_SIZE)
        # Initialize models
        self._init_models()
        # Load knowledge base
        self._init_knowledge_base()
        debug_print("MediQuery system initialized")

    def _init_models(self):
        """Initialize all required models with lazy loading"""
        debug_print("Initializing models...")
        # Image model
        if Config.IMAGE_MODEL == "chexnet":
            self.image_model = models.densenet121(pretrained=False)
            # For deployment, we'll download the weights during initialization
            try:
                # Simplified for deployment - would need to download weights
                self.image_model = nn.Sequential(*list(self.image_model.children())[:-1])
                debug_print("CheXNet model initialized")
            except Exception as e:
                debug_print(f"Error initializing CheXNet: {str(e)}")
                # Fallback to standard DenseNet
                self.image_model = nn.Sequential(*list(models.densenet121(pretrained=True).children())[:-1])
        else:
            self.image_model = nn.Sequential(*list(models.densenet121(pretrained=True).children())[:-1])
        self.image_model = self.image_model.to(self.device).eval()
        # Text model - lazy loaded
        text_model_name = "dmis-lab/biobert-v1.1" if Config.TEXT_MODEL == "biobert" else "emilyalsentzer/Bio_ClinicalBERT"
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
        self.text_model = LazyModel(
            text_model_name,
            AutoModel,
            self.device
        )
        # Generation model - lazy loaded
        if os.path.exists(self.model_path):
            gen_model_path = self.model_path
        else:
            gen_model_path = "google/flan-t5-base"  # Fallback to base model
        self.gen_tokenizer = T5Tokenizer.from_pretrained(gen_model_path)
        self.gen_model = LazyModel(
            gen_model_path,
            T5ForConditionalGeneration,
            self.device
        )
        # Image transformation
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        debug_print("Models initialized")
    def _init_knowledge_base(self):
        """Initialize knowledge base with FAISS indices"""
        debug_print("Initializing knowledge base...")
        # For deployment, we create a minimal knowledge base.
        # In a real deployment, you would download the knowledge base files.
        # Create dummy knowledge base for demonstration
        self.text_data = pd.DataFrame({
            'combined_text': [
                "The chest X-ray shows clear lung fields with no evidence of consolidation, effusion, or pneumothorax. The heart size is normal. No acute cardiopulmonary abnormality.",
                "Bilateral patchy airspace opacities consistent with multifocal pneumonia. No pleural effusion or pneumothorax. Heart size is normal.",
                "Cardiomegaly with pulmonary vascular congestion and bilateral pleural effusions, consistent with congestive heart failure. No pneumothorax or pneumonia.",
                "Right upper lobe opacity concerning for pneumonia. No pleural effusion or pneumothorax. Heart size is normal.",
                "Left lower lobe atelectasis. No pneumothorax or pleural effusion. Heart size is normal.",
                "Bilateral pleural effusions with bibasilar atelectasis. Cardiomegaly present. Findings consistent with heart failure.",
                "Right pneumothorax with partial lung collapse. No pleural effusion. Heart size is normal.",
                "Endotracheal tube, central venous catheter, and nasogastric tube in place. No pneumothorax or pleural effusion.",
                "Hyperinflated lungs with flattened diaphragms, consistent with COPD. No acute infiltrate or effusion.",
                "Multiple rib fractures on the right side. No pneumothorax or hemothorax. Lung fields are clear."
            ],
            'valid_index': list(range(10))
        })
        # Dummy FAISS indices - will be created on first use
        self.image_index = None
        self.text_index = None
        debug_print("Knowledge base initialized")
    def _create_dummy_indices(self):
        """Create dummy FAISS indices for demonstration"""
        # Text embeddings (768 dimensions for BERT-based models)
        text_dim = 768
        text_embeddings = np.random.rand(len(self.text_data), text_dim).astype('float32')
        # Image embeddings (1024 dimensions for DenseNet121)
        image_dim = 1024
        image_embeddings = np.random.rand(len(self.text_data), image_dim).astype('float32')
        # Create FAISS indices
        self.text_index = faiss.IndexFlatL2(text_dim)
        self.text_index.add(text_embeddings)
        self.image_index = faiss.IndexFlatL2(image_dim)
        self.image_index.add(image_embeddings)
        debug_print("Dummy FAISS indices created")
    def process_image(self, image_path):
        """Process an X-ray image and return analysis results"""
        try:
            debug_print(f"Processing image: {image_path}")
            # Check cache
            if Config.USE_CACHING:
                cached_result = self.query_cache.get(f"img_{image_path}")
                if cached_result:
                    debug_print("Using cached image result")
                    return cached_result
            # Load and preprocess image
            image = Image.open(image_path).convert('RGB')
            image_tensor = self.image_transform(image).unsqueeze(0).to(self.device)
            # Generate image embedding
            with torch.no_grad():
                image_embedding = self.image_model(image_tensor)
                image_embedding = nn.functional.avg_pool2d(image_embedding, kernel_size=7).squeeze().cpu().numpy()
            # Initialize FAISS indices if needed
            if self.image_index is None:
                self._create_dummy_indices()
            # Retrieve similar cases
            distances, indices = self.image_index.search(np.array([image_embedding]), k=Config.TOP_K_RETRIEVAL)
            # Get relevant text data
            retrieved_texts = [self.text_data.iloc[idx]['combined_text'] for idx in indices[0]]
            # Generate context for the model
            context = "\n\n".join(retrieved_texts[:Config.MAX_CONTEXT_DOCS])
            # Generate analysis
            prompt = f"Analyze this chest X-ray based on similar cases:\n\n{context}\n\nProvide a detailed radiological assessment including findings and impression:"
            analysis = self._generate_text(prompt)
            # Generate attention map (simplified for deployment)
            attention_map = self._generate_attention_map(image)
            # Prepare result
            result = {
                "analysis": analysis,
                "attention_map": attention_map,
                "confidence": 0.85,  # Placeholder
                "similar_cases": retrieved_texts[:3]  # Return top 3 similar cases
            }
            # Cache result
            if Config.USE_CACHING:
                self.query_cache.put(f"img_{image_path}", result)
            return result
        except Exception as e:
            error_msg = f"Error processing image: {str(e)}\n{traceback.format_exc()}"
            debug_print(error_msg)
            return {"error": error_msg}
    def process_query(self, query_text):
        """Process a text query and return relevant information"""
        try:
            debug_print(f"Processing query: {query_text}")
            # Check cache
            if Config.USE_CACHING:
                cached_result = self.query_cache.get(f"txt_{query_text}")
                if cached_result:
                    debug_print("Using cached query result")
                    return cached_result
            # Anonymize query
            query_text = anonymize_text(query_text)
            # Generate text embedding
            query_embedding = self._generate_text_embedding(query_text)
            # Initialize FAISS indices if needed
            if self.text_index is None:
                self._create_dummy_indices()
            # Retrieve similar texts
            distances, indices = self.text_index.search(np.array([query_embedding]), k=Config.TOP_K_RETRIEVAL)
            # Get relevant text data
            retrieved_texts = [self.text_data.iloc[idx]['combined_text'] for idx in indices[0]]
            # Generate context for the model
            context = "\n\n".join(retrieved_texts[:Config.MAX_CONTEXT_DOCS])
            # Generate response
            prompt = f"Answer this medical question based on the following information:\n\nQuestion: {query_text}\n\nRelevant information:\n{context}\n\nDetailed answer:"
            response = self._generate_text(prompt)
            # Prepare result
            result = {
                "response": response,
                "confidence": 0.9,  # Placeholder
                "sources": retrieved_texts[:3]  # Return top 3 sources
            }
            # Cache result
            if Config.USE_CACHING:
                self.query_cache.put(f"txt_{query_text}", result)
            return result
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}\n{traceback.format_exc()}"
            debug_print(error_msg)
            return {"error": error_msg}
    def _generate_text_embedding(self, text):
        """Generate embedding for text using the text model"""
        try:
            # Check cache
            if Config.USE_CACHING:
                cached_embedding = self.embedding_cache.get(f"txt_emb_{text}")
                if cached_embedding is not None:
                    return cached_embedding
            # Tokenize
            inputs = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)
            # Generate embedding
            with torch.no_grad():
                outputs = self.text_model(**inputs)
            # Use mean pooling
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]
            # Cache embedding
            if Config.USE_CACHING:
                self.embedding_cache.put(f"txt_emb_{text}", embedding)
            return embedding
        except Exception as e:
            debug_print(f"Error generating text embedding: {str(e)}")
            # Return random embedding as fallback
            return np.random.rand(768).astype('float32')

    def _generate_text(self, prompt):
        """Generate text using the language model"""
        try:
            # Tokenize
            inputs = self.gen_tokenizer(
                prompt,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)
            # Generate
            with torch.no_grad():
                output_ids = self.gen_model.generate(
                    inputs.input_ids,
                    max_length=256,
                    num_beams=4,
                    early_stopping=True
                )
            # Decode
            output_text = self.gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            return output_text
        except Exception as e:
            debug_print(f"Error generating text: {str(e)}")
            return "I apologize, but I'm unable to generate a response at this time. Please try again later."
    def _generate_attention_map(self, image):
        """Generate a simplified attention map for the image"""
        try:
            # Convert to numpy array
            img_np = np.array(image.resize((224, 224)))
            # Create a simple heatmap (this is a placeholder - a real implementation would use model attention)
            heatmap = np.zeros((224, 224), dtype=np.float32)
            # Add some random "attention" areas
            for _ in range(3):
                x, y = np.random.randint(50, 174, 2)
                radius = np.random.randint(20, 50)
                for i in range(224):
                    for j in range(224):
                        dist = np.sqrt((i - x)**2 + (j - y)**2)
                        if dist < radius:
                            heatmap[i, j] += max(0, 1 - dist / radius)
            # Normalize
            heatmap = heatmap / heatmap.max()
            # Apply colormap
            heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
            # Overlay on original image (OpenCV works in BGR channel order)
            img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            overlay = cv2.addWeighted(img_bgr, 0.7, heatmap_colored, 0.3, 0)
            # Convert to base64 for API response
            _, buffer = cv2.imencode('.png', overlay)
            img_str = base64.b64encode(buffer).decode('utf-8')
            return img_str
        except Exception as e:
            debug_print(f"Error generating attention map: {str(e)}")
            return None
    def cleanup(self):
        """Clean up resources"""
        debug_print("Cleaning up resources...")
        # Unload models
        if hasattr(self, 'text_model') and isinstance(self.text_model, LazyModel):
            self.text_model.unload()
        if hasattr(self, 'gen_model') and isinstance(self.gen_model, LazyModel):
            self.gen_model.unload()
        # Clear caches
        if hasattr(self, 'embedding_cache'):
            self.embedding_cache.clear()
        if hasattr(self, 'query_cache'):
            self.query_cache.clear()
        # Force garbage collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        debug_print("Cleanup complete")
# === FastAPI Application ===
app = FastAPI(title="MediQuery API", description="API for MediQuery AI medical assistant")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify the actual frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize MediQuery system
mediquery = MediQuery()

# Define API models
class QueryRequest(BaseModel):
    text: str

class QueryResponse(BaseModel):
    response: str
    confidence: float
    sources: List[str]
    error: Optional[str] = None

class ImageAnalysisResponse(BaseModel):
    analysis: str
    attention_map: Optional[str] = None
    confidence: float
    similar_cases: List[str]
    error: Optional[str] = None
# API routes (route paths below are illustrative defaults; adjust to match the frontend)
@app.post("/query")
async def process_text_query(query: QueryRequest):
    """Process a text query and return relevant information"""
    result = mediquery.process_query(query.text)
    return result

@app.post("/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    """Analyze an X-ray image and return results"""
    # Save uploaded file temporarily
    temp_file = f"/tmp/{file.filename}"
    with open(temp_file, "wb") as f:
        f.write(await file.read())
    # Process image
    result = mediquery.process_image(temp_file)
    # Clean up
    os.remove(temp_file)
    return result

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "ok", "version": "1.0.0"}
# === Gradio Interface ===
def create_gradio_interface():
    """Create a Gradio interface for the MediQuery system"""
    # Define processing functions
    def process_image_gradio(image):
        # Save image temporarily
        temp_file = "/tmp/gradio_image.png"
        image.save(temp_file)
        # Process image
        result = mediquery.process_image(temp_file)
        # Clean up
        os.remove(temp_file)
        # Prepare output
        analysis = result.get("analysis", "Error processing image")
        attention_map_b64 = result.get("attention_map")
        # Convert base64 to image if available
        attention_map = None
        if attention_map_b64:
            try:
                attention_map = Image.open(io.BytesIO(base64.b64decode(attention_map_b64)))
            except Exception:
                pass
        return analysis, attention_map

    def process_query_gradio(query):
        result = mediquery.process_query(query)
        return result.get("response", "Error processing query")

    # Create interface
    with gr.Blocks(title="MediQuery") as demo:
        gr.Markdown("# MediQuery - AI Medical Assistant")
        with gr.Tab("Image Analysis"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload Chest X-ray")
                    image_button = gr.Button("Analyze X-ray")
                with gr.Column():
                    text_output = gr.Textbox(label="Analysis Results", lines=10)
                    image_output = gr.Image(label="Attention Map")
            image_button.click(
                fn=process_image_gradio,
                inputs=[image_input],
                outputs=[text_output, image_output]
            )
        with gr.Tab("Text Query"):
            query_input = gr.Textbox(label="Medical Query", lines=3, placeholder="e.g., What does pneumonia look like on a chest X-ray?")
            query_button = gr.Button("Submit Query")
            query_output = gr.Textbox(label="Response", lines=10)
            query_button.click(
                fn=process_query_gradio,
                inputs=[query_input],
                outputs=[query_output]
            )
            gr.Markdown("## Example Queries")
            gr.Examples(
                examples=[
                    ["What does pleural effusion look like?"],
                    ["How to differentiate pneumonia from tuberculosis?"],
                    ["What are the signs of cardiomegaly on X-ray?"]
                ],
                inputs=[query_input]
            )
    return demo
# Create Gradio interface
demo = create_gradio_interface()

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

# Startup and shutdown events
@app.on_event("startup")
async def startup_event():
    """Initialize resources on startup"""
    debug_print("API starting up...")

@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown"""
    debug_print("API shutting down...")
    mediquery.cleanup()

# Run the FastAPI app with uvicorn when executed directly
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
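
# Example (illustrative) client call once the server is running, using the /query route
# defined above (the route path is an assumed default):
#   import requests
#   resp = requests.post("http://localhost:8000/query", json={"text": "What are the signs of cardiomegaly on X-ray?"})
#   print(resp.json().get("response"))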