Norakneath committed on
Commit
4d73384
·
verified ·
1 Parent(s): 07e5b4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -90
app.py CHANGED
@@ -2,117 +2,127 @@ import gradio as gr
2
  from ultralytics import YOLO
3
  from PIL import Image, ImageDraw, ImageFont
4
  import random
 
 
5
 
# Weights file for the detector. The path is relative, so the file has to be
# present in the directory the process is started from.
YOLO_MODEL_PATH = "best-Dense.pt"

# Load the weights once at import time as a detection model and move it to the CPU.
model = YOLO(YOLO_MODEL_PATH, task='detect').to("cpu")

# Cache of class id -> RGB tuple; entries are added on demand by get_class_color().
CLASS_COLORS = {}
 
12
 
def get_class_color(class_id):
    """Return the RGB colour associated with ``class_id``.

    A colour is drawn at random the first time a class id is seen and stored in
    the module-level ``CLASS_COLORS`` cache, so later calls for the same id
    return the same tuple.

    Args:
        class_id: Hashable identifier of the detected class.

    Returns:
        A ``(r, g, b)`` tuple with each channel in the range 0-255.
    """
    color = CLASS_COLORS.get(class_id)
    if color is None:
        # First sighting of this class: pick and remember a colour.
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
        )
        CLASS_COLORS[class_id] = color
    return color
18
 
# Display label for each class id. Ids missing from this mapping are labelled
# "Class <id>" by detect_text_lines(). Adjust the entries to match the dataset
# the model was trained on.
CLASS_NAMES = {0: "Text"}  # Example labels
21
-
def detect_text_lines(image):
    """Detect text lines at two confidence/IoU settings and annotate each result.

    The image is run through the module-level YOLO ``model`` twice, first with
    conf=0.6 / IoU=0.5 and then with conf=0.4 / IoU=0.3. Each pass is drawn on
    its own copy of the input so the two results can be compared side by side.

    Args:
        image: Array accepted by ``Image.fromarray``; the UI supplies it from a
            ``gr.Image(type="numpy")`` component.

    Returns:
        A 6-tuple ``(image_1, objects_1, classes_1, image_2, objects_2,
        classes_2)``: for each pass, the annotated PIL image followed by the
        object-count and class-count summary strings.
    """
    original_image = Image.fromarray(image)

    # Load the label font once instead of once per pass. Only failures that
    # mean "this font is unavailable" trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt and genuine programming errors.
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    outputs = []
    for conf, iou in ((0.6, 0.5), (0.4, 0.3)):
        # Run YOLO text detection with this pass's thresholds.
        results = model.predict(original_image, conf=conf, iou=iou, device="cpu")
        boxes = results[0].boxes
        raw_boxes = boxes.xyxy.tolist() if hasattr(boxes, 'xyxy') else []
        class_ids = boxes.cls.tolist() if hasattr(boxes, 'cls') else []
        detected_boxes = [list(map(int, box)) for box in raw_boxes]

        # Draw on a fresh copy so passes do not see each other's boxes.
        annotated_image = original_image.copy()
        draw = ImageDraw.Draw(annotated_image)

        for idx, (x1, y1, x2, y2) in enumerate(detected_boxes):
            class_id = int(class_ids[idx]) if idx < len(class_ids) else -1
            color = get_class_color(class_id)
            class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")

            # Bounding box.
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

            # Label on a filled background. The label normally sits directly
            # above the box; clamp it to the top edge so boxes touching the top
            # of the image still get a fully visible label.
            text_size = draw.textbbox((0, 0), class_name, font=font)
            text_width = text_size[2] - text_size[0]
            text_height = text_size[3] - text_size[1]
            label_top = max(y1 - text_height - 4, 0)
            draw.rectangle(
                [x1, label_top, x1 + text_width + 6, label_top + text_height + 4],
                fill=color,
            )
            draw.text((x1 + 3, label_top + 2), class_name, fill="white", font=font)

        total_objects = len(detected_boxes)
        total_classes = len(set(class_ids))

        outputs.extend((
            annotated_image,
            f"Total Objects Detected: {total_objects} (Conf={conf}, IoU={iou})",
            f"Total Classes Detected: {total_classes} (Conf={conf}, IoU={iou})",
        ))

    return tuple(outputs)
84
 
# Gradio UI
# NOTE(review): the nesting below is reconstructed from a flattened diff view;
# confirm the layout against the real app.py.
with gr.Blocks() as iface:
    gr.Markdown("# 📜 Text Line Detection with YOLO")
    gr.Markdown("## 📷 Upload an image to detect text lines")

    with gr.Row():
        # Narrow column: the image upload widget.
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Upload Image")
            image_input = gr.Image(type="numpy", label="Upload an image")

        # Wide column: one annotated image and one pair of counters per
        # threshold setting used by detect_text_lines().
        with gr.Column(scale=2):
            gr.Markdown("### 🖼 Annotated Images with Bounding Boxes")
            output_annotated_1 = gr.Image(type="pil", label="Detection (Conf=0.6, IoU=0.5)")
            output_annotated_2 = gr.Image(type="pil", label="Detection (Conf=0.4, IoU=0.3)")

            gr.Markdown("### 🔢 Detection Results")
            output_objects_1 = gr.Textbox(label="Total Objects Detected (Conf=0.6)", lines=1)
            output_classes_1 = gr.Textbox(label="Total Classes Detected (Conf=0.6)", lines=1)

            output_objects_2 = gr.Textbox(label="Total Objects Detected (Conf=0.4)", lines=1)
            output_classes_2 = gr.Textbox(label="Total Classes Detected (Conf=0.4)", lines=1)

    # Detection runs as soon as an image is uploaded. The order of `outputs`
    # must match the 6-tuple returned by detect_text_lines().
    image_input.upload(
        detect_text_lines,
        inputs=image_input,
        outputs=[
            output_annotated_1, output_objects_1, output_classes_1,
            output_annotated_2, output_objects_2, output_classes_2
        ]
    )
115
 
# Run the app locally when this file is executed as a script.
if __name__ == "__main__":
    # Host 0.0.0.0 with port 7860 is what the server is asked to bind to.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**launch_options)
 
 
 
 
 
 
2
  from ultralytics import YOLO
3
  from PIL import Image, ImageDraw, ImageFont
4
  import random
5
+ import numpy as np
6
+ import os
7
 
# Check that the YOLO weights exist before loading them, so a missing file
# fails at start-up with an explicit message naming the path.
YOLO_MODEL_PATH = "best-Dense.pt"
if not os.path.exists(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model file not found at {YOLO_MODEL_PATH}")

# Load the weights once at import time as a detection model and move it to the CPU.
model = YOLO(YOLO_MODEL_PATH, task='detect').to("cpu")

# CLASS_COLORS caches class id -> RGB tuple and is filled on demand by
# get_class_color(); CLASS_NAMES maps class id -> label text drawn on the image.
CLASS_COLORS = {}
CLASS_NAMES = {0: "Text"}
19
 
def get_class_color(class_id):
    """Look up, or create on first use, the RGB colour for ``class_id``.

    Colours are random but stable for the lifetime of the process: once an id
    has been assigned a colour in ``CLASS_COLORS`` the same tuple is returned
    on every later call.

    Args:
        class_id: Hashable identifier of the detected class.

    Returns:
        A ``(r, g, b)`` tuple with each channel in the range 0-255.
    """
    try:
        return CLASS_COLORS[class_id]
    except KeyError:
        # Unseen class: draw the three channels and cache the result.
        rgb = tuple(random.randint(0, 255) for _ in range(3))
        CLASS_COLORS[class_id] = rgb
        return rgb
25
 
def safe_font_load():
    """Return the font used for box labels.

    Tries the 18-point TrueType face ``arial.ttf`` and falls back to Pillow's
    built-in default font when that face cannot be loaded.

    Returns:
        A Pillow font object usable with ``ImageDraw`` text methods.
    """
    try:
        return ImageFont.truetype("arial.ttf", 18)
    except (OSError, ImportError):
        # OSError: the font file is missing or unreadable.
        # ImportError: Pillow was built without FreeType support.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return ImageFont.load_default()
32
+
def process_detection(image, conf, iou):
    """Run the detector on one image and draw the detections onto it.

    Args:
        image: Array accepted by ``Image.fromarray``. Presumably RGB, as
            delivered by the Gradio image input; confirm against the caller.
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict``.

    Returns:
        A tuple ``(annotated_image, box_count, class_count)``: the PIL image
        with boxes and labels drawn on it, the number of detected boxes, and
        the number of distinct class ids among them.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    font = safe_font_load()

    # Run model prediction.
    results = model.predict(pil_image, conf=conf, iou=iou, device="cpu")

    # A result without boxes leaves both lists empty.
    detected_boxes = []
    class_ids = []
    boxes = results[0].boxes
    if boxes is not None:
        detected_boxes = boxes.xyxy.cpu().numpy().tolist()
        class_ids = boxes.cls.cpu().numpy().astype(int).tolist()

    # Both lists come from the same ``boxes`` object, so they are paired
    # directly instead of indexing with a fallback class id.
    for (x1, y1, x2, y2), class_id in zip(detected_boxes, class_ids):
        color = get_class_color(class_id)
        class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")

        # Bounding box.
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        # Label on a filled background. It normally sits directly above the
        # box; clamp it to the top edge so a box touching the top of the image
        # still gets a fully visible label instead of one drawn at negative y.
        left, top, right, bottom = draw.textbbox((0, 0), class_name, font=font)
        text_width = right - left
        text_height = bottom - top
        label_top = max(y1 - text_height - 4, 0)
        draw.rectangle(
            [x1, label_top, x1 + text_width + 6, label_top + text_height + 4],
            fill=color,
        )
        draw.text(
            (x1 + 3, label_top + 2),
            class_name,
            fill="white",
            font=font,
        )

    return pil_image, len(detected_boxes), len(set(class_ids))
 
 
 
 
 
 
73
 
def detect_text_lines(image):
    """Run detection at two threshold settings and return both results.

    The image is processed with conf=0.6 / IoU=0.5 and then with
    conf=0.4 / IoU=0.3 via ``process_detection``.

    Args:
        image: Image array from the Gradio input component, or ``None`` when
            the button is pressed before an image has been provided.

    Returns:
        A 6-tuple ``(high_img, high_objects, high_classes, low_img,
        low_objects, low_classes)``: for each setting, the annotated image
        followed by its object-count and class-count summary strings.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Without this guard the failure surfaces as an opaque conversion
        # error from deep inside the image handling code.
        raise gr.Error("Please upload an image before running detection.")

    outputs = []
    for conf, iou in ((0.6, 0.5), (0.4, 0.3)):
        # np.array() copies by default, so each pass gets its own pixel buffer.
        annotated_img, obj_count, class_count = process_detection(
            np.array(image), conf, iou
        )
        outputs.extend((
            annotated_img,
            f"Objects: {obj_count} (Conf={conf}, IoU={iou})",
            f"Classes: {class_count} (Conf={conf}, IoU={iou})",
        ))

    return tuple(outputs)
95
 
# Gradio interface
# NOTE(review): the nesting below is reconstructed from a flattened diff view;
# confirm the layout against the real app.py.
with gr.Blocks() as iface:
    gr.Markdown("# 📜 Text Line Detection with YOLO")

    with gr.Row():
        # Left column: the image input and the button that starts detection.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image")
            submit_btn = gr.Button("Detect Text")

        # Right column: one tab per threshold setting used by detect_text_lines().
        with gr.Column():
            with gr.Tab("High Confidence"):
                high_conf_img = gr.Image(type="pil", label="Detections (0.6 conf)")
                high_conf_obj = gr.Textbox(label="Object Count")
                high_conf_cls = gr.Textbox(label="Class Count")

            with gr.Tab("Low Confidence"):
                low_conf_img = gr.Image(type="pil", label="Detections (0.4 conf)")
                low_conf_obj = gr.Textbox(label="Object Count")
                low_conf_cls = gr.Textbox(label="Class Count")

    # The order of `outputs` must match the 6-tuple returned by detect_text_lines().
    submit_btn.click(
        detect_text_lines,
        inputs=input_image,
        outputs=[high_conf_img, high_conf_obj, high_conf_cls, low_conf_img, low_conf_obj, low_conf_cls]
    )
121
 
 
if __name__ == "__main__":
    # Launch settings gathered in one place and passed through unchanged.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    iface.launch(**launch_options)