Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding: utf-8 | |
# In[1]: | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import cv2 | |
import numpy as np | |
from torchvision import transforms | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import streamlit as st | |
# In[2]: | |
def preprocess_image(image_path): | |
""" | |
Load and preprocess an image for inference. | |
""" | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
img = Image.open(image_path).convert('RGB') | |
tensor = transform(img) | |
return tensor.unsqueeze(0), img | |
# In[3]: | |
def get_last_conv_layer(model): | |
""" | |
Get the last convolutional layer in the model. | |
""" | |
# For ResNet architecture | |
for name, module in reversed(list(model.named_modules())): | |
if isinstance(module, nn.Conv2d): | |
return name | |
raise ValueError("No Conv2d layers found in the model.") | |
# In[4]: | |
def apply_gradcam(model, image_tensor, target_class=None): | |
""" | |
Apply Grad-CAM to an image. | |
""" | |
device = next(model.parameters()).device | |
image_tensor = image_tensor.to(device) | |
# Register hooks to get activations and gradients | |
features = [] | |
gradients = [] | |
def forward_hook(module, input, output): | |
features.append(output.detach()) | |
def backward_hook(module, grad_input, grad_output): | |
gradients.append(grad_output[0].detach()) | |
last_conv_layer_name = get_last_conv_layer(model) | |
last_conv_layer = dict(model.named_modules())[last_conv_layer_name] | |
handle_forward = last_conv_layer.register_forward_hook(forward_hook) | |
handle_backward = last_conv_layer.register_full_backward_hook(backward_hook) | |
# Forward pass | |
model.eval() | |
output = model(image_tensor) | |
if target_class is None: | |
target_class = output.argmax(dim=1).item() | |
# Zero out all gradients | |
model.zero_grad() | |
# Backward pass | |
one_hot = torch.zeros_like(output) | |
one_hot[0][target_class] = 1 | |
output.backward(gradient=one_hot) | |
# Remove hooks | |
handle_forward.remove() | |
handle_backward.remove() | |
# Get feature maps and gradients | |
feature_map = features[-1].squeeze().cpu().numpy() | |
gradient = gradients[-1].squeeze().cpu().numpy() | |
# Global Average Pooling on gradients | |
pooled_gradients = np.mean(gradient, axis=(1, 2), keepdims=True) | |
cam = feature_map * pooled_gradients | |
cam = np.sum(cam, axis=0) | |
# Apply ReLU | |
cam = np.maximum(cam, 0) | |
# Normalize the CAM | |
cam = cam - np.min(cam) | |
cam = cam / np.max(cam) | |
# Resize CAM to match the original image size | |
cam = cv2.resize(cam, (224, 224)) | |
return cam | |
# In[5]: | |
def overlay_heatmap(original_image, heatmap, alpha=0.5): | |
""" | |
Overlay the heatmap on the original image. | |
Args: | |
original_image (np.ndarray): Original image (H, W, 3), uint8 | |
heatmap (np.ndarray): Grad-CAM heatmap (H', W'), float between 0 and 1 | |
alpha (float): Weight for the heatmap | |
Returns: | |
np.ndarray: Overlayed image | |
""" | |
# Ensure heatmap is 2D | |
if heatmap.ndim == 3: | |
heatmap = np.mean(heatmap, axis=2) | |
# Resize heatmap to match original image size | |
heatmap_resized = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0])) | |
# Normalize heatmap to [0, 255] | |
heatmap_resized = np.uint8(255 * heatmap_resized) | |
# Apply colormap | |
heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET) | |
# Convert from BGR to RGB | |
heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB) | |
# Superimpose: blend heatmap and original image | |
superimposed_img = heatmap_colored * alpha + original_image * (1 - alpha) | |
return np.uint8(superimposed_img) | |
def visualize_gradcam(model, image_path): | |
""" | |
Visualize Grad-CAM for a given image. | |
""" | |
# Preprocess image | |
image_tensor, original_image = preprocess_image(image_path) | |
original_image_np = np.array(original_image) # PIL -> numpy array | |
# Resize original image for better display | |
max_size = (400, 400) # Max width and height | |
original_image_resized = cv2.resize(original_image_np, max_size) | |
# Apply Grad-CAM | |
cam = apply_gradcam(model, image_tensor) | |
# Resize CAM to match original image size | |
heatmap_resized = cv2.resize(cam, (original_image_np.shape[1], original_image_np.shape[0])) | |
# Normalize heatmap to [0, 255] | |
heatmap_resized = np.uint8(255 * heatmap_resized / np.max(heatmap_resized)) | |
# Apply color map | |
heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET) | |
heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB) | |
# Overlay | |
superimposed_img = heatmap_colored * 0.4 + original_image_np * 0.6 | |
superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8) | |
# Display results | |
fig, axes = plt.subplots(1, 2, figsize=(8, 4)) # Adjust figsize as needed | |
axes[0].imshow(original_image_resized) | |
axes[0].set_title("Original Image") | |
axes[0].axis("off") | |
axes[1].imshow(superimposed_img) | |
axes[1].set_title("Grad-CAM Heatmap") | |
axes[1].axis("off") | |
plt.tight_layout() | |
st.pyplot(fig) | |
plt.close(fig) | |
# In[6]: | |
if __name__ == "__main__": | |
from models.resnet_model import MalariaResNet50 | |
# Load your trained model | |
model = MalariaResNet50(num_classes=2) | |
model.load_state_dict(torch.load("models/malaria_model.pth")) | |
model.eval() | |
# Path to an image | |
image_path = "malaria_ds/split_dataset/test/Parasitized/C33P1thinF_IMG_20150619_114756a_cell_181.png" | |
# Visualize Grad-CAM | |
visualize_gradcam(model, image_path) | |
# In[ ]: | |