Spaces:
Running
Running
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from collections import defaultdict | |
import cv2 | |
from ultralytics import YOLO | |
from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER | |
from ultralytics.utils.checks import check_imshow, check_requirements | |
class BaseSolution:
    """
    A base class for managing Ultralytics Solutions.

    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
    and region initialization.

    Attributes:
        LineString (shapely.geometry.LineString): Class for creating line string geometries.
        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
        Point (shapely.geometry.Point): Class for creating point geometries.
        prep (Callable): `shapely.prepared.prep`, stored for use by other classes.
        CFG (Dict): Per-instance configuration built from the default solution and model configuration dictionaries,
            overridden by kwargs.
        region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
        line_width (int): Width of lines used in visualizations (2 when not configured).
        model (ultralytics.YOLO): Loaded YOLO model instance.
        names (Dict[int, str]): Dictionary mapping class indices to class names.
        track_add_args (Dict): Extra keyword arguments forwarded to `model.track`.
        env_check (bool): Result of `check_imshow`, used to decide whether frames can be displayed.
        track_history (collections.defaultdict): Dictionary to store tracking history for each object.
        tracks: Raw results of the most recent `model.track` call.
        track_data: OBB or boxes data taken from the first tracking result.
        boxes: Bounding boxes (xyxy) of the current tracks, or an empty list.
        clss (List): Class indices of the current tracks.
        track_ids (List[int]): Track identifiers of the current tracks.
        track_line (List[Tuple[float, float]]): History of centre points for the most recently stored track.
        r_s: Shapely Polygon or LineString built from `region` by `initialize_region`.

    Methods:
        extract_tracks: Apply object tracking and extract tracks from an input image.
        store_tracking_history: Store object tracking history for a given track ID and bounding box.
        initialize_region: Initialize the counting region and line segment based on configuration.
        display_output: Display the results of processing, including showing frames or saving results.

    Examples:
        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> solution.initialize_region()
        >>> image = cv2.imread("image.jpg")
        >>> solution.extract_tracks(image)
        >>> solution.display_output(image)
    """

    def __init__(self, IS_CLI=False, **kwargs):
        """
        Initializes the `BaseSolution` class with configuration settings and the YOLO model for Ultralytics solutions.

        Args:
            IS_CLI (bool, optional): Enables CLI mode if set. In CLI mode a demo video is downloaded and used as the
                source when no `source` is configured.
            **kwargs (Any): Configuration overrides applied on top of the default solution and model settings.
        """
        check_requirements("shapely>=2.0.0")
        from shapely.geometry import LineString, Point, Polygon
        from shapely.prepared import prep

        self.LineString = LineString
        self.Polygon = Polygon
        self.Point = Point
        self.prep = prep
        self.annotator = None  # Initialize annotator
        self.tracks = None
        self.track_data = None
        self.boxes = []
        self.clss = []
        self.track_ids = []
        self.track_line = None
        self.r_s = None

        # Load config and update with args. The defaults are module-level dictionaries shared by every solution, so
        # they are copied rather than updated in place; otherwise one instance's kwargs would leak into the next.
        sol_cfg = {**DEFAULT_SOL_DICT, **kwargs}
        self.CFG = {**sol_cfg, **DEFAULT_CFG_DICT, **kwargs}
        LOGGER.info(f"Ultralytics Solutions: ✅ {sol_cfg}")

        self.region = self.CFG["region"]  # Store region data for other classes usage
        configured_width = self.CFG["line_width"]
        self.line_width = configured_width if configured_width is not None else 2  # Store line_width for usage

        # Load Model and store classes names
        if self.CFG["model"] is None:
            self.CFG["model"] = "yolo11n.pt"
        self.model = YOLO(self.CFG["model"])
        self.names = self.model.names

        self.track_add_args = {  # Tracker additional arguments for advance configuration
            k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
        }

        if IS_CLI and self.CFG["source"] is None:
            # str() keeps the substring check valid when the model was supplied as a pathlib.Path
            is_pose = "-pose" in str(self.CFG["model"])
            d_s = "solution_ci_pose_demo.mp4" if is_pose else "solutions_ci_demo.mp4"
            LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
            from ultralytics.utils.downloads import safe_download

            safe_download(f"{ASSETS_URL}/{d_s}")  # download source from ultralytics assets
            self.CFG["source"] = d_s  # set default source

        # Initialize environment and region setup
        self.env_check = check_imshow(warn=True)
        self.track_history = defaultdict(list)

    def extract_tracks(self, im0):
        """
        Applies object tracking and extracts tracks from an input image or frame.

        The results are stored on the instance (`tracks`, `track_data`, `boxes`, `clss`, `track_ids`); nothing is
        returned. When no tracked objects are available, a warning is logged and the three lists are reset to empty.

        Args:
            im0 (ndarray): The input image or frame.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.extract_tracks(frame)
        """
        self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)

        # Extract tracks for OBB or object detection
        self.track_data = self.tracks[0].obb or self.tracks[0].boxes

        if self.track_data and self.track_data.id is not None:
            self.boxes = self.track_data.xyxy.cpu()
            self.clss = self.track_data.cls.cpu().tolist()
            self.track_ids = self.track_data.id.int().cpu().tolist()
        else:
            LOGGER.warning("WARNING ⚠️ no tracks found!")
            self.boxes, self.clss, self.track_ids = [], [], []

    def store_tracking_history(self, track_id, box):
        """
        Stores the tracking history of an object.

        This method updates the tracking history for a given object by appending the center point of its
        bounding box to the track line. It maintains a maximum of 30 points in the tracking history.

        Args:
            track_id (int): The unique identifier for the tracked object.
            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].

        Examples:
            >>> solution = BaseSolution()
            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
        """
        # `track_line` aliases the per-track list, so the append/pop below update `track_history` in place
        self.track_line = self.track_history[track_id]
        self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
        if len(self.track_line) > 30:
            self.track_line.pop(0)  # drop the oldest point

    def initialize_region(self):
        """
        Initialize the counting region and line segment based on configuration settings.

        Falls back to a default four-point region when none is configured. Regions with three or more points become a
        Polygon; shorter ones become a LineString. The geometry is stored in `r_s`.
        """
        if self.region is None:
            self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
        self.r_s = (
            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
        )  # region or line

    def display_output(self, im0):
        """
        Display the results of the processing, which could involve showing frames, printing counts, or saving results.

        This method is responsible for visualizing the output of the object detection and tracking process. It displays
        the processed frame with annotations, and allows for user interaction to close the display.

        Args:
            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.display_output(frame)

        Notes:
            - This method will only display output if the 'show' configuration is set to True and the environment
              supports image display.
            - The display can be closed by pressing the 'q' key.
        """
        if self.CFG.get("show") and self.env_check:
            cv2.imshow("Ultralytics Solutions", im0)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                cv2.destroyAllWindows()  # actually close the window, as documented
                return