ChaseHan commited on
Commit
1ff383d
·
verified ·
1 Parent(s): 3ab79b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -105
app.py CHANGED
@@ -2,135 +2,206 @@ import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import os
5
- import tempfile
6
- from ultralytics import YOLO
 
 
 
 
7
 
8
- # Define the model path for Latex2Layout
9
- model_path = "latex2layout_object_detection_yolov8.pt"
10
 
11
- # Check if the model file exists before loading
12
- if not os.path.exists(model_path):
13
- raise FileNotFoundError(f"Model file not found at {model_path}")
14
 
15
- # Load the Latex2Layout model
16
- try:
17
- model = YOLO(model_path)
18
- except Exception as e:
19
- raise RuntimeError(f"Failed to load Latex2Layout model: {e}")
20
-
21
- def detect_and_visualize(image):
22
  """
23
- Perform object detection on the uploaded image and visualize the results.
24
-
25
  Args:
26
- image: The uploaded image as a numpy array.
27
-
28
  Returns:
29
- annotated_image: Image with bounding boxes drawn.
30
- yolo_annotations: Annotations in YOLO format as a string.
31
  """
32
- # Validate input image
33
- if image is None or not isinstance(image, np.ndarray):
34
- raise ValueError("Invalid image input: Please upload a valid image.")
35
-
36
- # Run object detection with error handling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  try:
38
- results = model(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
- raise RuntimeError(f"Error during Latex2Layout detection: {e}")
41
-
42
- # Extract results from the first frame
43
- result = results[0]
44
- annotated_image = image.copy()
45
- yolo_annotations = []
46
-
47
- # Get image dimensions
48
- img_height, img_width = image.shape[:2]
49
-
50
- # Process each detected object
51
- for box in result.boxes:
52
- # Extract bounding box coordinates
53
- x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
54
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
55
-
56
- # Get confidence and class details
57
- conf = float(box.conf[0])
58
- cls_id = int(box.cls[0])
59
- cls_name = result.names[cls_id]
60
-
61
- # Assign a random color to the class
62
- color = tuple(np.random.randint(0, 255, 3).tolist())
63
 
64
- # Draw bounding box on the image
65
- cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
66
-
67
- # Create and draw label with confidence
68
- label = f"{cls_name} {conf:.2f}"
69
- (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
70
- cv2.rectangle(annotated_image, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, -1)
71
- cv2.putText(annotated_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
72
-
73
- # Convert bounding box to YOLO format (normalized coordinates)
74
- x_center = (x1 + x2) / (2 * img_width)
75
- y_center = (y1 + y2) / (2 * img_height)
76
- width = (x2 - x1) / img_width
77
- height = (y2 - y1) / img_height
78
- yolo_annotations.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
79
-
80
- # Combine annotations into a single string
81
- yolo_annotations_str = "\n".join(yolo_annotations) if yolo_annotations else "No objects detected."
82
- return annotated_image, yolo_annotations_str
83
-
84
- def save_yolo_annotations(yolo_annotations_str):
85
  """
86
- Save YOLO annotations to a temporary file and return its path.
87
-
88
  Args:
89
- yolo_annotations_str: Annotations string in YOLO format.
90
-
 
 
 
91
  Returns:
92
- file_path: Path to the saved annotation file.
93
  """
94
- # Handle empty annotations
95
- if not yolo_annotations_str or yolo_annotations_str == "No objects detected.":
96
- raise ValueError("No annotations available to save.")
97
-
98
- # Save annotations to a temporary file with error handling
 
 
 
 
99
  try:
100
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
101
- temp_file_path = temp_file.name
102
- with open(temp_file_path, "w") as f:
103
- f.write(yolo_annotations_str)
104
- return temp_file_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  except Exception as e:
106
- raise RuntimeError(f"Failed to save annotations: {e}")
107
-
108
- # Build the Gradio interface
109
- with gr.Blocks(title="Latex2Layout Object Detection Visualization") as demo:
110
- gr.Markdown("# Latex2Layout Object Detection Visualization")
111
- gr.Markdown("Upload an image to detect objects using the Latex2Layout model. View the results with bounding boxes and download annotations in YOLO format.")
112
 
 
 
 
 
 
113
  with gr.Row():
114
- with gr.Column():
115
  input_image = gr.Image(label="Upload Image", type="numpy")
116
- detect_btn = gr.Button("Start Detection")
 
117
 
118
- with gr.Column():
119
  output_image = gr.Image(label="Detection Results")
120
- yolo_annotations = gr.Textbox(label="YOLO Annotations", lines=10)
121
- download_btn = gr.Button("Download YOLO Annotations")
122
- download_file = gr.File(label="Download Annotations")
123
-
124
- # Define button click events
 
 
 
 
 
 
 
 
 
 
 
125
  detect_btn.click(
126
- fn=detect_and_visualize,
127
  inputs=[input_image],
128
- outputs=[output_image, yolo_annotations]
129
  )
130
- download_btn.click(
131
- fn=save_yolo_annotations,
132
- inputs=[yolo_annotations],
133
- outputs=[download_file]
 
134
  )
135
 
136
  # Launch the application
 
2
  import cv2
3
  import numpy as np
4
  import os
5
+ import requests
6
+ import json
7
+ from PIL import Image
8
+ import io
9
+ import base64
10
+ from openai import OpenAI
11
 
12
+ # API endpoints
13
+ YOLO_API_ENDPOINT = "https://api.example.com/yolo" # Replace with actual YOLO API endpoint
14
 
15
+ # Qwen API configuration
16
+ QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
17
+ QWEN_MODEL_ID = "qwen2.5-vl-3b-instruct"
18
 
19
+ def encode_image(image_array):
 
 
 
 
 
 
20
  """
21
+ Encode numpy array image to base64 string.
22
+
23
  Args:
24
+ image_array: numpy array of the image
25
+
26
  Returns:
27
+ base64 encoded string of the image
 
28
  """
29
+ # Convert numpy array to PIL Image
30
+ pil_image = Image.fromarray(image_array)
31
+
32
+ # Convert PIL Image to bytes
33
+ img_byte_arr = io.BytesIO()
34
+ pil_image.save(img_byte_arr, format='PNG')
35
+ img_byte_arr = img_byte_arr.getvalue()
36
+
37
+ # Encode to base64
38
+ return base64.b64encode(img_byte_arr).decode("utf-8")
39
+
40
+ def detect_layout(image):
41
+ """
42
+ Perform layout detection on the uploaded image using YOLO API.
43
+
44
+ Args:
45
+ image: The uploaded image as a numpy array
46
+
47
+ Returns:
48
+ annotated_image: Image with detection boxes
49
+ layout_info: Layout detection results
50
+ """
51
+ if image is None:
52
+ return None, "Error: No image uploaded."
53
+
54
+ # Convert numpy array to PIL Image
55
+ pil_image = Image.fromarray(image)
56
+
57
+ # Convert PIL Image to bytes for API request
58
+ img_byte_arr = io.BytesIO()
59
+ pil_image.save(img_byte_arr, format='PNG')
60
+ img_byte_arr = img_byte_arr.getvalue()
61
+
62
+ # Prepare API request
63
+ files = {'image': ('image.png', img_byte_arr, 'image/png')}
64
+
65
  try:
66
+ # Call YOLO API
67
+ response = requests.post(YOLO_API_ENDPOINT, files=files)
68
+ response.raise_for_status()
69
+ detection_results = response.json()
70
+
71
+ # Create a copy of the image for visualization
72
+ annotated_image = image.copy()
73
+
74
+ # Draw detection results
75
+ for detection in detection_results:
76
+ x1, y1, x2, y2 = detection['bbox']
77
+ cls_name = detection['class']
78
+ conf = detection['confidence']
79
+
80
+ # Generate a color for each class
81
+ color = tuple(np.random.randint(0, 255, 3).tolist())
82
+
83
+ # Draw bounding box and label
84
+ cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
85
+ label = f'{cls_name} {conf:.2f}'
86
+ (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
87
+ cv2.rectangle(annotated_image, (int(x1), int(y1)-label_height-5), (int(x1)+label_width, int(y1)), color, -1)
88
+ cv2.putText(annotated_image, label, (int(x1), int(y1)-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
89
+
90
+ # Format layout information for Qwen
91
+ layout_info = json.dumps(detection_results, indent=2)
92
+
93
+ return annotated_image, layout_info
94
+
95
  except Exception as e:
96
+ return None, f"Error during layout detection: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ def qa_about_layout(image, question, layout_info, api_key):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  """
100
+ Answer questions about the layout using Qwen2.5-VL API.
101
+
102
  Args:
103
+ image: The uploaded image
104
+ question: User's question about the layout
105
+ layout_info: Layout detection results from YOLO
106
+ api_key: User's Qwen API key
107
+
108
  Returns:
109
+ answer: Qwen's answer to the question
110
  """
111
+ if image is None or not question:
112
+ return "Please upload an image and ask a question."
113
+
114
+ if not layout_info:
115
+ return "No layout information available. Please detect layout first."
116
+
117
+ if not api_key:
118
+ return "Please enter your Qwen API key."
119
+
120
  try:
121
+ # Encode image to base64
122
+ base64_image = encode_image(image)
123
+
124
+ # Initialize OpenAI client for Qwen API
125
+ client = OpenAI(
126
+ api_key=api_key,
127
+ base_url=QWEN_BASE_URL,
128
+ )
129
+
130
+ # Prepare system prompt with layout information
131
+ system_prompt = f"""You are a helpful assistant specialized in analyzing document layouts.
132
+ The following layout information has been detected in the image:
133
+ {layout_info}
134
+
135
+ Please answer questions about the layout based on this information and the image."""
136
+
137
+ # Prepare messages for API call
138
+ messages = [
139
+ {
140
+ "role": "system",
141
+ "content": [{"type": "text", "text": system_prompt}]
142
+ },
143
+ {
144
+ "role": "user",
145
+ "content": [
146
+ {
147
+ "type": "image_url",
148
+ "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
149
+ },
150
+ {"type": "text", "text": question},
151
+ ],
152
+ }
153
+ ]
154
+
155
+ # Call Qwen API
156
+ completion = client.chat.completions.create(
157
+ model=QWEN_MODEL_ID,
158
+ messages=messages,
159
+ )
160
+
161
+ return completion.choices[0].message.content
162
+
163
  except Exception as e:
164
+ return f"Error during QA: {str(e)}"
 
 
 
 
 
165
 
166
+ # Create Gradio interface
167
+ with gr.Blocks(title="Latex2Layout QA System") as demo:
168
+ gr.Markdown("# Latex2Layout QA System")
169
+ gr.Markdown("Upload an image, detect layout elements, and ask questions about the layout.")
170
+
171
  with gr.Row():
172
+ with gr.Column(scale=1):
173
  input_image = gr.Image(label="Upload Image", type="numpy")
174
+ detect_btn = gr.Button("Detect Layout")
175
+ gr.Markdown("**Tip**: Upload a clear image for optimal detection results.")
176
 
177
+ with gr.Column(scale=1):
178
  output_image = gr.Image(label="Detection Results")
179
+ layout_info = gr.Textbox(label="Layout Information", lines=10)
180
+
181
+ with gr.Row():
182
+ with gr.Column(scale=1):
183
+ api_key_input = gr.Textbox(
184
+ label="Qwen API Key",
185
+ placeholder="Enter your Qwen API key here",
186
+ type="password"
187
+ )
188
+ question_input = gr.Textbox(label="Ask a question about the layout")
189
+ qa_btn = gr.Button("Ask Question")
190
+
191
+ with gr.Column(scale=1):
192
+ answer_output = gr.Textbox(label="Answer", lines=5)
193
+
194
+ # Event handlers
195
  detect_btn.click(
196
+ fn=detect_layout,
197
  inputs=[input_image],
198
+ outputs=[output_image, layout_info]
199
  )
200
+
201
+ qa_btn.click(
202
+ fn=qa_about_layout,
203
+ inputs=[input_image, question_input, layout_info, api_key_input],
204
+ outputs=[answer_output]
205
  )
206
 
207
  # Launch the application