Commit 1e87f84 · thinh-huynh-re committed
Parent(s): d8653f1

Refactor

Files changed:
- run_opencv.py +12 -27
- utils/frame_rate.py +3 -1
- utils/img_container.py +24 -0
run_opencv.py
CHANGED
@@ -1,13 +1,14 @@
-from typing import List, Optional
+from typing import List, Tuple
+
 import cv2
-from pandas import DataFrame
-from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
 import numpy as np
-import torch
 import pandas as pd
+import torch
 from torch import Tensor
+from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
 
-from utils.frame_rate import FrameRate
+from utils.img_container import ImgContainer
+
 
 def load_model(model_name: str):
     if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
@@ -19,23 +20,6 @@ def load_model(model_name: str):
     model = TimesformerForVideoClassification.from_pretrained(model_name)
     return feature_extractor, model
 
-class ImgContainer:
-    def __init__(self, frames_per_video: int = 8) -> None:
-        self.img: Optional[np.ndarray] = None  # raw image
-        self.frame_rate: FrameRate = FrameRate()
-        self.imgs: List[np.ndarray] = []
-        self.frame_rate.reset()
-        self.frames_per_video = frames_per_video
-        self.rs: Optional[DataFrame] = None
-
-    def add_frame(self, frame: np.ndarray):
-        if len(img_container.imgs) >= frames_per_video:
-            self.imgs.pop(0)
-        self.imgs.append(frame)
-
-    @property
-    def ready(self):
-        return len(img_container.imgs) == self.frames_per_video
 
 def inference():
     if not img_container.ready:
@@ -50,7 +34,7 @@ def inference():
     # model predicts one of the 400 Kinetics-400 classes
     max_index = logits.argmax(-1).item()
     predicted_label = model.config.id2label[max_index]
-
+
     img_container.frame_rate.label = f"{predicted_label}_{logits[0][max_index]:.2f}%"
 
 TOP_K = 12
@@ -67,6 +51,7 @@ def inference():
 
     img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
 
+
 def get_frames_per_video(model_name: str) -> int:
     if "base-finetuned" in model_name:
         return 8
@@ -100,7 +85,7 @@ num_skips = 0
 # define a video capture object
 vid = cv2.VideoCapture(0)
 
-while(True):
+while True:
     # Capture the video frame
     # by frame
     ret, frame = vid.read()
@@ -109,19 +94,19 @@ while(True):
 
     img_container.img = frame
     img_container.frame_rate.count()
-
+
     if num_skips == 0:
         img_container.add_frame(frame)
         inference()
     rs = img_container.frame_rate.show_fps(frame)
 
     # Display the resulting frame
-    cv2.imshow("frame", rs)
+    cv2.imshow("TimeSFormer", rs)
 
     # the 'q' button is set as the
     # quitting button you may use any
     # desired button of your choice
-    if cv2.waitKey(1) & 0xFF == ord('q'):
+    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
 
 # After the loop release the cap object
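Note: the extracted ImgContainer keeps a sliding window of the last frames_per_video frames, the clip length the TimeSformer feature extractor consumes. A minimal usage sketch against the refactored module (the zero-filled frames are stand-ins for webcam captures, not part of the commit):

import numpy as np

from utils.img_container import ImgContainer

container = ImgContainer(frames_per_video=8)

# Feed more frames than the window holds; the oldest are evicted.
for _ in range(10):
    container.add_frame(np.zeros((224, 224, 3), dtype=np.uint8))

assert container.ready             # window is full: exactly 8 frames
assert len(container.imgs) == 8    # pop(0) dropped the 2 oldest frames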
utils/frame_rate.py
CHANGED
@@ -1,6 +1,8 @@
+import time
 from typing import Optional
+
+import cv2
 import numpy as np
-import time, cv2
 
 
 class FrameRate:
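Note: this reshuffle follows the standard PEP 8 grouping: standard library imports first, then third-party, one import per statement, a blank line between groups. Assuming the ordering came from isort with default settings (the commit does not say how it was produced), a sketch that reproduces it:

import isort

messy = "import numpy as np\nimport time, cv2\nfrom typing import Optional\n"
print(isort.code(messy))
# Expected output, stdlib section first:
#   import time
#   from typing import Optional
#
#   import cv2
#   import numpy as np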
utils/img_container.py
ADDED
@@ -0,0 +1,24 @@
+from typing import List, Optional
+
+import numpy as np
+from pandas import DataFrame
+
+from .frame_rate import FrameRate
+
+
+class ImgContainer:
+    def __init__(self, frames_per_video: int = 8) -> None:
+        self.img: Optional[np.ndarray] = None  # raw image
+        self.frame_rate: FrameRate = FrameRate()
+        self.imgs: List[np.ndarray] = []
+        self.frames_per_video = frames_per_video
+        self.rs: Optional[DataFrame] = None
+
+    def add_frame(self, frame: np.ndarray) -> None:
+        if len(self.imgs) >= self.frames_per_video:
+            self.imgs.pop(0)
+        self.imgs.append(frame)
+
+    @property
+    def ready(self):
+        return len(self.imgs) == self.frames_per_video
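Note: extracting the class also fixes the old inline methods, which read the module-level img_container global (and a bare frames_per_video) instead of self. If eviction cost ever matters, collections.deque with maxlen gives the same sliding window without the O(n) list.pop(0); a sketch of that alternative, not what this commit does:

from collections import deque
from typing import Deque

import numpy as np


class DequeImgContainer:
    def __init__(self, frames_per_video: int = 8) -> None:
        self.frames_per_video = frames_per_video
        # deque(maxlen=...) evicts the oldest frame automatically in O(1)
        self.imgs: Deque[np.ndarray] = deque(maxlen=frames_per_video)

    def add_frame(self, frame: np.ndarray) -> None:
        self.imgs.append(frame)

    @property
    def ready(self) -> bool:
        return len(self.imgs) == self.frames_per_video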