import json
from typing import List, Optional, Tuple

import cv2
import numpy as np
import pandas as pd
import torch
from tap import Tap
from torch import Tensor
from transformers import (
    AutoFeatureExtractor,
    TimesformerForVideoClassification,
    VideoMAEFeatureExtractor,
)

from utils.img_container import ImgContainer


# Command-line options for the activity-tracking demo.
#
# Fields read in this file:
#   is_recording    - initial recording flag handed to ImgContainer.
#   model_name      - checkpoint id passed to `from_pretrained`.
#   num_skip_frames - every N-th captured frame is added to the clip.
#   top_k           - number of highest-scoring classes put in the result table.
#   id2label        - path to a JSON file mapping class index -> label
#                     (None keeps the mapping shipped with the checkpoint).
#   threshold       - minimum top-1 logit required to update the on-screen
#                     label (None always updates it).
#   max_confidence  - not read anywhere in this file.
class ArgParser(Tap):
    is_recording: Optional[bool] = False
    # Known checkpoints:
    # "facebook/timesformer-base-finetuned-k400"
    # "facebook/timesformer-base-finetuned-k600",
    # "facebook/timesformer-base-finetuned-ssv2",
    # "facebook/timesformer-hr-finetuned-k600",
    # "facebook/timesformer-hr-finetuned-k400",
    # "facebook/timesformer-hr-finetuned-ssv2",
    # "fcakyon/timesformer-large-finetuned-k400",
    # "fcakyon/timesformer-large-finetuned-k600",
    model_name: Optional[str] = "facebook/timesformer-base-finetuned-k400"
    num_skip_frames: Optional[int] = 4
    top_k: Optional[int] = 5
    id2label: Optional[str] = "labels/kinetics_400.json"
    threshold: Optional[float] = 10.0
    max_confidence: Optional[float] = 20.0  # Set None if not scale


class ActivityModel:
    """Wrap a TimeSformer video classifier and its feature extractor."""

    def __init__(self, args: ArgParser):
        """Load the model named by ``args.model_name`` and its label map.

        Args:
            args: Parsed command-line options; stored as ``self.args``.
        """
        self.feature_extractor, self.model = self.load_model(args.model_name)
        self.args = args
        self.frames_per_video = self.get_frames_per_video(args.model_name)
        print(f"Frames per video: {self.frames_per_video}")
        self.load_json()

    def load_json(self) -> None:
        """Replace the model's ``id2label`` with the mapping from the JSON file.

        Does nothing when ``self.args.id2label`` is ``None``. JSON object keys
        are strings, so they are converted to ``int`` to match the class
        indices used for lookup in :meth:`inference`.
        """
        # Use the instance's own options: relying on a module-level ``args``
        # only works when this file is executed as a script.
        label_path = self.args.id2label
        if label_path is None:
            return
        with open(label_path, encoding="utf-8") as f:
            raw_labels = json.load(f)
        self.model.config.id2label = {
            int(key): label for key, label in raw_labels.items()
        }

    def load_model(
        self, model_name: str
    ) -> Tuple[VideoMAEFeatureExtractor, TimesformerForVideoClassification]:
        """Load the feature extractor and classifier for ``model_name``.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.

        Returns:
            A ``(feature_extractor, model)`` tuple.
        """
        # The base k400/k600 checkpoints are paired with the VideoMAE
        # extractor instead of one loaded from the checkpoint itself.
        if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
            feature_extractor = AutoFeatureExtractor.from_pretrained(
                "MCG-NJU/videomae-base-finetuned-kinetics"
            )
        else:
            feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = TimesformerForVideoClassification.from_pretrained(model_name)
        return feature_extractor, model

    def inference(self, img_container: ImgContainer) -> None:
        """Classify the frames buffered in ``img_container``.

        Returns immediately when ``img_container.ready`` is falsy. Otherwise
        the result is written back onto the container:

        * ``img_container.frame_rate.label`` is set to the top-1 label and its
          score when no threshold is configured or the score reaches it.
        * ``img_container.rs`` is set to a DataFrame with columns ``Label``
          and ``Confidence`` holding the ``top_k`` highest-scoring classes.

        Args:
            img_container: Frame buffer; ``imgs`` is passed to the feature
                extractor as a single clip.
        """
        if not img_container.ready:
            return

        inputs = self.feature_extractor(list(img_container.imgs), return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits: Tensor = outputs.logits

        # ``.item()`` requires a single element, i.e. exactly one clip.
        max_index = logits.argmax(-1).item()
        predicted_label = self.model.config.id2label[max_index]
        # NOTE: this is a raw logit, not a probability; the "%" suffix in the
        # label below is kept unchanged for display compatibility.
        confidence = float(logits[0][max_index])

        threshold = self.args.threshold
        if threshold is None or confidence >= threshold:
            img_container.frame_rate.label = f"{predicted_label}_{confidence:.2f}%"

        # Rank all classes by score, highest first, and keep the first top_k.
        scores = logits.squeeze().numpy()
        indices = np.argsort(scores)[::-1][: self.args.top_k]
        values = scores[indices]
        results: List[Tuple[str, float]] = [
            (self.model.config.id2label[int(index)], value)
            for index, value in zip(indices, values)
        ]
        img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))

    def get_frames_per_video(self, model_name: str) -> int:
        """Return the clip length used for ``model_name``.

        Returns 8 for names containing ``"base-finetuned"``, 16 for names
        containing ``"hr-finetuned"`` and 96 for everything else.
        """
        if "base-finetuned" in model_name:
            return 8
        if "hr-finetuned" in model_name:
            return 16
        return 96


def main(args: ArgParser) -> None:
    """Run the webcam loop: capture, classify, display and optionally record.

    Press ``q`` to quit and ``r`` to toggle recording. Rendered frames are
    written to ``activities.mp4`` while ``img_container.is_recording`` is true.

    Args:
        args: Parsed command-line options.
    """
    activity_model = ActivityModel(args)
    img_container = ImgContainer(activity_model.frames_per_video, args.is_recording)

    # Guard against 0/None so the modulo below cannot raise; such values
    # fall back to sampling every frame.
    skip_every = max(1, args.num_skip_frames or 1)
    num_skips = 0

    # Default camera (device 0).
    camera = cv2.VideoCapture(0)
    frame_width = int(camera.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
    size = (frame_width, frame_height)
    video_output = cv2.VideoWriter(
        "activities.mp4", cv2.VideoWriter_fourcc(*"MP4V"), 10, size
    )

    try:
        if not camera.isOpened():
            print("Error reading video file")

        while camera.isOpened():
            ret, frame = camera.read()
            if not ret:
                # No frame was grabbed (camera unplugged / stream ended);
                # ``frame`` is not usable, so stop instead of crashing later.
                break

            num_skips = (num_skips + 1) % skip_every
            img_container.img = frame
            img_container.frame_rate.count()

            # Only every ``skip_every``-th frame is buffered and classified.
            # NOTE(review): the original layout was lost; inference is assumed
            # to run only on sampled frames - confirm.
            if num_skips == 0:
                img_container.add_frame(frame)
                activity_model.inference(img_container)

            rs = img_container.frame_rate.show_fps(frame, img_container.is_recording)

            # Display the resulting frame.
            cv2.imshow("ActivityTracking", rs)

            if img_container.is_recording:
                video_output.write(rs)

            # 'q' quits, 'r' toggles recording.
            k = cv2.waitKey(1)
            if k == ord("q"):
                break
            elif k == ord("r"):
                img_container.toggle_recording()
    finally:
        # Always release the devices/files and close the windows, even when
        # the loop exits through an exception.
        camera.release()
        video_output.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    args = ArgParser().parse_args()
    main(args)