# Spaces:
# Running
# Running
import io
import os
import re
import time
import traceback
from datetime import datetime
from datetime import datetime
from typing import List, Tuple, Union

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
import easyocr
import pytesseract
import keras_ocr
import pandas as pd
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
from paddleocr import PaddleOCR
# Model initialisation
def load_models():
    """Create every detector and OCR backend once and publish them as module globals.

    The OCR wrappers and plate-detection helpers in this module read these
    globals directly rather than receiving the models as arguments.
    """
    global model_vehicle, model_plate
    global reader_easyocr, pipeline_kerasocr
    global processor_trocr, model_trocr, ocr_paddle

    # YOLOv8 weights: one for vehicles, the custom 'best.pt' for plates.
    model_vehicle = YOLO('models/yolov8n.pt')
    model_plate = YOLO('models/best.pt')

    # EasyOCR is created with GPU disabled; keras-ocr uses its default pipeline.
    reader_easyocr = easyocr.Reader(['en'], gpu=False)
    pipeline_kerasocr = keras_ocr.pipeline.Pipeline()

    # Processor and model must come from the same TrOCR checkpoint.
    trocr_checkpoint = 'microsoft/trocr-base-handwritten'
    processor_trocr = TrOCRProcessor.from_pretrained(trocr_checkpoint)
    model_trocr = VisionEncoderDecoderModel.from_pretrained(trocr_checkpoint)

    # PaddleOCR with text-angle classification, CPU only.
    ocr_paddle = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
# Load every model once when the module is imported; the functions below use
# them as module-level globals.
load_models()
# Regex patterns for European licence-plate layouts, keyed by country code.
# validate_european_plate() tries them in insertion order and returns the first
# country whose pattern matches. The order of this dict therefore decides which
# country is reported when several patterns accept the same text, and the
# generic 'EU' entry acts as a last-resort fallback.
# NOTE(review): post_process_ocr() only validates text that clean_plate_text()
# has already stripped of all whitespace, so the optional '\s?' parts never see
# a space on that path. With whitespace removed, every string accepted by 'NO'
# or 'DK' (two letters + five digits) is accepted earlier by 'PL', so those two
# countries are never reported. In 'BE' the '1-...' alternative is already
# covered by the '\d-...' one. Confirm whether these overlaps are intended.
EUROPEAN_PATTERNS = {
    'FR': r'^(?:[A-Z]{2}-\d{3}-[A-Z]{2}|\d{2,4}\s?[A-Z]{2,3}\s?\d{2,4})$',  # France
    'DE': r'^[A-Z]{1,3}-[A-Z]{1,2}\s?\d{1,4}[EH]?$',  # Germany
    'ES': r'^(\d{4}[A-Z]{3}|[A-Z]{1,2}\d{4}[A-Z]{2,3})$',  # Spain
    'IT': r'^[A-Z]{2}\s?\d{3}\s?[A-Z]{2}$',  # Italy
    'GB': r'^[A-Z]{2}\d{2}\s?[A-Z]{3}$',  # Great Britain
    'NL': r'^[A-Z]{2}-\d{3}-[A-Z]$',  # Netherlands
    'BE': r'^(1-[A-Z]{3}-\d{3}|\d-[A-Z]{3}-\d{3})$',  # Belgium
    'PL': r'^[A-Z]{2,3}\s?\d{4,5}$',  # Poland
    'SE': r'^[A-Z]{3}\s?\d{3}$',  # Sweden
    'NO': r'^[A-Z]{2}\s?\d{5}$',  # Norway (shadowed by 'PL', see note above)
    'FI': r'^[A-Z]{3}-\d{3}$',  # Finland
    'DK': r'^[A-Z]{2}\s?\d{2}\s?\d{3}$',  # Denmark (shadowed by 'PL' once spaces are removed)
    'CH': r'^[A-Z]{2}\s?\d{1,6}$',  # Switzerland
    'AT': r'^[A-Z]{1,2}\s?\d{1,5}[A-Z]$',  # Austria
    'PT': r'^[A-Z]{2}-\d{2}-[A-Z]{2}$',  # Portugal
    'EU': r'^[A-Z0-9]{2,4}[-\s]?[A-Z0-9]{1,4}[-\s]?[A-Z0-9]{1,4}$'  # Generic European plate (checked last)
}
def preprocess_image(image):
    """Binarize a plate crop with a Gaussian blur followed by Otsu thresholding.

    Args:
        image: Plate crop taken from an RGB image (Gradio/PIL deliver RGB, not
            OpenCV's usual BGR). Grayscale ``(H, W)`` / ``(H, W, 1)`` and RGBA
            ``(H, W, 4)`` arrays are accepted as well.

    Returns:
        The black-and-white result converted back to 3-channel RGB, so every
        OCR engine can consume it the same way as the raw crop.
    """
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 1:
        gray = image[:, :, 0]
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # RGB input: BGR2GRAY would swap the red/blue luminance weights.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # OpenCV needs a contiguous buffer; channel slices are strided views.
    gray = np.ascontiguousarray(gray)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB)
def trocr_ocr(image):
    """Recognize the text in *image* with TrOCR and return the first decoded string.

    Uses the ``processor_trocr`` / ``model_trocr`` globals created by load_models().
    """
    encoded = processor_trocr(image, return_tensors="pt")
    output_ids = model_trocr.generate(encoded.pixel_values)
    decoded_texts = processor_trocr.batch_decode(output_ids, skip_special_tokens=True)
    return decoded_texts[0]
def _detections_to_text(detections):
    """Join ``(box, text, score)`` detections into one upper-case string.

    Spaces inside each detection are removed and detections are separated by a
    single space. Returns None when there are no detections at all.
    """
    texts = [text.upper().replace(' ', '') for _, text, _ in detections]
    return " ".join(texts) if texts else None


def _kerasocr_text(predictions):
    """Concatenate keras-ocr ``(word, box)`` predictions in left-to-right order.

    keras-ocr does not return words in reading order, so words are sorted by the
    left-most x coordinate of their box before being joined.
    """
    def left_edge(prediction):
        _, box = prediction
        return float(np.asarray(box).reshape(-1, 2)[:, 0].min())

    return ''.join(word for word, _ in sorted(predictions, key=left_edge))


def _paddleocr_detections(result):
    """Convert a PaddleOCR ``ocr()`` result for one image into ``(None, text, score)`` tuples.

    Every recognized line is kept, like the EasyOCR path. An empty or missing
    result yields a single empty detection.
    """
    lines = result[0] if result else None
    if not lines:
        return [(None, '', 0.0)]
    return [(None, text, score) for _, (text, score) in lines]


def read_license_plate(license_plate_crop, ocr_engine='easyocr'):
    """OCR a plate crop twice: as-is and after preprocess_image().

    Args:
        license_plate_crop: Image array of the cropped plate.
        ocr_engine: One of 'easyocr', 'pytesseract', 'kerasocr', 'trocr', 'paddleocr'.

    Returns:
        ``(raw_text, preprocessed_text)``. Each value holds the upper-cased
        detections joined by single spaces (spaces inside a detection removed),
        or None when the engine returned no detections.

    Raises:
        ValueError: If ``ocr_engine`` is not supported.
    """
    if ocr_engine == 'easyocr':
        detections_raw = reader_easyocr.readtext(license_plate_crop)
        detections_preprocessed = reader_easyocr.readtext(preprocess_image(license_plate_crop))
    elif ocr_engine == 'pytesseract':
        # --psm 8: treat the image as a single word.
        text_raw = pytesseract.image_to_string(license_plate_crop, config='--psm 8')
        text_preprocessed = pytesseract.image_to_string(preprocess_image(license_plate_crop), config='--psm 8')
        detections_raw = [(None, text_raw.strip(), None)]
        detections_preprocessed = [(None, text_preprocessed.strip(), None)]
    elif ocr_engine == 'kerasocr':
        # Grayscale crops are expanded to 3 channels before recognition.
        if len(license_plate_crop.shape) == 2 or license_plate_crop.shape[2] == 1:
            license_plate_crop = cv2.cvtColor(license_plate_crop, cv2.COLOR_GRAY2RGB)
        predictions_raw = pipeline_kerasocr.recognize([license_plate_crop])[0]
        predictions_preprocessed = pipeline_kerasocr.recognize([preprocess_image(license_plate_crop)])[0]
        detections_raw = [(None, _kerasocr_text(predictions_raw), None)]
        detections_preprocessed = [(None, _kerasocr_text(predictions_preprocessed), None)]
    elif ocr_engine == 'trocr':
        text_raw = trocr_ocr(license_plate_crop)
        text_preprocessed = trocr_ocr(preprocess_image(license_plate_crop))
        detections_raw = [(None, text_raw.strip(), None)]
        detections_preprocessed = [(None, text_preprocessed.strip(), None)]
    elif ocr_engine == 'paddleocr':
        # Empty results are handled by _paddleocr_detections before indexing.
        detections_raw = _paddleocr_detections(ocr_paddle.ocr(license_plate_crop))
        detections_preprocessed = _paddleocr_detections(ocr_paddle.ocr(preprocess_image(license_plate_crop)))
    else:
        raise ValueError(f"OCR engine '{ocr_engine}' not supported.")
    return _detections_to_text(detections_raw), _detections_to_text(detections_preprocessed)
def clean_plate_text(text):
    """Normalize raw OCR output into a compact, upper-case plate string.

    Only ASCII letters, digits and hyphens survive; every whitespace character
    and any other symbol is dropped. ``None`` yields an empty string.
    """
    if text is None:
        return ''
    upper = text.upper()
    without_symbols = re.sub(r'[^A-Z0-9\-\s]', '', upper)
    return re.sub(r'\s+', '', without_symbols).strip()
def validate_european_plate(text, patterns=None):
    """Match plate text against country patterns and report the first hit.

    Args:
        text: Cleaned plate text. None or empty text never matches.
        patterns: Optional mapping of country code -> regex, tried in iteration
            order. Defaults to the module-level EUROPEAN_PATTERNS.

    Returns:
        ``(text, country_code)`` for the first pattern matching the whole
        string, otherwise ``(None, None)``.
    """
    if not text:
        return None, None
    if patterns is None:
        patterns = EUROPEAN_PATTERNS
    for country, pattern in patterns.items():
        # fullmatch: unlike re.match with '$', a trailing newline is not accepted.
        if re.fullmatch(pattern, text):
            return text, country
    return None, None
def post_process_ocr(raw_text, preprocessed_text):
    """Pick the best plate reading from the raw and preprocessed OCR results.

    A reading that matches a known plate format wins, with the raw reading
    preferred. If neither validates, the cleaned raw text is returned, falling
    back to the cleaned preprocessed text when the raw one is empty so that no
    recognized text is discarded.

    Returns:
        ``(text, country, is_valid)``. ``country`` is 'Unknown' and
        ``is_valid`` is False when no pattern matched.
    """
    cleaned_raw = clean_plate_text(raw_text)
    validated_raw, country_raw = validate_european_plate(cleaned_raw)
    if validated_raw:
        return validated_raw, country_raw, True

    cleaned_preprocessed = clean_plate_text(preprocessed_text)
    validated_preprocessed, country_preprocessed = validate_european_plate(cleaned_preprocessed)
    if validated_preprocessed:
        return validated_preprocessed, country_preprocessed, True

    return cleaned_raw or cleaned_preprocessed, 'Unknown', False
def _to_bgr(image):
    """Return a contiguous BGR copy of a 3-channel RGB image for Ultralytics.

    Ultralytics interprets numpy inputs as BGR, while Gradio/PIL images are RGB.
    Arrays that are not 3-channel are returned unchanged.
    """
    if image.ndim == 3 and image.shape[2] == 3:
        return np.ascontiguousarray(image[:, :, ::-1])
    return image


def _recognize_plates_in_region(region, offset, annotated, ocr_engine, confidence_threshold):
    """Detect, crop, OCR and annotate every plate found in ``region``.

    ``offset`` is the integer ``(x, y)`` position of ``region`` in the full image.
    It maps plate boxes back to full-image coordinates, both for the reported
    bbox and for drawing on ``annotated``. Returns ``(plates, crops)`` for this region.
    """
    offset_x, offset_y = offset
    plates, crops = [], []
    for result_plate in model_plate(_to_bgr(region)):
        for px1, py1, px2, py2, pscore, _ in result_plate.boxes.data.tolist():
            if pscore < confidence_threshold:
                continue  # Skip detections below the confidence threshold
            plate = region[int(py1):int(py2), int(px1):int(px2)]
            if plate.size == 0:
                continue  # Degenerate box: OCR engines cannot process an empty crop
            crops.append(plate)
            raw_result, preprocessed_result = read_license_plate(plate, ocr_engine=ocr_engine)
            if not (raw_result or preprocessed_result):
                continue
            validated_text, country, is_valid = post_process_ocr(raw_result, preprocessed_result)
            x_min, y_min = offset_x + int(px1), offset_y + int(py1)
            x_max, y_max = offset_x + int(px2), offset_y + int(py2)
            plates.append({
                'raw_text': raw_result,
                'preprocessed_text': preprocessed_result,
                'validated_text': validated_text,
                'country': country,
                'is_valid': is_valid,
                'bbox': [x_min, y_min, x_max, y_max],
            })
            cv2.rectangle(annotated, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
            if validated_text:
                cv2.putText(annotated, f"{validated_text} ({country})", (x_min, y_min - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    return plates, crops


def detect_and_recognize_plates(image, ocr_engine='easyocr', confidence_threshold=0.5, *, vehicle_class_ids=(2,)):
    """Detect vehicles, locate and read their plates, and draw the results.

    Vehicles scoring at least ``confidence_threshold`` are cropped and searched
    for plates. If no such vehicle is found, the plate detector runs on the whole
    image instead.

    Args:
        image: RGB image array (as prepared by process_image). It is not modified.
        ocr_engine: Engine name forwarded to read_license_plate.
        confidence_threshold: Minimum YOLO score for vehicle and plate boxes.
        vehicle_class_ids: COCO class IDs treated as vehicles. The default
            ``(2,)`` (car) matches the original behaviour.

    Returns:
        ``(annotated, plates_detected, cropped_plates)``:
        - an annotated copy of ``image``;
        - one dict per plate whose OCR produced text (raw/preprocessed/validated
          text, country, validity flag, full-image bbox);
        - every plate crop above the threshold, taken from the unannotated image.
    """
    # Draw on a copy: drawing in place would leak boxes and labels into the
    # returned crops (views of ``image``) and into later runs on the same array.
    annotated = image.copy()
    plates_detected, cropped_plates = [], []
    vehicles_found = False

    for result in model_vehicle(_to_bgr(image)):
        for x1, y1, x2, y2, score, class_id in result.boxes.data.tolist():
            if score < confidence_threshold or int(class_id) not in vehicle_class_ids:
                continue
            vehicles_found = True
            vehicle_x, vehicle_y = int(x1), int(y1)
            vehicle = image[vehicle_y:int(y2), vehicle_x:int(x2)]
            if vehicle.size == 0:
                continue
            plates, crops = _recognize_plates_in_region(
                vehicle, (vehicle_x, vehicle_y), annotated, ocr_engine, confidence_threshold
            )
            plates_detected.extend(plates)
            cropped_plates.extend(crops)

    if not vehicles_found:
        # Fallback: close-up shots may contain a plate but no detectable vehicle.
        plates, crops = _recognize_plates_in_region(image, (0, 0), annotated, ocr_engine, confidence_threshold)
        plates_detected.extend(plates)
        cropped_plates.extend(crops)

    return annotated, plates_detected, cropped_plates
def process_image(input_image, ocr_engine='easyocr', confidence_threshold=0.5) -> Tuple[Union[np.ndarray, None], pd.DataFrame, List[np.ndarray]]:
    """Gradio handler: detect and read plates in one image and tabulate the results.

    Args:
        input_image: Image as a numpy array or a PIL image. Grayscale, single-
            channel and RGBA inputs are normalised to 3-channel RGB.
        ocr_engine: OCR engine name forwarded to detect_and_recognize_plates.
        confidence_threshold: Minimum YOLO score for vehicle and plate detections.

    Returns:
        ``(annotated_image, results_df, cropped_plates)``. On any failure the
        image is None, the DataFrame has a single "Error" column and the crop
        list is empty, so the UI always receives displayable values.
    """
    try:
        if input_image is None:
            # Presumably what the Gradio Image component sends when nothing was uploaded.
            raise ValueError("No input image provided")
        if isinstance(input_image, np.ndarray):
            # Copy so downstream annotation can never modify the caller's array.
            image_np = input_image.copy()
        elif isinstance(input_image, Image.Image):
            image_np = np.array(input_image.convert("RGB"))
        else:
            raise ValueError("Unsupported image type")

        # Normalise to 3-channel RGB; the colour conversions downstream expect it.
        if image_np.ndim == 2:
            image_np = np.stack([image_np] * 3, axis=-1)
        elif image_np.ndim == 3 and image_np.shape[2] == 1:
            image_np = np.repeat(image_np, 3, axis=2)
        elif image_np.ndim == 3 and image_np.shape[2] == 4:
            image_np = np.ascontiguousarray(image_np[:, :, :3])

        annotated_image, plates, cropped_plates = detect_and_recognize_plates(
            image_np, ocr_engine=ocr_engine, confidence_threshold=confidence_threshold
        )

        results = [
            {
                "Plate Number": number,
                "Validated Text": plate['validated_text'],
                "Country": plate['country'],
                "Valid": "Yes" if plate['is_valid'] else "No",
                "Raw OCR": plate['raw_text'],
                "Preprocessed OCR": plate['preprocessed_text'],
            }
            for number, plate in enumerate(plates, start=1)
        ]
        df = pd.DataFrame(results) if results else pd.DataFrame({"Message": ["No license plates detected"]})
        return annotated_image, df, cropped_plates
    except Exception as e:
        # UI boundary: report the failure in the results table instead of crashing
        # the app, but keep the full traceback in the server log.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        return None, pd.DataFrame({"Error": [str(e)]}), []
def compare_ocr_engines(image, ocr_engines=None, confidence_threshold=0.5):
    """Run the full pipeline once per OCR engine and summarise timing and output.

    Args:
        image: Input image (numpy array or PIL image) as accepted by process_image.
        ocr_engines: Engine names to compare. Defaults to every engine offered
            in the single-image UI, including PaddleOCR.
        confidence_threshold: Detection threshold forwarded to process_image.

    Returns:
        DataFrame with one row per engine: engine name, wall-clock processing
        time in seconds, number of plates detected, and the comma-joined
        validated texts.
    """
    if ocr_engines is None:
        ocr_engines = ['easyocr', 'paddleocr', 'pytesseract', 'kerasocr', 'trocr']

    rows = []
    for engine in ocr_engines:
        # Each engine gets its own copy, so no run can see another run's annotations.
        engine_input = image.copy() if isinstance(image, np.ndarray) else image
        start_time = time.perf_counter()
        _, df, _ = process_image(engine_input, ocr_engine=engine, confidence_threshold=confidence_threshold)
        elapsed = time.perf_counter() - start_time
        texts = df['Validated Text'].tolist() if 'Validated Text' in df.columns else []
        rows.append({
            'OCR Engine': engine,
            'Processing Time (s)': elapsed,
            'Plates Detected': len(df) if 'Plate Number' in df.columns else 0,
            'Detected Texts': ', '.join(str(text) for text in texts),
        })

    return pd.DataFrame(
        rows,
        columns=['OCR Engine', 'Processing Time (s)', 'Plates Detected', 'Detected Texts'],
    )
# Gradio app: builds the UI and wires the buttons to process_image / compare_ocr_engines.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚗 ALPR YOLOv8 and Multi-OCR 🚗
        Test this ALPR solution using YOLOv8 and various OCR engines!
        > Better results with high quality images, plate aligned horizontally, clearly visible.
        """
    )
    with gr.Tabs():
        with gr.TabItem("Single Image Processing"):
            with gr.Accordion("How It Works", open=False):
                gr.Markdown(
                    """
                    This ALPR (Automatic License Plate Recognition) system works in several steps:
                    1. Vehicle Detection: Uses a YOLOv8 model pretrained on the MS-COCO dataset to detect cars in the image. If no car is found, the whole image is searched for plates.
                    2. License Plate Detection: Applies a custom YOLOv8 model to locate the license plate region within each detected vehicle and crop it.
                    3. Preprocessing: Binarizes the cropped plate, which can give better results in some situations. Both the raw and the preprocessed crops are read.
                    4. OCR: Employs various OCR engines to read the text on the cropped license plates.
                    5. Post-processing: Cleans and validates the detected text against known license plate patterns.
                    """
                )
            with gr.Accordion("OCR Engines", open=False):
                gr.Markdown(
                    """
                    The system supports multiple OCR engines:
                    - [EasyOCR](https://github.com/JaidedAI/EasyOCR): General-purpose OCR library with good accuracy.
                    - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR): OCR toolkit from the PaddlePaddle project, used here with text-angle classification.
                    - [Pytesseract](https://github.com/madmaze/pytesseract): Open-source OCR engine based on Tesseract.
                    - [Keras-OCR](https://github.com/faustomorales/keras-ocr): Deep learning-based OCR solution.
                    - [TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr): Transformer-based OCR model for handwritten and printed text.
                    Each engine has its strengths and may perform differently depending on the image quality and license plate style.
                    """
                )
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="numpy", label="Input image")
                    ocr_selector = gr.Radio(choices=['easyocr', 'paddleocr', 'pytesseract', 'kerasocr', 'trocr'], value='easyocr', label="Select OCR Engine")
                    confidence_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Detection Confidence Threshold")
                    submit_btn = gr.Button("Detect License Plates", variant="primary")
                with gr.Column(scale=1):
                    output_image = gr.Image(type="numpy", label="Annotated image")
                    cropped_plate_gallery = gr.Gallery(label="Cropped plates")
                    output_table = gr.Dataframe(label="Detection results")
            with gr.Accordion("Understanding the Results", open=False):
                gr.Markdown(
                    """
                    The results table provides the following information:
                    - Plate Number: Sequential number assigned to each detected plate.
                    - Validated Text: The cleaned license plate text; it matched a known plate format when Valid is "Yes".
                    - Country: Estimated country of origin based on the plate format ("Unknown" when no format matched).
                    - Valid: Indicates whether the plate matches a known format.
                    - Raw OCR: The initial text detected by the OCR engine.
                    - Preprocessed OCR: Text detected after image preprocessing.
                    The detection confidence threshold is the minimum YOLO score a vehicle or plate detection needs to be kept; lower-scoring boxes are ignored.
                    """
                )
        with gr.TabItem("OCR Engine Comparison"):
            with gr.Row():
                comparison_input = gr.Image(type="numpy", label="Input Image for Comparison")
                compare_btn = gr.Button("Compare OCR Engines")
            comparison_output = gr.Dataframe(label="OCR Engine Comparison Results")

    # Event handlers (must be registered inside the Blocks context)
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, ocr_selector, confidence_slider],
        outputs=[output_image, output_table, cropped_plate_gallery]
    )
    compare_btn.click(
        fn=compare_ocr_engines,
        inputs=[comparison_input],
        outputs=[comparison_output]
    )

if __name__ == "__main__":
    demo.launch()