import os
import streamlit as st
import torch
from PIL import Image
import numpy as np
import warnings
import torch.nn.functional as F
# Add the project root directory to the Python path so local packages resolve
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))
# Avoid OMP error from PyTorch/OpenCV
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Suppress noisy UserWarning messages
warnings.filterwarnings("ignore", category=UserWarning)
# Import custom modules
from models.resnet_model import MalariaResNet50
from gradcam.gradcam import visualize_gradcam
# -----------------------------
# Streamlit Page Setup
# -----------------------------
st.set_page_config(page_title="🧬 Malaria Cell Classifier", layout="wide")
st.title("🧬 Malaria Cell Classifier with Grad-CAM")
st.write("Upload a blood smear image and the model will classify it as infected or uninfected, and highlight key regions using Grad-CAM.")
# -----------------------------
# Load Model
# -----------------------------
@st.cache_resource
def load_model():
    # Ensure model class doesn't wrap backbone
    model = MalariaResNet50(num_classes=2)
    model.load_state_dict(torch.load("models/malaria_model.pth", map_location='cpu'))
    model.eval()
    return model

model = load_model()
# -----------------------------
# Upload Image
# -----------------------------
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    # Save the uploaded image to a temporary file on disk
    temp_image_path = f"temp_{uploaded_file.name}"
    with open(temp_image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Display the original image (downscaled if needed)
    image = Image.open(uploaded_file).convert("RGB")
    max_size = (400, 400)  # Max width and height
    image.thumbnail(max_size)
    st.image(image, caption="Uploaded Image", use_container_width=False)

    # Predict button
    if st.button("Predict"):
        with st.spinner("Classifying..."):
            # Run prediction
            pred_label, confidence = model.predict(temp_image_path, device='cpu', show_image=False)
            st.success(f"✅ Prediction: **{pred_label}** | Confidence: **{confidence:.2%}**")

            # Show Grad-CAM
            st.subheader("🔍 Grad-CAM Visualization")
            with st.expander("ℹ️ What is Grad-CAM?"):
                st.markdown("""
**Grad-CAM (Gradient-weighted Class Activation Mapping)** is an interpretability method that shows which parts of an image are most important for a CNN's prediction.

How it works:
1. Gradients flow from the predicted output neuron back to the last convolutional layer.
2. These gradients are global-average-pooled to get per-channel importance weights.
3. A weighted combination of the feature maps produces a coarse heatmap.
4. The heatmap is upsampled and overlaid on the original image.

🔬 In this app, Grad-CAM:
- Helps explain *why* the model thinks a blood smear cell is infected
- Makes the model's decisions more transparent and easier to audit
""")
            visualize_gradcam(model, temp_image_path)
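
# -----------------------------
# Appendix: minimal Grad-CAM sketch (illustrative only, not called by this app)
# -----------------------------
# The app's actual Grad-CAM code lives in gradcam/gradcam.py (visualize_gradcam).
# The helper below is only a sketch of the four steps described in the expander
# above. It assumes a torchvision-style ResNet whose last convolutional block is
# reachable as a module attribute (e.g. `model.backbone.layer4` -- a hypothetical
# name); adapt the target layer to the real MalariaResNet50 layout before use.
def _gradcam_sketch(model, input_tensor, target_layer):
    """Return an [H, W] heatmap in [0, 1] for the top-scoring class of `input_tensor` ([1, 3, H, W])."""
    activations, gradients = {}, {}

    def fwd_hook(_module, _inputs, output):
        activations["value"] = output          # feature maps of the last conv block

    def bwd_hook(_module, _grad_in, grad_out):
        gradients["value"] = grad_out[0]       # gradients w.r.t. those feature maps

    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)
    try:
        # 1. Forward pass, then backpropagate from the top-scoring output neuron
        scores = model(input_tensor)                                  # [1, num_classes]
        top_class = scores.argmax(dim=1).item()
        scores[0, top_class].backward()

        # 2. Global-average-pool the gradients to get per-channel importance weights
        weights = gradients["value"].mean(dim=(2, 3), keepdim=True)   # [1, C, 1, 1]

        # 3. Weighted combination of the feature maps, then ReLU -> coarse heatmap
        cam = F.relu((weights * activations["value"]).sum(dim=1, keepdim=True))

        # 4. Upsample to the input resolution and normalize to [0, 1] for overlaying
        cam = F.interpolate(cam, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam.squeeze().detach().cpu().numpy()
    finally:
        fwd_handle.remove()
        bwd_handle.remove()
# Example (not executed): _gradcam_sketch(model, preprocessed_tensor, model.backbone.layer4)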