import functools
import logging

import gradio as gr
from PIL import Image
import numpy as np
from utils import perspective_transform
from ultralytics import YOLO

logger = logging.getLogger(__name__)

# Number of output components (image + textbox pairs) per result column.
# Each plate uses two slots, so at most max_len // 2 plates are displayed.
max_len = 10


@functools.lru_cache(maxsize=8)
def _load_model(name, task):
    """Load ``./weights/{name}.pt`` once and reuse it on later calls.

    The original code re-read the weights from disk on every request and for
    every single plate. The dropdowns only offer a handful of model names, so
    the cache stays small.
    NOTE(review): ultralytics models are not documented as thread-safe; sharing
    one instance assumes Gradio's default of running one event at a time.
    """
    return YOLO(f'./weights/{name}.pt', task=task)


def ocr(plate_image, ocr_model):
    """Read the characters on a cropped plate image.

    Only YOLO-based OCR models (names containing 'yolo') are implemented. Each
    detected character box carries a class id. The class names are joined
    left-to-right, ordered by the box's x1 coordinate.

    Args:
        plate_image: cropped plate image, passed straight to ``YOLO.predict``.
        ocr_model: weight-file stem under ``./weights``.

    Returns:
        The recognised string.
        '' for non-YOLO models (e.g. 'trocr', which is not implemented here).
        '' when no character is detected.
        'error' if inference raised (the exception is logged).
    """
    if 'yolo' not in ocr_model:
        # TODO: 'trocr' is offered in the UI but has no implementation yet.
        return ''
    try:
        model = _load_model(ocr_model, 'detect')
        result = model.predict(plate_image)[0]
        # Rows are (x1, y1, x2, y2, ..., conf, cls): the class id is the last column.
        data = result.boxes.data.cpu().numpy()
        if len(data) == 0:
            # Previously this raised IndexError and the UI showed 'error'.
            return ''
        # A stable sort keeps tie order identical to sorted(..., key=x1).
        order = np.argsort(data[:, 0], kind='stable')
        return ''.join(model.names[int(cls)] for cls in data[order, -1])
    except Exception:
        logger.exception('OCR with model %r failed', ocr_model)
        return 'error'


def process_image(image, detection_model, ocr_model, yolo_thresh,
                  perpective_width, perpective_height):
    """Detect plates in *image* and run OCR on each one.

    Args:
        image: input image as a NumPy array (from ``gr.Image(type="numpy")``).
        detection_model: weight-file stem. Names containing 'obb' are loaded
            as oriented-bounding-box models.
        ocr_model: OCR weight-file stem, forwarded to :func:`ocr`.
        yolo_thresh: detection confidence threshold.
        perpective_width, perpective_height: output size for the
            perspective-rectified plates (OBB models only).

    Returns:
        ``(crop_results, transform_results)``, each a list of
        ``(plate_image, text)`` pairs.
        OBB models: the pairs come from ``perspective_transform``'s two
        outputs (crops and rectified plates).
        Axis-aligned models: plain box crops, with an empty second list.
    """
    task = 'obb' if 'obb' in detection_model else 'detect'
    model = _load_model(detection_model, task)
    predict = model(image, conf=yolo_thresh)
    if task == 'obb':
        # Slider values may arrive as floats; the output size is in whole pixels.
        obb_crops, transformed = perspective_transform(
            predict,
            dst_width=int(perpective_width),
            dst_height=int(perpective_height),
        )
        crop_results = [(plate, ocr(plate, ocr_model)) for plate in obb_crops]
        transform_results = [(plate, ocr(plate, ocr_model)) for plate in transformed]
        return crop_results, transform_results

    boxes = predict[0].boxes.xyxy.cpu().numpy().astype(np.int32)
    # Slicing only the first two axes also works for single-channel images.
    crops = [image[y1:y2, x1:x2] for x1, y1, x2, y2 in boxes]
    return [(plate, ocr(plate, ocr_model)) for plate in crops], []


def _slot_updates(results, label_prefix):
    """Build Gradio updates filling one result column's image/textbox slots.

    Exactly ``max_len`` components are returned, matching the declared outputs.
    Results beyond ``max_len // 2`` plates are dropped, because returning extra
    updates would make Gradio reject the callback's output. Unused slots are
    hidden.
    """
    n_slots = max_len // 2
    if len(results) > n_slots:
        logger.warning('Showing %d of %d detected plates', n_slots, len(results))
    shown = results[:n_slots]
    updates = []
    for i, (plate, text) in enumerate(shown, start=1):
        updates.append(gr.Image(label=f'{label_prefix}{i}', value=plate, visible=True))
        updates.append(gr.Textbox(label=f'text{i}', value=text, visible=True))
    for _ in range(n_slots - len(shown)):
        updates.append(gr.Image(visible=False))
        updates.append(gr.Textbox(visible=False))
    return updates


def create_interface():
    """Build the Gradio Blocks UI: model/threshold controls, input, result slots."""
    with gr.Blocks(css="footer{display:none !important}") as demo:
        with gr.Row():
            with gr.Column(scale=1):
                detection_model = gr.Dropdown(
                    label="Detection Model",
                    choices=["yolov8-m", "yolov8-obb-m", "yolov8-s", "yolov8-obb-s"],
                    value="yolov8-m",
                )
                ocr_model = gr.Dropdown(label="OCR Model", choices=["yolov32c", "trocr"], value="yolov32c")
                # 0.25 matches ultralytics' own default confidence threshold.
                yolo_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label='yolo_threshold')
                with gr.Row():
                    perpective_width = gr.Slider(minimum=100, maximum=500, label='perspective_width')
                    perpective_height = gr.Slider(minimum=50, maximum=200, label='perspective_height')
            with gr.Column(scale=3):
                input_image = gr.Image(label="Upload Image", type="numpy")
                # One row per plate: crop image/text, then transformed image/text.
                crop_slots, transform_slots = [], []
                for _ in range(max_len // 2):
                    with gr.Row():
                        crop_slots += [gr.Image(visible=False), gr.Textbox(visible=False)]
                        transform_slots += [gr.Image(visible=False), gr.Textbox(visible=False)]

        def main_fn(image, detection_model, ocr_model, yolo_thresh,
                    perpective_width, perpective_height):
            """Run the pipeline and return updates for all result slots."""
            if image is None:
                # Nothing uploaded: hide every slot instead of failing inside YOLO.
                return _slot_updates([], 'clp') + _slot_updates([], 'tlp')
            crop_results, transform_results = process_image(
                image, detection_model, ocr_model, yolo_thresh,
                perpective_width, perpective_height)
            return _slot_updates(crop_results, 'clp') + _slot_updates(transform_results, 'tlp')

        submit_button = gr.Button("Process Image")
        submit_button.click(
            fn=main_fn,
            inputs=[input_image, detection_model, ocr_model, yolo_thresh,
                    perpective_width, perpective_height],
            outputs=crop_slots + transform_slots,
        )
    return demo


# Build the app at import time (so `gradio app.py` reload mode can find `demo`),
# but only start the server when this file is run as a script.
demo = create_interface()

if __name__ == "__main__":
    demo.launch()