"""Gradio demo that detects text lines in an image with a YOLO model.

Each uploaded image is run through the detector twice - once with strict and
once with relaxed confidence/IoU thresholds - and the two annotated results
are shown in separate tabs.
"""

import os
import random
from typing import Tuple

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

# Fail fast at startup if the weights are missing instead of on first request.
YOLO_MODEL_PATH = "best-Dense.pt"
if not os.path.exists(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model file not found at {YOLO_MODEL_PATH}")

# Load the model once at import time; every request reuses it on the CPU.
model = YOLO(YOLO_MODEL_PATH, task='detect').to("cpu")

# Per-class drawing colors (filled lazily by get_class_color) and display names.
CLASS_COLORS = {}
CLASS_NAMES = {0: "Text"}

# (confidence, IoU) pairs for the "High Confidence" and "Low Confidence" tabs,
# in that order. The UI labels below are derived from these values so they
# cannot drift out of sync with what is actually run.
THRESHOLD_SETS = ((0.6, 0.5), (0.4, 0.3))

# Padding (pixels) between a label's text and the edge of its filled background.
_LABEL_PAD_X = 3
_LABEL_PAD_Y = 2


def get_class_color(class_id: int) -> Tuple[int, int, int]:
    """Return the RGB drawing color for *class_id*.

    A random color is picked the first time a class id is seen and cached in
    ``CLASS_COLORS``, so it stays the same for the lifetime of the process
    (but is re-rolled on every restart).
    """
    if class_id not in CLASS_COLORS:
        CLASS_COLORS[class_id] = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
        )
    return CLASS_COLORS[class_id]


def safe_font_load():
    """Load the 18 px Arial label font, falling back to Pillow's default font.

    ``ImageFont.truetype`` raises ``OSError`` when ``arial.ttf`` cannot be
    found or read, and ``ImportError`` when Pillow was built without FreeType
    support; in either case the built-in default font is returned so label
    drawing still works.
    """
    try:
        return ImageFont.truetype("arial.ttf", 18)
    except (OSError, ImportError):
        return ImageFont.load_default()


def process_detection(image, conf, iou):
    """Run the detector on *image* and return an annotated copy plus counts.

    Args:
        image: Input image as a NumPy array (``gr.Image(type="numpy")`` output).
        conf: Minimum confidence score forwarded to ``model.predict``.
        iou: NMS IoU threshold forwarded to ``model.predict``.

    Returns:
        ``(annotated_image, object_count, class_count)``: a PIL image with a
        box and class label drawn for every detection, the number of detected
        boxes, and the number of distinct class ids among them.
    """
    pil_image = Image.fromarray(np.asarray(image))
    # Colors are RGB tuples, which Pillow cannot draw onto single-channel
    # images (e.g. mode "L" from a 2-D array), so normalize those to RGB.
    # RGB and RGBA images are left untouched.
    if pil_image.mode not in ("RGB", "RGBA"):
        pil_image = pil_image.convert("RGB")
    draw = ImageDraw.Draw(pil_image)
    font = safe_font_load()

    # Prediction runs before anything is drawn, so the model sees the clean image.
    results = model.predict(pil_image, conf=conf, iou=iou, device="cpu")

    boxes = results[0].boxes
    detected_boxes = []
    class_ids = []
    if boxes is not None:
        detected_boxes = boxes.xyxy.cpu().numpy().tolist()
        class_ids = boxes.cls.cpu().numpy().astype(int).tolist()

    # xyxy and cls come from the same Boxes object, one entry per detection.
    for (x1, y1, x2, y2), class_id in zip(detected_boxes, class_ids):
        color = get_class_color(class_id)
        text = CLASS_NAMES.get(class_id, f"Class {class_id}")

        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        # textbbox measured at the origin gives the glyph extent relative to
        # the draw anchor; left/top may be non-zero, so they are subtracted
        # when drawing to keep the text inside its background rectangle.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        label_w = (right - left) + 2 * _LABEL_PAD_X
        label_h = (bottom - top) + 2 * _LABEL_PAD_Y

        # Place the label above the box; if that would be cut off by the top
        # edge of the image, place it just inside the top of the box instead.
        label_top = y1 - label_h
        if label_top < 0:
            label_top = y1

        draw.rectangle(
            [x1, label_top, x1 + label_w, label_top + label_h],
            fill=color,
        )
        draw.text(
            (x1 + _LABEL_PAD_X - left, label_top + _LABEL_PAD_Y - top),
            text,
            fill="white",
            font=font,
        )

    return pil_image, len(detected_boxes), len(set(class_ids))


def detect_text_lines(image):
    """Detect text lines once per threshold set in ``THRESHOLD_SETS``.

    Args:
        image: The uploaded image as a NumPy array, or ``None`` when the user
            clicks the button without uploading anything.

    Returns:
        A flat tuple of ``(annotated_image, objects_text, classes_text)`` for
        each threshold set, in ``THRESHOLD_SETS`` order, matching the Gradio
        ``outputs`` list wired up below.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    outputs = []
    for conf, iou in THRESHOLD_SETS:
        annotated_img, obj_count, class_count = process_detection(image, conf, iou)
        outputs.extend((
            annotated_img,
            f"Objects: {obj_count} (Conf={conf}, IoU={iou})",
            f"Classes: {class_count} (Conf={conf}, IoU={iou})",
        ))
    return tuple(outputs)


# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# 📜 Text Line Detection with YOLO")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image")
            submit_btn = gr.Button("Detect Text")

        with gr.Column():
            with gr.Tab("High Confidence"):
                high_conf_img = gr.Image(
                    type="pil", label=f"Detections ({THRESHOLD_SETS[0][0]} conf)"
                )
                high_conf_obj = gr.Textbox(label="Object Count")
                high_conf_cls = gr.Textbox(label="Class Count")

            with gr.Tab("Low Confidence"):
                low_conf_img = gr.Image(
                    type="pil", label=f"Detections ({THRESHOLD_SETS[1][0]} conf)"
                )
                low_conf_obj = gr.Textbox(label="Object Count")
                low_conf_cls = gr.Textbox(label="Class Count")

    # Output order must match the flat tuple returned by detect_text_lines.
    submit_btn.click(
        detect_text_lines,
        inputs=input_image,
        outputs=[high_conf_img, high_conf_obj, high_conf_cls,
                 low_conf_img, low_conf_obj, low_conf_cls],
    )

if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )