ChaseHan commited on
Commit
3ab79b5
·
verified ·
1 Parent(s): b1925a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -170
app.py CHANGED
@@ -4,206 +4,133 @@ import numpy as np
4
  import os
5
  import tempfile
6
  from ultralytics import YOLO
7
- import logging
8
 
9
- # Configure logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
 
 
 
12
 
13
  # Load the Latex2Layout model
14
- model_path = "latex2layout_object_detection_yolov8.pt"
15
  try:
16
- if not os.path.exists(model_path):
17
- raise FileNotFoundError(f"Model file not found: {model_path}")
18
  model = YOLO(model_path)
19
- logger.info("Model loaded successfully")
20
  except Exception as e:
21
- logger.error(f"Error loading model: {str(e)}")
22
- raise
23
 
24
  def detect_and_visualize(image):
25
  """
26
- Perform layout detection on the uploaded image using the Latex2Layout model and visualize the results.
27
-
28
  Args:
29
- image: The uploaded image
30
-
31
  Returns:
32
- annotated_image: Image with detection boxes
33
- layout_annotations: Annotations in YOLO format
34
  """
 
 
 
 
 
35
  try:
36
- if image is None:
37
- return None, "Error: No image uploaded."
38
-
39
- # Validate image format and dimensions
40
- if not isinstance(image, np.ndarray):
41
- return None, "Error: Invalid image format."
42
-
43
- if image.size == 0:
44
- return None, "Error: Empty image."
45
-
46
- # Run detection using the Latex2Layout model
47
  results = model(image)
48
- result = results[0]
49
-
50
- # Create a copy of the image for visualization
51
- annotated_image = image.copy()
52
- layout_annotations = []
53
-
54
- # Get image dimensions
55
- img_height, img_width = image.shape[:2]
56
-
57
- # Draw detection results
58
- for box in result.boxes:
59
- x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
60
- conf = float(box.conf[0])
61
- cls_id = int(box.cls[0])
62
- cls_name = result.names[cls_id]
63
-
64
- # Generate a color for each class
65
- color = tuple(np.random.randint(0, 255, 3).tolist())
66
-
67
- # Draw bounding box and label
68
- cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
69
- label = f'{cls_name} {conf:.2f}'
70
- (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
71
- cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
72
- cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
73
-
74
- # Convert to YOLO format (normalized)
75
- x_center = (x1 + x2) / (2 * img_width)
76
- y_center = (y1 + y2) / (2 * img_height)
77
- width = (x2 - x1) / img_width
78
- height = (y2 - y1) / img_height
79
- layout_annotations.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
80
-
81
- return annotated_image, "\n".join(layout_annotations)
82
  except Exception as e:
83
- logger.error(f"Error in detect_and_visualize: {str(e)}")
84
- return None, f"Error during detection: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- def save_layout_annotations(layout_annotations_str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  """
88
- Save layout annotations to a temporary file and return the file path.
89
-
90
  Args:
91
- layout_annotations_str: Annotations string in YOLO format
92
-
93
  Returns:
94
- file_path: Path to the saved annotation file
95
  """
 
 
 
 
 
96
  try:
97
- if not layout_annotations_str:
98
- return None
99
-
100
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
101
- with open(temp_file.name, "w") as f:
102
- f.write(layout_annotations_str)
103
- return temp_file.name
 
104
  except Exception as e:
105
- logger.error(f"Error in save_layout_annotations: {str(e)}")
106
- return None
 
 
 
 
107
 
108
- def load_example_image():
109
- """
110
- Load an example image for demonstration.
111
-
112
- Returns:
113
- image: The loaded example image or None if loading fails
114
- """
115
- try:
116
- example_path = "example_image.jpg"
117
- if not os.path.exists(example_path):
118
- logger.error(f"Example image not found: {example_path}")
119
- return None
120
- return cv2.imread(example_path)
121
- except Exception as e:
122
- logger.error(f"Error loading example image: {str(e)}")
123
- return None
124
-
125
- # Custom CSS for styling
126
- custom_css = """
127
- .container { max-width: 1200px; margin: auto; }
128
- .button-primary { background-color: #4CAF50; color: white; }
129
- .button-secondary { background-color: #008CBA; color: white; }
130
- .gr-image { border: 2px solid #ddd; border-radius: 5px; }
131
- .gr-textbox { font-family: monospace; }
132
- """
133
-
134
- # Create Gradio interface with enhanced styling
135
- with gr.Blocks(
136
- title="Latex2Layout Detection",
137
- theme=gr.themes.Default(),
138
- css=custom_css
139
- ) as demo:
140
- # Header with instructions
141
- gr.Markdown(
142
- """
143
- # Latex2Layout Layout Detection
144
- Upload an image to detect layout elements using the **Latex2Layout** model. View the annotated image and download the results in YOLO format.
145
- """
146
- )
147
-
148
- # Main layout with two columns
149
  with gr.Row():
150
- # Input column
151
- with gr.Column(scale=1):
152
- input_image = gr.Image(
153
- label="Upload Image",
154
- type="numpy",
155
- height=400,
156
- elem_classes="gr-image"
157
- )
158
- detect_btn = gr.Button(
159
- "Start Detection",
160
- variant="primary",
161
- elem_classes="button-primary"
162
- )
163
- gr.Markdown("**Tip**: Upload a clear image for optimal detection results.")
164
 
165
- # Output column
166
- with gr.Column(scale=1):
167
- output_image = gr.Image(
168
- label="Detection Results",
169
- height=400,
170
- elem_classes="gr-image"
171
- )
172
- layout_annotations = gr.Textbox(
173
- label="Layout Annotations (YOLO Format)",
174
- lines=10,
175
- max_lines=15,
176
- elem_classes="gr-textbox"
177
- )
178
- download_btn = gr.Button(
179
- "Download Annotations",
180
- variant="secondary",
181
- elem_classes="button-secondary"
182
- )
183
- download_file = gr.File(
184
- label="Download File",
185
- interactive=False
186
- )
187
-
188
- # Example image button (optional)
189
- with gr.Row():
190
- gr.Button("Load Example Image").click(
191
- fn=load_example_image,
192
- outputs=input_image
193
- )
194
-
195
- # Event handlers
196
  detect_btn.click(
197
  fn=detect_and_visualize,
198
- inputs=input_image,
199
- outputs=[output_image, layout_annotations],
200
- show_progress=True
201
  )
202
-
203
  download_btn.click(
204
- fn=save_layout_annotations,
205
- inputs=layout_annotations,
206
- outputs=download_file
207
  )
208
 
209
  # Launch the application
 
4
  import os
5
  import tempfile
6
  from ultralytics import YOLO
 
7
 
8
+ # Define the model path for Latex2Layout
9
+ model_path = "latex2layout_object_detection_yolov8.pt"
10
+
11
+ # Check if the model file exists before loading
12
+ if not os.path.exists(model_path):
13
+ raise FileNotFoundError(f"Model file not found at {model_path}")
14
 
15
  # Load the Latex2Layout model
 
16
  try:
 
 
17
  model = YOLO(model_path)
 
18
  except Exception as e:
19
+ raise RuntimeError(f"Failed to load Latex2Layout model: {e}")
 
20
 
21
  def detect_and_visualize(image):
22
  """
23
+ Perform object detection on the uploaded image and visualize the results.
24
+
25
  Args:
26
+ image: The uploaded image as a numpy array.
27
+
28
  Returns:
29
+ annotated_image: Image with bounding boxes drawn.
30
+ yolo_annotations: Annotations in YOLO format as a string.
31
  """
32
+ # Validate input image
33
+ if image is None or not isinstance(image, np.ndarray):
34
+ raise ValueError("Invalid image input: Please upload a valid image.")
35
+
36
+ # Run object detection with error handling
37
  try:
 
 
 
 
 
 
 
 
 
 
 
38
  results = model(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
+ raise RuntimeError(f"Error during Latex2Layout detection: {e}")
41
+
42
+ # Extract results from the first frame
43
+ result = results[0]
44
+ annotated_image = image.copy()
45
+ yolo_annotations = []
46
+
47
+ # Get image dimensions
48
+ img_height, img_width = image.shape[:2]
49
+
50
+ # Process each detected object
51
+ for box in result.boxes:
52
+ # Extract bounding box coordinates
53
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
54
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
55
+
56
+ # Get confidence and class details
57
+ conf = float(box.conf[0])
58
+ cls_id = int(box.cls[0])
59
+ cls_name = result.names[cls_id]
60
+
61
+ # Assign a random color to the class
62
+ color = tuple(np.random.randint(0, 255, 3).tolist())
63
 
64
+ # Draw bounding box on the image
65
+ cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
66
+
67
+ # Create and draw label with confidence
68
+ label = f"{cls_name} {conf:.2f}"
69
+ (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
70
+ cv2.rectangle(annotated_image, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, -1)
71
+ cv2.putText(annotated_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
72
+
73
+ # Convert bounding box to YOLO format (normalized coordinates)
74
+ x_center = (x1 + x2) / (2 * img_width)
75
+ y_center = (y1 + y2) / (2 * img_height)
76
+ width = (x2 - x1) / img_width
77
+ height = (y2 - y1) / img_height
78
+ yolo_annotations.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
79
+
80
+ # Combine annotations into a single string
81
+ yolo_annotations_str = "\n".join(yolo_annotations) if yolo_annotations else "No objects detected."
82
+ return annotated_image, yolo_annotations_str
83
+
84
+ def save_yolo_annotations(yolo_annotations_str):
85
  """
86
+ Save YOLO annotations to a temporary file and return its path.
87
+
88
  Args:
89
+ yolo_annotations_str: Annotations string in YOLO format.
90
+
91
  Returns:
92
+ file_path: Path to the saved annotation file.
93
  """
94
+ # Handle empty annotations
95
+ if not yolo_annotations_str or yolo_annotations_str == "No objects detected.":
96
+ raise ValueError("No annotations available to save.")
97
+
98
+ # Save annotations to a temporary file with error handling
99
  try:
 
 
 
100
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
101
+ temp_file_path = temp_file.name
102
+ with open(temp_file_path, "w") as f:
103
+ f.write(yolo_annotations_str)
104
+ return temp_file_path
105
  except Exception as e:
106
+ raise RuntimeError(f"Failed to save annotations: {e}")
107
+
108
+ # Build the Gradio interface
109
+ with gr.Blocks(title="Latex2Layout Object Detection Visualization") as demo:
110
+ gr.Markdown("# Latex2Layout Object Detection Visualization")
111
+ gr.Markdown("Upload an image to detect objects using the Latex2Layout model. View the results with bounding boxes and download annotations in YOLO format.")
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  with gr.Row():
114
+ with gr.Column():
115
+ input_image = gr.Image(label="Upload Image", type="numpy")
116
+ detect_btn = gr.Button("Start Detection")
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ with gr.Column():
119
+ output_image = gr.Image(label="Detection Results")
120
+ yolo_annotations = gr.Textbox(label="YOLO Annotations", lines=10)
121
+ download_btn = gr.Button("Download YOLO Annotations")
122
+ download_file = gr.File(label="Download Annotations")
123
+
124
+ # Define button click events
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  detect_btn.click(
126
  fn=detect_and_visualize,
127
+ inputs=[input_image],
128
+ outputs=[output_image, yolo_annotations]
 
129
  )
 
130
  download_btn.click(
131
+ fn=save_yolo_annotations,
132
+ inputs=[yolo_annotations],
133
+ outputs=[download_file]
134
  )
135
 
136
  # Launch the application