Spaces:
Sleeping
Sleeping
import gradio as gr | |
from ultralytics import YOLO | |
from PIL import Image, ImageDraw, ImageFont | |
import random | |
import numpy as np | |
import os | |
# Trained detector weights, resolved relative to the current working directory.
YOLO_MODEL_PATH = "best-Dense.pt"

# Refuse to start with an explicit error when the weights file is missing;
# otherwise load the model once at import time and pin it to the CPU.
if os.path.exists(YOLO_MODEL_PATH):
    model = YOLO(YOLO_MODEL_PATH, task='detect').to("cpu")
else:
    raise FileNotFoundError(f"YOLO model file not found at {YOLO_MODEL_PATH}")
# Define class colors and names
# Lazily filled cache: class id -> RGB colour used for that class's box outline and label tab.
CLASS_COLORS = {}
# Display names for known class ids.
CLASS_NAMES = {0: "Text"}


def get_class_color(class_id):
    """Return a stable RGB colour tuple for ``class_id``.

    The colour is drawn from a private RNG seeded with the class id, so a given
    class always maps to the same colour: within a run (via the CLASS_COLORS
    cache) and across restarts. The global ``random`` state is not touched.

    Each channel is capped at 160 so the fill stays mid-to-dark. This keeps the
    white label text drawn on top of it readable.

    Args:
        class_id: Class index (any value accepted as a ``random.Random`` seed).

    Returns:
        Tuple of three ints in ``[0, 160]``.
    """
    if class_id not in CLASS_COLORS:
        rng = random.Random(class_id)
        CLASS_COLORS[class_id] = tuple(rng.randint(0, 160) for _ in range(3))
    return CLASS_COLORS[class_id]
def safe_font_load(size=18, font_path="arial.ttf"):
    """Load a TrueType font for box labels, falling back to Pillow's default font.

    Args:
        size: Point size for the TrueType font (default 18).
        font_path: Font file name or path passed to ``ImageFont.truetype``.

    Returns:
        The requested TrueType font, or ``ImageFont.load_default()`` when it
        cannot be loaded.
    """
    try:
        return ImageFont.truetype(font_path, size)
    except (OSError, ImportError):
        # OSError: the font file could not be found or read (e.g. arial.ttf not installed).
        # ImportError: Pillow was built without FreeType support.
        # Unlike the previous bare `except:`, KeyboardInterrupt/SystemExit now propagate.
        return ImageFont.load_default()
def process_detection(image, conf, iou):
    """Run the YOLO model on one image and draw labelled boxes on it.

    Args:
        image: Image as a NumPy array (detect_text_lines passes ``np.array``
            of the Gradio upload).
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict``.

    Returns:
        Tuple ``(annotated_image, box_count, distinct_class_count)``.
        ``annotated_image`` is a PIL RGB image with one outlined rectangle and
        one filled label tab per detected box.
    """
    # Normalise to RGB so the RGB tuple colours used below are valid ink even
    # when a single-channel (grayscale) array is passed in.
    pil_image = Image.fromarray(image).convert("RGB")
    draw = ImageDraw.Draw(pil_image)
    font = safe_font_load()

    # Predict before any drawing, so the model sees the clean image.
    results = model.predict(pil_image, conf=conf, iou=iou, device="cpu")

    detected_boxes = []
    class_ids = []
    boxes = results[0].boxes
    if boxes is not None:
        detected_boxes = boxes.xyxy.cpu().numpy().tolist()
        class_ids = boxes.cls.cpu().numpy().astype(int).tolist()

    for idx, (x1, y1, x2, y2) in enumerate(detected_boxes):
        # xyxy and cls come from the same Boxes object; default to class 0 defensively.
        class_id = class_ids[idx] if idx < len(class_ids) else 0
        color = get_class_color(class_id)
        label = CLASS_NAMES.get(class_id, f"Class {class_id}")

        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        # Bounding box of the label text when anchored at the origin (may carry an offset).
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        text_w = right - left
        text_h = bottom - top

        # Place the label tab just above the box. If that would start above the
        # image's top edge (box touching the top), put it just inside the box instead.
        tab_top = y1 - text_h - 4
        if tab_top < 0:
            tab_top = y1
        draw.rectangle([x1, tab_top, x1 + text_w + 6, tab_top + text_h + 4], fill=color)

        # Subtract the bbox offset so the visible glyphs sit 3px/2px inside the tab.
        draw.text((x1 + 3 - left, tab_top + 2 - top), label, fill="white", font=font)

    return pil_image, len(detected_boxes), len(set(class_ids))
def detect_text_lines(image):
    """Detect text lines with a strict and a lenient threshold pair.

    Args:
        image: Uploaded image from ``gr.Image(type="numpy")``; ``None`` when
            the button is pressed before anything is uploaded.

    Returns:
        Flat 6-tuple. The first three values are the annotated image, the
        object-count text and the class-count text for (conf=0.6, iou=0.5).
        The last three are the same values for (conf=0.4, iou=0.3).

    Raises:
        gr.Error: If ``image`` is ``None``. Gradio displays the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    # Convert once; both runs read the same pixels.
    image_array = np.array(image)
    outputs = []
    for conf, iou in ((0.6, 0.5), (0.4, 0.3)):
        annotated_img, obj_count, class_count = process_detection(image_array, conf, iou)
        outputs.extend((
            annotated_img,
            f"Objects: {obj_count} (Conf={conf}, IoU={iou})",
            f"Classes: {class_count} (Conf={conf}, IoU={iou})",
        ))
    return tuple(outputs)
# Gradio interface: upload + trigger on the left, one result tab per threshold setting on the right.
with gr.Blocks() as iface:
    gr.Markdown("# 📜 Text Line Detection with YOLO")

    with gr.Row():
        # Left column: the image to analyse and the button that starts detection.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            submit_btn = gr.Button(value="Detect Text")

        # Right column: one tab per (conf, IoU) run performed by detect_text_lines.
        with gr.Column():
            with gr.Tab(label="High Confidence"):
                high_conf_img = gr.Image(label="Detections (0.6 conf)", type="pil")
                high_conf_obj = gr.Textbox(label="Object Count")
                high_conf_cls = gr.Textbox(label="Class Count")
            with gr.Tab(label="Low Confidence"):
                low_conf_img = gr.Image(label="Detections (0.4 conf)", type="pil")
                low_conf_obj = gr.Textbox(label="Object Count")
                low_conf_cls = gr.Textbox(label="Class Count")

    # Output order must mirror the flat 6-tuple returned by detect_text_lines.
    submit_btn.click(
        fn=detect_text_lines,
        inputs=[input_image],
        outputs=[
            high_conf_img, high_conf_obj, high_conf_cls,
            low_conf_img, low_conf_obj, low_conf_cls,
        ],
    )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, surface errors in the UI,
    # and do not create a public share link.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    iface.launch(**launch_kwargs)