# TestingYolo / app.py
# Norakneath's picture
# Update app.py
# 4d73384 verified
import gradio as gr
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
import random
import numpy as np
import os
# Weights file for the detector, resolved relative to the working directory.
YOLO_MODEL_PATH = "best-Dense.pt"

# Fail at import time with a clear message when the weights are missing,
# before the model constructor is ever called.
if not os.path.exists(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model file not found at {YOLO_MODEL_PATH}")

# Detection model, moved to the CPU once and shared by every request.
model = YOLO(YOLO_MODEL_PATH, task="detect").to("cpu")

# class id -> RGB tuple; starts empty and is filled by get_class_color().
CLASS_COLORS = {}

# class id -> label text drawn next to each box.
CLASS_NAMES = {0: "Text"}
def get_class_color(class_id):
    """Return the RGB colour for ``class_id``, creating one on first use.

    A colour is drawn at random the first time a class id is seen and is
    cached in ``CLASS_COLORS``, so later calls with the same id return the
    same tuple for the lifetime of the process.
    """
    try:
        return CLASS_COLORS[class_id]
    except KeyError:
        # First sighting of this class: draw the three channels in order.
        rgb = tuple(random.randint(0, 255) for _ in range(3))
        CLASS_COLORS[class_id] = rgb
        return rgb
def safe_font_load():
    """Return the label font, falling back to Pillow's built-in default.

    Returns:
        The font from ``ImageFont.truetype("arial.ttf", 18)`` when it can be
        loaded; otherwise the result of ``ImageFont.load_default()``.
    """
    try:
        return ImageFont.truetype("arial.ttf", 18)
    except (OSError, ImportError):
        # OSError: the font file cannot be found or read.
        # ImportError: Pillow was built without FreeType support.
        # Anything else (including KeyboardInterrupt, which a bare
        # ``except:`` would also swallow) is left to propagate.
        return ImageFont.load_default()
def process_detection(image, conf, iou):
    """Run the YOLO model on an image and draw labelled boxes on it.

    Args:
        image: Array accepted by ``Image.fromarray`` (the caller passes
            ``np.array`` of the Gradio input).
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict``.

    Returns:
        A tuple ``(annotated_image, box_count, class_count)`` where
        ``annotated_image`` is the PIL image with boxes and labels drawn,
        ``box_count`` is the number of detected boxes and ``class_count``
        is the number of distinct class ids among them.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    font = safe_font_load()

    # Prediction runs before anything is drawn, so the model sees the
    # unannotated picture.
    results = model.predict(pil_image, conf=conf, iou=iou, device="cpu")

    detected_boxes = []
    class_ids = []
    boxes = results[0].boxes
    if boxes is not None:
        detected_boxes = boxes.xyxy.cpu().numpy().tolist()
        class_ids = boxes.cls.cpu().numpy().astype(int).tolist()

    for idx, (x1, y1, x2, y2) in enumerate(detected_boxes):
        # Fall back to class 0 if the class list is shorter than the boxes.
        class_id = class_ids[idx] if idx < len(class_ids) else 0
        color = get_class_color(class_id)
        class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")

        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        text = f"{class_name}"
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        text_w = right - left
        text_h = bottom - top

        # The label strip normally sits directly above the box. For a box
        # that touches the top edge that position is above the canvas and
        # the label would be clipped away, so clamp it to y = 0 (the strip
        # then overlaps the top of the box instead of disappearing).
        label_top = max(y1 - text_h - 4, 0)
        draw.rectangle(
            [x1, label_top, x1 + text_w + 6, label_top + text_h + 4],
            fill=color,
        )
        draw.text((x1 + 3, label_top + 2), text, fill="white", font=font)

    return pil_image, len(detected_boxes), len(set(class_ids))
def detect_text_lines(image):
    """Run detection at a strict and a lenient threshold pair.

    Args:
        image: Input image as delivered by the Gradio image component, or
            ``None`` when the user has not supplied one.

    Returns:
        A flat 6-tuple: ``(image, objects_text, classes_text)`` for
        ``conf=0.6, iou=0.5`` followed by the same three values for
        ``conf=0.4, iou=0.3``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        # Without this guard the None reaches Image.fromarray and fails
        # with an error that says nothing about the missing upload.
        raise ValueError("No input image provided; please upload an image first.")

    outputs = []
    for conf, iou in [(0.6, 0.5), (0.4, 0.3)]:
        annotated_img, obj_count, class_count = process_detection(
            np.array(image), conf, iou
        )
        outputs.extend((
            annotated_img,
            f"Objects: {obj_count} (Conf={conf}, IoU={iou})",
            f"Classes: {class_count} (Conf={conf}, IoU={iou})",
        ))
    return tuple(outputs)
# Gradio UI: the input image and trigger button on the left, one result tab
# per threshold pair on the right.
with gr.Blocks() as iface:
    gr.Markdown("# 📜 Text Line Detection with YOLO")

    with gr.Row():
        # Left column: upload and trigger.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image")
            submit_btn = gr.Button("Detect Text")

        # Right column: annotated image plus counts for each threshold pair.
        with gr.Column():
            with gr.Tab("High Confidence"):
                high_conf_img = gr.Image(type="pil", label="Detections (0.6 conf)")
                high_conf_obj = gr.Textbox(label="Object Count")
                high_conf_cls = gr.Textbox(label="Class Count")

            with gr.Tab("Low Confidence"):
                low_conf_img = gr.Image(type="pil", label="Detections (0.4 conf)")
                low_conf_obj = gr.Textbox(label="Object Count")
                low_conf_cls = gr.Textbox(label="Class Count")

    # The output order must match the 6-tuple returned by detect_text_lines.
    submit_btn.click(
        fn=detect_text_lines,
        inputs=input_image,
        outputs=[
            high_conf_img,
            high_conf_obj,
            high_conf_cls,
            low_conf_img,
            low_conf_obj,
            low_conf_cls,
        ],
    )
if __name__ == "__main__":
    # Start the web server only when run as a script: listen on every
    # interface at port 7860, show handler errors in the UI, and do not
    # create a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    iface.launch(**launch_options)