import gradio as gr
import torch
from mmdet.apis import init_detector, inference_detector
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as TF

SIGNS_CLASSES = (
    'A10', 'A11', 'A12', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A1a',
    'A1b', 'A22', 'A24', 'A28', 'A29', 'A2a', 'A2b', 'A30', 'A31a', 'A31b',
    'A31c', 'A32a', 'A32b', 'A4', 'A5a', 'A6a', 'A6b', 'A7a', 'A8', 'A9',
    'B1', 'B11', 'B12', 'B13', 'B14', 'B15', 'B16', 'B17', 'B19', 'B2',
    'B20a', 'B20b', 'B21a', 'B21b', 'B24a', 'B24b', 'B26', 'B28', 'B29',
    'B32', 'B4', 'B5', 'B6', 'C1', 'C10a', 'C10b', 'C13a', 'C14a', 'C2a',
    'C2b', 'C2c', 'C2d', 'C2e', 'C2f', 'C3a', 'C3b', 'C4a', 'C4b', 'C4c',
    'C7a', 'C9a', 'C9b', 'E1', 'E11', 'E11c', 'E12', 'E13', 'E2a', 'E2b',
    'E2c', 'E2d', 'E3a', 'E3b', 'E4', 'E5', 'E6', 'E7a', 'E7b', 'E8a',
    'E8b', 'E8c', 'E8d', 'E8e', 'E9', 'I2', 'IJ1', 'IJ10', 'IJ11a', 'IJ11b',
    'IJ14c', 'IJ15', 'IJ2', 'IJ3', 'IJ4a', 'IJ4b', 'IJ4c', 'IJ4d', 'IJ4e',
    'IJ5', 'IJ6', 'IJ7', 'IJ8', 'IJ9', 'IP10a', 'IP10b', 'IP11a', 'IP11b',
    'IP11c', 'IP11e', 'IP11g', 'IP12', 'IP13a', 'IP13b', 'IP13c', 'IP13d',
    'IP14a', 'IP15a', 'IP15b', 'IP16', 'IP17', 'IP18a', 'IP18b', 'IP19',
    'IP2', 'IP21', 'IP21a', 'IP22', 'IP25a', 'IP25b', 'IP26a', 'IP26b',
    'IP27a', 'IP3', 'IP31a', 'IP4a', 'IP4b', 'IP5', 'IP6', 'IP7', 'IP8a',
    'IP8b', 'IS10b', 'IS11a', 'IS11b', 'IS11c', 'IS12a', 'IS12b', 'IS12c',
    'IS13', 'IS14', 'IS15a', 'IS15b', 'IS16b', 'IS16c', 'IS16d', 'IS17',
    'IS18a', 'IS18b', 'IS19a', 'IS19b', 'IS19c', 'IS19d', 'IS1a', 'IS1b',
    'IS1c', 'IS1d', 'IS20', 'IS21a', 'IS21b', 'IS21c', 'IS22a', 'IS22c',
    'IS22d', 'IS22e', 'IS22f', 'IS23', 'IS24a', 'IS24b', 'IS24c', 'IS2a',
    'IS2b', 'IS2c', 'IS2d', 'IS3a', 'IS3b', 'IS3c', 'IS3d', 'IS4a', 'IS4b',
    'IS4c', 'IS4d', 'IS5', 'IS6a', 'IS6b', 'IS6c', 'IS6e', 'IS6f', 'IS6g',
    'IS7a', 'IS8a', 'IS8b', 'IS9a', 'IS9b', 'IS9c', 'IS9d', 'O2', 'P1',
    'P2', 'P3', 'P4', 'P6', 'P7', 'P8', 'UNKNOWN', 'X1', 'X2', 'X3', 'XXX',
    'Z2', 'Z3', 'Z4a', 'Z4b', 'Z4c', 'Z4d', 'Z4e', 'Z7', 'Z9',
)

# Paths to the model config and checkpoint file
config_file = 'configs/config_cascade_rcnn_traffic_signs.py'
checkpoint_file = 'checkpoints/traffic_signs_cascade_2v2.pth'


def draw_coco_bboxes(img, bboxes, color=(255, 255, 0), width=5, show=False,
                     export_p=None, labels=None, resize_to=None):
    """Draw xyxy pixel-coordinate bboxes onto a CHW uint8 image tensor."""
    img = draw_bounding_boxes(img, torch.as_tensor(bboxes), colors=color,
                              width=width, labels=labels)
    # Only materialize a PIL image when we actually display or export it;
    # the original code raised a NameError when export_p was set without show.
    if show or export_p:
        if resize_to is not None:
            img = TF.resize(img, resize_to)
        img_pil = TF.to_pil_image(img)
        if show:
            img_pil.show()
        if export_p:
            img_pil.save(export_p)
    return img


# Build the model from the config and checkpoint once at startup;
# reloading it on every request would make each prediction needlessly slow.
model = init_detector(config_file, checkpoint_file, device='cpu')


def traffic_sign_inference(img):
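    # NOTE (assumption, not verified against this checkpoint's training
    # pipeline): gr.Image() hands this function an RGB numpy array, while
    # mmdet treats a raw ndarray input as BGR (the mmcv convention). If
    # colour-sensitive signs score poorly, converting before inference may
    # help -- a hedged sketch along the lines of the author's earlier notes:
    # import mmcv
    # img = mmcv.imconvert(img, 'rgb', 'bgr')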
    result = inference_detector(model, img)
    bboxes = result.pred_instances.bboxes
    labels = [SIGNS_CLASSES[int(l)] for l in result.pred_instances.labels]
    # numpy HWC -> torch CHW, as torchvision's drawing utilities expect
    img_t = torch.from_numpy(img).permute(2, 0, 1)
    img_res_vis = draw_coco_bboxes(img_t, bboxes, labels=labels)
    # torch CHW -> numpy HWC, which the Gradio output image can display
    return img_res_vis.permute(1, 2, 0).numpy()


demo = gr.Interface(traffic_sign_inference, gr.Image(), "image")

with demo:
    gr.Markdown('''
# Czech Traffic Signs Detector

Using [Cascade R-CNN](https://arxiv.org/abs/1712.00726) pretrained on COCO, fine-tuned on a dataset
of 39,425 images provided by kky.zcu.cz, running on [MMDetection](https://github.com/open-mmlab/mmdetection).

Report (in Czech): https://drive.google.com/file/d/1bFafvrTdd6Gs9-uwIia8R1CZ1a-Fn9KR/view?usp=drive_link

## Run a prediction

1. Upload an image (left box)
2. Press Submit
3. See the detection result (on the right)

## Some of the classes

![Czech traffic signs](sdz.jpg "Czech traffic signs")

Image source: https://www.znackydubi.cz/images/5bb48f0d7fa21/original
''')

demo.launch()
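# Optional: confidence filtering (a hedged sketch, not wired into the demo
# above). The interface currently draws every predicted box, including
# low-confidence ones; mmdet's `result.pred_instances` also carries per-box
# `scores`, so boxes could be thresholded inside traffic_sign_inference, e.g.:
#
#     keep = result.pred_instances.scores > 0.3  # 0.3 is an arbitrary cutoff
#     bboxes = result.pred_instances.bboxes[keep]
#     labels = [SIGNS_CLASSES[int(l)]
#               for l in result.pred_instances.labels[keep]]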