norfair-demo / inference.py
Diego Fernandez
chore: add tmp folder
2f7f294
raw
history blame
2.94 kB
import argparse
import glob
import os
import numpy as np
from norfair import Paths, Tracker, Video
from norfair.camera_motion import HomographyTransformationGetter, MotionEstimator
from inference_utils import (
YOLO,
ModelsPath,
Style,
center,
clean_videos,
draw,
euclidean_distance,
iou,
yolo_detections_to_norfair_detections,
)
# Tracker distance thresholds, selected per track mode in `inference`:
# the bbox threshold is paired with the `iou` distance function and the
# centroid threshold with `euclidean_distance`.
DISTANCE_THRESHOLD_BBOX: float = 3.33
DISTANCE_THRESHOLD_CENTROID: int = 30
# NOTE(review): not referenced in this file; presumably consumed elsewhere
# (e.g. by the distance helpers) — confirm before removing.
MAX_DISTANCE: int = 10000

parser = argparse.ArgumentParser(description="Track objects in a video.")
parser.add_argument("--img-size", type=int, default=720, help="YOLOv7 inference size (pixels)")
parser.add_argument(
    "--iou-threshold", type=float, default=0.45, help="YOLOv7 IOU threshold for NMS"
)
parser.add_argument(
    "--classes", nargs="+", type=int, help="Filter by class: --classes 0, or --classes 0 2 3"
)
# Parsed at import time. This module may be imported by a host application
# whose own command line carries flags this parser does not define;
# parse_known_args ignores those instead of exiting with
# "unrecognized arguments".
args, _ = parser.parse_known_args()
def inference(
    input_video: str,
    model: str,
    motion_estimation: bool,
    drawing_paths: bool,
    track_points: str,
    model_threshold: float,
):
    """Detect and track objects in a video, writing an annotated copy.

    Every frame is run through the YOLO model selected by ``model``. The
    detections are converted to Norfair detections and fed to a Norfair
    ``Tracker``. They are then drawn onto the frame, which is written to an
    output video in the ``tmp`` folder.

    Args:
        input_video: Path to the input video file.
        model: Name of a ``ModelsPath`` member; its value is passed to ``YOLO``.
        motion_estimation: If true, estimate camera motion with a
            homography-based ``MotionEstimator`` and pass the resulting
            coordinate transformations to the tracker on every frame.
        drawing_paths: If true, draw object trails with Norfair ``Paths``.
        track_points: Name of a ``Style`` member. When its value is ``"bbox"``,
            tracking uses the ``iou`` distance with ``DISTANCE_THRESHOLD_BBOX``.
            Otherwise it uses ``euclidean_distance`` with
            ``DISTANCE_THRESHOLD_CENTROID``.
        model_threshold: Detection confidence threshold passed to the model.

    The NMS IoU threshold, the inference image size and the class filter come
    from the module-level command-line ``args``.

    Returns:
        Path of the written video: ``tmp/<input base name>_out.mp4``.
    """
    output_path = "tmp"
    # norfair's Video only treats output_path as a folder if it exists;
    # otherwise it would write to a file literally named "tmp" and the path
    # returned below would not exist.
    os.makedirs(output_path, exist_ok=True)
    clean_videos(output_path)

    track_mode = Style[track_points].value
    detector = YOLO(ModelsPath[model].value)
    video = Video(input_path=input_video, output_path=output_path)

    motion_estimator = None
    if motion_estimation:
        motion_estimator = MotionEstimator(
            max_points=500,
            min_distance=7,
            transformations_getter=HomographyTransformationGetter(),
            draw_flow=True,
        )

    if track_mode == "bbox":
        distance_function, distance_threshold = iou, DISTANCE_THRESHOLD_BBOX
    else:
        distance_function, distance_threshold = euclidean_distance, DISTANCE_THRESHOLD_CENTROID
    tracker = Tracker(
        distance_function=distance_function,
        distance_threshold=distance_threshold,
    )

    paths_drawer = Paths(center, attenuation=0.01) if drawing_paths else None

    coord_transformations = None
    for frame in video:
        yolo_detections = detector(
            frame,
            conf_threshold=model_threshold,
            iou_threshold=args.iou_threshold,
            image_size=args.img_size,  # honour --img-size instead of a fixed 720
            classes=args.classes,
        )
        if motion_estimator is not None:
            # All-ones mask: presumably lets the estimator sample points from
            # the whole frame — confirm against norfair's MotionEstimator docs.
            mask = np.ones(frame.shape[:2], frame.dtype)
            coord_transformations = motion_estimator.update(frame, mask)
        detections = yolo_detections_to_norfair_detections(
            yolo_detections, track_points=track_mode
        )
        tracked_objects = tracker.update(
            detections=detections, coord_transformations=coord_transformations
        )
        frame = draw(paths_drawer, track_mode, frame, detections, tracked_objects)
        video.write(frame)

    # Mirrors the file name norfair's Video derives inside an output folder:
    # the input's base name up to the first dot, plus "_out.mp4".
    # NOTE(review): must stay in sync with norfair's naming, so do not switch
    # to os.path.splitext here alone (multi-dot names would then diverge).
    base_file_name = input_video.split("/")[-1].split(".")[0]
    return os.path.join(output_path, f"{base_file_name}_out.mp4")