# YOLOV3-GradCAM / gradio_utils.py
import glob
import random

import albumentations as A
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from pytorch_grad_cam.utils.model_targets import FasterRCNNBoxScoreTarget

import config
from lightning_utils import YOLOv3Lightning
from utils import cells_to_bboxes, non_max_suppression

cmap = plt.get_cmap("tab20b")
class_labels = config.PASCAL_CLASSES
height, width = config.INFERENCE_IMAGE_SIZE, config.INFERENCE_IMAGE_SIZE
colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]

# Material Icons names, one per class, in config.PASCAL_CLASSES order
# ('tv' is the Material Icons name; 'tvmonitor' is not a valid icon)
icons = [
    'flight', 'pedal_bike', 'flutter_dash', 'sailing',
    'liquor', 'directions_bus', 'directions_car',
    'pets', 'chair', 'pets', 'table_restaurant',
    'pets', 'bedroom_baby', 'motorcycle', 'person', 'yard',
    'kebab_dining', 'chair', 'train', 'tv']
icons_mapping = {config.PASCAL_CLASSES[i]: icons[i] for i in range(len(icons))}

# Load the trained checkpoint on CPU and switch to inference mode
model = YOLOv3Lightning.load_from_checkpoint('YoLoV3Model2.ckpt',
                                             map_location=torch.device('cpu'))
model.eval()

# Anchors rescaled to each prediction grid; config.S lists the grid size per
# scale, so the result has shape (3 scales, 3 anchors, 2)
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(config.DEVICE)

def get_examples():
    """Pair each bundled .jpg with a random CAM transparency for the Gradio examples."""
    example_images = glob.glob('*.jpg')
    example_transparency = [random.choice([0.7, 0.8]) for _ in range(len(example_images))]
    return [[img, alpha] for img, alpha in zip(example_images, example_transparency)]
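
# Typical use (an assumption about the Gradio wiring, which is not in this file):
#   gr.Interface(fn=upload_image_inference, examples=get_examples(), ...)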

def yolov3_reshape_transform(x):
    """Fold the three YOLO head outputs into one activation map for pytorch-grad-cam.

    Each head has shape (B, anchors, S, S, 5 + num_classes); anchors and
    predictions are merged into the channel dimension, and every scale is
    upsampled to the first head's grid so all three can be concatenated.
    """
    activations = []
    size = x[0].size()[2:4]  # target spatial size: the first (coarsest) grid
    for x_item in x:
        # (B, A, S, S, C) -> (B, A, C, S, S) -> (B, A*C, S, S)
        x_permute = x_item.permute(0, 1, 4, 2, 3)
        x_permute = x_permute.reshape((x_permute.shape[0],
                                       x_permute.shape[1] * x_permute.shape[2],
                                       *x_permute.shape[3:]))
        activations.append(torch.nn.functional.interpolate(torch.abs(x_permute), size, mode='bilinear'))
    return torch.cat(activations, dim=1)
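
# Shape sanity check for the transform above (a sketch; assumes a 416x416 input,
# giving 13/26/52 grids, and 20 classes -> 3 * (5 + 20) = 75 channels per head):
#   dummy = [torch.randn(1, 3, s, s, 25) for s in (13, 26, 52)]
#   yolov3_reshape_transform(dummy).shape  # torch.Size([1, 225, 13, 13])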

def infer_transform(IMAGE_SIZE=config.INFERENCE_IMAGE_SIZE):
    """Resize, pad to a square, normalize, and convert to a CHW tensor."""
    transforms = A.Compose(
        [
            A.LongestMaxSize(max_size=IMAGE_SIZE),
            A.PadIfNeeded(
                min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
            ),
            A.Normalize(mean=[0.45484068, 0.43406072, 0.40103856],
                        std=[0.23936155, 0.23471538, 0.23876129],
                        max_pixel_value=255),
            ToTensorV2(),
        ]
    )
    return transforms
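
# Example (a sketch with a dummy image; output is (3, IMAGE_SIZE, IMAGE_SIZE)):
#   aug = infer_transform()
#   tensor = aug(image=np.zeros((300, 400, 3), dtype=np.uint8))['image']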

def generate_html():
    # Start the HTML string with some style and the Material Icons stylesheet
    classes = config.PASCAL_CLASSES
    html_string = """
    <link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
    <style>
        .title {
            font-size: 24px;
            font-weight: bold;
            text-align: center;
            margin-bottom: 20px;
            color: #4a4a4a;
        }
        .subtitle {
            font-size: 18px;
            text-align: center;
            margin-bottom: 10px;
            color: #7a7a7a;
        }
        .class-container {
            display: flex;
            flex-wrap: wrap;
            justify-content: center;
            align-items: center;
            padding: 20px;
            border: 2px solid #e0e0e0;
            border-radius: 10px;
            background-color: #f5f5f5;
        }
        .class-item {
            display: inline-flex;
            align-items: center;
            margin: 5px 10px;
            padding: 5px 8px;
            border: 1px solid #d1d1d1;
            border-radius: 20px;
            background-color: #ffffff;
            font-weight: bold;
            color: #333;
            box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.1);
            transition: transform 0.2s, box-shadow 0.2s;
        }
        .class-item:hover {
            transform: scale(1.05);
            box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.2);
            background-color: #e7e7e7;
        }
        .material-icons {
            margin-right: 8px;
        }
    </style>
    <div class="title">Object Detection Prediction & Grad-CAM for YOLOv3</div>
    <div class="subtitle">Supported Classes</div>
    <div class="class-container">
    """
    # Add one chip per class, looking up its Material Icons name
    for class_name in classes:
        icon_name = icons_mapping[class_name.lower()]
        html_string += f'<div class="class-item"><i class="material-icons">{icon_name}</i>{class_name}</div>'

    # Close the class container
    html_string += "</div>"
    return html_string
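
# Typical use (an assumption about the app wiring, which is not in this file):
#   gr.HTML(value=generate_html())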

def upload_image_inference(img, transparency):
    """Run detection on one RGB image; return annotations plus CAM overlays."""
    bboxes = [[] for _ in range(1)]  # one list per image; the batch size here is 1
    nms_boxes_output, annotations = [], []
    img_copy = img.copy()

    transform = infer_transform()
    img = transform(image=img)['image'].unsqueeze(0)
    out = model(img)

    # Decode each of the three scale outputs into normalized midpoint boxes
    for i in range(3):
        batch_size, num_anchors, S, _, _ = out[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(
            out[i], anchor, S=S, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    # Suppress overlapping predictions per image
    iou_thresh, thresh = 0.5, 0.6
    for i in range(img.shape[0]):
        nms_boxes = non_max_suppression(
            bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        nms_boxes_output.append(nms_boxes)

    # Convert midpoint boxes to pixel corner coordinates for the annotations
    for box in nms_boxes_output[0]:
        class_prediction = int(box[0])
        box = box[2:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[class_prediction],
            facecolor="none",
        )
        corners = rect.get_bbox().get_points()
        annotations.append([corners[0].astype(int).tolist() + corners[1].astype(int).tolist(),
                            config.PASCAL_CLASSES[class_prediction]])

    # Keep only boxes that lie fully inside the image
    new_bboxes = [a[0] for a in annotations]
    new_bboxes = [box for box in new_bboxes if all(val >= 0 for val in box)]

    # EigenCAM over all three YOLO heads, merged by yolov3_reshape_transform.
    # EigenCAM is gradient-free, so the box-score targets mainly satisfy the API.
    objs = [b[1] for b in nms_boxes_output[0]]
    bbox_coord = [b[2:] for b in nms_boxes_output[0]]
    targets = [FasterRCNNBoxScoreTarget(objs, bbox_coord)]
    cam = EigenCAM(model=model,
                   target_layers=[model.model],
                   reshape_transform=yolov3_reshape_transform)
    grayscale_cam = cam(input_tensor=img, targets=targets)[0, :]
    visualization = show_cam_on_image(np.float32(img_copy) / 255, grayscale_cam,
                                      use_rgb=True, image_weight=transparency)

    # Renormalize the CAM inside each detected box to highlight per-object evidence
    renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
    for x1, y1, x2, y2 in new_bboxes:
        renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
    renormalized_cam = scale_cam_image(renormalized_cam)
    eigencam_image_renormalized = show_cam_on_image(np.float32(img_copy) / 255, renormalized_cam,
                                                    use_rgb=True, image_weight=transparency)

    return [[img_copy, annotations],
            [grayscale_cam,
             renormalized_cam,
             visualization,
             eigencam_image_renormalized]]
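
# Minimal smoke test (a sketch, not part of the Space's app wiring). Assumes at
# least one example .jpg sits next to this file; the CAM overlay expects the
# image to match the CAM's spatial size, so the image is resized to the
# inference size first (this distorts non-square images slightly).
if __name__ == "__main__":
    examples = get_examples()
    if examples:
        image_path, alpha = examples[0]
        image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (width, height))
        annotated, cams = upload_image_inference(image, alpha)
        print(f"{len(annotated[1])} boxes detected; CAM shape: {cams[0].shape}")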