Chappieut commited on
Commit
1f5d7bc
·
verified ·
1 Parent(s): e90de48

Upload 10 files

Browse files
Files changed (10) hide show
  1. onnx/friendfoe.onnx +3 -0
  2. onnx/hero.onnx +3 -0
  3. onnx/hp.onnx +3 -0
  4. onnx/ui.onnx +3 -0
  5. pt/friendfoe.pt +3 -0
  6. pt/hero.pt +3 -0
  7. pt/hp.pt +3 -0
  8. pt/ui.pt +3 -0
  9. quad_detect.py +258 -0
  10. training.ipynb +0 -0
onnx/friendfoe.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:771ce0dc372ae817b02cbbd71775ffd31474c6b4b63a4ce7d35bb0ae58d72d28
3
+ size 12569522
onnx/hero.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5275890ecc28e74c7c559e502e03166035f6f012659ba7a3d8abbe344cc7ea4
3
+ size 175131042
onnx/hp.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f22e83aec5f2f22f4e2c361fb371eb92780dbcab5484d07254d22d99f516e2af
3
+ size 45063519
onnx/ui.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6fd5b1978ca6083f5b92dcbcc62103a9df5ff8fcf996b3e030a30ed21fffbce
3
+ size 12580671
pt/friendfoe.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d8b9d7961d659c6e9b5181d23f26e565c7088e322718920b5b322d56762341d
3
+ size 6281379
pt/hero.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31c73e4e30c60afc9cc8e2889df0bc7dccc46c09eda6905d2ff6e3bed4d6d22a
3
+ size 87683187
pt/hp.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3313cbb42da36bf3fc17e4e6387d95f12dbc5470e34281627415b41752d76853
3
+ size 22623203
pt/ui.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad39e98df822beab1de29dae729bc971c0002b1e626dcd1dcd2facd6981c0907
3
+ size 6254435
quad_detect.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import glob
5
+ import time
6
+ import cv2
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+ import threading
10
+
11
+ # -----------------------------------
12
+ # Parse command-line arguments
13
+ # -----------------------------------
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--model1', help='Path to YOLO model #1', required=True)
16
+ parser.add_argument('--model2', help='Path to YOLO model #2', required=True)
17
+ parser.add_argument('--model3', help='Path to YOLO model #3', required=True)
18
+ parser.add_argument('--model4', help='Path to YOLO model #4', required=True)
19
+ parser.add_argument('--source', help='Image or video source (file, folder, or camera)', required=True)
20
+ parser.add_argument('--thresh', help='Minimum confidence threshold for displaying detected objects',
21
+ default=0.5, type=float)
22
+ parser.add_argument('--resolution', help='Resolution in WxH (e.g., "640x480"), otherwise matches source',
23
+ default=None)
24
+ parser.add_argument('--record', help='Record results and save as "demo1.avi". Requires --resolution.',
25
+ action='store_true')
26
+ args = parser.parse_args()
27
+
28
+ # -----------------------------------
29
+ # Load YOLO models
30
+ # -----------------------------------
31
+ model_paths = [args.model1, args.model2, args.model3, args.model4]
32
+ min_thresh = args.thresh
33
+ models = []
34
+ for path in model_paths:
35
+ if not os.path.exists(path):
36
+ print(f'ERROR: Model path {path} is invalid. Check the file path.')
37
+ sys.exit(0)
38
+ models.append(YOLO(path))
39
+
40
+ # Class names for each model
41
+ model_labels = [m.names for m in models]
42
+
43
+ # -----------------------------------
44
+ # Determine source type
45
+ # -----------------------------------
46
+ img_source = args.source
47
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.bmp']
48
+ vid_ext_list = ['.avi', '.mov', '.mp4', '.mkv', '.wmv']
49
+
50
+ if os.path.isdir(img_source):
51
+ source_type = 'folder'
52
+ elif os.path.isfile(img_source):
53
+ _, ext = os.path.splitext(img_source)
54
+ if ext in img_ext_list:
55
+ source_type = 'image'
56
+ elif ext in vid_ext_list:
57
+ source_type = 'video'
58
+ else:
59
+ print(f'Unsupported file extension: {ext}')
60
+ sys.exit(0)
61
+ elif 'usb' in img_source:
62
+ source_type = 'usb'
63
+ usb_idx = int(img_source[3:])
64
+ else:
65
+ print(f'Invalid source: {img_source}')
66
+ sys.exit(0)
67
+
68
+ # -----------------------------------
69
+ # Set up resolution and recorder
70
+ # -----------------------------------
71
+ resize = False
72
+ user_res = args.resolution
73
+ record = args.record
74
+ if user_res:
75
+ resize = True
76
+ resW, resH = map(int, user_res.split('x'))
77
+
78
+ if record:
79
+ if source_type not in ['video', 'usb']:
80
+ print('Recording only works for video and camera sources.')
81
+ sys.exit(0)
82
+ if not user_res:
83
+ print('Please specify resolution to record.')
84
+ sys.exit(0)
85
+ record_name = 'demo1.avi'
86
+ record_fps = 30
87
+ recorder = cv2.VideoWriter(record_name, cv2.VideoWriter_fourcc(*'MJPG'),
88
+ record_fps, (resW, resH))
89
+
90
+ # -----------------------------------
91
+ # Load image/video source
92
+ # -----------------------------------
93
+ if source_type == 'image':
94
+ imgs_list = [img_source]
95
+ elif source_type == 'folder':
96
+ imgs_list = [file for file in glob.glob(img_source + '/*')
97
+ if os.path.splitext(file)[1] in img_ext_list]
98
+ elif source_type in ['video', 'usb']:
99
+ cap = cv2.VideoCapture(img_source if source_type == 'video' else usb_idx)
100
+ if resize:
101
+ cap.set(3, resW)
102
+ cap.set(4, resH)
103
+
104
+ # -----------------------------------
105
+ # Colors for each model's bounding boxes
106
+ # -----------------------------------
107
+ model_colors = [
108
+ (0, 128, 0), # Model 1: dark green
109
+ (128, 0, 0), # Model 2: dark red
110
+ (0, 0, 128), # Model 3: dark blue
111
+ (128, 128, 0) # Model 4: dark olive (muted cyan alternative)
112
+ ]
113
+ # -----------------------------------
114
+ # FPS tracking
115
+ # -----------------------------------
116
+ avg_frame_rate = 0
117
+ frame_rate_buffer = []
118
+ fps_avg_len = 200
119
+ img_count = 0
120
+
121
+ # -----------------------------------
122
+ # Thread function for model inference
123
+ # -----------------------------------
124
+ def run_detection(index, frame, results_dict):
125
+ results_dict[index] = models[index](frame, verbose=False)[0]
126
+
127
+ # -----------------------------------
128
+ # draw_detections with left/right anchor
129
+ # -----------------------------------
130
+ def draw_detections(results, labels, frame, color, min_thresh, anchor='left'):
131
+ for det in results.boxes:
132
+ xyxy = det.xyxy.cpu().numpy().squeeze().astype(int)
133
+ class_idx = int(det.cls.item())
134
+ conf = det.conf.item()
135
+
136
+ if conf > min_thresh:
137
+ x1, y1, x2, y2 = xyxy
138
+
139
+ # Draw bounding box
140
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
141
+
142
+ label = f'{labels[class_idx]}: {int(conf * 100)}%'
143
+ label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
144
+
145
+ if anchor == 'left':
146
+ # Label at top-left corner
147
+ text_x = x1
148
+ text_y = y1 - 5
149
+ box_x1 = x1
150
+ box_x2 = x1 + label_size[0]
151
+ else:
152
+ # Label at top-right corner
153
+ text_x = x2 - label_size[0]
154
+ text_y = y1 - 5
155
+ box_x1 = x2 - label_size[0]
156
+ box_x2 = x2
157
+
158
+ # Draw filled rectangle for label background
159
+ cv2.rectangle(
160
+ frame,
161
+ (box_x1, y1 - label_size[1] - 10),
162
+ (box_x2, y1),
163
+ color,
164
+ -1
165
+ )
166
+
167
+ # Draw label text in white
168
+ cv2.putText(
169
+ frame,
170
+ label,
171
+ (text_x, text_y),
172
+ cv2.FONT_HERSHEY_SIMPLEX,
173
+ 0.5,
174
+ (255, 255, 255),
175
+ 1
176
+ )
177
+
178
+ # -----------------------------------
179
+ # Main loop
180
+ # -----------------------------------
181
+ while True:
182
+ t_start = time.perf_counter()
183
+
184
+ # Grab frame
185
+ if source_type in ['image', 'folder']:
186
+ if img_count >= len(imgs_list):
187
+ print('All images processed. Exiting.')
188
+ sys.exit(0)
189
+ frame = cv2.imread(imgs_list[img_count])
190
+ img_count += 1
191
+ elif source_type in ['video', 'usb']:
192
+ ret, frame = cap.read()
193
+ if not ret:
194
+ print('No more frames or camera disconnected. Exiting.')
195
+ break
196
+
197
+ # Resize if needed
198
+ if resize:
199
+ frame = cv2.resize(frame, (resW, resH))
200
+
201
+ # Dictionary to store results from each model
202
+ results_dict = {}
203
+
204
+ # Run all models in separate threads
205
+ threads = []
206
+ for i in range(4):
207
+ t = threading.Thread(target=run_detection, args=(i, frame, results_dict))
208
+ t.start()
209
+ threads.append(t)
210
+
211
+ # Wait for all threads to finish
212
+ for t in threads:
213
+ t.join()
214
+
215
+ # Draw each model’s detections
216
+ for i in range(4):
217
+ # Move the 3rd model's label to the right side (i == 2 in 0-based indexing)
218
+ if i == 2:
219
+ draw_detections(results_dict[i], model_labels[i], frame, model_colors[i],
220
+ min_thresh, anchor='right')
221
+ else:
222
+ draw_detections(results_dict[i], model_labels[i], frame, model_colors[i],
223
+ min_thresh, anchor='left')
224
+
225
+ # Show FPS
226
+ cv2.putText(frame, f'FPS: {avg_frame_rate:.2f}',
227
+ (10, 20),
228
+ cv2.FONT_HERSHEY_SIMPLEX,
229
+ 0.7,
230
+ (0, 255, 255),
231
+ 2)
232
+
233
+ # Display frame
234
+ cv2.imshow('YOLO Detection Results', frame)
235
+ if record:
236
+ recorder.write(frame)
237
+
238
+ key = cv2.waitKey(5)
239
+ if key == ord('q'):
240
+ break
241
+
242
+ # Compute FPS
243
+ t_stop = time.perf_counter()
244
+ frame_rate_calc = 1 / (t_stop - t_start)
245
+ frame_rate_buffer.append(frame_rate_calc)
246
+ if len(frame_rate_buffer) > fps_avg_len:
247
+ frame_rate_buffer.pop(0)
248
+ avg_frame_rate = np.mean(frame_rate_buffer)
249
+
250
+ # -----------------------------------
251
+ # Cleanup
252
+ # -----------------------------------
253
+ print(f'Average FPS: {avg_frame_rate:.2f}')
254
+ if source_type in ['video', 'usb']:
255
+ cap.release()
256
+ if record:
257
+ recorder.release()
258
+ cv2.destroyAllWindows()
training.ipynb ADDED
The diff for this file is too large to render. See raw diff