coldlike commited on
Commit
ddd9242
·
1 Parent(s): 3ef5868

Ready for huggingface deployment

Browse files
Files changed (1) hide show
  1. grdcam/gradcam.py +0 -216
grdcam/gradcam.py DELETED
@@ -1,216 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- # In[1]:
5
-
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import cv2
11
- import numpy as np
12
- from torchvision import transforms
13
- import matplotlib.pyplot as plt
14
- from PIL import Image
15
- import streamlit as st
16
-
17
-
18
- # In[2]:
19
-
20
-
21
- def preprocess_image(image_path):
22
- """
23
- Load and preprocess an image for inference.
24
- """
25
- transform = transforms.Compose([
26
- transforms.Resize((224, 224)),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
- ])
30
-
31
- img = Image.open(image_path).convert('RGB')
32
- tensor = transform(img)
33
- return tensor.unsqueeze(0), img
34
-
35
-
36
- # In[3]:
37
-
38
-
39
- def get_last_conv_layer(model):
40
- """
41
- Get the last convolutional layer in the model.
42
- """
43
- # For ResNet architecture
44
- for name, module in reversed(list(model.named_modules())):
45
- if isinstance(module, nn.Conv2d):
46
- return name
47
- raise ValueError("No Conv2d layers found in the model.")
48
-
49
-
50
- # In[4]:
51
-
52
-
53
- def apply_gradcam(model, image_tensor, target_class=None):
54
- """
55
- Apply Grad-CAM to an image.
56
- """
57
- device = next(model.parameters()).device
58
- image_tensor = image_tensor.to(device)
59
-
60
- # Register hooks to get activations and gradients
61
- features = []
62
- gradients = []
63
-
64
- def forward_hook(module, input, output):
65
- features.append(output.detach())
66
-
67
- def backward_hook(module, grad_input, grad_output):
68
- gradients.append(grad_output[0].detach())
69
-
70
- last_conv_layer_name = get_last_conv_layer(model)
71
- last_conv_layer = dict(model.named_modules())[last_conv_layer_name]
72
- handle_forward = last_conv_layer.register_forward_hook(forward_hook)
73
- handle_backward = last_conv_layer.register_full_backward_hook(backward_hook)
74
-
75
- # Forward pass
76
- model.eval()
77
- output = model(image_tensor)
78
- if target_class is None:
79
- target_class = output.argmax(dim=1).item()
80
-
81
- # Zero out all gradients
82
- model.zero_grad()
83
-
84
- # Backward pass
85
- one_hot = torch.zeros_like(output)
86
- one_hot[0][target_class] = 1
87
- output.backward(gradient=one_hot)
88
-
89
- # Remove hooks
90
- handle_forward.remove()
91
- handle_backward.remove()
92
-
93
- # Get feature maps and gradients
94
- feature_map = features[-1].squeeze().cpu().numpy()
95
- gradient = gradients[-1].squeeze().cpu().numpy()
96
-
97
- # Global Average Pooling on gradients
98
- pooled_gradients = np.mean(gradient, axis=(1, 2), keepdims=True)
99
- cam = feature_map * pooled_gradients
100
- cam = np.sum(cam, axis=0)
101
-
102
- # Apply ReLU
103
- cam = np.maximum(cam, 0)
104
-
105
- # Normalize the CAM
106
- cam = cam - np.min(cam)
107
- cam = cam / np.max(cam)
108
-
109
- # Resize CAM to match the original image size
110
- cam = cv2.resize(cam, (224, 224))
111
-
112
- return cam
113
-
114
-
115
- # In[5]:
116
-
117
-
118
- def overlay_heatmap(original_image, heatmap, alpha=0.5):
119
- """
120
- Overlay the heatmap on the original image.
121
-
122
- Args:
123
- original_image (np.ndarray): Original image (H, W, 3), uint8
124
- heatmap (np.ndarray): Grad-CAM heatmap (H', W'), float between 0 and 1
125
- alpha (float): Weight for the heatmap
126
-
127
- Returns:
128
- np.ndarray: Overlayed image
129
- """
130
- # Ensure heatmap is 2D
131
- if heatmap.ndim == 3:
132
- heatmap = np.mean(heatmap, axis=2)
133
-
134
- # Resize heatmap to match original image size
135
- heatmap_resized = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
136
-
137
- # Normalize heatmap to [0, 255]
138
- heatmap_resized = np.uint8(255 * heatmap_resized)
139
-
140
- # Apply colormap
141
- heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
142
-
143
- # Convert from BGR to RGB
144
- heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
145
-
146
- # Superimpose: blend heatmap and original image
147
- superimposed_img = heatmap_colored * alpha + original_image * (1 - alpha)
148
- return np.uint8(superimposed_img)
149
-
150
- def visualize_gradcam(model, image_path):
151
- """
152
- Visualize Grad-CAM for a given image.
153
- """
154
- # Preprocess image
155
- image_tensor, original_image = preprocess_image(image_path)
156
- original_image_np = np.array(original_image) # PIL -> numpy array
157
-
158
- # Resize original image for better display
159
- max_size = (400, 400) # Max width and height
160
- original_image_resized = cv2.resize(original_image_np, max_size)
161
-
162
- # Apply Grad-CAM
163
- cam = apply_gradcam(model, image_tensor)
164
-
165
- # Resize CAM to match original image size
166
- heatmap_resized = cv2.resize(cam, (original_image_np.shape[1], original_image_np.shape[0]))
167
-
168
- # Normalize heatmap to [0, 255]
169
- heatmap_resized = np.uint8(255 * heatmap_resized / np.max(heatmap_resized))
170
-
171
- # Apply color map
172
- heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
173
- heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
174
-
175
- # Overlay
176
- superimposed_img = heatmap_colored * 0.4 + original_image_np * 0.6
177
- superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
178
-
179
- # Display results
180
- fig, axes = plt.subplots(1, 2, figsize=(8, 4)) # Adjust figsize as needed
181
- axes[0].imshow(original_image_resized)
182
- axes[0].set_title("Original Image")
183
- axes[0].axis("off")
184
-
185
- axes[1].imshow(superimposed_img)
186
- axes[1].set_title("Grad-CAM Heatmap")
187
- axes[1].axis("off")
188
-
189
- plt.tight_layout()
190
- st.pyplot(fig)
191
- plt.close(fig)
192
-
193
-
194
- # In[6]:
195
-
196
-
197
- if __name__ == "__main__":
198
-
199
- from models.resnet_model import MalariaResNet50
200
- # Load your trained model
201
- model = MalariaResNet50(num_classes=2)
202
- model.load_state_dict(torch.load("models/malaria_model.pth"))
203
- model.eval()
204
-
205
- # Path to an image
206
- image_path = "malaria_ds/split_dataset/test/Parasitized/C33P1thinF_IMG_20150619_114756a_cell_181.png"
207
-
208
- # Visualize Grad-CAM
209
- visualize_gradcam(model, image_path)
210
-
211
-
212
- # In[ ]:
213
-
214
-
215
-
216
-