Norakneath committed on
Commit
9c1365f
·
verified ·
1 Parent(s): 4114182

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import random
5
+
6
# Weights file used by the detector. The path is relative, so the file must be
# reachable from the working directory the app is started in.
YOLO_MODEL_PATH = "120epochs.pt"

# Build the detection model once at import time and keep it on the CPU.
model = YOLO(YOLO_MODEL_PATH, task="detect").to("cpu")
9
+
10
# Cache of class id -> RGB tuple so a class keeps one colour once assigned.
CLASS_COLORS = {}

# Upper bound for each RGB channel. Labels are drawn as white text on a
# background of the class colour, so near-white colours must be avoided.
_MAX_CHANNEL = 200


def get_class_color(class_id):
    """Return the RGB colour assigned to ``class_id``, creating it on first use.

    The colour is derived from a random generator seeded with ``class_id``,
    so a given class gets the same colour on every run of the app instead of
    a new one after each restart.

    Args:
        class_id: Identifier of the detected class (used as cache key and
            as the seed of the generator).

    Returns:
        A ``(r, g, b)`` tuple with every channel in ``0.._MAX_CHANNEL``.
    """
    try:
        return CLASS_COLORS[class_id]
    except KeyError:
        # A private generator keeps the global ``random`` state untouched.
        rng = random.Random(class_id)
        color = tuple(rng.randint(0, _MAX_CHANNEL) for _ in range(3))
        CLASS_COLORS[class_id] = color
        return color
18
+
19
# Display label for each class id. These are example labels; adjust them to
# match the dataset the weights were trained on.
CLASS_NAMES = {
    0: "Text Line",
    1: "Heading",
    2: "Signature",
}
21
+
22
def _load_label_font(size=18):
    """Return the TrueType label font, or PIL's built-in font if unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # truetype() raises OSError when the font file cannot be read.
        return ImageFont.load_default()


def _draw_detections(base_image, boxes, class_ids, font):
    """Return a copy of ``base_image`` with one labelled box per detection."""
    annotated = base_image.copy()
    draw = ImageDraw.Draw(annotated)

    for (x1, y1, x2, y2), class_id in zip(boxes, class_ids):
        color = get_class_color(class_id)
        class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")

        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        left, top, right, bottom = draw.textbbox((0, 0), class_name, font=font)
        text_width = right - left
        text_height = bottom - top

        # The label normally sits just above the box. For a box touching the
        # top edge that position is off-image, so place it inside the box.
        label_top = y1 - text_height - 4
        if label_top < 0:
            label_top = y1

        # Filled strip behind the text keeps the label readable.
        draw.rectangle(
            [x1, label_top, x1 + text_width + 6, label_top + text_height + 4],
            fill=color,
        )
        draw.text((x1 + 3, label_top + 2), class_name, fill="white", font=font)

    return annotated


def detect_text_lines(image):
    """Run detection at two confidence/IoU settings and annotate the image.

    Args:
        image: Input picture as an array accepted by ``Image.fromarray``
            (the UI passes a numpy array), or ``None`` when no image is set.

    Returns:
        A 6-tuple ``(image_1, objects_1, classes_1, image_2, objects_2,
        classes_2)``: for each threshold pair, the annotated PIL image, a
        summary of the number of boxes and a summary of the number of
        distinct classes. All six entries are ``None`` when ``image`` is
        ``None``.
    """
    if image is None:
        # Nothing to analyse; hand back empty values for all six outputs.
        return (None,) * 6

    # Colours are RGB tuples, so work on an RGB image.
    source = Image.fromarray(image).convert("RGB")

    # Two settings are shown side by side for comparison.
    thresholds = [
        {"conf": 0.6, "iou": 0.5},  # Default thresholds
        {"conf": 0.4, "iou": 0.3},  # Lower thresholds for comparison
    ]

    # The font does not depend on the thresholds; load it once per call.
    font = _load_label_font()

    outputs = []
    for threshold in thresholds:
        conf = threshold["conf"]
        iou = threshold["iou"]

        results = model.predict(source, conf=conf, iou=iou, device="cpu")
        boxes = [list(map(int, box)) for box in results[0].boxes.xyxy.tolist()]
        class_ids = [int(cls) for cls in results[0].boxes.cls.tolist()]

        annotated_image = _draw_detections(source, boxes, class_ids, font)

        total_objects = len(boxes)
        total_classes = len(set(class_ids))

        outputs.extend(
            (
                annotated_image,
                f"Total Objects Detected: {total_objects} (Conf={conf}, IoU={iou})",
                f"Total Classes Detected: {total_classes} (Conf={conf}, IoU={iou})",
            )
        )

    return tuple(outputs)
84
+
85
# Gradio UI: an upload column on the left, annotated images and detection
# counts on the right.
# NOTE(review): the nesting of the widgets below was reconstructed from a
# listing without indentation; confirm that the results section belongs inside
# the right-hand column.
with gr.Blocks() as iface:
    gr.Markdown("# 📜 Text Line Detection with YOLO")
    gr.Markdown("## 📷 Upload an image to detect text lines")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Upload Image")
            image_input = gr.Image(type="numpy", label="Upload an image")

        with gr.Column(scale=2):
            gr.Markdown("### 🖼 Annotated Images with Bounding Boxes")
            output_annotated_1 = gr.Image(
                type="pil", label="Detection (Conf=0.6, IoU=0.5)"
            )
            output_annotated_2 = gr.Image(
                type="pil", label="Detection (Conf=0.4, IoU=0.3)"
            )

            gr.Markdown("### 🔢 Detection Results")
            output_objects_1 = gr.Textbox(
                label="Total Objects Detected (Conf=0.6)", lines=1
            )
            output_classes_1 = gr.Textbox(
                label="Total Classes Detected (Conf=0.6)", lines=1
            )

            output_objects_2 = gr.Textbox(
                label="Total Objects Detected (Conf=0.4)", lines=1
            )
            output_classes_2 = gr.Textbox(
                label="Total Classes Detected (Conf=0.4)", lines=1
            )

    # Each upload runs detection and fills the six widgets; the order matches
    # the tuple returned by detect_text_lines.
    image_input.upload(
        detect_text_lines,
        inputs=image_input,
        outputs=[
            output_annotated_1,
            output_objects_1,
            output_classes_1,
            output_annotated_2,
            output_objects_2,
            output_classes_2,
        ],
    )

# When run as a script (e.g. in Hugging Face Spaces), listen on all interfaces
# on port 7860.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)