MiVOLO / mivolo /predictor.py
admin
sync
4c4ff57
raw
history blame
2.65 kB
from collections import defaultdict
from typing import Dict, Generator, List, Optional, Tuple
import cv2
import numpy as np
import tqdm
from mivolo.model.mi_volo import MiVOLO
from mivolo.model.yolo_detector import Detector
from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
class Predictor:
    """Run person/face detection followed by MiVOLO age/gender estimation.

    Wraps a ``Detector`` (YOLO-based) and a ``MiVOLO`` model built from a single
    ``config`` object, and exposes helpers for single images and video sources.
    """

    def __init__(self, config, verbose: bool = False):
        """Build the detector and the age/gender model from ``config``.

        Args:
            config: object exposing ``detector_weights``, ``device``, ``checkpoint``,
                ``with_persons``, ``disable_faces`` and ``draw`` attributes.
            verbose: forwarded to both the detector and the MiVOLO model.
        """
        self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
        self.age_gender_model = MiVOLO(
            config.checkpoint,
            config.device,
            # NOTE(review): half precision is hard-coded; confirm this is
            # supported on every device passed via config.device (e.g. CPU).
            half=True,
            use_persons=config.with_persons,
            disable_faces=config.disable_faces,
            verbose=verbose,
        )
        # When True, recognize()/recognize_video() also return an annotated image.
        self.draw = config.draw

    def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
        """Detect persons/faces in ``image`` and estimate their age and gender.

        Args:
            image: input image array (presumably BGR as read by OpenCV — confirm).

        Returns:
            A tuple ``(detected_objects, out_im)``. ``out_im`` is the result of
            ``detected_objects.plot()`` when ``self.draw`` is set, else ``None``.
        """
        detected_objects: PersonAndFaceResult = self.detector.predict(image)
        # The return value is unused; presumably predict() writes the age/gender
        # estimates into detected_objects in place.
        self.age_gender_model.predict(image, detected_objects)

        out_im = None
        if self.draw:
            # plot results on image
            out_im = detected_objects.plot()

        return detected_objects, out_im

    @staticmethod
    def _append_to_history(
        history: Dict[int, List[AGE_GENDER_TYPE]],
        current: Dict[int, AGE_GENDER_TYPE],
    ) -> None:
        """Append each complete estimate in ``current`` to ``history`` under its track id."""
        for guid, data in current.items():
            # Estimates containing None are not useful for tracking, so skip them.
            if None not in data:
                history[guid].append(data)

    def recognize_video(
        self, source: str
    ) -> Generator[Tuple[Dict[int, List[AGE_GENDER_TYPE]], np.ndarray], None, None]:
        """Track persons/faces through a video and estimate age/gender per frame.

        Frames are read until the source is exhausted. The reported frame count is
        used only to size the progress bar, because it can be inaccurate for some
        containers and is typically non-positive for live/network streams.

        Args:
            source: anything ``cv2.VideoCapture`` accepts as a string (file path, URL).

        Yields:
            ``(detected_objects_history, frame)`` for each frame. The history maps a
            track id to the list of complete age/gender estimates seen so far; the
            same dict object is yielded (and mutated) on every step. ``frame`` is
            the ``plot()`` output when ``self.draw`` is set, else the raw frame.

        Raises:
            ValueError: if the video source cannot be opened (raised on the
                first ``next()``, since this is a generator).
        """
        video_capture = cv2.VideoCapture(source)
        try:
            if not video_capture.isOpened():
                raise ValueError(f"Failed to open video source {source}")

            detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)

            total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
            with tqdm.tqdm(total=total_frames if total_frames > 0 else None) as progress:
                while True:
                    ret, frame = video_capture.read()
                    if not ret:
                        break

                    detected_objects: PersonAndFaceResult = self.detector.track(frame)
                    self.age_gender_model.predict(frame, detected_objects)

                    current_frame_objs = detected_objects.get_results_for_tracking()
                    cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
                    cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]

                    # add tracked persons and faces to the shared history
                    self._append_to_history(detected_objects_history, cur_persons)
                    self._append_to_history(detected_objects_history, cur_faces)

                    detected_objects.set_tracked_age_gender(detected_objects_history)
                    if self.draw:
                        frame = detected_objects.plot()

                    progress.update(1)
                    yield detected_objects_history, frame
        finally:
            # Runs on exhaustion, on error, and when the consumer closes or
            # drops the generator, so the capture handle is never leaked.
            video_capture.release()