import torch
# For data transformation (torchvision v2 API)
from torchvision import transforms
from torchvision.transforms import v2  # imported explicitly so transforms.v2 is available
# For ML Model
from transformers import VivitImageProcessor, VivitConfig, VivitModel
# For Data Loaders
from torch.utils.data import Dataset, DataLoader
# For GPU
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
# For reading videos as frames
from decord import VideoReader
# For logging control inside the inference loop
import datasets
import transformers
# General Libraries
import os
import PIL
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.nn import Linear, Softmax
import gradio as gr
# Mediapipe Library
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
# Constants
CLIP_LENGTH = 32
FRAME_STEPS = 4
CLIP_SIZE = 224
BATCH_SIZE = 1
SEED = 42
# Dropout used when the ViViT config is rebuilt in SignClassificationModel.
# NOTE: placeholder value; the rate used during training is not recorded in this file.
hyperparameters = {'dropout_rate': 0.1}
# Set the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained backbone checkpoint on the Hugging Face Hub
MODEL_TRANSFORMER = 'google/vivit-b-16x2'
# Path to the fine-tuned model weights
model_path = 'vivit_pytorch_loss051.pt'
# Create Mediapipe Objects
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_face = mp.solutions.face_mesh
mp_pose = mp.solutions.pose
mp_holistic = mp.solutions.holistic
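# Paths to the MediaPipe Tasks landmarker bundles (expected next to app.py)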
hand_model_path = 'hand_landmarker.task'
pose_model_path = 'pose_landmarker.task'
BaseOptions = mp.tasks.BaseOptions
HandLandmarker = mp.tasks.vision.HandLandmarker
HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions
PoseLandmarker = mp.tasks.vision.PoseLandmarker
PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode
# Create a hand landmarker instance with the video mode:
options_hand = HandLandmarkerOptions(
base_options=BaseOptions(model_asset_path = hand_model_path),
running_mode=VisionRunningMode.VIDEO)
# Create a pose landmarker instance with the video mode:
options_pose = PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=pose_model_path),
running_mode=VisionRunningMode.VIDEO)
detector_hand = vision.HandLandmarker.create_from_options(options_hand)
detector_pose = vision.PoseLandmarker.create_from_options(options_pose)
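# The Holistic solution below produces the landmarks that are drawn on each frame;
# the Tasks-API detectors above are initialised but not used in the drawing pipeline.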
holistic = mp_holistic.Holistic(
static_image_mode=False,
model_complexity=1,
smooth_landmarks=True,
enable_segmentation=False,
refine_face_landmarks=True,
min_detection_confidence=0.5,
min_tracking_confidence=0.5
)
# Creating Dataloader
class CustomDatasetProd(Dataset):
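    """Wraps pre-computed pixel values so they can be served batch-wise by a DataLoader."""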
def __init__(self, pixel_values):
self.pixel_values = pixel_values.to('cpu')
def __len__(self):
return len(self.pixel_values)
def __getitem__(self, idx):
item = {
'pixel_values': self.pixel_values[idx]
}
return item
class CreateDatasetProd():
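    """Reads a video, overlays MediaPipe landmarks and converts it into ViViT-ready pixel values."""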
def __init__(self
, clip_len
, clip_size
, frame_step
):
super().__init__()
self.clip_len = clip_len
self.clip_size = clip_size
self.frame_step = frame_step
        # Production transformation pipeline: to image tensor, resize, scale to float32 in [0, 1]
self.transform_prod = transforms.v2.Compose([
transforms.v2.ToImage(),
transforms.v2.Resize((self.clip_size, self.clip_size)),
transforms.v2.ToDtype(torch.float32, scale=True)
])
def read_video(self, video_path):
# Read the video and convert to frames
vr = VideoReader(video_path)
total_frames = len(vr)
# Determine frame indices based on total frames
if total_frames < self.clip_len:
key_indices = list(range(total_frames))
for _ in range(self.clip_len - len(key_indices)):
key_indices.append(key_indices[-1])
else:
key_indices = list(range(0, total_frames, max(1, total_frames // self.clip_len)))[:self.clip_len]
        # Load the selected frames; decord returns an NDArray, so convert to a torch tensor (F, H, W, C)
        frames = torch.from_numpy(vr.get_batch(key_indices).asnumpy())
del vr
# Force garbage collection
gc.collect()
return frames
def add_landmarks(self, video):
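        """Draws hand and pose landmarks onto each frame; expects (F, C, H, W) uint8 frames, returns (F, H, W, C)."""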
annotated_image = []
for frame in video:
            # Convert the (C, H, W) torch tensor to a contiguous (H, W, C) numpy image for MediaPipe/OpenCV drawing
            image = np.ascontiguousarray(frame.permute(1, 2, 0).numpy())
            results = holistic.process(image)
mp_drawing.draw_landmarks(
image,
results.left_hand_landmarks,
mp_hands.HAND_CONNECTIONS,
landmark_drawing_spec = mp_drawing_styles.get_default_hand_landmarks_style(),
connection_drawing_spec = mp_drawing_styles.get_default_hand_connections_style()
)
mp_drawing.draw_landmarks(
image,
results.right_hand_landmarks,
mp_hands.HAND_CONNECTIONS,
landmark_drawing_spec = mp_drawing_styles.get_default_hand_landmarks_style(),
connection_drawing_spec = mp_drawing_styles.get_default_hand_connections_style()
)
mp_drawing.draw_landmarks(
image,
results.pose_landmarks,
mp_holistic.POSE_CONNECTIONS,
landmark_drawing_spec = mp_drawing_styles.get_default_pose_landmarks_style(),
#connection_drawing_spec = None
)
annotated_image.append(torch.from_numpy(image))
del image, results
# Force garbage collection
gc.collect()
return torch.stack(annotated_image)
def create_dataset(self, video_paths):
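        """Builds a CustomDatasetProd of transformed pixel values from a list of video paths."""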
pixel_values = []
for path in tqdm(video_paths):
#print('Video', path)
# Read and process Videos
video = self.read_video(path)
            # (F, H, W, C) -> (F, C, H, W), then resize so landmark detection runs on a larger frame
            video = transforms.v2.functional.resize(video.permute(0, 3, 1, 2), size=(self.clip_size*2, self.clip_size*3))
video = self.add_landmarks(video)
            # Data preparation for the ML model (no augmentation at inference time)
            video = self.transform_prod(video.permute(0, 3, 1, 2)) # (F, H, W, C) -> (F, C, H, W)
pixel_values.append(video.to(device))
del video
# Force garbage collection
gc.collect()
pixel_values = torch.stack(pixel_values).to(device)
return CustomDatasetProd(pixel_values=pixel_values)
# Create the dataset-builder object used by the Gradio callback
dataset_prod_obj = CreateDatasetProd(CLIP_LENGTH, CLIP_SIZE, FRAME_STEPS)
# Creating ML Model
class SignClassificationModel(torch.nn.Module):
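    """ViViT backbone with a linear classification head over mean-pooled token embeddings."""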
def __init__(self, model_name, idx_to_label, label_to_idx, classes_len):
super(SignClassificationModel, self).__init__()
self.config = VivitConfig.from_pretrained(model_name, id2label=idx_to_label,
label2id=label_to_idx, hidden_dropout_prob=hyperparameters['dropout_rate'],
attention_probs_dropout_prob=hyperparameters['dropout_rate'],
return_dict=True)
        self.backbone = VivitModel.from_pretrained(model_name, config=self.config) # Load the ViViT backbone
self.ff_head = Linear(self.backbone.config.hidden_size, classes_len)
    def forward(self, images):
        x = self.backbone(images).last_hidden_state # Token embeddings: (batch, tokens, hidden_size)
        # Mean-pool over the token dimension, then project to class logits
        reduced_tensor = x.mean(dim=1)
        reduced_tensor = self.ff_head(reduced_tensor)
        return reduced_tensor
# Load the model
model_pretrained = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
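# Note: weights_only=False unpickles the full module, so the SignClassificationModel
# class definition above must match the one used when the checkpoint was saved.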
# Evaluation Function
def prod_function(model_pretrained, prod_dl):
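    """Runs inference over the production dataloader with Accelerate and returns per-clip predictions."""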
# Initialize accelerator
accelerator = Accelerator()
if accelerator.is_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
    # The seed needs to be set before preparing the model, as it determines the random head.
    set_seed(SEED)
    # Unpack in the same order the objects were passed to the prepare method.
    accelerated_model, accelerated_prod_dl = accelerator.prepare(model_pretrained, prod_dl)
    # Distributed-friendly evaluation loop
    accelerated_model.eval()
    prod_preds = []
    for batch in accelerated_prod_dl:
videos = batch['pixel_values']
with torch.no_grad():
outputs = accelerated_model(videos)
prod_logits = outputs.squeeze(1)
prod_pred = prod_logits.argmax(-1)
prod_preds.append(prod_pred)
return prod_preds
def translate_sign_language(gesture):
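    """Gradio callback: converts a recorded sign-language video into its text label."""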
    # Create Dataset (wrap the single video path in a list, as create_dataset expects an iterable of paths)
    prod_ds = dataset_prod_obj.create_dataset([gesture])
prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
# Run ML Model
predicted_prod_label = prod_function(model_pretrained, prod_dl)
# Identify the hand gesture
predicted_prod_label = torch.stack(predicted_prod_label)
predicted_prod_label = predicted_prod_label.squeeze(1)
idx_to_label = model_pretrained.config.id2label
    # Move predictions to CPU before converting to numpy; with a single video there is one label
    for val in predicted_prod_label.cpu().numpy():
        gesture_translation = idx_to_label[val]
    return gesture_translation
with gr.Blocks() as demo:
gr.Markdown("# Indian Sign Language Translation App")
# Add webcam input for sign language video capture
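    # Note: Gradio 3.x accepts source="webcam"; Gradio 4.x renamed this parameter to sources=["webcam"].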
video_input = gr.Video(source="webcam")
# Add a button or functionality to process the video
output = gr.Textbox()
# Set up the interface
video_input.change(translate_sign_language, inputs=video_input, outputs=output)
if __name__ == "__main__":
demo.launch()