Kaushik066 committed (verified)
Commit d6d3ea7 · 1 Parent(s): 98362e5

Create app.py

Files changed (1)
  1. app.py +276 -0
app.py ADDED
@@ -0,0 +1,276 @@
+ import torch
+ # For data transformation
+ from torchvision import transforms
+ import torchvision.transforms.v2  # explicit import so the transforms.v2 namespace resolves
+ # For ML Model
+ from transformers import VivitImageProcessor, VivitConfig, VivitModel
+ # For Data Loaders
+ from torch.utils.data import Dataset, DataLoader
+ # For GPU
+ from accelerate import Accelerator, notebook_launcher
+ from accelerate.utils import set_seed
+ # For video decoding (assumed to be decord, which provides len() and get_batch() used below)
+ from decord import VideoReader
+ # General Libraries
+ import os
+ import PIL
+ import gc
+ import pandas as pd
+ import numpy as np
+ import datasets        # used for logging verbosity control in prod_function
+ import transformers    # used for logging verbosity control in prod_function
+ from tqdm import tqdm
+ from torch.nn import Linear, Softmax
+ import gradio as gr
+ # Mediapipe Library
+ import mediapipe as mp
+ from mediapipe.tasks import python
+ from mediapipe.tasks.python import vision
+ from mediapipe import solutions
+ from mediapipe.framework.formats import landmark_pb2
+ 
+ # Constants
+ CLIP_LENGTH = 32   # number of frames sampled per video
+ FRAME_STEPS = 4    # sampling stride (stored by CreateDatasetProd but not used in frame selection)
+ CLIP_SIZE = 224    # spatial resolution fed to the model
+ BATCH_SIZE = 1
+ SEED = 42
+ 
+ 
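+ # Note (assumption): CLIP_LENGTH and CLIP_SIZE match the 32-frame, 224x224 clips that the
+ # ViViT-B/16x2 backbone referenced below is typically configured for; changing them would
+ # likely require a backbone with a matching input configuration.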
+ # Set the device (GPU or CPU)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # Pretrained model checkpoint
+ MODEL_TRANSFORMER = 'google/vivit-b-16x2'
+ # Set Paths
+ model_path = 'vivit_pytorch_loss051.pt'
+ 
+ # Create Mediapipe objects
+ mp_drawing = mp.solutions.drawing_utils
+ mp_drawing_styles = mp.solutions.drawing_styles
+ mp_hands = mp.solutions.hands
+ mp_face = mp.solutions.face_mesh
+ mp_pose = mp.solutions.pose
+ mp_holistic = mp.solutions.holistic
+ hand_model_path = 'hand_landmarker.task'
+ pose_model_path = 'pose_landmarker.task'
+ 
+ BaseOptions = mp.tasks.BaseOptions
+ HandLandmarker = mp.tasks.vision.HandLandmarker
+ HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions
+ PoseLandmarker = mp.tasks.vision.PoseLandmarker
+ PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
+ VisionRunningMode = mp.tasks.vision.RunningMode
+ 
+ # Create a hand landmarker instance with the video mode:
+ options_hand = HandLandmarkerOptions(
+     base_options=BaseOptions(model_asset_path=hand_model_path),
+     running_mode=VisionRunningMode.VIDEO)
+ 
+ # Create a pose landmarker instance with the video mode:
+ options_pose = PoseLandmarkerOptions(
+     base_options=BaseOptions(model_asset_path=pose_model_path),
+     running_mode=VisionRunningMode.VIDEO)
+ 
+ detector_hand = vision.HandLandmarker.create_from_options(options_hand)
+ detector_pose = vision.PoseLandmarker.create_from_options(options_pose)
+ 
+ holistic = mp_holistic.Holistic(
+     static_image_mode=False,
+     model_complexity=1,
+     smooth_landmarks=True,
+     enable_segmentation=False,
+     refine_face_landmarks=True,
+     min_detection_confidence=0.5,
+     min_tracking_confidence=0.5
+ )
+ 
+ # Creating Dataloader
+ class CustomDatasetProd(Dataset):
+     def __init__(self, pixel_values):
+         self.pixel_values = pixel_values.to('cpu')
+ 
+     def __len__(self):
+         return len(self.pixel_values)
+ 
+     def __getitem__(self, idx):
+         item = {
+             'pixel_values': self.pixel_values[idx]
+         }
+         return item
+ 
+ class CreateDatasetProd():
+     def __init__(self, clip_len, clip_size, frame_step):
+         super().__init__()
+         self.clip_len = clip_len
+         self.clip_size = clip_size
+         self.frame_step = frame_step
+ 
+         # Define the transformation pipeline used at inference time
+         self.transform_prod = transforms.v2.Compose([
+             transforms.v2.ToImage(),
+             transforms.v2.Resize((self.clip_size, self.clip_size)),
+             transforms.v2.ToDtype(torch.float32, scale=True)
+         ])
+ 
+     def read_video(self, video_path):
+         # Read the video and decode a fixed number of frames
+         vr = VideoReader(video_path)
+         total_frames = len(vr)
+ 
+         # Determine frame indices based on total frames
+         if total_frames < self.clip_len:
+             # Short video: keep every frame and pad by repeating the last one
+             key_indices = list(range(total_frames))
+             for _ in range(self.clip_len - len(key_indices)):
+                 key_indices.append(key_indices[-1])
+         else:
+             # Long video: sample clip_len frames at a uniform stride
+             key_indices = list(range(0, total_frames, max(1, total_frames // self.clip_len)))[:self.clip_len]
+ 
+         # Load frames; .asnumpy() assumes the decord reader, whose get_batch()
+         # returns an NDArray that needs converting before the torch operations below
+         frames = torch.from_numpy(vr.get_batch(key_indices).asnumpy())
+         del vr
+         # Force garbage collection
+         gc.collect()
+ 
+         return frames
+ 
+     def add_landmarks(self, video):
+         annotated_image = []
+         for frame in video:
+             # Convert the torch tensor to an (H, W, C) numpy image for mediapipe;
+             # a contiguous copy is used because permute() returns a non-contiguous view
+             image = np.ascontiguousarray(frame.permute(1, 2, 0).numpy())
+ 
+             results = holistic.process(image)
+ 
+             # Draw left-hand landmarks
+             mp_drawing.draw_landmarks(
+                 image,
+                 results.left_hand_landmarks,
+                 mp_hands.HAND_CONNECTIONS,
+                 landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
+                 connection_drawing_spec=mp_drawing_styles.get_default_hand_connections_style()
+             )
+             # Draw right-hand landmarks
+             mp_drawing.draw_landmarks(
+                 image,
+                 results.right_hand_landmarks,
+                 mp_hands.HAND_CONNECTIONS,
+                 landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
+                 connection_drawing_spec=mp_drawing_styles.get_default_hand_connections_style()
+             )
+             # Draw pose landmarks
+             mp_drawing.draw_landmarks(
+                 image,
+                 results.pose_landmarks,
+                 mp_holistic.POSE_CONNECTIONS,
+                 landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style(),
+                 #connection_drawing_spec = None
+             )
+ 
+             annotated_image.append(torch.from_numpy(image))
+ 
+             del image, results
+             # Force garbage collection
+             gc.collect()
+ 
+         return torch.stack(annotated_image)
+ 
+     def create_dataset(self, video_paths):
+         pixel_values = []
+         for path in tqdm(video_paths):
+             #print('Video', path)
+             # Read and process the video
+             video = self.read_video(path)
+             # Resize in (F, C, H, W) layout before annotation
+             video = transforms.v2.functional.resize(video.permute(0, 3, 1, 2), size=(self.clip_size*2, self.clip_size*3))
+             video = self.add_landmarks(video)
+             # Data preparation for the ML model, without augmentation
+             video = self.transform_prod(video.permute(0, 3, 1, 2))
+             pixel_values.append(video.to(device))
+             del video
+             # Force garbage collection
+             gc.collect()
+ 
+         pixel_values = torch.stack(pixel_values).to(device)
+         return CustomDatasetProd(pixel_values=pixel_values)
+ 
+ # Creating the dataset-builder object
+ dataset_prod_obj = CreateDatasetProd(CLIP_LENGTH, CLIP_SIZE, FRAME_STEPS)
+ 
+ # Creating ML Model
+ class SignClassificationModel(torch.nn.Module):
+     def __init__(self, model_name, idx_to_label, label_to_idx, classes_len):
+         super(SignClassificationModel, self).__init__()
+         # Note: `hyperparameters` is not defined in this file; it is only needed if the
+         # class is instantiated directly (see the note after the checkpoint load below)
+         self.config = VivitConfig.from_pretrained(model_name, id2label=idx_to_label,
+                                                   label2id=label_to_idx,
+                                                   hidden_dropout_prob=hyperparameters['dropout_rate'],
+                                                   attention_probs_dropout_prob=hyperparameters['dropout_rate'],
+                                                   return_dict=True)
+         self.backbone = VivitModel.from_pretrained(model_name, config=self.config)  # Load the ViViT backbone
+         self.ff_head = Linear(self.backbone.config.hidden_size, classes_len)
+ 
+     def forward(self, images):
+         x = self.backbone(images).last_hidden_state  # Extract token embeddings
+         self.backbone.gradient_checkpointing_enable()
+ 
+         # Average over the token dimension (axis 1), then classify
+         reduced_tensor = x.mean(dim=1)
+         reduced_tensor = self.ff_head(reduced_tensor)
+         return reduced_tensor
+ 
+ # Load the model
+ model_pretrained = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
+ 
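+ # Note: weights_only=False unpickles the full SignClassificationModel object rather than a
+ # state dict, so the class definition above must be present for loading to succeed.
+ # Unpickling restores the saved attributes without re-running __init__, which is why the
+ # undefined `hyperparameters` dictionary is not needed at inference time.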
+ # Evaluation Function
+ def prod_function(model_pretrained, prod_dl):
+     # Initialize accelerator
+     accelerator = Accelerator()
+ 
+     if accelerator.is_main_process:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+ 
+     # The seed needs to be set before we instantiate the model, as it will determine the random head.
+     set_seed(SEED)
+ 
+     # Unpack the prepared objects in the same order they were passed to prepare()
+     accelerated_model, accelerated_prod_dl = accelerator.prepare(model_pretrained, prod_dl)
+ 
+     # Put the model in evaluation mode
+     accelerated_model.eval()
+ 
+     prod_preds = []
+ 
+     for batch in accelerated_prod_dl:
+         videos = batch['pixel_values']
+         with torch.no_grad():
+             outputs = accelerated_model(videos)
+ 
+         prod_logits = outputs.squeeze(1)
+         prod_pred = prod_logits.argmax(-1)
+         prod_preds.append(prod_pred)
+     return prod_preds
+ 
+ def translate_sign_language(gesture):
+     # Create the dataset (create_dataset expects a list of video paths)
+     prod_ds = dataset_prod_obj.create_dataset([gesture])
+     prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
+ 
+     # Run the ML model
+     predicted_prod_label = prod_function(model_pretrained, prod_dl)
+ 
+     # Identify the predicted gesture
+     predicted_prod_label = torch.stack(predicted_prod_label)
+     predicted_prod_label = predicted_prod_label.squeeze(1)
+ 
+     idx_to_label = model_pretrained.config.id2label
+     for val in predicted_prod_label.cpu().numpy():
+         gesture_translation = idx_to_label[val]
+ 
+     return gesture_translation
+ 
+ with gr.Blocks() as demo:
+     gr.Markdown("# Indian Sign Language Translation App")
+     # Webcam input for capturing the sign language video
+     video_input = gr.Video(source="webcam")
+     # Output textbox for the predicted gesture
+     output = gr.Textbox()
+     # Run the translation whenever a new video is recorded
+     video_input.change(translate_sign_language, inputs=video_input, outputs=output)
+ 
+ if __name__ == "__main__":
+     demo.launch()
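+ # Running `python app.py` locally starts the Gradio server; on a Hugging Face Space this
+ # file is picked up automatically as the app entry point.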