tanthinhdt committed on
Commit 3b64f55 · verified · 1 Parent(s): b808ec1

feat: add application file

Files changed (1)
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
+ from functools import partial  # used to bind the non-UI arguments of `inference` for Gradio
+ import yaml
+ import gradio as gr
+ from mediapipe.python.solutions import holistic
+ from torchvision.transforms.v2 import Compose, Lambda, Normalize
+ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
+ from utils import get_predictions, preprocess
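+ # NOTE: `preprocess` and `get_predictions` are helpers from this repo's local
+ # utils.py; their behaviour (keypoint-guided preprocessing and top-k decoding)
+ # is inferred from how they are called below.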
+
+
+ title = '''
+
+ '''
+
+ cite_markdown = '''
+
+ '''
+
+ description = '''
+
+ '''
+
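+ # Sample clips shipped in the repo's samples/ folder; the second column is the
+ # ground-truth Vietnamese gloss for the clip (e.g. 'Con chó' = 'dog') and
+ # populates the Examples panel of the demo.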
+ examples = [
+     ['samples/000_con_cho.mp4', 'Con chó'],
+     ['samples/001_con_meo.mp4', 'Con mèo'],
+     ['samples/005_con_tho.mp4', 'Con thỏ'],
+     ['samples/006_con_trau.mp4', 'Con trâu'],
+     ['samples/007_con_bo.mp4', 'Con bò'],
+     ['samples/008_con_de.mp4', 'Con dê'],
+     ['samples/009_con_heo.mp4', 'Con heo'],
+     ['samples/010_mau_den.mp4', 'Màu đen'],
+     ['samples/021_qua_man.mp4', 'Quả mận'],
+     ['samples/022_qua_dua.mp4', 'Quả dứa'],
+     ['samples/023_qua_dao.mp4', 'Quả đào'],
+     ['samples/029_qua_dua.mp4', 'Quả dưa'],
+     ['samples/031_me.mp4', 'Mẹ'],
+     ['samples/032_con_trai.mp4', 'Con trai'],
+     ['samples/033_con_gai.mp4', 'Con gái'],
+     ['samples/035_chong.mp4', 'Chồng'],
+     ['samples/044_mach.mp4', 'Mách'],
+     ['samples/051_chay.mp4', 'Chạy'],
+     ['samples/054_mua.mp4', 'Múa'],
+     ['samples/055_nau.mp4', 'Nấu'],
+     ['samples/057_nham_lan.mp4', 'Nhầm lẫn'],
+     ['samples/059_cam_trai.mp4', 'Cắm trại'],
+     ['samples/060_cung_cap.mp4', 'Cung cấp'],
+     ['samples/062_bat_buoc.mp4', 'Bắt buộc'],
+     ['samples/064_mua_ban.mp4', 'Mua bán'],
+     ['samples/066_khong_nen.mp4', 'Không nên'],
+     ['samples/067_khong_can.mp4', 'Không cần'],
+     ['samples/069_khong_nghe_loi.mp4', 'Không nghe lời'],
+     ['samples/073_ngot.mp4', 'Ngọt'],
+     ['samples/079_chat.mp4', 'Chật'],
+     ['samples/080_hep.mp4', 'Hẹp'],
+     ['samples/081_rong.mp4', 'Rộng'],
+     ['samples/082_dai.mp4', 'Dài'],
+     ['samples/085_om.mp4', 'Ốm'],
+     ['samples/086_map.mp4', 'Mập'],
+     ['samples/087_ngoan.mp4', 'Ngoan'],
+     ['samples/089_khoe.mp4', 'Khoẻ'],
+     ['samples/091_dau.mp4', 'Đau'],
+     ['samples/095_tot_bung.mp4', 'Tốt bụng'],
+     ['samples/097_thu_vi.mp4', 'Thú vị'],
+ ]
+
+
+ def inference(
+     video: str,
+     k: int,
+     model,
+     keypoints_detector,
+     data_height: int,
+     data_width: int,
+     model_input_height: int,
+     model_input_width: int,
+     device: str,
+     transform: Compose,
+     progress: gr.Progress = gr.Progress(),  # default lets Gradio inject progress tracking
+ ) -> str:
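+     """
+     Run the full recognition pipeline on a single uploaded video.
+
+     The clip is preprocessed (frame sampling, MediaPipe Holistic keypoint
+     detection, resizing and normalization) and passed to the VideoMAE
+     classifier; the top-k labels and scores are returned as a formatted string.
+     """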
+     progress(0, desc='Preprocessing video')
+     inputs = preprocess(
+         model_num_frames=model.config.num_frames,
+         keypoints_detector=keypoints_detector,
+         source=video,
+         data_height=data_height,
+         data_width=data_width,
+         model_input_height=model_input_height,
+         model_input_width=model_input_width,
+         device=device,
+         transform=transform,
+     )
+
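+     # Run the classifier and keep the k highest-scoring labels
+     # (see utils.get_predictions).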
+     progress(1/2, desc='Getting predictions')
+     predictions = get_predictions(inputs=inputs, model=model, k=k)
+     output_message = ''
+     for i, prediction in enumerate(predictions, start=1):
+         output_message += f'{i}. {prediction["label"]} ({prediction["score"]})\n'
+     output_message = output_message.strip()
+
+     progress(1, desc='Completed')
+
+     return output_message
+
+
+ if __name__ == '__main__':
+     with open('config.yaml', 'r') as file:
+         config = yaml.safe_load(file)
+
+     device = 'cpu'
+     image_processor = VideoMAEImageProcessor.from_pretrained(config['model']['name'])
+     model = VideoMAEForVideoClassification.from_pretrained(config['model']['name'])
+     model = model.eval().to(device)
+
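+     # The image processor config provides the normalization statistics and the
+     # spatial input size expected by the VideoMAE checkpoint.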
+     mean = image_processor.image_mean
+     std = image_processor.image_std
+     if 'shortest_edge' in image_processor.size:
+         height = width = image_processor.size['shortest_edge']
+     else:
+         height = image_processor.size['height']
+         width = image_processor.size['width']
+
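+     # MediaPipe Holistic tracks body, hand and face landmarks across frames;
+     # `preprocess` receives this detector to extract them from the input video.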
+     keypoints_detector = holistic.Holistic(
+         static_image_mode=False,
+         model_complexity=2,
+         enable_segmentation=True,
+         refine_face_landmarks=True,
+     )
+
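+     # Scale pixel values to [0, 1] and normalize with the processor's mean/std
+     # before frames are fed to the model.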
+     transform = Compose(
+         [
+             Lambda(lambda x: x / 255.0),
+             Normalize(mean=mean, std=std),
+         ]
+     )
+
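+     # Warm-up / sanity-check run on the sample clip referenced in config.yaml
+     # before the web interface starts serving requests.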
+     inference(
+         video=config['inference']['source'],
+         k=3,  # matches the slider default below
+         model=model,
+         keypoints_detector=keypoints_detector,
+         data_height=config['data']['height'],
+         data_width=config['data']['width'],
+         model_input_height=height,
+         model_input_width=width,
+         device=device,
+         transform=transform,
+     )
+
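+     # Gradio UI: a video upload and a top-k slider in, plain text out. The
+     # non-UI arguments are bound with functools.partial so Gradio only needs to
+     # supply the two user-facing inputs.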
+     iface = gr.Interface(
+         fn=partial(
+             inference,
+             model=model,
+             keypoints_detector=keypoints_detector,
+             data_height=config['data']['height'],
+             data_width=config['data']['width'],
+             model_input_height=height,
+             model_input_width=width,
+             device=device,
+             transform=transform,
+         ),
+         inputs=[
+             'video',
+             gr.components.Slider(
+                 minimum=1,
+                 maximum=5,
+                 value=3,
+                 step=1,
+                 label='k',
+                 info='Return top-k results',
+             ),
+         ],
+         outputs='text',
+         examples=examples,
+         title=title,
+         description=description,
+         article=cite_markdown,  # citation/credits shown below the demo
+     )
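+     # NOTE: gr.Progress updates are delivered through Gradio's queue; recent
+     # Gradio releases enable it by default, older ones need iface.queue()
+     # before launch().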
+     iface.launch()