eusholli committed on
Commit 3332470
1 Parent(s): 6234614

updated to follow analyze_frame function

Files changed (2)
  1. .gitignore +4 -0
  2. object_detection.py +322 -63
.gitignore CHANGED
@@ -116,3 +116,7 @@ dmypy.json
 
 # MacOS specific
 .DS_Store
+
+# Keep empty models dir
+models/*
+!models/.gitkeep
object_detection.py CHANGED
@@ -1,8 +1,7 @@
-"""Object detection demo with MobileNet SSD.
-This model and code are based on
-https://github.com/robmarkcole/object-detection-app
-"""
-
+import torch
+import tensorflow as tf
+import time
+import os
 import logging
 import queue
 from pathlib import Path
@@ -14,19 +13,36 @@ import numpy as np
 import streamlit as st
 from streamlit_webrtc import WebRtcMode, webrtc_streamer
 
-from utils.download import download_file
+from utils.download import download_file
 from utils.turn import get_ice_servers
 
-HERE = Path(__file__).parent
-ROOT = HERE.parent
+from PIL import Image, ImageDraw  # Import PIL for image processing
+from transformers import pipeline  # Import Hugging Face transformers pipeline
 
-logger = logging.getLogger(__name__)
+import requests
+from io import BytesIO  # Import for handling byte streams
+
+
+# Named tuple to store detection results
+class Detection(NamedTuple):
+    class_id: int
+    label: str
+    score: float
+    box: np.ndarray
+
+
+# Queue to store detection results
+result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
 
+# CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
+# Update below string to set display title of analysis
 
-MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
-MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
-PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
-PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
+# Appropriate imports needed for analysis
+
+MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"
+MODEL_LOCAL_PATH = Path("./models/MobileNetSSD_deploy.caffemodel")
+PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"
+PROTOTXT_LOCAL_PATH = Path("./models/MobileNetSSD_deploy.prototxt.txt")
 
 CLASSES = [
     "background",
@@ -52,55 +68,63 @@ CLASSES = [
     "tvmonitor",
 ]
 
-
-class Detection(NamedTuple):
-    class_id: int
-    label: str
-    score: float
-    box: np.ndarray
+# Generate random colors for each class label
 
 
-@st.cache_resource  # type: ignore
 def generate_label_colors():
     return np.random.uniform(0, 255, size=(len(CLASSES), 3))
 
 
 COLORS = generate_label_colors()
 
+# Download model and prototxt files
+
+
+def download_file(url, local_path, expected_size=None):
+    if not local_path.exists() or (expected_size and local_path.stat().st_size != expected_size):
+        import requests
+        with open(local_path, "wb") as f:
+            response = requests.get(url)
+            f.write(response.content)
+
+
 download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
 download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
 
+# Load the model
+net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
 
-# Session-specific caching
-cache_key = "object_detection_dnn"
-if cache_key in st.session_state:
-    net = st.session_state[cache_key]
-else:
-    net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
-    st.session_state[cache_key] = net
 
-score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
+# Default title - "Facial Sentiment Analysis"
 
-# NOTE: The callback will be called in another thread,
-# so use a queue here for thread-safety to pass the data
-# from inside to outside the callback.
-# TODO: A general-purpose shared state object may be more useful.
-result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
+ANALYSIS_TITLE = "Object Detection Analysis"
 
+# CHANGE THE CONTENTS OF THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
+#
+
+# Set analysis results in img_container and result queue for display
+# img_container["input"] - holds the input frame contents - of type np.ndarray
+# img_container["analyzed"] - holds the analyzed frame with any added annotations - of type np.ndarray
+# img_container["analysis_time"] - holds how long the analysis has taken in miliseconds
+# result_queue - holds the analysis metadata results - of type queue.Queue[List[Detection]]
 
-def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-    image = frame.to_ndarray(format="bgr24")
+
+def analyze_frame(frame: np.ndarray):
+    start_time = time.time()  # Start timing the analysis
+    img_container["input"] = frame  # Store the input frame
+    frame = frame.copy()  # Create a copy of the frame to modify
 
     # Run inference
     blob = cv2.dnn.blobFromImage(
-        cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
+        cv2.resize(frame, (300, 300)), 0.007843, (300, 300), 127.5
     )
     net.setInput(blob)
    output = net.forward()
 
-    h, w = image.shape[:2]
+    h, w = frame.shape[:2]
 
-    # Convert the output array into a structured form.
+    # Filter the detections based on the score threshold
+    score_threshold = 0.5  # You can adjust the score threshold as needed
     output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
     output = output[output[:, 2] >= score_threshold]
    detections = [
@@ -119,9 +143,9 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
         color = COLORS[detection.class_id]
         xmin, ymin, xmax, ymax = detection.box.astype("int")
 
-        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
+        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)
         cv2.putText(
-            image,
+            frame,
             caption,
             (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
             cv2.FONT_HERSHEY_SIMPLEX,
@@ -130,35 +154,270 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
             2,
         )
 
-    result_queue.put(detections)
+    end_time = time.time()  # End timing the analysis
+    # Calculate execution time in milliseconds
+    execution_time_ms = round((end_time - start_time) * 1000, 2)
+    # Store the execution time
+    img_container["analysis_time"] = execution_time_ms
+
+    result_queue.put(detections)  # Put the results in the result queue
+    img_container["analyzed"] = frame  # Store the analyzed frame
+
+    return  # End of the function
+
+#
+#
+# DO NOT TOUCH THE BELOW CODE (NOT NEEDED)
+#
+#
+
+
+# Suppress FFmpeg logs
+os.environ["FFMPEG_LOG_LEVEL"] = "quiet"
 
-    return av.VideoFrame.from_ndarray(image, format="bgr24")
+# Suppress TensorFlow or PyTorch progress bars
 
+tf.get_logger().setLevel("ERROR")
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+# Suppress PyTorch logs
+
+logging.getLogger().setLevel(logging.WARNING)
+torch.set_num_threads(1)
+logging.getLogger("torch").setLevel(logging.ERROR)
+
+# Suppress Streamlit logs using the logging module
+logging.getLogger("streamlit").setLevel(logging.ERROR)
+
+# Container to hold image data and analysis results
+img_container = {"input": None, "analyzed": None, "analysis_time": None}
+
+# Initialize MTCNN for face detection
+mtcnn = MTCNN()
+
+# Logger for debugging and information
+logger = logging.getLogger(__name__)
+
+
+# Callback function to process video frames
+# This function is called for each video frame in the WebRTC stream.
+# It converts the frame to a numpy array in RGB format, analyzes the frame,
+# and returns the original frame.
+def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+    # Convert frame to numpy array in RGB format
+    img = frame.to_ndarray(format="rgb24")
+    analyze_frame(img)  # Analyze the frame
+    return frame  # Return the original frame
+
+
+# Get ICE servers for WebRTC
 ice_servers = get_ice_servers()
 
-webrtc_ctx = webrtc_streamer(
-    key="object-detection",
-    mode=WebRtcMode.SENDRECV,
-    rtc_configuration=ice_servers,
-    video_frame_callback=video_frame_callback,
-    media_stream_constraints={"video": True, "audio": False},
-    async_processing=True,
+# Streamlit UI configuration
+st.set_page_config(layout="wide")
+
+# Custom CSS for the Streamlit page
+st.markdown(
+    """
+    <style>
+    .main {
+        padding: 2rem;
+    }
+    h1, h2, h3 {
+        font-family: 'Arial', sans-serif;
+    }
+    h1 {
+        font-weight: 700;
+        font-size: 2.5rem;
+    }
+    h2 {
+        font-weight: 600;
+        font-size: 2rem;
+    }
+    h3 {
+        font-weight: 500;
+        font-size: 1.5rem;
+    }
+    </style>
+    """,
+    unsafe_allow_html=True,
 )
 
-if st.checkbox("Show the detected labels", value=True):
-    if webrtc_ctx.state.playing:
-        labels_placeholder = st.empty()
-        # NOTE: The video transformation with object detection and
-        # this loop displaying the result labels are running
-        # in different threads asynchronously.
-        # Then the rendered video frames and the labels displayed here
-        # are not strictly synchronized.
-        while True:
-            result = result_queue.get()
-            labels_placeholder.table(result)
+# Streamlit page title and subtitle
+st.title("Computer Vision Playground")
+
+# Add a link to the README file
+st.markdown(
+    """
+    <div style="text-align: left;">
+    <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
+    target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
+    </div>
+    """,
+    unsafe_allow_html=True,
+)
+
+st.subheader(ANALYSIS_TITLE)
+
+# Columns for input and output streams
+col1, col2 = st.columns(2)
+
+with col1:
+    st.header("Input Stream")
+    st.subheader("input")
+    # WebRTC streamer to get video input from the webcam
+    webrtc_ctx = webrtc_streamer(
+        key="input-webcam",
+        mode=WebRtcMode.SENDRECV,
+        rtc_configuration=ice_servers,
+        video_frame_callback=video_frame_callback,
+        media_stream_constraints={"video": True, "audio": False},
+        async_processing=True,
+    )
+
+    # File uploader for images
+    st.subheader("Upload an Image")
+    uploaded_file = st.file_uploader(
+        "Choose an image...", type=["jpg", "jpeg", "png"])
+
+    # Text input for image URL
+    st.subheader("Or Enter Image URL")
+    image_url = st.text_input("Image URL")
+
+    # File uploader for videos
+    st.subheader("Upload a Video")
+    uploaded_video = st.file_uploader(
+        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
+    )
 
+    # Text input for video URL
+    st.subheader("Or Enter Video Download URL")
+    video_url = st.text_input("Video URL")
+
+# Streamlit footer
 st.markdown(
-    "This demo uses a model and code from "
-    "https://github.com/robmarkcole/object-detection-app. "
-    "Many thanks to the project."
+    """
+    <div style="text-align: center; margin-top: 2rem;">
+    <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
+    </div>
+    """,
+    unsafe_allow_html=True
 )
+
+# Function to initialize the analysis UI
+# This function sets up the placeholders and UI elements in the analysis section.
+# It creates placeholders for input and output frames, analysis time, and detected labels.
+
+
+def analysis_init():
+    global analysis_time, show_labels, labels_placeholder, input_placeholder, output_placeholder
+
+    with col2:
+        st.header("Analysis")
+        st.subheader("Input Frame")
+        input_placeholder = st.empty()  # Placeholder for input frame
+
+        st.subheader("Output Frame")
+        output_placeholder = st.empty()  # Placeholder for output frame
+        analysis_time = st.empty()  # Placeholder for analysis time
+        show_labels = st.checkbox(
+            "Show the detected labels", value=True
+        )  # Checkbox to show/hide labels
+        labels_placeholder = st.empty()  # Placeholder for labels
+
+
+# Function to publish frames and results to the Streamlit UI
+# This function retrieves the latest frames and results from the global container and result queue,
+# and updates the placeholders in the Streamlit UI with the current input frame, analyzed frame, analysis time, and detected labels.
+def publish_frame():
+    if not result_queue.empty():
+        result = result_queue.get()
+        if show_labels:
+            labels_placeholder.table(
+                result
+            )  # Display labels if the checkbox is checked
+
+    img = img_container["input"]
+    if img is None:
+        return
+    input_placeholder.image(img, channels="RGB")  # Display the input frame
+
+    analyzed = img_container["analyzed"]
+    if analyzed is None:
+        return
+    # Display the analyzed frame
+    output_placeholder.image(analyzed, channels="RGB")
+
+    time = img_container["analysis_time"]
+    if time is None:
+        return
+    # Display the analysis time
+    analysis_time.text(f"Analysis Time: {time} ms")
+
+
+# If the WebRTC streamer is playing, initialize and publish frames
+if webrtc_ctx.state.playing:
+    analysis_init()  # Initialize the analysis UI
+    while True:
+        publish_frame()  # Publish the frames and results
+        time.sleep(0.1)  # Delay to control frame rate
+
+
+# If an image is uploaded or a URL is provided, process the image
+if uploaded_file is not None or image_url:
+    analysis_init()  # Initialize the analysis UI
+
+    if uploaded_file is not None:
+        image = Image.open(uploaded_file)  # Open the uploaded image
+        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
+    else:
+        response = requests.get(image_url)  # Download the image from the URL
+        # Open the downloaded image
+        image = Image.open(BytesIO(response.content))
+        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
+
+    analyze_frame(img)  # Analyze the image
+    publish_frame()  # Publish the results
+
+
+# Function to process video files
+# This function reads frames from a video file, analyzes each frame for face detection and sentiment analysis,
+# and updates the Streamlit UI with the current input frame, analyzed frame, and detected labels.
+def process_video(video_path):
+    cap = cv2.VideoCapture(video_path)  # Open the video file
+    while cap.isOpened():
+        ret, frame = cap.read()  # Read a frame from the video
+        if not ret:
+            break  # Exit the loop if no more frames are available
+
+        # Display the current frame as the input frame
+        input_placeholder.image(frame)
+        analyze_frame(
+            frame
+        )  # Analyze the frame for face detection and sentiment analysis
+        publish_frame()  # Publish the results
+
+        if not result_queue.empty():
+            result = result_queue.get()
+            if show_labels:
+                labels_placeholder.table(
+                    result
+                )  # Display labels if the checkbox is checked
+
+    cap.release()  # Release the video capture object
+
+
+# If a video is uploaded or a URL is provided, process the video
+if uploaded_video is not None or video_url:
+    analysis_init()  # Initialize the analysis UI
+
+    if uploaded_video is not None:
+        video_path = uploaded_video.name  # Get the name of the uploaded video
+        with open(video_path, "wb") as f:
+            # Save the uploaded video to a file
+            f.write(uploaded_video.getbuffer())
+    else:
+        # Download the video from the URL
+        video_path = download_file(video_url)
+
+    process_video(video_path)  # Process the video
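For context on the commit message, the comments committed above spell out the template contract that "follow analyze_frame function" refers to: analyze_frame() receives an RGB numpy frame, stores the raw frame in img_container["input"], the annotated frame in img_container["analyzed"], and the elapsed milliseconds in img_container["analysis_time"], and pushes a List[Detection] onto result_queue for the untouched UI code to render. Below is a minimal sketch of plugging a different analysis into that contract; the Canny-edge overlay is a hypothetical stand-in for "your wanted analysis", while the Detection, img_container, and result_queue names are taken from the diff above.

    import queue
    import time
    from typing import List, NamedTuple

    import cv2
    import numpy as np


    class Detection(NamedTuple):  # same fields as the Detection tuple in the diff
        class_id: int
        label: str
        score: float
        box: np.ndarray


    result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
    img_container = {"input": None, "analyzed": None, "analysis_time": None}


    def analyze_frame(frame: np.ndarray):
        start_time = time.time()
        img_container["input"] = frame  # raw RGB frame shown in the "Input Frame" panel
        annotated = frame.copy()

        # Hypothetical stand-in analysis: overlay Canny edges instead of MobileNet SSD boxes.
        edges = cv2.Canny(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY), 100, 200)
        annotated[edges > 0] = (0, 255, 0)

        # Metadata that publish_frame() renders as a table when "Show the detected labels" is checked.
        result_queue.put([
            Detection(class_id=0, label="edges", score=1.0,
                      box=np.array([0, 0, frame.shape[1], frame.shape[0]])),
        ])

        img_container["analysis_time"] = round((time.time() - start_time) * 1000, 2)
        img_container["analyzed"] = annotated  # annotated frame shown in the "Output Frame" panel

In the committed file the same contract is satisfied by the MobileNet SSD pass inside analyze_frame; only the body between the timing calls would change.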