|
|
|
import torch |
|
|
|
from torchvision import transforms
from torchvision.transforms import v2  # imported explicitly so that transforms.v2 is always available
|
|
|
import transformers |
|
from transformers import VivitImageProcessor, VivitConfig, VivitModel |
|
from transformers import set_seed |
|
|
|
import datasets |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from accelerate import Accelerator, notebook_launcher |
|
|
|
import decord
from decord import VideoReader
from decord.bridge import set_bridge

# Return decoded frames as torch tensors rather than decord NDArrays
set_bridge("torch")
|
|
|
import os |
|
import PIL |
|
import gc |
|
import pandas as pd |
|
import numpy as np |
|
from torch.nn import Linear, Softmax |
|
import gradio as gr |
|
import cv2 |
|
import io |
|
import tempfile |
|
|
|
import mediapipe as mp |
|
from mediapipe.tasks import python |
|
from mediapipe.tasks.python import vision |
|
from mediapipe import solutions |
|
from mediapipe.framework.formats import landmark_pb2 |
|
|
|
# Video sampling and preprocessing constants
CLIP_LENGTH = 32   # number of frames sampled per clip
FRAME_STEPS = 4    # frame stride (not used by the current sampling logic)
CLIP_SIZE = 224    # spatial size expected by the ViViT model
BATCH_SIZE = 1
SEED = 42
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
MODEL_TRANSFORMER = 'google/vivit-b-16x2'   # ViViT backbone checkpoint id

model_path = 'vivit_pytorch_loss051.pt'     # fine-tuned sign-classification model loaded below
data_path = 'signs'                         # directory of reference ISL gesture videos
|
|
|
|
|
# CSS to keep the landmarked video preview within a fixed display size
custom_css = """
|
#landmarked_video { |
|
max-height: 300px; |
|
max-width: 600px; |
|
object-fit: fill; |
|
width: 100%; |
|
height: 100%; |
|
} |
|
""" |
|
|
|
|
|
# MediaPipe drawing utilities, solution shortcuts, and landmarker task model paths
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_face = mp.solutions.face_mesh
mp_pose = mp.solutions.pose
mp_holistic = mp.solutions.holistic

hand_model_path = 'hand_landmarker.task'
pose_model_path = 'pose_landmarker.task'
|
|
|
# Aliases for the MediaPipe Tasks API
BaseOptions = mp.tasks.BaseOptions
HandLandmarker = mp.tasks.vision.HandLandmarker
HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions
PoseLandmarker = mp.tasks.vision.PoseLandmarker
PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode
|
|
|
|
|
# Landmarker task options (video running mode)
options_hand = HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=hand_model_path),
    running_mode=VisionRunningMode.VIDEO)

options_pose = PoseLandmarkerOptions(
    base_options=BaseOptions(model_asset_path=pose_model_path),
    running_mode=VisionRunningMode.VIDEO)
|
|
|
detector_hand = vision.HandLandmarker.create_from_options(options_hand) |
|
detector_pose = vision.PoseLandmarker.create_from_options(options_pose) |
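# Note: frame annotation below uses the Holistic pipeline; the task-API detectors
# above are created but not used by the current preprocessing code.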
|
|
|
holistic = mp_holistic.Holistic( |
|
static_image_mode=False, |
|
model_complexity=1, |
|
smooth_landmarks=True, |
|
enable_segmentation=False, |
|
refine_face_landmarks=True, |
|
min_detection_confidence=0.5, |
|
min_tracking_confidence=0.5 |
|
) |
|
|
|
|
|
class CreateDatasetProd():
    def __init__(self, clip_len, clip_size, frame_step):
        super().__init__()
        self.clip_len = clip_len
        self.clip_size = clip_size
        self.frame_step = frame_step

        # Production transform: convert frames to image tensors, resize to the
        # ViViT input size, and scale pixel values to float32 in [0, 1].
        self.transform_prod = transforms.v2.Compose([
            transforms.v2.ToImage(),
            transforms.v2.Resize((self.clip_size, self.clip_size)),
            transforms.v2.ToDtype(torch.float32, scale=True)
        ])
|
|
|
    def read_video(self, video_path):
        # Decode the video and sample self.clip_len frames spread across its length.
        vr = VideoReader(video_path)
        total_frames = len(vr)

        if total_frames < self.clip_len:
            # Short clip: take every frame and pad by repeating the last one.
            key_indices = list(range(total_frames))
            for _ in range(self.clip_len - len(key_indices)):
                key_indices.append(key_indices[-1])
        else:
            # Long clip: sample evenly spaced frames, capped at clip_len.
            key_indices = list(range(0, total_frames, max(1, total_frames // self.clip_len)))[:self.clip_len]

        frames = vr.get_batch(key_indices)
        del vr
        gc.collect()

        return frames
|
|
|
    def add_landmarks(self, video):
        # Run MediaPipe Holistic on each frame and draw hand and pose landmarks on it.
        annotated_image = []
        for frame in video:
            # (C, H, W) tensor -> contiguous (H, W, C) uint8 array for MediaPipe/OpenCV drawing.
            image = np.ascontiguousarray(frame.permute(1, 2, 0).numpy())

            results = holistic.process(image)

            mp_drawing.draw_landmarks(
                image,
                results.left_hand_landmarks,
                mp_hands.HAND_CONNECTIONS,
                landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
                connection_drawing_spec=mp_drawing_styles.get_default_hand_connections_style()
            )
            mp_drawing.draw_landmarks(
                image,
                results.right_hand_landmarks,
                mp_hands.HAND_CONNECTIONS,
                landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
                connection_drawing_spec=mp_drawing_styles.get_default_hand_connections_style()
            )
            mp_drawing.draw_landmarks(
                image,
                results.pose_landmarks,
                mp_holistic.POSE_CONNECTIONS,
                landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()
            )

            annotated_image.append(torch.from_numpy(image))

            del image, results
            gc.collect()

        return torch.stack(annotated_image)
|
|
|
    def create_dataset(self, video_paths):
        video = self.read_video(video_paths)
        # With the torch bridge active, get_batch already returns a torch tensor (T, H, W, C);
        # fall back to converting from a decord NDArray otherwise.
        if not torch.is_tensor(video):
            video = torch.from_numpy(video.asnumpy())
        # Resize to a larger working resolution for landmark detection, then annotate.
        video = transforms.v2.functional.resize(video.permute(0, 3, 1, 2), size=(self.clip_size*2, self.clip_size*3))
        video = self.add_landmarks(video)
        # Final model transform expects (T, C, H, W) frames.
        video = self.transform_prod(video.permute(0, 3, 1, 2))
        pixel_values = video.to(device)

        del video
        gc.collect()

        return pixel_values
|
|
|
|
|
dataset_prod_obj = CreateDatasetProd(CLIP_LENGTH, CLIP_SIZE, FRAME_STEPS) |
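
# NOTE: the training script defined a `hyperparameters` dict used by SignClassificationModel.__init__;
# the value below is an assumed placeholder so the class can be instantiated standalone. It is not
# needed for inference here, because torch.load() unpickles the saved model without calling __init__.
hyperparameters = {'dropout_rate': 0.1}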
|
|
|
|
|
class SignClassificationModel(torch.nn.Module):
    def __init__(self, model_name, idx_to_label, label_to_idx, classes_len):
        super(SignClassificationModel, self).__init__()
        self.config = VivitConfig.from_pretrained(
            model_name,
            id2label=idx_to_label,
            label2id=label_to_idx,
            hidden_dropout_prob=hyperparameters['dropout_rate'],
            attention_probs_dropout_prob=hyperparameters['dropout_rate'],
            return_dict=True,
        )
        self.backbone = VivitModel.from_pretrained(model_name, config=self.config)  # ViViT encoder
        self.ff_head = Linear(self.backbone.config.hidden_size, classes_len)        # classification head

    def forward(self, images):
        x = self.backbone(images).last_hidden_state
        # Training-time memory optimisation; has no effect during no_grad inference.
        self.backbone.gradient_checkpointing_enable()

        # Mean-pool the token embeddings and project to class logits.
        reduced_tensor = x.mean(dim=1)
        reduced_tensor = self.ff_head(reduced_tensor)
        return reduced_tensor
|
|
|
|
|
# Load the fine-tuned model saved as a full pickled module (architecture + weights).
model_pretrained = torch.load(model_path, map_location=device, weights_only=False)
|
|
|
|
|
def prod_function(model_pretrained, prod_ds):
    # Accelerator handles device placement for the model.
    accelerator = Accelerator()

    # Keep library logging quiet outside the main process.
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    set_seed(SEED)

    accelerated_model, accelerated_prod_ds = accelerator.prepare(model_pretrained, prod_ds)

    accelerated_model.eval()

    with torch.no_grad():
        # Add a batch dimension: (T, C, H, W) -> (1, T, C, H, W).
        outputs = accelerated_model(accelerated_prod_ds.unsqueeze(0))

    prod_logits = outputs.squeeze(1)
    prod_pred = prod_logits.argmax(-1)
    return prod_pred
|
|
|
|
|
def save_video_to_mp4(video_tensor, fps=10):
    # (T, C, H, W) tensor -> (T, H, W, C) numpy array.
    video_numpy = video_tensor.permute(0, 2, 3, 1).cpu().numpy()

    # Rescale [0, 1] floats to uint8 for OpenCV.
    if video_numpy.max() <= 1.0:
        video_numpy = (video_numpy * 255).astype(np.uint8)
    else:
        video_numpy = video_numpy.astype(np.uint8)

    # Write to a temporary .mp4 file that Gradio can serve.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    output_path = temp_file.name

    height, width, channels = video_numpy[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for frame in video_numpy:
        # OpenCV expects BGR channel order.
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)

    out.release()

    return output_path
|
|
|
|
|
def list_videos():
    # List the reference gesture videos available in the signs directory.
    if os.path.exists(data_path):
        video_lst = [f for f in os.listdir(data_path) if f.endswith((".mp4", ".mov", ".MOV", ".webm", ".avi"))]
        return video_lst
    return []
|
|
|
|
|
def play_video(selected_video): |
|
return os.path.join(data_path, selected_video) if selected_video else None |
|
|
|
|
|
|
|
def translate_sign_language(gesture):
    # Preprocess the recorded gesture and render the landmark-annotated clip for display.
    prod_ds = dataset_prod_obj.create_dataset(gesture)
    prod_video_path = save_video_to_mp4(prod_ds)

    # Run inference and map the predicted class index back to its text label.
    predicted_prod_label = prod_function(model_pretrained, prod_ds)
    predicted_prod_label = predicted_prod_label.squeeze(0)

    idx_to_label = model_pretrained.config.id2label
    gesture_translation = idx_to_label[predicted_prod_label.cpu().numpy().item()]

    return gesture_translation, prod_video_path
|
|
|
with gr.Blocks(css=custom_css) as demo: |
|
gr.Markdown("# Indian Sign Language Translation App") |
|
|
|
|
|
    with gr.Tab("Gesture recognition"):
|
with gr.Row(): |
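            # Left column: webcam recording, Submit button, and the translated text.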
|
with gr.Column(scale=1, variant="panel"): |
|
with gr.Row(height=350, variant="panel"): |
|
|
|
video_input = gr.Video(sources=["webcam"], format="mp4", label="Gesture") |
|
with gr.Row(variant="panel"): |
|
|
|
video_button = gr.Button("Submit") |
|
|
|
text_output = gr.Textbox(label="Translation in English") |
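
        # Right column: playback of the landmark-annotated clip returned by translate_sign_language().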
|
with gr.Column(scale=1, variant="panel"): |
|
with gr.Row(): |
|
|
|
                    # elem_id matches the #landmarked_video rule in custom_css
                    video_output = gr.Video(interactive=False, autoplay=True,
                                            streaming=False, label="Landmarked Gesture",
                                            elem_id="landmarked_video")
|
|
|
video_button.click(translate_sign_language, inputs=video_input, outputs=[text_output, video_output]) |
|
|
|
|
|
|
|
with gr.Tab("Indian Sign Language gesture reference"): |
|
with gr.Row(height=500, variant="panel", equal_height=False, show_progress=True): |
|
with gr.Column(scale=1, variant="panel"): |
|
                video_dropdown = gr.Dropdown(choices=list_videos(), label="ISL gestures", info="More gestures coming soon!")
|
search_button = gr.Button("Search Gesture") |
|
with gr.Column(scale=1, variant="panel"): |
|
search_output = gr.Video(streaming=False, label="ISL gestures Video") |
|
|
|
search_button.click(play_video, inputs=video_dropdown, outputs=search_output) |
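
# Launch the Gradio app when this file is run as a script.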
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|