eusholli committed on
Commit 9457a82 · 1 Parent(s): 8dff46f

refactored results_queue for detections

Files changed (2)
  1. app.py +14 -27
  2. object_detection.py +0 -420
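
The refactor drops the thread-safe result_queue and instead keeps the latest detections in the shared img_container dict that already carries the input frame, annotated frame, and timing. A minimal sketch of the pattern, assuming the Detection tuple removed from object_detection.py; store_results and latest_detections are hypothetical helper names used only for illustration, not functions in app.py:

from typing import List, NamedTuple, Optional

import numpy as np


class Detection(NamedTuple):
    # Fields as defined in the (now deleted) object_detection.py
    class_id: int
    label: str
    score: float
    box: np.ndarray


# Shared container: the "detections" entry replaces the old
# queue.Queue[List[Detection]] used before this commit.
img_container = {"input": None, "analyzed": None,
                 "analysis_time": None, "detections": None}


def store_results(frame: np.ndarray, detections: List[Detection], elapsed_ms: float) -> None:
    # Writer side (hypothetical helper): what analyze_frame() now does in
    # place of result_queue.put(detections).
    img_container["analysis_time"] = elapsed_ms
    img_container["detections"] = detections
    img_container["analyzed"] = frame


def latest_detections() -> Optional[List[Detection]]:
    # Reader side (hypothetical helper): publish_frame() reads the most recent
    # detections directly and skips the labels table while this is still None.
    return img_container["detections"]

Unlike the queue, the container holds only the most recent result, so a slow reader drops intermediate detections instead of backing up, which matches the display-only use in publish_frame().
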
app.py CHANGED
@@ -3,7 +3,6 @@ import tensorflow as tf
  import time
  import os
  import logging
- import queue
  from pathlib import Path
  from typing import List, NamedTuple

@@ -24,17 +23,6 @@ import requests
  from io import BytesIO  # Import for handling byte streams


- # Named tuple to store detection results
- class Detection(NamedTuple):
-     class_id: int
-     label: str
-     score: float
-     box: np.ndarray
-
-
- # Queue to store detection results
- result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
-
  # CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
  # Update below string to set display title of analysis

@@ -112,7 +100,9 @@ def analyze_frame(frame: np.ndarray):
      # Store the execution time
      img_container["analysis_time"] = execution_time_ms

-     result_queue.put(results)  # Put the results in the result queue
+     # store the detections
+     img_container["detections"] = results
+
      img_container["analyzed"] = frame  # Store the analyzed frame

      return  # End of the function
@@ -157,7 +147,8 @@ logging.getLogger("torch").setLevel(logging.ERROR)
  logging.getLogger("streamlit").setLevel(logging.ERROR)

  # Container to hold image data and analysis results
- img_container = {"input": None, "analyzed": None, "analysis_time": None}
+ img_container = {"input": None, "analyzed": None,
+                  "analysis_time": None, "detections": None}

  # Logger for debugging and information
  logger = logging.getLogger(__name__)
@@ -294,12 +285,6 @@ def analysis_init():
  # This function retrieves the latest frames and results from the global container and result queue,
  # and updates the placeholders in the Streamlit UI with the current input frame, analyzed frame, analysis time, and detected labels.
  def publish_frame():
-     if not result_queue.empty():
-         result = result_queue.get()
-         if show_labels:
-             labels_placeholder.table(
-                 result
-             )  # Display labels if the checkbox is checked

      img = img_container["input"]
      if img is None:
@@ -318,6 +303,15 @@ def publish_frame():
      # Display the analysis time
      analysis_time.text(f"Analysis Time: {time} ms")

+     detections = img_container["detections"]
+     if detections is None:
+         return
+
+     if show_labels:
+         labels_placeholder.table(
+             detections
+         )  # Display labels if the checkbox is checked
+

  # If the WebRTC streamer is playing, initialize and publish frames
  if webrtc_ctx.state.playing:
@@ -361,13 +355,6 @@ def process_video(video_path):
          )  # Analyze the frame for face detection and sentiment analysis
          publish_frame()  # Publish the results

-         if not result_queue.empty():
-             result = result_queue.get()
-             if show_labels:
-                 labels_placeholder.table(
-                     result
-                 )  # Display labels if the checkbox is checked
-
      cap.release()  # Release the video capture object

 
object_detection.py DELETED
@@ -1,420 +0,0 @@
- import torch
- import tensorflow as tf
- import time
- import os
- import logging
- import queue
- from pathlib import Path
- from typing import List, NamedTuple
-
- import av
- import cv2
- import numpy as np
- import streamlit as st
- from streamlit_webrtc import WebRtcMode, webrtc_streamer
-
- from utils.download import download_file
- from utils.turn import get_ice_servers
-
- from PIL import Image, ImageDraw  # Import PIL for image processing
- from transformers import pipeline  # Import Hugging Face transformers pipeline
-
- import requests
- from io import BytesIO  # Import for handling byte streams
-
-
- # Named tuple to store detection results
- class Detection(NamedTuple):
-     class_id: int
-     label: str
-     score: float
-     box: np.ndarray
-
-
- # Queue to store detection results
- result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
-
- # CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
- # Update below string to set display title of analysis
-
- # Appropriate imports needed for analysis
-
- MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"
- MODEL_LOCAL_PATH = Path("./models/MobileNetSSD_deploy.caffemodel")
- PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"
- PROTOTXT_LOCAL_PATH = Path("./models/MobileNetSSD_deploy.prototxt.txt")
-
- CLASSES = [
-     "background",
-     "aeroplane",
-     "bicycle",
-     "bird",
-     "boat",
-     "bottle",
-     "bus",
-     "car",
-     "cat",
-     "chair",
-     "cow",
-     "diningtable",
-     "dog",
-     "horse",
-     "motorbike",
-     "person",
-     "pottedplant",
-     "sheep",
-     "sofa",
-     "train",
-     "tvmonitor",
- ]
-
- # Generate random colors for each class label
-
-
- def generate_label_colors():
-     return np.random.uniform(0, 255, size=(len(CLASSES), 3))
-
-
- COLORS = generate_label_colors()
-
- # Download model and prototxt files
-
-
- def download_file(url, local_path, expected_size=None):
-     if not local_path.exists() or (expected_size and local_path.stat().st_size != expected_size):
-         import requests
-         with open(local_path, "wb") as f:
-             response = requests.get(url)
-             f.write(response.content)
-
-
- download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
- download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
-
- # Load the model
- net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
-
-
- # Default title - "Facial Sentiment Analysis"
-
- ANALYSIS_TITLE = "Object Detection Analysis"
-
- # CHANGE THE CONTENTS OF THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
- #
-
- # Set analysis results in img_container and result queue for display
- # img_container["input"] - holds the input frame contents - of type np.ndarray
- # img_container["analyzed"] - holds the analyzed frame with any added annotations - of type np.ndarray
- # img_container["analysis_time"] - holds how long the analysis has taken in miliseconds
- # result_queue - holds the analysis metadata results - of type queue.Queue[List[Detection]]
-
-
- def analyze_frame(frame: np.ndarray):
-     start_time = time.time()  # Start timing the analysis
-     img_container["input"] = frame  # Store the input frame
-     frame = frame.copy()  # Create a copy of the frame to modify
-
-     # Run inference
-     blob = cv2.dnn.blobFromImage(
-         cv2.resize(frame, (300, 300)), 0.007843, (300, 300), 127.5
-     )
-     net.setInput(blob)
-     output = net.forward()
-
-     h, w = frame.shape[:2]
-
-     # Filter the detections based on the score threshold
-     score_threshold = 0.5  # You can adjust the score threshold as needed
-     output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
-     output = output[output[:, 2] >= score_threshold]
-     detections = [
-         Detection(
-             class_id=int(detection[1]),
-             label=CLASSES[int(detection[1])],
-             score=float(detection[2]),
-             box=(detection[3:7] * np.array([w, h, w, h])),
-         )
-         for detection in output
-     ]
-
-     # Render bounding boxes and captions
-     for detection in detections:
-         caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
-         color = COLORS[detection.class_id]
-         xmin, ymin, xmax, ymax = detection.box.astype("int")
-
-         cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)
-         cv2.putText(
-             frame,
-             caption,
-             (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
-             cv2.FONT_HERSHEY_SIMPLEX,
-             0.5,
-             color,
-             2,
-         )
-
-     end_time = time.time()  # End timing the analysis
-     # Calculate execution time in milliseconds
-     execution_time_ms = round((end_time - start_time) * 1000, 2)
-     # Store the execution time
-     img_container["analysis_time"] = execution_time_ms
-
-     result_queue.put(detections)  # Put the results in the result queue
-     img_container["analyzed"] = frame  # Store the analyzed frame
-
-     return  # End of the function
-
- #
- #
- # DO NOT TOUCH THE BELOW CODE (NOT NEEDED)
- #
- #
-
-
- # Suppress FFmpeg logs
- os.environ["FFMPEG_LOG_LEVEL"] = "quiet"
-
- # Suppress TensorFlow or PyTorch progress bars
-
- tf.get_logger().setLevel("ERROR")
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
- # Suppress PyTorch logs
-
- logging.getLogger().setLevel(logging.WARNING)
- torch.set_num_threads(1)
- logging.getLogger("torch").setLevel(logging.ERROR)
-
- # Suppress Streamlit logs using the logging module
- logging.getLogger("streamlit").setLevel(logging.ERROR)
-
- # Container to hold image data and analysis results
- img_container = {"input": None, "analyzed": None, "analysis_time": None}
-
- # Logger for debugging and information
- logger = logging.getLogger(__name__)
-
-
- # Callback function to process video frames
- # This function is called for each video frame in the WebRTC stream.
- # It converts the frame to a numpy array in RGB format, analyzes the frame,
- # and returns the original frame.
- def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-     # Convert frame to numpy array in RGB format
-     img = frame.to_ndarray(format="rgb24")
-     analyze_frame(img)  # Analyze the frame
-     return frame  # Return the original frame
-
-
- # Get ICE servers for WebRTC
- ice_servers = get_ice_servers()
-
- # Streamlit UI configuration
- st.set_page_config(layout="wide")
-
- # Custom CSS for the Streamlit page
- st.markdown(
-     """
-     <style>
-     .main {
-         padding: 2rem;
-     }
-     h1, h2, h3 {
-         font-family: 'Arial', sans-serif;
-     }
-     h1 {
-         font-weight: 700;
-         font-size: 2.5rem;
-     }
-     h2 {
-         font-weight: 600;
-         font-size: 2rem;
-     }
-     h3 {
-         font-weight: 500;
-         font-size: 1.5rem;
-     }
-     </style>
-     """,
-     unsafe_allow_html=True,
- )
-
- # Streamlit page title and subtitle
- st.title("Computer Vision Playground")
-
- # Add a link to the README file
- st.markdown(
-     """
-     <div style="text-align: left;">
-         <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
-         target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
-     </div>
-     """,
-     unsafe_allow_html=True,
- )
-
- st.subheader(ANALYSIS_TITLE)
-
- # Columns for input and output streams
- col1, col2 = st.columns(2)
-
- with col1:
-     st.header("Input Stream")
-     st.subheader("input")
-     # WebRTC streamer to get video input from the webcam
-     webrtc_ctx = webrtc_streamer(
-         key="input-webcam",
-         mode=WebRtcMode.SENDRECV,
-         rtc_configuration=ice_servers,
-         video_frame_callback=video_frame_callback,
-         media_stream_constraints={"video": True, "audio": False},
-         async_processing=True,
-     )
-
-     # File uploader for images
-     st.subheader("Upload an Image")
-     uploaded_file = st.file_uploader(
-         "Choose an image...", type=["jpg", "jpeg", "png"])
-
-     # Text input for image URL
-     st.subheader("Or Enter Image URL")
-     image_url = st.text_input("Image URL")
-
-     # File uploader for videos
-     st.subheader("Upload a Video")
-     uploaded_video = st.file_uploader(
-         "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
-     )
-
-     # Text input for video URL
-     st.subheader("Or Enter Video Download URL")
-     video_url = st.text_input("Video URL")
-
- # Streamlit footer
- st.markdown(
-     """
-     <div style="text-align: center; margin-top: 2rem;">
-         <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
-     </div>
-     """,
-     unsafe_allow_html=True
- )
-
- # Function to initialize the analysis UI
- # This function sets up the placeholders and UI elements in the analysis section.
- # It creates placeholders for input and output frames, analysis time, and detected labels.
-
-
- def analysis_init():
-     global analysis_time, show_labels, labels_placeholder, input_placeholder, output_placeholder
-
-     with col2:
-         st.header("Analysis")
-         st.subheader("Input Frame")
-         input_placeholder = st.empty()  # Placeholder for input frame
-
-         st.subheader("Output Frame")
-         output_placeholder = st.empty()  # Placeholder for output frame
-         analysis_time = st.empty()  # Placeholder for analysis time
-         show_labels = st.checkbox(
-             "Show the detected labels", value=True
-         )  # Checkbox to show/hide labels
-         labels_placeholder = st.empty()  # Placeholder for labels
-
-
- # Function to publish frames and results to the Streamlit UI
- # This function retrieves the latest frames and results from the global container and result queue,
- # and updates the placeholders in the Streamlit UI with the current input frame, analyzed frame, analysis time, and detected labels.
- def publish_frame():
-     if not result_queue.empty():
-         result = result_queue.get()
-         if show_labels:
-             labels_placeholder.table(
-                 result
-             )  # Display labels if the checkbox is checked
-
-     img = img_container["input"]
-     if img is None:
-         return
-     input_placeholder.image(img, channels="RGB")  # Display the input frame
-
-     analyzed = img_container["analyzed"]
-     if analyzed is None:
-         return
-     # Display the analyzed frame
-     output_placeholder.image(analyzed, channels="RGB")
-
-     time = img_container["analysis_time"]
-     if time is None:
-         return
-     # Display the analysis time
-     analysis_time.text(f"Analysis Time: {time} ms")
-
-
- # If the WebRTC streamer is playing, initialize and publish frames
- if webrtc_ctx.state.playing:
-     analysis_init()  # Initialize the analysis UI
-     while True:
-         publish_frame()  # Publish the frames and results
-         time.sleep(0.1)  # Delay to control frame rate
-
-
- # If an image is uploaded or a URL is provided, process the image
- if uploaded_file is not None or image_url:
-     analysis_init()  # Initialize the analysis UI
-
-     if uploaded_file is not None:
-         image = Image.open(uploaded_file)  # Open the uploaded image
-         img = np.array(image.convert("RGB"))  # Convert the image to RGB format
-     else:
-         response = requests.get(image_url)  # Download the image from the URL
-         # Open the downloaded image
-         image = Image.open(BytesIO(response.content))
-         img = np.array(image.convert("RGB"))  # Convert the image to RGB format
-
-     analyze_frame(img)  # Analyze the image
-     publish_frame()  # Publish the results
-
-
- # Function to process video files
- # This function reads frames from a video file, analyzes each frame for face detection and sentiment analysis,
- # and updates the Streamlit UI with the current input frame, analyzed frame, and detected labels.
- def process_video(video_path):
-     cap = cv2.VideoCapture(video_path)  # Open the video file
-     while cap.isOpened():
-         ret, frame = cap.read()  # Read a frame from the video
-         if not ret:
-             break  # Exit the loop if no more frames are available
-
-         # Display the current frame as the input frame
-         input_placeholder.image(frame)
-         analyze_frame(
-             frame
-         )  # Analyze the frame for face detection and sentiment analysis
-         publish_frame()  # Publish the results
-
-         if not result_queue.empty():
-             result = result_queue.get()
-             if show_labels:
-                 labels_placeholder.table(
-                     result
-                 )  # Display labels if the checkbox is checked
-
-     cap.release()  # Release the video capture object
-
-
- # If a video is uploaded or a URL is provided, process the video
- if uploaded_video is not None or video_url:
-     analysis_init()  # Initialize the analysis UI
-
-     if uploaded_video is not None:
-         video_path = uploaded_video.name  # Get the name of the uploaded video
-         with open(video_path, "wb") as f:
-             # Save the uploaded video to a file
-             f.write(uploaded_video.getbuffer())
-     else:
-         # Download the video from the URL
-         video_path = download_file(video_url)
-
-     process_video(video_path)  # Process the video