coldlike commited on
Commit
55459f2
Β·
1 Parent(s): faf90bc

docs: Add demo GIF to README

Browse files
Files changed (3) hide show
  1. README.md +13 -10
  2. app/app.py +8 -2
  3. gradcam/gradcam.py +216 -0
README.md CHANGED
@@ -11,14 +11,14 @@ A deep learning-based malaria detection system using ResNet50 and Grad-CAM expla
11
 
12
  ## πŸ› οΈ Built With
13
 
14
- - [PyTorch](https://pytorch.org/)
15
- - [Streamlit](https://streamlit.io/)
16
- - [Grad-CAM](https://arxiv.org/abs/1610.02391)
17
- - [ResNet50](https://pytorch.org/vision/stable/models.html)
18
 
19
  ## πŸ“¦ Dataset
20
 
21
- Uses the [Malaria Cell Images Dataset](https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria)
22
 
23
  ## πŸ“ Folder Structure
24
 
@@ -27,13 +27,16 @@ data/cell_images/
27
  β”œβ”€β”€ Parasitized/
28
  └── Uninfected/
29
 
 
30
 
31
- ## πŸ“· Example Output
32
-
33
- ![Example Grad-CAM Output](image.png)
34
 
35
  ## πŸ§ͺ Usage
36
 
37
- ### Train the Model
 
 
 
38
  ```bash
39
- python notebooks/train.py
 
 
11
 
12
  ## πŸ› οΈ Built With
13
 
14
+ - [PyTorch](https://pytorch.org/)
15
+ - [Streamlit](https://streamlit.io/)
16
+ - [Grad-CAM](https://arxiv.org/abs/1610.02391)
17
+ - [ResNet50](https://pytorch.org/vision/stable/models.html)
18
 
19
  ## πŸ“¦ Dataset
20
 
21
+ Uses the [Malaria Cell Images Dataset](https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria)
22
 
23
  ## πŸ“ Folder Structure
24
 
 
27
  β”œβ”€β”€ Parasitized/
28
  └── Uninfected/
29
 
30
+ ## Here's a quick preview of the app in action:
31
 
32
+ ![Malaria Classifier Demo](demo.gif)
 
 
33
 
34
  ## πŸ§ͺ Usage
35
 
36
+ ## πŸ› οΈ Requirements
37
+
38
+ Install dependencies:
39
+
40
  ```bash
41
+ pip install torch torchvision streamlit opencv-python matplotlib scikit-learn
42
+ ```
app/app.py CHANGED
@@ -5,6 +5,12 @@ from PIL import Image
5
  import numpy as np
6
  import warnings
7
  import torch.nn.functional as F
 
 
 
 
 
 
8
 
9
  # Avoid OMP error from PyTorch/OpenCV
10
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -13,8 +19,8 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
13
  warnings.filterwarnings("ignore", category=UserWarning)
14
 
15
  # Import custom modules
16
- from models.resnet_model import MalariaResNet50
17
- from gradcam import visualize_gradcam
18
 
19
 
20
  # -----------------------------
 
5
  import numpy as np
6
  import warnings
7
  import torch.nn.functional as F
8
+ # Add root to PYTHONPATH
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ # Add root directory to Python path
13
+ sys.path.append(str(Path(__file__).parent.parent))
14
 
15
  # Avoid OMP error from PyTorch/OpenCV
16
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
 
19
  warnings.filterwarnings("ignore", category=UserWarning)
20
 
21
  # Import custom modules
22
+ from models.resnet_model import MalariaResNet50
23
+ from gradcam.gradcam import visualize_gradcam
24
 
25
 
26
  # -----------------------------
gradcam/gradcam.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import cv2
11
+ import numpy as np
12
+ from torchvision import transforms
13
+ import matplotlib.pyplot as plt
14
+ from PIL import Image
15
+ import streamlit as st
16
+
17
+
18
+ # In[2]:
19
+
20
+
21
+ def preprocess_image(image_path):
22
+ """
23
+ Load and preprocess an image for inference.
24
+ """
25
+ transform = transforms.Compose([
26
+ transforms.Resize((224, 224)),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
+ ])
30
+
31
+ img = Image.open(image_path).convert('RGB')
32
+ tensor = transform(img)
33
+ return tensor.unsqueeze(0), img
34
+
35
+
36
+ # In[3]:
37
+
38
+
39
+ def get_last_conv_layer(model):
40
+ """
41
+ Get the last convolutional layer in the model.
42
+ """
43
+ # For ResNet architecture
44
+ for name, module in reversed(list(model.named_modules())):
45
+ if isinstance(module, nn.Conv2d):
46
+ return name
47
+ raise ValueError("No Conv2d layers found in the model.")
48
+
49
+
50
+ # In[4]:
51
+
52
+
53
+ def apply_gradcam(model, image_tensor, target_class=None):
54
+ """
55
+ Apply Grad-CAM to an image.
56
+ """
57
+ device = next(model.parameters()).device
58
+ image_tensor = image_tensor.to(device)
59
+
60
+ # Register hooks to get activations and gradients
61
+ features = []
62
+ gradients = []
63
+
64
+ def forward_hook(module, input, output):
65
+ features.append(output.detach())
66
+
67
+ def backward_hook(module, grad_input, grad_output):
68
+ gradients.append(grad_output[0].detach())
69
+
70
+ last_conv_layer_name = get_last_conv_layer(model)
71
+ last_conv_layer = dict(model.named_modules())[last_conv_layer_name]
72
+ handle_forward = last_conv_layer.register_forward_hook(forward_hook)
73
+ handle_backward = last_conv_layer.register_full_backward_hook(backward_hook)
74
+
75
+ # Forward pass
76
+ model.eval()
77
+ output = model(image_tensor)
78
+ if target_class is None:
79
+ target_class = output.argmax(dim=1).item()
80
+
81
+ # Zero out all gradients
82
+ model.zero_grad()
83
+
84
+ # Backward pass
85
+ one_hot = torch.zeros_like(output)
86
+ one_hot[0][target_class] = 1
87
+ output.backward(gradient=one_hot)
88
+
89
+ # Remove hooks
90
+ handle_forward.remove()
91
+ handle_backward.remove()
92
+
93
+ # Get feature maps and gradients
94
+ feature_map = features[-1].squeeze().cpu().numpy()
95
+ gradient = gradients[-1].squeeze().cpu().numpy()
96
+
97
+ # Global Average Pooling on gradients
98
+ pooled_gradients = np.mean(gradient, axis=(1, 2), keepdims=True)
99
+ cam = feature_map * pooled_gradients
100
+ cam = np.sum(cam, axis=0)
101
+
102
+ # Apply ReLU
103
+ cam = np.maximum(cam, 0)
104
+
105
+ # Normalize the CAM
106
+ cam = cam - np.min(cam)
107
+ cam = cam / np.max(cam)
108
+
109
+ # Resize CAM to match the original image size
110
+ cam = cv2.resize(cam, (224, 224))
111
+
112
+ return cam
113
+
114
+
115
+ # In[5]:
116
+
117
+
118
+ def overlay_heatmap(original_image, heatmap, alpha=0.5):
119
+ """
120
+ Overlay the heatmap on the original image.
121
+
122
+ Args:
123
+ original_image (np.ndarray): Original image (H, W, 3), uint8
124
+ heatmap (np.ndarray): Grad-CAM heatmap (H', W'), float between 0 and 1
125
+ alpha (float): Weight for the heatmap
126
+
127
+ Returns:
128
+ np.ndarray: Overlayed image
129
+ """
130
+ # Ensure heatmap is 2D
131
+ if heatmap.ndim == 3:
132
+ heatmap = np.mean(heatmap, axis=2)
133
+
134
+ # Resize heatmap to match original image size
135
+ heatmap_resized = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
136
+
137
+ # Normalize heatmap to [0, 255]
138
+ heatmap_resized = np.uint8(255 * heatmap_resized)
139
+
140
+ # Apply colormap
141
+ heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
142
+
143
+ # Convert from BGR to RGB
144
+ heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
145
+
146
+ # Superimpose: blend heatmap and original image
147
+ superimposed_img = heatmap_colored * alpha + original_image * (1 - alpha)
148
+ return np.uint8(superimposed_img)
149
+
150
+ def visualize_gradcam(model, image_path):
151
+ """
152
+ Visualize Grad-CAM for a given image.
153
+ """
154
+ # Preprocess image
155
+ image_tensor, original_image = preprocess_image(image_path)
156
+ original_image_np = np.array(original_image) # PIL -> numpy array
157
+
158
+ # Resize original image for better display
159
+ max_size = (400, 400) # Max width and height
160
+ original_image_resized = cv2.resize(original_image_np, max_size)
161
+
162
+ # Apply Grad-CAM
163
+ cam = apply_gradcam(model, image_tensor)
164
+
165
+ # Resize CAM to match original image size
166
+ heatmap_resized = cv2.resize(cam, (original_image_np.shape[1], original_image_np.shape[0]))
167
+
168
+ # Normalize heatmap to [0, 255]
169
+ heatmap_resized = np.uint8(255 * heatmap_resized / np.max(heatmap_resized))
170
+
171
+ # Apply color map
172
+ heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
173
+ heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
174
+
175
+ # Overlay
176
+ superimposed_img = heatmap_colored * 0.4 + original_image_np * 0.6
177
+ superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
178
+
179
+ # Display results
180
+ fig, axes = plt.subplots(1, 2, figsize=(8, 4)) # Adjust figsize as needed
181
+ axes[0].imshow(original_image_resized)
182
+ axes[0].set_title("Original Image")
183
+ axes[0].axis("off")
184
+
185
+ axes[1].imshow(superimposed_img)
186
+ axes[1].set_title("Grad-CAM Heatmap")
187
+ axes[1].axis("off")
188
+
189
+ plt.tight_layout()
190
+ st.pyplot(fig)
191
+ plt.close(fig)
192
+
193
+
194
+ # In[6]:
195
+
196
+
197
+ if __name__ == "__main__":
198
+
199
+ from models.resnet_model import MalariaResNet50
200
+ # Load your trained model
201
+ model = MalariaResNet50(num_classes=2)
202
+ model.load_state_dict(torch.load("models/malaria_model.pth"))
203
+ model.eval()
204
+
205
+ # Path to an image
206
+ image_path = "malaria_ds/split_dataset/test/Parasitized/C33P1thinF_IMG_20150619_114756a_cell_181.png"
207
+
208
+ # Visualize Grad-CAM
209
+ visualize_gradcam(model, image_path)
210
+
211
+
212
+ # In[ ]:
213
+
214
+
215
+
216
+