Spaces:
Running
Running
import os
import sys
import warnings
from pathlib import Path

# Work around "OMP: Error #15" (duplicate OpenMP runtime, e.g. one copy shipped
# with PyTorch and another with a numpy/OpenCV build). The runtime reads this
# variable when it initializes, and importing the numeric libraries below may
# load it, so it is set BEFORE any of those imports rather than after them.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Third-party imports intentionally follow the environment setup above.
import numpy as np  # noqa: E402
import streamlit as st  # noqa: E402
import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from PIL import Image  # noqa: E402

# Make the project root (the parent of this file's directory) importable so the
# local `models` and `gradcam` packages resolve. The path is resolved to an
# absolute one so it does not depend on the current working directory, and it
# is only added once even if this module is re-executed.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# Silence every UserWarning, from any module, to keep the app output clean.
# NOTE(review): this was previously described as suppressing Matplotlib's
# FutureWarning, which this filter does not do; narrow it with `module=` /
# another `category=` if only specific warnings were meant to be hidden.
warnings.filterwarnings("ignore", category=UserWarning)

# Project-local modules (importable thanks to the sys.path entry above).
from models.resnet_model import MalariaResNet50  # noqa: E402
from gradcam.gradcam import visualize_gradcam  # noqa: E402
# -----------------------------
# Streamlit Page Setup
# -----------------------------
# Browser-tab title and wide layout are configured before anything is drawn;
# the heading and the one-line usage hint follow.
_PAGE_TITLE = "🧬 Malaria Cell Classifier"
st.set_page_config(layout="wide", page_title=_PAGE_TITLE)

st.title("🧬 Malaria Cell Classifier with Grad-CAM")
st.write(
    "Upload a blood smear image and the model will classify it as infected or "
    "uninfected, and highlight key regions using Grad-CAM."
)
# -----------------------------
# Load Model
# -----------------------------
@st.cache_resource
def _load_weights(weights_path):
    """Read a saved ``state_dict`` from *weights_path* onto the CPU, once per path.

    Streamlit re-executes this whole script on every widget interaction, so the
    disk read is cached with ``st.cache_resource`` instead of repeating it on
    every click. The returned dict is shared across reruns and sessions and is
    only ever read: ``load_state_dict`` copies its values into the model's own
    parameters.
    """
    return torch.load(weights_path, map_location='cpu')


def load_model(weights_path="models/malaria_model.pth"):
    """Build a two-class ``MalariaResNet50``, load trained weights, and set eval mode.

    Only the weights are cached; a fresh model object is built on each call so
    that any per-run state attached to it (for example hooks a Grad-CAM helper
    might register -- unverified here) is not shared between reruns or users.

    Args:
        weights_path: Path to the saved ``state_dict``. A relative path is
            resolved against the current working directory. Defaults to the
            bundled checkpoint.

    Returns:
        The model with the checkpoint loaded, in evaluation mode.
    """
    # The class must not wrap its backbone: load_state_dict is strict by
    # default, so the checkpoint's parameter names must match this layout.
    model = MalariaResNet50(num_classes=2)
    model.load_state_dict(_load_weights(weights_path))
    model.eval()  # inference behaviour for layers such as dropout / batch norm
    return model


model = load_model()
# -----------------------------
# Upload Image
# -----------------------------
# Maximum width/height of the on-page preview; thumbnail() keeps the aspect ratio.
_PREVIEW_MAX_SIZE = (400, 400)

# Text shown inside the "What is Grad-CAM?" expander.
_GRADCAM_EXPLANATION = """
**Grad-CAM (Gradient-weighted Class Activation Mapping)** is an interpretability method that shows which parts of an image are most important for a CNN's prediction.

How it works:
1. Gradients flow from the output neuron back to the last convolutional layer.
2. These gradients are global average pooled to get importance weights.
3. A weighted combination creates a coarse heatmap.
4. Final heatmap is overlaid on the original image.

🔬 In this app:
- Helps understand *why* the model thinks a blood smear cell is infected
- Makes predictions more transparent and reliable
"""


def _save_upload_to_temp(upload):
    """Write an uploaded file to a new, uniquely named temp file and return its path.

    A unique system temp file is used instead of a name derived from the
    client's filename, so concurrent sessions cannot overwrite each other's
    image and nothing is left behind in the working directory. The original
    extension is kept in case downstream loaders look at it. The caller must
    delete the file.
    """
    import tempfile

    suffix = Path(upload.name).suffix
    fd, path = tempfile.mkstemp(prefix="malaria_upload_", suffix=suffix)
    with os.fdopen(fd, "wb") as fh:
        fh.write(upload.getbuffer())
    return path


def _remove_quietly(path):
    """Delete *path*, ignoring OS errors.

    Cleanup is deliberately best-effort: failing to delete a temp file (e.g. a
    handle still open on Windows) must not replace the results already shown.
    """
    import contextlib

    with contextlib.suppress(OSError):
        os.remove(path)


uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    # Decode up front so a corrupt or non-image upload gets a readable message
    # instead of a traceback. Pillow reports both cases as OSError
    # (UnidentifiedImageError is a subclass).
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        image = None
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")

    if image is not None:
        image.thumbnail(_PREVIEW_MAX_SIZE)  # downscale in place, for display only
        st.image(image, caption="Uploaded Image", use_container_width=False)

        if st.button("Predict"):
            # model.predict and visualize_gradcam are called with a file path, so
            # the upload is written to disk only now and always removed afterwards.
            temp_image_path = _save_upload_to_temp(uploaded_file)
            try:
                with st.spinner("Classifying..."):
                    pred_label, confidence = model.predict(temp_image_path, device='cpu', show_image=False)
                st.success(f"✅ Prediction: **{pred_label}** | Confidence: **{confidence:.2%}**")

                # Show Grad-CAM
                st.subheader("🔍 Grad-CAM Visualization")
                with st.expander("ℹ️ What is Grad-CAM?"):
                    st.markdown(_GRADCAM_EXPLANATION)
                visualize_gradcam(model, temp_image_path)
            finally:
                _remove_quietly(temp_image_path)