ahasera committed on
Commit
078e8d3
·
1 Parent(s): f78a3ee

Add application file

Browse files
Files changed (1) hide show
  1. app.py +345 -0
app.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ from ultralytics import YOLO
6
+ import easyocr
7
+ import pytesseract
8
+ import keras_ocr
9
+ import pandas as pd
10
+ from PIL import Image
11
+ import io
12
+ import re
13
+ from typing import List, Tuple, Union
14
+ from datetime import datetime
15
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
16
+ import torch
17
+ from datetime import datetime
18
+ import time
19
+
20
+ # Initialisation of models
21
def load_models():
    """Load every detection and OCR model into module-level globals.

    Binds ``model_vehicle``, ``model_plate``, ``reader_easyocr``,
    ``pipeline_kerasocr``, ``processor_trocr`` and ``model_trocr`` so the
    rest of the module can use them without reloading weights per request.
    """
    global model_vehicle, model_plate, reader_easyocr, pipeline_kerasocr, processor_trocr, model_trocr

    # YOLO detectors: one for vehicles, one for licence plates.
    model_vehicle = YOLO('models/yolov8n.pt')
    model_plate = YOLO('models/best.pt')

    # OCR back-ends that need an explicit model object (pytesseract does not).
    reader_easyocr = easyocr.Reader(['en'], gpu=True)
    pipeline_kerasocr = keras_ocr.pipeline.Pipeline()

    # The TrOCR processor and model are loaded from the same checkpoint.
    trocr_checkpoint = 'microsoft/trocr-base-handwritten'
    processor_trocr = TrOCRProcessor.from_pretrained(trocr_checkpoint)
    model_trocr = VisionEncoderDecoderModel.from_pretrained(trocr_checkpoint)


# Models are loaded once, at import time, before the UI is built.
load_models()
30
+
31
# Regular expressions for European licence-plate layouts, keyed by country
# code. Insertion order matters: validation returns the first pattern that
# matches, so the permissive generic 'EU' pattern must stay last.
EUROPEAN_PATTERNS = {
    # France
    'FR': r'^(?:[A-Z]{2}-\d{3}-[A-Z]{2}|\d{2,4}\s?[A-Z]{2,3}\s?\d{2,4})$',
    # Germany
    'DE': r'^[A-Z]{1,3}-[A-Z]{1,2}\s?\d{1,4}[EH]?$',
    # Spain
    'ES': r'^(\d{4}[A-Z]{3}|[A-Z]{1,2}\d{4}[A-Z]{2,3})$',
    # Italy
    'IT': r'^[A-Z]{2}\s?\d{3}\s?[A-Z]{2}$',
    # Great Britain
    'GB': r'^[A-Z]{2}\d{2}\s?[A-Z]{3}$',
    # Netherlands
    'NL': r'^[A-Z]{2}-\d{3}-[A-Z]$',
    # Belgium
    'BE': r'^(1-[A-Z]{3}-\d{3}|\d-[A-Z]{3}-\d{3})$',
    # Poland
    'PL': r'^[A-Z]{2,3}\s?\d{4,5}$',
    # Sweden
    'SE': r'^[A-Z]{3}\s?\d{3}$',
    # Norway
    'NO': r'^[A-Z]{2}\s?\d{5}$',
    # Finland
    'FI': r'^[A-Z]{3}-\d{3}$',
    # Denmark
    'DK': r'^[A-Z]{2}\s?\d{2}\s?\d{3}$',
    # Switzerland
    'CH': r'^[A-Z]{2}\s?\d{1,6}$',
    # Austria
    'AT': r'^[A-Z]{1,2}\s?\d{1,5}[A-Z]$',
    # Portugal
    'PT': r'^[A-Z]{2}-\d{2}-[A-Z]{2}$',
    # Generic European plate (catch-all)
    'EU': r'^[A-Z0-9]{2,4}[-\s]?[A-Z0-9]{1,4}[-\s]?[A-Z0-9]{1,4}$',
}
50
+
51
def preprocess_image(image):
    """Binarise a plate crop to give the OCR engines a cleaner input.

    The crop is reduced to grayscale, smoothed with a 5x5 Gaussian blur and
    thresholded with Otsu's method. The binary result is expanded back to
    three channels before being returned.

    Args:
        image: ``numpy`` image array. Colour input is interpreted as RGB,
            the channel order Gradio delivers for ``type="numpy"`` images
            (NOTE(review): confirm no caller passes BGR frames, e.g. from
            ``cv2.imread``). Grayscale (2-D or single-channel) and RGBA
            inputs are accepted as well.

    Returns:
        A 3-channel binary image with the same height and width as
        ``image``.
    """
    if image.ndim == 2:
        # Already grayscale: nothing to convert.
        gray = image
    elif image.shape[2] == 1:
        # Single-channel image stored with an explicit channel axis.
        gray = np.ascontiguousarray(image[:, :, 0])
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # The original used COLOR_BGR2GRAY, which swaps the red and blue
        # luminance weights for the RGB arrays this app works with.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # With THRESH_OTSU the threshold argument (0) is ignored and the
    # threshold is computed from the image histogram.
    thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    return cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB)
56
+
57
@torch.no_grad()
def trocr_ocr(image):
    """Recognise the text in ``image`` with the module-level TrOCR model.

    The image is encoded by ``processor_trocr``, token ids are generated by
    ``model_trocr`` and the first decoded string is returned. Gradient
    tracking is disabled for the whole call.
    """
    encoded = processor_trocr(image, return_tensors="pt")
    token_ids = model_trocr.generate(encoded.pixel_values)
    decoded = processor_trocr.batch_decode(token_ids, skip_special_tokens=True)
    # A single image was submitted, so only the first entry is relevant.
    return decoded[0]
62
+
63
def read_license_plate(license_plate_crop, ocr_engine='easyocr'):
    """Read a plate crop with the chosen OCR engine, raw and preprocessed.

    Args:
        license_plate_crop: image array containing the cropped plate.
        ocr_engine: one of ``'easyocr'``, ``'pytesseract'``, ``'kerasocr'``
            or ``'trocr'``.

    Returns:
        ``(raw_text, preprocessed_text)``: the text read from the crop as
        is and from ``preprocess_image(crop)``. Each text is upper-cased
        with spaces removed; multiple detections are joined with a single
        space. A value is ``None`` when the engine returned no detections.

    Raises:
        ValueError: if ``ocr_engine`` is not a supported engine name.
    """
    def _easyocr_texts(img):
        # EasyOCR yields one (bbox, text, confidence) triple per region.
        return [text for _, text, _ in reader_easyocr.readtext(img)]

    def _pytesseract_texts(img):
        # --psm 8 tells Tesseract to treat the image as a single word.
        return [pytesseract.image_to_string(img, config='--psm 8').strip()]

    def _kerasocr_texts(img):
        # The pipeline takes a batch; one image in, one prediction list out.
        words = pipeline_kerasocr.recognize([img])[0]
        return [''.join([word for word, _box in words])]

    def _trocr_texts(img):
        return [trocr_ocr(img).strip()]

    engines = {
        'easyocr': _easyocr_texts,
        'pytesseract': _pytesseract_texts,
        'kerasocr': _kerasocr_texts,
        'trocr': _trocr_texts,
    }
    recognise = engines.get(ocr_engine) if isinstance(ocr_engine, str) else None
    if recognise is None:
        raise ValueError(f"OCR engine '{ocr_engine}' not supported.")

    if ocr_engine == 'kerasocr':
        # Grayscale crops are expanded to three channels for keras-ocr.
        is_grayscale = len(license_plate_crop.shape) == 2 or license_plate_crop.shape[2] == 1
        if is_grayscale:
            license_plate_crop = cv2.cvtColor(license_plate_crop, cv2.COLOR_GRAY2RGB)

    def _normalise(texts):
        pieces = [piece.upper().replace(' ', '') for piece in texts]
        if not pieces:
            return None
        return " ".join(pieces)

    raw_texts = recognise(license_plate_crop)
    preprocessed_texts = recognise(preprocess_image(license_plate_crop))
    return _normalise(raw_texts), _normalise(preprocessed_texts)
96
+
97
def clean_plate_text(text):
    """Normalise OCR output to the characters a plate can contain.

    The text is upper-cased, every character other than ``A-Z``, ``0-9``,
    hyphen and whitespace is dropped, and then all whitespace is removed.

    Args:
        text: OCR output, or ``None``.

    Returns:
        The cleaned string; ``''`` when ``text`` is ``None``.
    """
    if text is None:
        return ''

    # First pass: keep only plate characters (and whitespace, for now).
    allowed_only = re.sub(r'[^A-Z0-9\-\s]', '', text.upper())
    # Second pass: remove the whitespace that survived the first pass.
    return re.sub(r'\s+', '', allowed_only).strip()
103
+
104
def validate_european_plate(text):
    """Match ``text`` against the known European plate patterns.

    Patterns are tried in the insertion order of ``EUROPEAN_PATTERNS`` and
    the first match wins.

    Returns:
        ``(text, country_code)`` for the first matching pattern, or
        ``(None, None)`` when nothing matches.
    """
    matched_country = next(
        (code for code, regex in EUROPEAN_PATTERNS.items() if re.match(regex, text)),
        None,
    )
    if matched_country is None:
        return None, None
    return text, matched_country
109
+
110
def post_process_ocr(raw_text, preprocessed_text):
    """Pick the best plate reading from the raw and preprocessed OCR output.

    Both readings are cleaned with ``clean_plate_text`` and checked with
    ``validate_european_plate``. The raw reading is preferred when it
    validates; otherwise the preprocessed reading is used if it validates.

    Args:
        raw_text: text read from the unprocessed crop (may be ``None``).
        preprocessed_text: text read from the preprocessed crop (may be
            ``None``).

    Returns:
        ``(text, country, is_valid)``. When neither reading validates,
        ``country`` is ``'Unknown'``, ``is_valid`` is ``False`` and ``text``
        is the cleaned raw reading, or the cleaned preprocessed reading if
        the raw one is empty.
    """
    cleaned_raw = clean_plate_text(raw_text)
    validated_raw, country_raw = validate_european_plate(cleaned_raw)
    if validated_raw:
        return validated_raw, country_raw, True

    cleaned_preprocessed = clean_plate_text(preprocessed_text)
    validated_preprocessed, country_preprocessed = validate_european_plate(cleaned_preprocessed)
    if validated_preprocessed:
        return validated_preprocessed, country_preprocessed, True

    # Callers invoke this when either reading is non-empty, so do not report
    # an empty string when only the preprocessed pass produced text.
    return cleaned_raw or cleaned_preprocessed, 'Unknown', False
123
+
124
def _recognize_plates_in_region(region, canvas, origin, ocr_engine, confidence_threshold):
    """Detect and read every plate in ``region``, drawing results on ``canvas``.

    Args:
        region: image array to search (a vehicle crop or the full image).
        canvas: full-size image that receives the annotations.
        origin: ``(x, y)`` position of ``region`` inside the full image,
            used to translate plate boxes to full-image coordinates.
        ocr_engine: OCR engine name forwarded to ``read_license_plate``.
        confidence_threshold: minimum plate detection score to keep.

    Returns:
        ``(plates, crops)``: the result dictionaries for plates that yielded
        text, and the cropped image of every plate above the threshold.
    """
    origin_x, origin_y = origin
    plates = []
    crops = []

    for result_plate in model_plate(region):
        for px1, py1, px2, py2, pscore, _pclass_id in result_plate.boxes.data.tolist():
            if pscore < confidence_threshold:
                continue  # Skip detections below the confidence threshold

            plate = region[int(py1):int(py2), int(px1):int(px2)]
            if plate.size == 0:
                continue  # Degenerate box: nothing to hand to the OCR engine
            crops.append(plate)

            raw_result, preprocessed_result = read_license_plate(plate, ocr_engine=ocr_engine)
            if not (raw_result or preprocessed_result):
                continue

            validated_text, country, is_valid = post_process_ocr(raw_result, preprocessed_result)

            left, top = int(origin_x + px1), int(origin_y + py1)
            right, bottom = int(origin_x + px2), int(origin_y + py2)
            plates.append({
                'raw_text': raw_result,
                'preprocessed_text': preprocessed_result,
                'validated_text': validated_text,
                'country': country,
                'is_valid': is_valid,
                'bbox': [left, top, right, bottom],
            })

            # Annotate the output copy only, never the pixels being analysed.
            cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 0), 2)
            if validated_text:
                cv2.putText(canvas, f"{validated_text} ({country})", (left, top - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    return plates, crops


def detect_and_recognize_plates(image, ocr_engine='easyocr', confidence_threshold=0.5):
    """Find licence plates in ``image`` and read them with an OCR engine.

    Cars are detected first and plates are searched inside each car crop.
    If no car is found, plates are searched in the whole image instead.

    The input array is left untouched: annotations are drawn on a copy.
    Previously they were drawn on ``image`` itself, which leaked boxes and
    labels into later vehicle crops, into the returned plate crops, and
    into subsequent calls that reuse the same array.

    Args:
        image: ``numpy`` image array.
        ocr_engine: OCR engine name forwarded to ``read_license_plate``.
        confidence_threshold: minimum score for both vehicle and plate
            detections.

    Returns:
        ``(annotated_image, plates_detected, cropped_plates)``: the
        annotated copy, one dictionary per plate that yielded text (keys
        ``raw_text``, ``preprocessed_text``, ``validated_text``,
        ``country``, ``is_valid``, ``bbox``), and the plate crops.
    """
    annotated = image.copy()
    plates_detected = []
    cropped_plates = []
    vehicles_found = False

    for result in model_vehicle(image):
        for x1, y1, x2, y2, score, class_id in result.boxes.data.tolist():
            if score < confidence_threshold:
                continue  # Skip detections below the confidence threshold
            if int(class_id) != 2:  # Class ID 2 represents cars in COCO dataset
                continue

            vehicle = image[int(y1):int(y2), int(x1):int(x2)]
            if vehicle.size == 0:
                continue  # Degenerate box: do not let it suppress the fallback
            vehicles_found = True

            plates, crops = _recognize_plates_in_region(
                vehicle, annotated, (x1, y1), ocr_engine, confidence_threshold
            )
            plates_detected.extend(plates)
            cropped_plates.extend(crops)

    if not vehicles_found:
        # No usable car detection: look for plates in the full frame.
        plates, crops = _recognize_plates_in_region(
            image, annotated, (0, 0), ocr_engine, confidence_threshold
        )
        plates_detected.extend(plates)
        cropped_plates.extend(crops)

    return annotated, plates_detected, cropped_plates
201
+
202
def process_image(input_image, ocr_engine='easyocr', confidence_threshold=0.5) -> Tuple[Union[np.ndarray, None], pd.DataFrame, List[np.ndarray]]:
    """Run plate detection and OCR on one image and tabulate the results.

    Args:
        input_image: ``numpy`` array or PIL image.
        ocr_engine: OCR engine name forwarded to the detection pipeline.
        confidence_threshold: minimum detection score to keep.

    Returns:
        ``(annotated_image, table, cropped_plates)``. ``table`` has one row
        per recognised plate, or a single ``Message`` row when none was
        found. If anything fails, the error is printed and
        ``(None, <one-row "Error" table>, [])`` is returned instead of
        raising.
    """
    try:
        # Normalise the input to a numpy array.
        image_np = input_image
        if not isinstance(input_image, np.ndarray):
            if not isinstance(input_image, Image.Image):
                raise ValueError("Unsupported image type")
            image_np = np.array(input_image)

        # Detect and recognise plates.
        annotated_image, plates, cropped_plates = detect_and_recognize_plates(
            image_np,
            ocr_engine=ocr_engine,
            confidence_threshold=confidence_threshold,
        )

        # One table row per plate, numbered from 1.
        rows = [
            {
                "Plate Number": number,
                "Validated Text": plate['validated_text'],
                "Country": plate['country'],
                "Valid": "Yes" if plate['is_valid'] else "No",
                "Raw OCR": plate['raw_text'],
                "Preprocessed OCR": plate['preprocessed_text'],
            }
            for number, plate in enumerate(plates, start=1)
        ]

        if rows:
            table = pd.DataFrame(rows)
        else:
            table = pd.DataFrame({"Message": ["No license plates detected"]})

        return annotated_image, table, cropped_plates
    except Exception as e:
        # UI boundary: report the failure in the results table.
        print(f"An error occurred: {str(e)}")
        return None, pd.DataFrame({"Error": [str(e)]}), []
233
+
234
def compare_ocr_engines(image):
    """Run every OCR engine on ``image`` and summarise the outcomes.

    Each engine is run through ``process_image`` with its default
    confidence threshold and timed with wall-clock time.

    Returns:
        A DataFrame with one row per engine and the columns
        ``OCR Engine``, ``Processing Time (s)``, ``Plates Detected`` and
        ``Detected Texts`` (validated texts joined with ``', '``).
    """
    ocr_engines = ['easyocr', 'pytesseract', 'kerasocr', 'trocr']
    durations = []
    plate_counts = []
    text_lists = []

    for engine in ocr_engines:
        started = time.time()
        _, table, _ = process_image(image, ocr_engine=engine)
        finished = time.time()

        durations.append(finished - started)
        # The "Message"/"Error" tables have no plate columns: count zero.
        has_plates = 'Plate Number' in table.columns
        plate_counts.append(len(table) if has_plates else 0)
        has_texts = 'Validated Text' in table.columns
        text_lists.append(table['Validated Text'].tolist() if has_texts else [])

    return pd.DataFrame({
        'OCR Engine': ocr_engines,
        'Processing Time (s)': durations,
        'Plates Detected': plate_counts,
        'Detected Texts': [', '.join(texts) for texts in text_lists],
    })
257
+
258
+ # gradio app
259
# NOTE(review): the nesting of the layout containers below was reconstructed
# from an unindented listing; confirm which widgets belong inside each
# Row/Column before relying on the exact layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header. The car emoji was mis-encoded as "πŸš—" and is restored.
    gr.Markdown(
        """
        # 🚗 ALPR YOLOv8 and Multi-OCR 🚗

        Test this ALPR solution using YOLOv8 and various OCR engines!

        > Better results with high quality images, plate aligned horizontally, clearly visible.
        """
    )

    with gr.Tabs():
        # Tab 1: run a single OCR engine on one image.
        with gr.TabItem("Single Image Processing"):
            with gr.Accordion("How It Works", open=False):
                gr.Markdown(
                    """
                    This ALPR (Automatic License Plate Recognition) system works in several steps:
                    1. Vehicle Detection: Uses YOLOv8 to detect vehicles in the image with pretrained model on MS-COCO dataset.
                    2. License Plate Detection: Applies a custom YOLOv8 model to locate license plates region within detected vehicles to crop it.
                    3. Add preprocessing on the cropped plate that can help to give better results in some situation.
                    4. OCR: Employs various OCR engines to read the text on the cropped license plates.
                    5. Post-processing: Cleans and validates the detected text against known license plate patterns.
                    """
                )

            with gr.Accordion("OCR Engines", open=False):
                gr.Markdown(
                    """
                    The system supports multiple OCR engines:
                    - [EasyOCR](https://github.com/JaidedAI/EasyOCR): General-purpose OCR library with good accuracy.
                    - [Pytesseract](https://github.com/madmaze/pytesseract): Open-source OCR engine based on Tesseract.
                    - [Keras-OCR](https://github.com/faustomorales/keras-ocr): Deep learning-based OCR solution.
                    - [TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr): Transformer-based OCR model for handwritten and printed text.

                    Each engine has its strengths and may perform differently depending on the image quality and license plate style.
                    """
                )

            with gr.Row():
                # Left column: inputs and controls.
                with gr.Column(scale=1):
                    input_image = gr.Image(type="numpy", label="Input image")
                    ocr_selector = gr.Radio(choices=['easyocr', 'pytesseract', 'kerasocr', 'trocr'], value='easyocr', label="Select OCR Engine")
                    confidence_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Detection Confidence Threshold")
                    submit_btn = gr.Button("Detect License Plates", variant="primary")

                # Right column: visual outputs.
                with gr.Column(scale=1):
                    output_image = gr.Image(type="numpy", label="Annotated image")
                    cropped_plate_gallery = gr.Gallery(label="Cropped plates")

            output_table = gr.Dataframe(label="Detection results")

            with gr.Accordion("Understanding the Results", open=False):
                gr.Markdown(
                    """
                    The results table provides the following information:
                    - Plate Number: Sequential number assigned to each detected plate.
                    - Validated Text: The final, cleaned, and validated license plate text.
                    - Country: Estimated country of origin based on the plate format.
                    - Valid: Indicates whether the plate matches a known format.
                    - Raw OCR: The initial text detected by the OCR engine.
                    - Preprocessed OCR: Text detected after image preprocessing.

                    The confidence threshold determines the minimum confidence score for a detection to be considered valid.
                    """
                )

        # Tab 2: run every OCR engine on the same image and compare.
        with gr.TabItem("OCR Engine Comparison"):
            with gr.Row():
                comparison_input = gr.Image(type="numpy", label="Input Image for Comparison")
                compare_btn = gr.Button("Compare OCR Engines")
            comparison_output = gr.Dataframe(label="OCR Engine Comparison Results")

    # Event handlers
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, ocr_selector, confidence_slider],
        outputs=[output_image, output_table, cropped_plate_gallery]
    )

    compare_btn.click(
        fn=compare_ocr_engines,
        inputs=[comparison_input],
        outputs=[comparison_output]
    )

if __name__ == "__main__":
    demo.launch()