# Spaces:
# Runtime error
# Runtime error
# File size: 5,658 Bytes
# 3db63ac 4b27b73 3db63ac 060e0a7 b29e506 060e0a7 b29e506 08834d7 b29e506 48f025f 060e0a7 b29e506 d56b4d4 a69ad45 b29e506 d56b4d4 f915926 904c6e8 23a56cb a69ad45 23a56cb d56b4d4 904c6e8 22d154c 904c6e8 23a56cb 904c6e8 68f9039 904c6e8 a69ad45 904c6e8 68f9039 a69ad45 68f9039 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import sys
import subprocess

# ---- Runtime environment bootstrap -----
# The AV-HuBERT sources and the Python packages listed below are fetched at
# start-up, before the imports further down try to use them.
_AV_HUBERT_DIR = '/home/user/app/av_hubert'
_FAIRSEQ_DIR = os.path.join(_AV_HUBERT_DIR, 'fairseq')
_AVHUBERT_PKG_DIR = os.path.join(_AV_HUBERT_DIR, 'avhubert')


def _run_setup_command(command, cwd=None):
    """Run one bootstrap command, reporting but tolerating a failure.

    Args:
        command: Program and arguments as a list; executed without a shell.
        cwd: Directory to run the command in, or None for the current one.

    Returns:
        The exit status of the command, or -1 when it could not be started.
    """
    try:
        status = subprocess.run(command, cwd=cwd).returncode
    except OSError as err:
        # Missing executable or missing working directory.
        print(f"[setup] could not run {' '.join(command)}: {err}", file=sys.stderr)
        return -1
    if status != 0:
        # Stay best-effort (e.g. the clone target may already exist), but make
        # the failure visible instead of silently dropping the exit status.
        print(f"[setup] {' '.join(command)} exited with status {status}", file=sys.stderr)
    return status


def _pip_install(*requirement, cwd=None):
    """Install into the interpreter that is running this script."""
    return _run_setup_command([sys.executable, '-m', 'pip', 'install', *requirement], cwd=cwd)


# The clone is relative to the start-up directory, exactly as before.
_run_setup_command(['git', 'clone', 'https://github.com/facebookresearch/av_hubert.git'])
_run_setup_command(['git', 'submodule', 'init'], cwd=_AV_HUBERT_DIR)
_run_setup_command(['git', 'submodule', 'update'], cwd=_AV_HUBERT_DIR)

# fairseq is installed from the checked-out submodule; the rest come from the
# package index. Each requirement is installed on its own, in the original
# order, so the later version pins still win.
_pip_install('./', cwd=_FAIRSEQ_DIR)
for _requirement in (
    'scipy',
    'sentencepiece',
    'python_speech_features',
    'scikit-video',
    'transformers',
    'gradio==3.12',
    'numpy==1.23.3',
):
    _pip_install(_requirement)

# The rest of the script runs from inside the avhubert package directory and
# imports modules from the checkout through sys.path.
os.chdir(_AVHUBERT_PKG_DIR)
sys.path.append(_AV_HUBERT_DIR)
sys.path.append(_AVHUBERT_PKG_DIR)
print(sys.path)
print(os.listdir())
from fairseq import checkpoint_utils, options, tasks, utils
from argparse import Namespace
import dlib, cv2, os
import numpy as np
import skvideo
import skvideo.io
from tqdm import tqdm
from preparation.align_mouth import landmarks_interpolate, crop_patch, write_video_ffmpeg
from base64 import b64encode
import torch
import cv2
import tempfile
import fairseq
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import GenerationConfig
from huggingface_hub import hf_hub_download
import gradio as gr
# ---- Paths and model assets -----
# Directory handed to fairseq below as the user module directory.
user_dir = "/home/user/app/av_hubert/avhubert"
# Checkpoint file 'model.pt' fetched from the 'vumichien/AV-HuBERT' hub repo;
# predict() loads it with checkpoint_utils.load_model_ensemble_and_task.
ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')
# Weights for dlib.cnn_face_detection_model_v1 (used when CUDA is available).
face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
# Weights for dlib.shape_predictor.
face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
# Array loaded with np.load and passed to crop_patch as the mean face landmarks.
mean_face_path = "/home/user/app/20words_mean_face.npy"
# Output file written by preprocess_video(); the same path is reused on every call.
mouth_roi_path = "/home/user/app/roi.mp4"
# Register the modules under user_dir with fairseq before any model is loaded.
utils.import_user_module(Namespace(user_dir=user_dir))
def detect_landmark(image, detector, predictor):
    """Locate 68 facial landmarks for a face found in one RGB frame.

    Args:
        image: Frame converted to grayscale with ``cv2.COLOR_RGB2GRAY``.
        detector: Face detector, called as ``detector(gray, 1)``. Its results
            may be plain rectangles or objects that carry the rectangle in a
            ``rect`` attribute (as dlib's CNN detector results do).
        predictor: Landmark predictor, called as ``predictor(gray, rect)``;
            the result must expose ``part(i)`` with ``x`` and ``y``.

    Returns:
        An int32 array of shape (68, 2) holding (x, y) per landmark, or None
        when the detector reports no face. When several faces are reported,
        the landmarks of the last one are returned.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    faces = list(detector(gray, 1))
    if not faces:
        return None

    # Only the last detection ever reached the caller, so the predictor is
    # run once on that face instead of on every face.
    face = faces[-1]
    # Decide from the detection itself whether the rectangle is wrapped,
    # rather than from whether CUDA happens to be available.
    rect = getattr(face, "rect", face)

    shape = predictor(gray, rect)
    return np.array(
        [(shape.part(i).x, shape.part(i).y) for i in range(68)],
        dtype=np.int32,
    )
def preprocess_video(input_video_path):
    """Crop the mouth region out of a video and save it as a new video.

    Landmarks are detected frame by frame, interpolated with
    ``landmarks_interpolate``, and passed to ``crop_patch``; the resulting
    patches are written to ``mouth_roi_path`` with ffmpeg.

    Args:
        input_video_path: Path of the video file to process.

    Returns:
        ``mouth_roi_path``, the location of the written mouth-region video.

    Raises:
        ValueError: If no path is given, or if no face is detected in any
            frame (there is then nothing to interpolate or crop).
    """
    if not input_video_path:
        raise ValueError("No input video was provided.")

    # The CNN detector is only chosen when a GPU is available; otherwise the
    # frontal face detector is used.
    if torch.cuda.is_available():
        detector = dlib.cnn_face_detection_model_v1(face_detector_path)
    else:
        detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(face_predictor_path)

    STD_SIZE = (256, 256)
    mean_face_landmarks = np.load(mean_face_path)
    stablePntsIDs = [33, 36, 39, 42, 45]

    # vread already returns the frames as one array, so it is iterated
    # directly instead of being copied into a second array first.
    frames = skvideo.io.vread(input_video_path)
    landmarks = [detect_landmark(frame, detector, predictor) for frame in tqdm(frames)]
    if all(landmark is None for landmark in landmarks):
        raise ValueError("No face was detected in any frame of the input video.")

    preprocessed_landmarks = landmarks_interpolate(landmarks)
    rois = crop_patch(
        input_video_path,
        preprocessed_landmarks,
        mean_face_landmarks,
        stablePntsIDs,
        STD_SIZE,
        window_margin=12,
        start_idx=48,
        stop_idx=68,
        crop_height=96,
        crop_width=96,
    )
    write_video_ffmpeg(rois, mouth_roi_path, "/usr/bin/ffmpeg")
    return mouth_roi_path
def predict(process_video):
    """Decode the text for one video with the checkpoint at ``ckpt_path``.

    A single-entry manifest (``test.tsv``) and a dummy label file
    (``test.wrd``) are written to a temporary directory, the checkpoint is
    loaded, the task is set up with the "video" modality only, and the first
    hypothesis returned for the first sample is decoded.

    Args:
        process_video: Path of the video to run inference on.

    Returns:
        The decoded hypothesis.

    Raises:
        ValueError: If no video path is given.
    """
    if not process_video:
        raise ValueError("No video was provided for prediction.")

    # Release the capture handle as soon as the frame count is known.
    capture = cv2.VideoCapture(process_video)
    try:
        num_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        capture.release()

    # The manifest directory is only needed while the dataset is read, so it
    # is removed again once decoding has finished (or failed).
    with tempfile.TemporaryDirectory() as data_dir:
        # Manifest row: id, video path, a literal "None" in the third column
        # (presumably the unused audio path), frame count and a sample count
        # derived as 16_000 * frames / 25.
        tsv_cont = ["/\n", f"test-0\t{process_video}\t{None}\t{num_frames}\t{int(16_000*num_frames/25)}\n"]
        label_cont = ["DUMMY\n"]
        with open(f"{data_dir}/test.tsv", "w") as fo:
            fo.write("".join(tsv_cont))
        with open(f"{data_dir}/test.wrd", "w") as fo:
            fo.write("".join(label_cont))

        modalities = ["video"]
        gen_subset = "test"
        gen_cfg = GenerationConfig(beam=20)

        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]

        # Point the saved task configuration at the manifest written above.
        saved_cfg.task.modalities = modalities
        saved_cfg.task.data = data_dir
        saved_cfg.task.label_dir = data_dir
        task = tasks.setup_task(saved_cfg.task)
        task.load_dataset(gen_subset, task_cfg=saved_cfg.task)
        generator = task.build_generator(models, gen_cfg)

        def decode_fn(x):
            """Turn a token tensor into text, ignoring strip symbols and padding."""
            dictionary = task.target_dictionary
            symbols_ignore = generator.symbols_to_strip_from_output
            symbols_ignore.add(dictionary.pad())
            return task.datasets[gen_subset].label_processors[0].decode(x, symbols_ignore)

        itr = task.get_batch_iterator(dataset=task.dataset(gen_subset)).next_epoch_itr(shuffle=False)
        sample = next(itr)
        if torch.cuda.is_available():
            sample = utils.move_to_cuda(sample)

        hypos = task.inference_step(generator, models, sample)
        # The reference is only the dummy label, so it is no longer decoded.
        hypo = hypos[0][0]['tokens'].int().cpu()
        return decode_fn(hypo)
# ---- Gradio Layout -----
# Two-step UI: "Detect landmark" runs preprocess_video on the uploaded video
# and shows the result in the second player; "Predict" runs predict on that
# second video and shows the returned text.
demo = gr.Blocks()
demo.encrypt = False
# Created up front so the click handler can reference it; it is placed in the
# layout by the render() call in the last row.
text_output = gr.Textbox()
with demo:
    with gr.Row():
        video_in = gr.Video(label="Input Video", mirror_webcam=False, interactive=True)
        video_out = gr.Video(label="Audio Visual Video", mirror_webcam=False, interactive=True)
    with gr.Row():
        detect_landmark_btn = gr.Button("Detect landmark")
        detect_landmark_btn.click(preprocess_video, [video_in], [video_out])
        predict_btn = gr.Button("Predict")
        predict_btn.click(predict, [video_out], [text_output])
    with gr.Row():
        # video_lip = gr.Video(label="Audio Visual Video", mirror_webcam=False)
        text_output.render()
# The stray trailing "|" after this call was a syntax error and is removed.
demo.launch(debug=True)