Upload 10 files
- onnx/friendfoe.onnx +3 -0
- onnx/hero.onnx +3 -0
- onnx/hp.onnx +3 -0
- onnx/ui.onnx +3 -0
- pt/friendfoe.pt +3 -0
- pt/hero.pt +3 -0
- pt/hp.pt +3 -0
- pt/ui.pt +3 -0
- quad_detect.py +258 -0
- training.ipynb +0 -0
onnx/friendfoe.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:771ce0dc372ae817b02cbbd71775ffd31474c6b4b63a4ce7d35bb0ae58d72d28
size 12569522
onnx/hero.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5275890ecc28e74c7c559e502e03166035f6f012659ba7a3d8abbe344cc7ea4
size 175131042
onnx/hp.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f22e83aec5f2f22f4e2c361fb371eb92780dbcab5484d07254d22d99f516e2af
size 45063519
onnx/ui.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6fd5b1978ca6083f5b92dcbcc62103a9df5ff8fcf996b3e030a30ed21fffbce
size 12580671
pt/friendfoe.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d8b9d7961d659c6e9b5181d23f26e565c7088e322718920b5b322d56762341d
size 6281379
pt/hero.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31c73e4e30c60afc9cc8e2889df0bc7dccc46c09eda6905d2ff6e3bed4d6d22a
size 87683187
pt/hp.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3313cbb42da36bf3fc17e4e6387d95f12dbc5470e34281627415b41752d76853
size 22623203
pt/ui.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad39e98df822beab1de29dae729bc971c0002b1e626dcd1dcd2facd6981c0907
size 6254435
quad_detect.py
ADDED
@@ -0,0 +1,258 @@
import os
import sys
import argparse
import glob
import time
import cv2
import numpy as np
from ultralytics import YOLO
import threading

# -----------------------------------
# Parse command-line arguments
# -----------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--model1', help='Path to YOLO model #1', required=True)
parser.add_argument('--model2', help='Path to YOLO model #2', required=True)
parser.add_argument('--model3', help='Path to YOLO model #3', required=True)
parser.add_argument('--model4', help='Path to YOLO model #4', required=True)
parser.add_argument('--source', help='Image or video source (file, folder, or camera)', required=True)
parser.add_argument('--thresh', help='Minimum confidence threshold for displaying detected objects',
                    default=0.5, type=float)
parser.add_argument('--resolution', help='Resolution in WxH (e.g., "640x480"), otherwise matches source',
                    default=None)
parser.add_argument('--record', help='Record results and save as "demo1.avi". Requires --resolution.',
                    action='store_true')
args = parser.parse_args()

# -----------------------------------
# Load YOLO models
# -----------------------------------
model_paths = [args.model1, args.model2, args.model3, args.model4]
min_thresh = args.thresh
models = []
for path in model_paths:
    if not os.path.exists(path):
        print(f'ERROR: Model path {path} is invalid. Check the file path.')
        sys.exit(0)
    models.append(YOLO(path))

# Class names for each model
model_labels = [m.names for m in models]

# -----------------------------------
# Determine source type
# -----------------------------------
img_source = args.source
img_ext_list = ['.jpg', '.jpeg', '.png', '.bmp']
vid_ext_list = ['.avi', '.mov', '.mp4', '.mkv', '.wmv']

if os.path.isdir(img_source):
    source_type = 'folder'
elif os.path.isfile(img_source):
    _, ext = os.path.splitext(img_source)
    if ext in img_ext_list:
        source_type = 'image'
    elif ext in vid_ext_list:
        source_type = 'video'
    else:
        print(f'Unsupported file extension: {ext}')
        sys.exit(0)
elif img_source.startswith('usb'):
    source_type = 'usb'
    usb_idx = int(img_source[3:])
else:
    print(f'Invalid source: {img_source}')
    sys.exit(0)

# -----------------------------------
# Set up resolution and recorder
# -----------------------------------
resize = False
user_res = args.resolution
record = args.record
if user_res:
    resize = True
    resW, resH = map(int, user_res.split('x'))

if record:
    if source_type not in ['video', 'usb']:
        print('Recording only works for video and camera sources.')
        sys.exit(0)
    if not user_res:
        print('Please specify resolution to record.')
        sys.exit(0)
    record_name = 'demo1.avi'
    record_fps = 30
    recorder = cv2.VideoWriter(record_name, cv2.VideoWriter_fourcc(*'MJPG'),
                               record_fps, (resW, resH))

# -----------------------------------
# Load image/video source
# -----------------------------------
if source_type == 'image':
    imgs_list = [img_source]
elif source_type == 'folder':
    imgs_list = [file for file in glob.glob(img_source + '/*')
                 if os.path.splitext(file)[1] in img_ext_list]
elif source_type in ['video', 'usb']:
    cap = cv2.VideoCapture(img_source if source_type == 'video' else usb_idx)
    if resize:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, resW)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, resH)

# -----------------------------------
# Colors for each model's bounding boxes (BGR)
# -----------------------------------
model_colors = [
    (0, 128, 0),    # Model 1: dark green
    (128, 0, 0),    # Model 2: dark blue
    (0, 0, 128),    # Model 3: dark red
    (128, 128, 0)   # Model 4: dark cyan
]
# -----------------------------------
# FPS tracking
# -----------------------------------
avg_frame_rate = 0
frame_rate_buffer = []
fps_avg_len = 200
img_count = 0

# -----------------------------------
# Thread function for model inference
# -----------------------------------
def run_detection(index, frame, results_dict):
    results_dict[index] = models[index](frame, verbose=False)[0]

# -----------------------------------
# draw_detections with left/right anchor
# -----------------------------------
def draw_detections(results, labels, frame, color, min_thresh, anchor='left'):
    for det in results.boxes:
        xyxy = det.xyxy.cpu().numpy().squeeze().astype(int)
        class_idx = int(det.cls.item())
        conf = det.conf.item()

        if conf > min_thresh:
            x1, y1, x2, y2 = xyxy

            # Draw bounding box
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            label = f'{labels[class_idx]}: {int(conf * 100)}%'
            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

            if anchor == 'left':
                # Label at top-left corner
                text_x = x1
                text_y = y1 - 5
                box_x1 = x1
                box_x2 = x1 + label_size[0]
            else:
                # Label at top-right corner
                text_x = x2 - label_size[0]
                text_y = y1 - 5
                box_x1 = x2 - label_size[0]
                box_x2 = x2

            # Draw filled rectangle for label background
            cv2.rectangle(
                frame,
                (box_x1, y1 - label_size[1] - 10),
                (box_x2, y1),
                color,
                -1
            )

            # Draw label text in white
            cv2.putText(
                frame,
                label,
                (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 255),
                1
            )

# -----------------------------------
# Main loop
# -----------------------------------
while True:
    t_start = time.perf_counter()

    # Grab frame
    if source_type in ['image', 'folder']:
        if img_count >= len(imgs_list):
            print('All images processed. Exiting.')
            break
        frame = cv2.imread(imgs_list[img_count])
        img_count += 1
    elif source_type in ['video', 'usb']:
        ret, frame = cap.read()
        if not ret:
            print('No more frames or camera disconnected. Exiting.')
            break

    # Resize if needed
    if resize:
        frame = cv2.resize(frame, (resW, resH))

    # Dictionary to store results from each model
    results_dict = {}

    # Run all models in separate threads
    threads = []
    for i in range(4):
        t = threading.Thread(target=run_detection, args=(i, frame, results_dict))
        t.start()
        threads.append(t)

    # Wait for all threads to finish
    for t in threads:
        t.join()

    # Draw each model's detections
    for i in range(4):
        # Move the 3rd model's label to the right side (i == 2 in 0-based indexing)
        if i == 2:
            draw_detections(results_dict[i], model_labels[i], frame, model_colors[i],
                            min_thresh, anchor='right')
        else:
            draw_detections(results_dict[i], model_labels[i], frame, model_colors[i],
                            min_thresh, anchor='left')

    # Show FPS
    cv2.putText(frame, f'FPS: {avg_frame_rate:.2f}',
                (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 255, 255),
                2)

    # Display frame
    cv2.imshow('YOLO Detection Results', frame)
    if record:
        recorder.write(frame)

    key = cv2.waitKey(5)
    if key == ord('q'):
        break

    # Compute FPS
    t_stop = time.perf_counter()
    frame_rate_calc = 1 / (t_stop - t_start)
    frame_rate_buffer.append(frame_rate_calc)
    if len(frame_rate_buffer) > fps_avg_len:
        frame_rate_buffer.pop(0)
    avg_frame_rate = np.mean(frame_rate_buffer)

# -----------------------------------
# Cleanup
# -----------------------------------
print(f'Average FPS: {avg_frame_rate:.2f}')
if source_type in ['video', 'usb']:
    cap.release()
if record:
    recorder.release()
cv2.destroyAllWindows()
training.ipynb
ADDED
The diff for this file is too large to render.