StephanST committed

Commit b84126d · verified · 1 Parent(s): 2c3bb74

Upload 5 files

FOSS models release 1.0

run_sliced_inference.py ADDED
@@ -0,0 +1,141 @@
+ import cv2
+ import sys
+ import os
+ import numpy as np
+ from sahi import AutoDetectionModel
+ from sahi.predict import get_sliced_prediction, get_prediction
+ import supervision as sv
+
+ # Check the number of command-line arguments
+ if len(sys.argv) != 8:
+     print("Usage: python run_sliced_inference.py <model_path> <input_path> <output_path> <slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
+     sys.exit(1)
+
+ # Get command-line arguments
+ model_path = sys.argv[1]
+ input_path = sys.argv[2]
+ output_path = sys.argv[3]
+ slice_height = int(sys.argv[4])
+ slice_width = int(sys.argv[5])
+ overlap_height_ratio = float(sys.argv[6])
+ overlap_width_ratio = float(sys.argv[7])
+
+ # Load YOLOv8 model with SAHI
+ detection_model = AutoDetectionModel.from_pretrained(
+     model_type='yolov8',
+     model_path=model_path,
+     confidence_threshold=0.1,
+     device="cpu"  # or "cuda"
+ )
+
+ # Annotators
+ box_annotator = sv.BoxCornerAnnotator(thickness=2)
+ label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=2)
+
+ def annotate_image(image, object_predictions):
+     """
+     Given an OpenCV image and a list of object predictions from SAHI,
+     return an annotated copy of that image.
+     """
+     if not object_predictions:
+         return image.copy()
+
+     xyxy, confidences, class_ids, class_names = [], [], [], []
+     for pred in object_predictions:
+         bbox = pred.bbox.to_xyxy()  # [x1, y1, x2, y2]
+         xyxy.append(bbox)
+         confidences.append(pred.score.value)
+         class_ids.append(pred.category.id)
+         class_names.append(pred.category.name)
+
+     xyxy = np.array(xyxy, dtype=np.float32)
+     confidences = np.array(confidences, dtype=np.float32)
+     class_ids = np.array(class_ids, dtype=int)
+
+     detections = sv.Detections(
+         xyxy=xyxy,
+         confidence=confidences,
+         class_id=class_ids
+     )
+
+     labels = [f"{cn} {conf:.2f}" for cn, conf in zip(class_names, confidences)]
+
+     annotated = image.copy()
+     annotated = box_annotator.annotate(scene=annotated, detections=detections)
+     annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
+     return annotated
+
+ def run_sliced_inference(image):
+     result = get_sliced_prediction(
+         image=image,
+         detection_model=detection_model,
+         slice_height=slice_height,
+         slice_width=slice_width,
+         overlap_height_ratio=overlap_height_ratio,
+         overlap_width_ratio=overlap_width_ratio
+     )
+     return annotate_image(image, result.object_prediction_list)
+
+ def run_full_inference(image):
+     # Normal inference without slicing
+     result = get_prediction(
+         image=image,
+         detection_model=detection_model
+         # postprocess_match_threshold=0.5,  # adjust the post-processing threshold if needed
+     )
+     return annotate_image(image, result.object_prediction_list)
+
+ # Determine whether the input is an image or a video based on file extension
+ _, ext = os.path.splitext(input_path.lower())
+ image_extensions = [".png", ".jpg", ".jpeg", ".bmp"]
+
+ if ext in image_extensions:
+     # ----- IMAGE PROCESSING -----
+     image = cv2.imread(input_path)
+     if image is None:
+         print(f"Error loading image: {input_path}")
+         sys.exit(1)
+
+     h, w = image.shape[:2]
+
+     # Decide whether or not to slice
+     if (h > slice_height) or (w > slice_width):
+         # The image is bigger than the slice dims, so do sliced inference
+         annotated_image = run_sliced_inference(image)
+     else:
+         # Otherwise do normal inference
+         annotated_image = run_full_inference(image)
+
+     cv2.imwrite(output_path, annotated_image)
+     print(f"Inference complete. Annotated image saved at '{output_path}'")
+
+ else:
+     # ----- VIDEO PROCESSING -----
+     cap = cv2.VideoCapture(input_path)
+     if not cap.isOpened():
+         print(f"Error opening video: {input_path}")
+         sys.exit(1)
+
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+     frame_count = 0
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Run sliced inference on every frame; swap in run_full_inference if slicing is not needed
+         annotated_frame = run_sliced_inference(frame)
+         out.write(annotated_frame)
+
+         frame_count += 1
+         print(f"Processed frame {frame_count}", end='\r')
+
+     cap.release()
+     out.release()
+     print(f"\nInference complete. Video saved at '{output_path}'")
run_sliced_inference_with_tracker.py ADDED
@@ -0,0 +1,152 @@
+ import cv2
+ import sys
+ from sahi.models.yolov8 import Yolov8DetectionModel
+ from sahi.predict import get_sliced_prediction
+ import supervision as sv
+ import numpy as np
+
+ # Check the number of command-line arguments
+ if len(sys.argv) != 8:
+     print("Usage: python run_sliced_inference_with_tracker.py <model_path> <input_video_path> <output_video_path> <slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
+     sys.exit(1)
+
+ # Get command-line arguments
+ model_path = sys.argv[1]
+ input_video_path = sys.argv[2]
+ output_video_path = sys.argv[3]
+ slice_height = int(sys.argv[4])
+ slice_width = int(sys.argv[5])
+ overlap_height_ratio = float(sys.argv[6])
+ overlap_width_ratio = float(sys.argv[7])
+
+ # Load YOLOv8 model with SAHI
+ detection_model = Yolov8DetectionModel(
+     model_path=model_path,
+     confidence_threshold=0.25,
+     device="cuda"  # or "cpu"
+ )
+
+ # Get video info
+ video_info = sv.VideoInfo.from_video_path(video_path=input_video_path)
+
+ # Open input video
+ cap = cv2.VideoCapture(input_video_path)
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+ # Set up output video writer
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
+
+ # Initialize tracker and smoother
+ tracker = sv.ByteTrack(frame_rate=video_info.fps)
+ smoother = sv.DetectionsSmoother()
+
+ # Create bounding box and label annotators
+ box_annotator = sv.BoxCornerAnnotator(thickness=2)
+ label_annotator = sv.LabelAnnotator(
+     text_scale=0.5,
+     text_thickness=1,
+     text_padding=1
+ )
+
+ # Process each frame
+ frame_count = 0
+ class_id_to_name = {}  # Initialized once so the class_id-to-name mapping persists across frames
+
+ while cap.isOpened():
+     ret, frame = cap.read()
+     if not ret:
+         break
+
+     # Perform sliced inference on the current frame using SAHI
+     result = get_sliced_prediction(
+         image=frame,
+         detection_model=detection_model,
+         slice_height=slice_height,
+         slice_width=slice_width,
+         overlap_height_ratio=overlap_height_ratio,
+         overlap_width_ratio=overlap_width_ratio
+     )
+
+     # Extract data from the SAHI result
+     object_predictions = result.object_prediction_list
+
+     # Initialize lists to hold the data
+     xyxy = []
+     confidences = []
+     class_ids = []
+     # Build or update the class_id-to-name mapping
+     for pred in object_predictions:
+         if pred.category.id not in class_id_to_name:
+             class_id_to_name[pred.category.id] = pred.category.name
+
+     # Loop over the object predictions and extract data
+     for pred in object_predictions:
+         bbox = pred.bbox.to_xyxy()  # Convert bbox to [x1, y1, x2, y2]
+         xyxy.append(bbox)
+         confidences.append(pred.score.value)
+         class_ids.append(pred.category.id)
+
+     # Check if there are any detections
+     if xyxy:
+         # Convert lists to numpy arrays
+         xyxy = np.array(xyxy, dtype=np.float32)
+         confidences = np.array(confidences, dtype=np.float32)
+         class_ids = np.array(class_ids, dtype=int)
+
+         # Create sv.Detections object
+         detections = sv.Detections(
+             xyxy=xyxy,
+             confidence=confidences,
+             class_id=class_ids
+         )
+
+         # Update tracker with detections
+         detections = tracker.update_with_detections(detections)
+
+         # Update smoother with detections
+         detections = smoother.update_with_detections(detections)
+
+         # Prepare labels for the label annotator,
+         # including the tracker ID when available
+         labels = []
+         for i in range(len(detections.xyxy)):
+             class_id = detections.class_id[i]
+             confidence = detections.confidence[i]
+             class_name = class_id_to_name.get(class_id, 'Unknown')
+             label = f"{class_name} {confidence:.2f}"
+
+             # Add the tracker ID if available
+             if detections.tracker_id is not None:
+                 tracker_id = detections.tracker_id[i]
+                 label = f"ID {tracker_id} {label}"
+
+             labels.append(label)
+
+         # Annotate frame with detection results
+         annotated_frame = frame.copy()
+         annotated_frame = box_annotator.annotate(
+             scene=annotated_frame,
+             detections=detections
+         )
+         annotated_frame = label_annotator.annotate(
+             scene=annotated_frame,
+             detections=detections,
+             labels=labels
+         )
+     else:
+         # If no detections, use the original frame
+         annotated_frame = frame.copy()
+
+     # Write the annotated frame to the output video
+     out.write(annotated_frame)
+
+     frame_count += 1
+     print(f"Processed frame {frame_count}", end='\r')
+
+ # Release resources
+ cap.release()
+ out.release()
+ print("\nInference complete. Video saved at", output_video_path)
unidrone_yolov8m_640px.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab44f269ec1c57087b30e54aa17aeff6470fa6a6d17d6329b507f396b109d1d2
+ size 97694621
unidrone_yolov8n_448px.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1ec33a56bfd4ff1fb4f4f0f90a9852746428b33b4681041c1b1096a13581c54
+ size 13055080
unidrone_yolov8n_640px.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b60edd7ede370fdbb89b91aec2ed2f0f0650a86f6cfa8fabb9e32f62a0a7de78
+ size 13111623
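
The three .zip entries above are Git LFS pointer files rather than the archives themselves; a plain clone without LFS retrieves only these stubs. Fetching the actual weights follows the standard Git LFS workflow:

git lfs install
git lfs pull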