Spaces:
Running
on
L40S
Running
on
L40S
import torch | |
import cv2 | |
import sys | |
import numpy as np | |
import os | |
from PIL import Image | |
# from zdete import Predictor as BboxPredictor | |
from transformers import Wav2Vec2Model, Wav2Vec2Processor | |
class MyWav2Vec:
    """Pair a Hugging Face Wav2Vec2 processor with its model on a single device.

    Args:
        model_path: Path or hub id handed to both ``Wav2Vec2Processor.from_pretrained``
            and ``Wav2Vec2Model.from_pretrained``.
        device: Device the model is moved to; ``process`` sends its output here too.
    """

    def __init__(self, model_path, device="cuda"):
        super().__init__()
        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
        model = Wav2Vec2Model.from_pretrained(model_path)
        self.wav2Vec = model.to(device)
        self.device = device
        print("### Wav2Vec model loaded ###")

    def forward(self, x):
        """Run the Wav2Vec2 model on ``x`` and return its ``last_hidden_state``."""
        outputs = self.wav2Vec(x)
        return outputs.last_hidden_state

    def process(self, x):
        """Convert raw audio into model ``input_values`` placed on ``self.device``.

        The processor is always told ``sampling_rate=16000``; presumably callers pass
        audio already resampled to 16 kHz (NOTE(review): confirm at call sites).
        """
        features = self.processor(x, sampling_rate=16000, return_tensors="pt")
        return features.input_values.to(self.device)
class AutoFlow():
    """Build per-frame binary face masks from a face bounding-box detector.

    NOTE(review): ``BboxPredictor`` is not imported in this file (its import line
    ``from zdete import Predictor as BboxPredictor`` is commented out at the top),
    so constructing this class raises ``NameError`` as the file currently stands.
    """

    def __init__(self, auto_flow_dir, imh=512, imw=512):
        """Load the detector from ``<auto_flow_dir>/third_lib/model_zoo``.

        Args:
            auto_flow_dir: Root directory that contains ``third_lib/model_zoo``.
            imh: Height every output mask is resized to.
            imw: Width every output mask is resized to.
        """
        super(AutoFlow, self).__init__()
        model_dir = auto_flow_dir+'/third_lib/model_zoo/'
        cfg_file = model_dir + '/zdete_detector/mobilenet_v1_0.25.yaml'
        model_file = model_dir + '/zdete_detector/last_39.pt'
        self.bbox_predictor = BboxPredictor(cfg_file, model_file, imgsz=320, conf_thres=0.6, iou_thres=0.2)
        self.imh = imh
        self.imw = imw
        print("### AutoFlow bbox_predictor loaded ###")

    def frames_to_face_regions(self, frames, toPIL=True):
        """Return one face mask per frame.

        Args:
            frames: Iterable of numpy images in BGR order (per the original author's note).
            toPIL: If True each mask is converted to a ``PIL.Image``; otherwise it is
                returned as a ``uint8`` numpy array.

        Returns:
            List of masks resized to ``(self.imh, self.imw)``: 255 inside the first
            detected box, 0 elsewhere.
        """
        face_region_list = []
        for img in frames:
            img_h, img_w = img.shape[0], img.shape[1]
            # Uses the first box of the first result; assumes predict() always returns
            # at least one detection (IndexError otherwise) -- TODO confirm detector contract.
            bbox = self.bbox_predictor.predict(img)[0][0]
            x0, y0, x1, y1 = np.round(bbox[:4]).astype('int')
            # Nothing guarantees the box lies inside the image; a negative coordinate
            # would wrap around in the slice below and mark the wrong region, so clip.
            row_begin, row_end = np.clip([y0, y1], 0, img_h)
            col_begin, col_end = np.clip([x0, x1], 0, img_w)
            face_mask = np.zeros((img_h, img_w), dtype=np.uint8)
            face_mask[row_begin:row_end, col_begin:col_end] = 255
            face_mask = cv2.resize(face_mask, (self.imw, self.imh))
            if toPIL:
                face_mask = Image.fromarray(face_mask)
            face_region_list.append(face_mask)
        return face_region_list
def xyxy2x0y0wh(bbox):
    """Convert a corner box ``[x0, y0, x1, y1, ...]`` into ``[x0, y0, width, height]``.

    Only the first four entries of ``bbox`` are read; any trailing entries (e.g. a
    detection score) are ignored. Always returns a new ``list``.
    """
    left, top, right, bottom = bbox[:4]
    width = right - left
    height = bottom - top
    return [left, top, width, height]
def video_to_frame(video_path: str, interval=1, max_frame=None, imh=None, imw=None, is_return_sum=False, is_rgb=False):
    """Decode a video file into a list of frames using OpenCV.

    Args:
        video_path: Path passed to ``cv2.VideoCapture``.
        interval: Keep every ``interval``-th decoded frame (1 keeps all). Must be >= 1.
        max_frame: If not None, stop once this many frames have been decoded
            (counted before ``interval`` subsampling).
        imh, imw: If both are given, each frame is resized with
            ``img_resize(image, imh=imh, imw=imw)``.
        is_return_sum: Also return the float32 sum of the kept frames.
        is_rgb: Convert frames from OpenCV's BGR channel order to RGB.

    Returns:
        ``key_frames``, or ``(key_frames, sum_frames)`` when ``is_return_sum`` is set;
        ``sum_frames`` is None if no frame was kept.

    Raises:
        ValueError: If ``interval`` is less than 1.
    """
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")
    key_frames = []
    sum_frames = None
    count = 0
    vidcap = cv2.VideoCapture(video_path)
    try:
        while True:
            success, image = vidcap.read()
            if not success or image is None:
                break
            if is_rgb:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if imh is not None and imw is not None:
                # Forward the requested size; it was previously dropped (None, None).
                image = img_resize(image, imh=imh, imw=imw)
            if count % interval == 0:
                key_frames.append(image)
                if is_return_sum:
                    if sum_frames is None:
                        sum_frames = image.copy().astype('float32')
                    else:
                        sum_frames = sum_frames + image
            count += 1
            if max_frame is not None and count >= max_frame:
                break
    finally:
        # Release the capture even if decoding/resizing raises.
        vidcap.release()
    if is_return_sum:
        return key_frames, sum_frames
    return key_frames
def img_resize(input_img, imh=None, imw=None, max_val=512):
    """Resize an image to an explicit size, or so its longer side is ``max_val``.

    Args:
        input_img: Image array; only ``shape[0]`` (height) and ``shape[1]`` (width)
            are read before calling ``cv2.resize``.
        imh, imw: Target height/width. Used only when BOTH are given; otherwise the
            aspect ratio is preserved and the longer side becomes ``max_val``.
        max_val: Longer-side length for the aspect-preserving mode.

    Returns:
        The resized image. Both sides are rounded to the nearest multiple of 8 and
        are never smaller than 8.
    """
    if imh is not None and imw is not None:
        width, height = imw, imh
    else:
        height, width = input_img.shape[0], input_img.shape[1]
        if height > width:
            ratio = width / height
            height = max_val
            width = ratio * height
        else:
            ratio = height / width
            width = max_val
            height = ratio * width
    # Snap to multiples of 8; clamp to >= 8 so very thin images (or tiny explicit
    # sizes) don't round to 0, which cv2.resize rejects.
    height = max(8, int(round(height / 8) * 8))
    width = max(8, int(round(width / 8) * 8))
    return cv2.resize(input_img, (width, height))
def assign_audio_to_frame(audio_input, frame_num):
    """Mean-pool audio feature rows into one row per video frame.

    The ``audio_len = audio_input.shape[0]`` rows are split into ``frame_num``
    consecutive spans of roughly ``audio_len / frame_num`` rows (boundaries rounded
    to the nearest row), and each span is averaged along axis 0. When frames
    outnumber rows, every frame still receives at least one row.

    Args:
        audio_input: ``np.ndarray`` or ``torch.Tensor``; axis 0 is split across frames.
        frame_num: Number of output frames.

    Returns:
        Same array type as ``audio_input``, shape ``(frame_num, *audio_input.shape[1:])``
        (an empty slice along axis 0 when ``frame_num <= 0``).

    Raises:
        ValueError: If ``audio_input`` has no rows while ``frame_num > 0``.
    """
    audio_len = audio_input.shape[0]
    if frame_num <= 0:
        return audio_input[:0]
    if audio_len == 0:
        raise ValueError("assign_audio_to_frame: audio_input has no rows to assign")
    is_numpy = isinstance(audio_input, np.ndarray)
    audio_per_frame = audio_len / frame_num
    audio_to_frame_list = []
    for f_i in range(frame_num):
        start_idx = int(round(f_i * audio_per_frame))
        end_idx = int(round((f_i + 1) * audio_per_frame))
        # With audio_per_frame < 1 the rounded span can be empty or start past the
        # end; clamp so the mean below never runs over an empty slice (-> NaN).
        start_idx = min(start_idx, audio_len - 1)
        end_idx = max(end_idx, start_idx + 1)
        seg_audio = audio_input[start_idx:end_idx]
        if is_numpy:
            seg_audio = seg_audio.mean(axis=0, keepdims=True)
        else:
            seg_audio = seg_audio.mean(dim=0, keepdim=True)
        audio_to_frame_list.append(seg_audio)
    if is_numpy:
        return np.concatenate(audio_to_frame_list, 0)
    return torch.cat(audio_to_frame_list, 0)
def assign_audio_to_frame_new(audio_input, frame_num, pad_frame):
    """Gather a fixed-length window of audio feature rows for each video frame.

    Frame ``f_i`` is centred on row ``int(f_i / frame_num * audio_len)`` of
    ``audio_input`` (``audio_len = audio_input.shape[0]``) and takes the
    ``2 * pad_frame + 1`` rows around it. A window that would run past either end
    is shifted back inside the audio rather than truncated, so all windows share one
    length. If the audio is shorter than a window, every window is the whole audio.

    Args:
        audio_input: ``np.ndarray`` or ``torch.Tensor``; windows are taken along axis 0.
        frame_num: Number of video frames (windows) to produce.
        pad_frame: Number of rows taken on each side of the centre row.

    Returns:
        Same array type as ``audio_input``, shape
        ``(frame_num, min(2 * pad_frame + 1, audio_len), *audio_input.shape[1:])``.
    """
    audio_len = audio_input.shape[0]
    win_len = pad_frame * 2 + 1
    if frame_num <= 0:
        # Empty stack with the same trailing shape the loop below would produce.
        return audio_input[None, :win_len][:0]
    audio_to_frame_list = []
    for f_i in range(frame_num):
        mid_index = int(f_i / frame_num * audio_len)
        # Shift the window right if it would start before the first row.
        start_idx = max(mid_index - pad_frame, 0)
        end_idx = start_idx + win_len
        # end_idx is an exclusive slice bound: end_idx == audio_len still includes
        # the last row, so only shift left when the window actually overruns.
        if end_idx > audio_len:
            end_idx = audio_len
            start_idx = max(end_idx - win_len, 0)
        audio_to_frame_list.append(audio_input[None, start_idx:end_idx])
    if isinstance(audio_input, np.ndarray):
        return np.concatenate(audio_to_frame_list, 0)
    return torch.cat(audio_to_frame_list, 0)
class DotDict(dict):
    """``dict`` subclass that also exposes its keys as attributes (``d.key``).

    ``__getattr__`` is only consulted for names that are not real attributes, so
    dict methods (``d.keys``, ``d.items``, ...) take precedence over same-named keys.
    Nested plain ``dict`` values are wrapped in a new ``DotDict`` on every access;
    that wrapper is a shallow copy, so assigning keys into it does not write back
    to the parent mapping.
    """

    def __getattr__(self, key):
        try:
            value = self[key]
        except KeyError:
            # Missing names must surface as AttributeError so hasattr(),
            # getattr(obj, name, default) and protocol probes behave normally.
            raise AttributeError(key) from None
        if isinstance(value, dict):
            value = DotDict(value)
        return value