import torch
import cv2
import sys
import numpy as np
import os
from PIL import Image

from transformers import Wav2Vec2Model, Wav2Vec2Processor


class MyWav2Vec():
    """Thin wrapper around a pretrained wav2vec 2.0 model for audio feature extraction."""

    def __init__(self, model_path, device="cuda"):
        super(MyWav2Vec, self).__init__()
        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.wav2Vec = Wav2Vec2Model.from_pretrained(model_path).to(device)
        self.device = device
        print("### Wav2Vec model loaded ###")

    def forward(self, x):
        # x: preprocessed input_values tensor already on self.device
        return self.wav2Vec(x).last_hidden_state

    def process(self, x):
        # x: raw 16 kHz waveform; returns model-ready input_values on self.device
        return self.processor(x, sampling_rate=16000, return_tensors="pt").input_values.to(self.device)
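
# Example usage (a minimal sketch; the checkpoint name and the dummy 1 s waveform
# below are placeholders, not values taken from this project):
#
#   wav_model = MyWav2Vec("facebook/wav2vec2-base-960h", device="cuda")
#   waveform = np.zeros(16000, dtype=np.float32)   # 1 s of silence at 16 kHz
#   input_values = wav_model.process(waveform)     # (1, 16000) tensor on the GPU
#   features = wav_model.forward(input_values)     # (1, T, hidden_size) hidden states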


class AutoFlow():
    """Face-region detector built on the project's bbox detector."""

    def __init__(self, auto_flow_dir, imh=512, imw=512):
        super(AutoFlow, self).__init__()
        model_dir = auto_flow_dir + '/third_lib/model_zoo/'
        cfg_file = model_dir + '/zdete_detector/mobilenet_v1_0.25.yaml'
        model_file = model_dir + '/zdete_detector/last_39.pt'
        # BboxPredictor is assumed to be importable from the project's detector
        # module; it is not defined in this file.
        self.bbox_predictor = BboxPredictor(cfg_file, model_file, imgsz=320, conf_thres=0.6, iou_thres=0.2)
        self.imh = imh
        self.imw = imw
        print("### AutoFlow bbox_predictor loaded ###")

    def frames_to_face_regions(self, frames, toPIL=True):
        # For each frame, detect the face bbox and build a binary mask
        # (255 inside the face box, 0 elsewhere) resized to (imw, imh).
        face_region_list = []
        for img in frames:
            bbox = self.bbox_predictor.predict(img)[0][0]
            xyxy = bbox[:4]
            score = bbox[4]  # detection confidence (unused)
            xyxy = np.round(xyxy).astype('int')
            rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
            face_mask = np.zeros((img.shape[0], img.shape[1])).astype('uint8')
            face_mask[rb:re, cb:ce] = 255
            face_mask = cv2.resize(face_mask, (self.imw, self.imh))
            if toPIL:
                face_mask = Image.fromarray(face_mask)
            face_region_list.append(face_mask)
        return face_region_list
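
# Example usage (a sketch only; the directory, the video path and the availability
# of BboxPredictor are assumptions, not values taken from this project):
#
#   auto_flow = AutoFlow("/path/to/auto_flow_dir", imh=512, imw=512)
#   frames = video_to_frame("/path/to/video.mp4")            # defined below
#   face_masks = auto_flow.frames_to_face_regions(frames)    # list of PIL mask images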


def xyxy2x0y0wh(bbox):
    # Convert a [x0, y0, x1, y1, ...] box to [x0, y0, width, height].
    x0, y0, x1, y1 = bbox[:4]
    return [x0, y0, x1 - x0, y1 - y0]


def video_to_frame(video_path: str, interval=1, max_frame=None, imh=None, imw=None, is_return_sum=False, is_rgb=False):
    # Decode a video into a list of frames, keeping one frame every `interval`
    # frames and optionally accumulating a running pixel sum of the kept frames.
    vidcap = cv2.VideoCapture(video_path)
    success = True

    key_frames = []
    sum_frames = None
    count = 0
    while success:
        success, image = vidcap.read()
        if image is not None:
            if is_rgb:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if imh is not None and imw is not None:
                image = img_resize(image, imh=imh, imw=imw)

            if count % interval == 0:
                key_frames.append(image)
                if is_return_sum:
                    if sum_frames is None:
                        sum_frames = image.copy().astype('float32')
                    else:
                        sum_frames = sum_frames + image

            count += 1

            if max_frame is not None:
                if count >= max_frame:
                    break

    vidcap.release()
    if is_return_sum:
        return key_frames, sum_frames
    else:
        return key_frames
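
# Example usage (a sketch; the path is a placeholder):
#
#   frames = video_to_frame("/path/to/video.mp4", interval=1, max_frame=100, is_rgb=True)
#   frames, frame_sum = video_to_frame("/path/to/video.mp4", is_return_sum=True)
#   mean_frame = frame_sum / len(frames)   # average kept frame as float32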


def img_resize(input_img, imh=None, imw=None, max_val=512):
    # Resize to (imw, imh) when both are given; otherwise scale the longer side
    # to max_val while keeping the aspect ratio. Output sides are rounded to
    # multiples of 8.
    if imh is not None and imw is not None:
        width, height = imw, imh
    else:
        height, width = input_img.shape[0], input_img.shape[1]
        if height > width:
            ratio = width / height
            height = max_val
            width = ratio * height
        else:
            ratio = height / width
            width = max_val
            height = ratio * width

    height = int(round(height / 8) * 8)
    width = int(round(width / 8) * 8)

    input_img = cv2.resize(input_img, (width, height))
    return input_img
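
# Example (a sketch): a 1080x1920 (HxW) frame has width > height, so the width is
# scaled to max_val=512 and the height to 512 * 1080/1920 = 288; both are already
# multiples of 8, giving a 288x512 output.
#
#   img = np.zeros((1080, 1920, 3), dtype=np.uint8)
#   out = img_resize(img)                       # out.shape == (288, 512, 3)
#   out2 = img_resize(img, imh=512, imw=512)    # out2.shape == (512, 512, 3)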


def assign_audio_to_frame(audio_input, frame_num):
    # Split the audio feature sequence (time on axis 0) into frame_num segments
    # and average each segment, yielding one feature vector per video frame.
    audio_len = audio_input.shape[0]
    audio_per_frame = audio_len / frame_num
    audio_to_frame_list = []
    for f_i in range(frame_num):
        start_idx = int(round(f_i * audio_per_frame))
        end_idx = int(round((f_i + 1) * audio_per_frame))
        if start_idx >= audio_len:
            start_idx = int(round(start_idx - audio_per_frame))

        seg_audio = audio_input[start_idx:end_idx, :]
        if isinstance(seg_audio, np.ndarray):
            seg_audio = seg_audio.mean(axis=0, keepdims=True)
        elif torch.is_tensor(seg_audio):
            seg_audio = seg_audio.mean(dim=0, keepdim=True)
        audio_to_frame_list.append(seg_audio)

    if isinstance(seg_audio, np.ndarray):
        audio_to_frames = np.concatenate(audio_to_frame_list, 0)
    else:
        audio_to_frames = torch.cat(audio_to_frame_list, 0)
    return audio_to_frames
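
# Example (a sketch with made-up shapes): 100 audio feature steps mapped onto 25
# video frames gives one averaged feature vector per frame.
#
#   audio_feats = torch.randn(100, 768)
#   per_frame = assign_audio_to_frame(audio_feats, frame_num=25)   # shape (25, 768)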


def assign_audio_to_frame_new(audio_input, frame_num, pad_frame):
    # For each video frame, take a window of 2 * pad_frame + 1 audio feature
    # steps centred on the frame's position, clamping the window at both ends.
    audio_len = audio_input.shape[0]
    audio_to_frame_list = []
    for f_i in range(frame_num):
        mid_index = int(f_i / frame_num * audio_len)
        start_idx = mid_index - pad_frame
        end_idx = mid_index + pad_frame + 1

        if start_idx < 0:
            start_idx = 0
            end_idx = start_idx + pad_frame * 2 + 1
        if end_idx >= audio_len:
            end_idx = audio_len - 1
            start_idx = end_idx - (pad_frame * 2 + 1)

        seg_audio = audio_input[None, start_idx:end_idx, :]
        audio_to_frame_list.append(seg_audio)

    if isinstance(seg_audio, np.ndarray):
        audio_to_frames = np.concatenate(audio_to_frame_list, 0)
    else:
        audio_to_frames = torch.cat(audio_to_frame_list, 0)
    return audio_to_frames
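
# Example (a sketch, continuing the one above): with pad_frame=2 each video frame
# gets a window of 5 consecutive audio feature steps instead of a mean.
#
#   windows = assign_audio_to_frame_new(audio_feats, frame_num=25, pad_frame=2)
#   # windows.shape == (25, 5, 768)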


class DotDict(dict):
    """dict subclass that allows attribute-style access to keys, recursing into nested dicts."""

    def __init__(self, *args, **kwargs):
        super(DotDict, self).__init__(*args, **kwargs)

    def __getattr__(self, key):
        try:
            value = self[key]
        except KeyError:
            # raise AttributeError instead of KeyError so hasattr()/copy behave as expected
            raise AttributeError(key)
        if isinstance(value, dict):
            value = DotDict(value)
        return value
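
# Example usage (a sketch):
#
#   cfg = DotDict({"model": {"dim": 256}, "lr": 1e-4})
#   cfg.lr          # 0.0001
#   cfg.model.dim   # 256 (nested dicts are wrapped on access)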