import os
import signal
import time
import csv
import sys
import warnings
import random
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
import pprint
from loguru import logger
import smplx
from torch.utils.tensorboard import SummaryWriter
import wandb
import matplotlib.pyplot as plt
from utils import config, logger_tools, other_tools_hf, metric, data_transfer, other_tools
from dataloaders import data_tools
from dataloaders.build_vocab import Vocab
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from dataloaders.data_tools import joints_list
from utils import rotation_conversions as rc
import soundfile as sf
import librosa
import subprocess
from transformers import pipeline
from diffusion.model_util import create_gaussian_diffusion
from diffusion.resample import create_named_schedule_sampler
from models.vq.model import RVQVAE
import train
import spaces
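# Install the Montreal Forced Aligner (MFA) via the bundled shell script; MFA is invoked later
# ("mfa align") to align the ASR transcript with the audio and produce a TextGrid file.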
command = ["bash","./demo/install_mfa.sh"]
result = subprocess.run(command, capture_output=True, text=True)
print("debug1: ", result)
device = "cuda" if torch.cuda.is_available() else "cpu"
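# Whisper-tiny ASR pipeline, used to transcribe the uploaded audio into the text that MFA aligns.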
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-tiny.en",
chunk_length_s=30,
device=device,
)
debug = False  # when True, skip the ASR transcription and MFA alignment in BaseTrainer.__init__
class BaseTrainer(object):
    def __init__(self, args, ap):
args.use_ddim=True
hf_dir = "hf"
time_local = time.localtime()
time_name_expend = "%02d%02d_%02d%02d%02d_"%(time_local[1], time_local[2],time_local[3], time_local[4], time_local[5])
self.time_name_expend = time_name_expend
tmp_dir = args.out_path + "custom/"+ time_name_expend + hf_dir
if not os.path.exists(tmp_dir + "/"):
os.makedirs(tmp_dir + "/")
self.audio_path = tmp_dir + "/tmp.wav"
sf.write(self.audio_path, ap[1], ap[0])
audio, ssr = librosa.load(self.audio_path,sr=args.audio_sr)
# use asr model to get corresponding text transcripts
file_path = tmp_dir+"/tmp.lab"
self.textgrid_path = tmp_dir + "/tmp.TextGrid"
if not debug:
text = pipe(audio, batch_size=8)["text"]
with open(file_path, "w", encoding="utf-8") as file:
file.write(text)
# use montreal forced aligner to get textgrid
command = ["mfa", "align", tmp_dir, "english_us_arpa", "english_us_arpa", tmp_dir]
result = subprocess.run(command, capture_output=True, text=True)
print("debug2: ", result)
ap = (ssr, audio)
self.args = args
self.rank = 0 # dist.get_rank()
args.textgrid_file_path = self.textgrid_path
args.audio_file_path = self.audio_path
self.checkpoint_path = tmp_dir
args.tmp_dir = tmp_dir
if self.rank == 0:
self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test")
self.test_loader = torch.utils.data.DataLoader(
self.test_data,
batch_size=1,
shuffle=False,
num_workers=args.loader_workers,
drop_last=False,
)
logger.info(f"Init test dataloader success")
model_module = __import__(f"models.{args.model}", fromlist=["something"])
self.model = torch.nn.DataParallel(getattr(model_module, args.g_name)(args), args.gpus).cuda()
if self.rank == 0:
logger.info(self.model)
logger.info(f"init {args.g_name} success")
self.smplx = smplx.create(
self.args.data_path_1+"smplx_models/",
model_type='smplx',
gender='NEUTRAL_2020',
use_face_contour=False,
num_betas=300,
num_expression_coeffs=100,
ext='npz',
use_pca=False,
).to(self.rank).eval()
self.args = args
self.ori_joint_list = joints_list[self.args.ori_joints]
self.tar_joint_list_face = joints_list["beat_smplx_face"]
self.tar_joint_list_upper = joints_list["beat_smplx_upper"]
self.tar_joint_list_hands = joints_list["beat_smplx_hands"]
self.tar_joint_list_lower = joints_list["beat_smplx_lower"]
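        # Build per-part binary masks over the flattened axis-angle pose vector (3 channels per
        # joint): one mask each for the face, upper-body, hands, and lower-body joints.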
self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3)
self.joints = 55
for joint_name in self.tar_joint_list_face:
self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3)
for joint_name in self.tar_joint_list_upper:
self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3)
for joint_name in self.tar_joint_list_hands:
self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3)
for joint_name in self.tar_joint_list_lower:
self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self","predict_x0_loss"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False, False, False,False,False,False])
vq_model_module = __import__(f"models.motion_representation", fromlist=["something"])
self.args.vae_layer = 2
self.args.vae_length = 256
self.args.vae_test_dim = 106
self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
other_tools.load_checkpoints(self.vq_model_face, "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin", args.e_name)
vq_type = self.args.vqvae_type
if vq_type=="vqvae":
self.args.vae_layer = 4
self.args.vae_test_dim = 78
self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
other_tools.load_checkpoints(self.vq_model_upper, args.vqvae_upper_path, args.e_name)
self.args.vae_test_dim = 180
self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
other_tools.load_checkpoints(self.vq_model_hands, args.vqvae_hands_path, args.e_name)
self.args.vae_test_dim = 54
self.args.vae_layer = 4
self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
other_tools.load_checkpoints(self.vq_model_lower, args.vqvae_lower_path, args.e_name)
elif vq_type=="rvqvae":
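            # Hard-coded hyperparameters for the residual VQ-VAE (RVQVAE) codebooks used to
            # encode/decode the upper-body, hands, and lower-body pose streams.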
args.num_quantizers = 6
args.shared_codebook = False
args.quantize_dropout_prob = 0.2
args.mu = 0.99
args.nb_code = 512
args.code_dim = 512
args.down_t = 2
args.stride_t = 2
args.width = 512
args.depth = 3
args.dilation_growth_rate = 3
args.vq_act = "relu"
args.vq_norm = None
dim_pose = 78
args.body_part = "upper"
self.vq_model_upper = RVQVAE(args,
dim_pose,
args.nb_code,
args.code_dim,
args.code_dim,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate,
args.vq_act,
args.vq_norm)
dim_pose = 180
args.body_part = "hands"
self.vq_model_hands = RVQVAE(args,
dim_pose,
args.nb_code,
args.code_dim,
args.code_dim,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate,
args.vq_act,
args.vq_norm)
dim_pose = 54
if args.use_trans:
dim_pose = 57
self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
args.body_part = "lower"
self.vq_model_lower = RVQVAE(args,
dim_pose,
args.nb_code,
args.code_dim,
args.code_dim,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate,
args.vq_act,
args.vq_norm)
self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
self.vqvae_latent_scale = self.args.vqvae_latent_scale
self.vq_model_upper.eval().to(self.rank)
self.vq_model_hands.eval().to(self.rank)
self.vq_model_lower.eval().to(self.rank)
self.args.vae_test_dim = 330
self.args.vae_layer = 4
self.args.vae_length = 240
self.vq_model_face.eval()
self.vq_model_upper.eval()
self.vq_model_hands.eval()
self.vq_model_lower.eval()
self.cls_loss = nn.NLLLoss().to(self.rank)
self.reclatent_loss = nn.MSELoss().to(self.rank)
self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank)
self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank)
self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank)
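        # Gaussian diffusion process (DDPM, or DDIM when args.use_ddim is set) operating on the
        # concatenated body-part latents, plus a uniform timestep schedule sampler.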
self.diffusion = create_gaussian_diffusion(use_ddim=args.use_ddim)
self.schedule_sampler_type = 'uniform'
self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, self.diffusion)
self.mean = np.load(args.mean_pose_path)
self.std = np.load(args.std_pose_path)
self.use_trans = args.use_trans
if self.use_trans:
self.trans_mean = np.load(args.mean_trans_path)
self.trans_std = np.load(args.std_trans_path)
self.trans_mean = torch.from_numpy(self.trans_mean).cuda()
self.trans_std = torch.from_numpy(self.trans_std).cuda()
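        # Channel masks in the rotation-6d representation: joint i occupies channels [i*6, i*6+6),
        # so these index lists slice the per-part mean/std statistics for upper, hands, and lower body.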
joints = [3,6,9,12,13,14,15,16,17,18,19,20,21]
upper_body_mask = []
for i in joints:
upper_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5])
joints = list(range(25,55))
hands_body_mask = []
for i in joints:
hands_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5])
joints = [0,1,2,4,5,7,8,10,11]
lower_body_mask = []
for i in joints:
lower_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5])
self.mean_upper = self.mean[upper_body_mask]
self.mean_hands = self.mean[hands_body_mask]
self.mean_lower = self.mean[lower_body_mask]
self.std_upper = self.std[upper_body_mask]
self.std_hands = self.std[hands_body_mask]
self.std_lower = self.std[lower_body_mask]
self.mean_upper = torch.from_numpy(self.mean_upper).cuda()
self.mean_hands = torch.from_numpy(self.mean_hands).cuda()
self.mean_lower = torch.from_numpy(self.mean_lower).cuda()
self.std_upper = torch.from_numpy(self.std_upper).cuda()
self.std_hands = torch.from_numpy(self.std_hands).cuda()
self.std_lower = torch.from_numpy(self.std_lower).cuda()
def inverse_selection(self, filtered_t, selection_array, n):
original_shape_t = np.zeros((n, selection_array.size))
selected_indices = np.where(selection_array == 1)[0]
for i in range(n):
original_shape_t[i, selected_indices] = filtered_t[i]
return original_shape_t
def inverse_selection_tensor(self, filtered_t, selection_array, n):
selection_array = torch.from_numpy(selection_array).cuda()
original_shape_t = torch.zeros((n, 165)).cuda()
selected_indices = torch.where(selection_array == 1)[0]
for i in range(n):
original_shape_t[i, selected_indices] = filtered_t[i]
return original_shape_t
def _load_data(self, dict_data):
tar_pose_raw = dict_data["pose"]
tar_pose = tar_pose_raw[:, :, :165].to(self.rank)
tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank)
tar_trans = dict_data["trans"].to(self.rank)
tar_trans_v = dict_data["trans_v"].to(self.rank)
tar_exps = dict_data["facial"].to(self.rank)
in_audio = dict_data["audio"].to(self.rank)
in_word = dict_data["word"].to(self.rank)
tar_beta = dict_data["beta"].to(self.rank)
tar_id = dict_data["id"].to(self.rank).long()
bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
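        # Split the axis-angle pose into jaw, hands, upper-body, and lower-body parts and convert
        # each to the rotation-6d representation; the face stream is jaw rotation + expression coeffs.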
tar_pose_jaw = tar_pose[:, :, 66:69]
tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
tar_pose_hands = tar_pose[:, :, 25*3:55*3]
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
tar_pose_lower = tar_pose_leg
tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)
if self.args.pose_norm:
tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower
if self.use_trans:
tar_trans_v = (tar_trans_v - self.trans_mean)/self.trans_std
tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1)
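        # Encode each (normalized) part with its frozen VQ-VAE and concatenate the upper / hands /
        # lower latents, scaled down by vqvae_latent_scale.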
latent_face_top = self.vq_model_face.map2latent(tar_pose_face) # bs*n/4
latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)/self.args.vqvae_latent_scale
tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
style_feature = None
if self.args.use_motionclip:
motionclip_feat = tar_pose_6d[...,:22*6]
batch = {}
bs,seq,feat = motionclip_feat.shape
batch['x']=motionclip_feat.permute(0,2,1).contiguous()
batch['y']=torch.zeros(bs).int().cuda()
batch['mask']=torch.ones([bs,seq]).bool().cuda()
style_feature = self.motionclip.encoder(batch)['mu'].detach().float()
# print(tar_index_value_upper_top.shape, index_in.shape)
return {
"tar_pose_jaw": tar_pose_jaw,
"tar_pose_face": tar_pose_face,
"tar_pose_upper": tar_pose_upper,
"tar_pose_lower": tar_pose_lower,
"tar_pose_hands": tar_pose_hands,
'tar_pose_leg': tar_pose_leg,
"in_audio": in_audio,
"in_word": in_word,
"tar_trans": tar_trans,
"tar_exps": tar_exps,
"tar_beta": tar_beta,
"tar_pose": tar_pose,
"tar4dis": tar4dis,
"latent_face_top": latent_face_top,
"latent_upper_top": latent_upper_top,
"latent_hands_top": latent_hands_top,
"latent_lower_top": latent_lower_top,
"latent_in": latent_in,
"tar_id": tar_id,
"latent_all": latent_all,
"tar_pose_6d": tar_pose_6d,
"tar_contact": tar_contact,
"style_feature":style_feature,
}
def _g_test(self, loaded_data):
sample_fn = self.diffusion.p_sample_loop
if self.args.use_ddim:
sample_fn = self.diffusion.ddim_sample_loop
mode = 'test'
bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
tar_pose = loaded_data["tar_pose"]
tar_beta = loaded_data["tar_beta"]
tar_exps = loaded_data["tar_exps"]
tar_contact = loaded_data["tar_contact"]
tar_trans = loaded_data["tar_trans"]
in_word = loaded_data["in_word"]
in_audio = loaded_data["in_audio"]
in_x0 = loaded_data['latent_in']
in_seed = loaded_data['latent_in']
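        # Trim all inputs so the frame count n is a multiple of 8; the latent sequences are shortened
        # by the corresponding number of downsampled steps (remain // vqvae_squeeze_scale).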
remain = n%8
if remain != 0:
tar_pose = tar_pose[:, :-remain, :]
tar_beta = tar_beta[:, :-remain, :]
tar_trans = tar_trans[:, :-remain, :]
in_word = in_word[:, :-remain]
tar_exps = tar_exps[:, :-remain, :]
tar_contact = tar_contact[:, :-remain, :]
in_x0 = in_x0[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :]
in_seed = in_seed[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :]
n = n - remain
tar_pose_jaw = tar_pose[:, :, 66:69]
tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
tar_pose_hands = tar_pose[:, :, 25*3:55*3]
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
rec_all_face = []
rec_all_upper = []
rec_all_lower = []
rec_all_hands = []
vqvae_squeeze_scale = self.args.vqvae_squeeze_scale
roundt = (n - self.args.pre_frames * vqvae_squeeze_scale) // (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
remain = (n - self.args.pre_frames * vqvae_squeeze_scale) % (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
round_l = self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale
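        # Autoregressive sliding-window generation: each window spans args.pose_length frames and
        # reuses the last pre_frames of the previous window's sample as the seed for continuity.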
for i in range(0, roundt):
in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames * vqvae_squeeze_scale]
in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale]
in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames]
in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float().cuda()
mask_val[:, :self.args.pre_frames, :] = 0.0
if i == 0:
in_seed_tmp = in_seed_tmp[:, :self.args.pre_frames, :]
else:
in_seed_tmp = last_sample[:, -self.args.pre_frames:, :]
cond_ = {'y':{}}
cond_['y']['audio'] = in_audio_tmp
cond_['y']['word'] = in_word_tmp
cond_['y']['id'] = in_id_tmp
            cond_['y']['seed'] = in_seed_tmp
cond_['y']['mask'] = (torch.zeros([self.args.batch_size, 1, 1, self.args.pose_length]) < 1).cuda()
cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda()
shape_ = (bs, 1536, 1, 32)
sample = sample_fn(
self.model,
shape_,
clip_denoised=False,
model_kwargs=cond_,
skip_timesteps=0,
init_image=None,
progress=True,
dump_steps=None,
noise=None,
const_noise=False,
)
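            # The sampled latent has shape (bs, 1536, 1, 32); with batch size 1 it is reshaped to
            # (1, 32, 1536) and split channel-wise into upper / hands / lower latents of 512 dims each.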
sample = sample.squeeze().permute(1,0).unsqueeze(0)
last_sample = sample.clone()
rec_latent_upper = sample[...,:512]
rec_latent_hands = sample[...,512:1024]
rec_latent_lower = sample[...,1024:1536]
if i == 0:
rec_all_upper.append(rec_latent_upper)
rec_all_hands.append(rec_latent_hands)
rec_all_lower.append(rec_latent_lower)
else:
rec_all_upper.append(rec_latent_upper[:, self.args.pre_frames:])
rec_all_hands.append(rec_latent_hands[:, self.args.pre_frames:])
rec_all_lower.append(rec_latent_lower[:, self.args.pre_frames:])
rec_all_upper = torch.cat(rec_all_upper, dim=1) * self.vqvae_latent_scale
rec_all_hands = torch.cat(rec_all_hands, dim=1) * self.vqvae_latent_scale
rec_all_lower = torch.cat(rec_all_lower, dim=1) * self.vqvae_latent_scale
rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]
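        # Decode the accumulated latents back to pose space. If translation is modeled, the last 3
        # channels of the lower-body output are a per-frame root velocity that is de-normalized and
        # integrated (cumulative sum) to recover the global translation (the vertical component is
        # kept as-is rather than accumulated).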
if self.use_trans:
rec_trans_v = rec_lower[...,-3:]
rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean
rec_trans = torch.cumsum(rec_trans_v, dim=-2)
rec_trans[...,1]=rec_trans_v[...,1]
rec_lower = rec_lower[...,:-3]
if self.args.pose_norm:
rec_upper = rec_upper * self.std_upper + self.mean_upper
rec_hands = rec_hands * self.std_hands + self.mean_hands
rec_lower = rec_lower * self.std_lower + self.mean_lower
n = n - remain
tar_pose = tar_pose[:, :n, :]
tar_exps = tar_exps[:, :n, :]
tar_trans = tar_trans[:, :n, :]
tar_beta = tar_beta[:, :n, :]
rec_exps = tar_exps
#rec_pose_jaw = rec_face[:, :, :6]
rec_pose_legs = rec_lower[:, :, :54]
bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)#
rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n)
rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n)
rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n)
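        # Recombine the per-part axis-angle poses into the full 55-joint vector; the jaw rotation is
        # copied from the target pose since the face is not generated by this model.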
rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
return {
'rec_pose': rec_pose,
'rec_trans': rec_trans,
'tar_pose': tar_pose,
'tar_exps': tar_exps,
'tar_beta': tar_beta,
'tar_trans': tar_trans,
'rec_exps': rec_exps,
}
def test_demo(self, epoch):
        '''
        Input: audio (and its transcript); output: generated motion.
        No losses or metrics are computed; the rendered video and motion file are saved.
        '''
results_save_path = self.checkpoint_path + f"/{epoch}/"
if os.path.exists(results_save_path):
import shutil
shutil.rmtree(results_save_path)
os.makedirs(results_save_path)
start_time = time.time()
total_length = 0
test_seq_list = self.test_data.selected_file
align = 0
latent_out = []
latent_ori = []
l2_all = 0
lvel = 0
self.model.eval()
self.smplx.eval()
# self.eval_copy.eval()
with torch.no_grad():
for its, batch_data in enumerate(self.test_loader):
loaded_data = self._load_data(batch_data)
net_out = self._g_test(loaded_data)
tar_pose = net_out['tar_pose']
rec_pose = net_out['rec_pose']
tar_exps = net_out['tar_exps']
tar_beta = net_out['tar_beta']
rec_trans = net_out['rec_trans']
tar_trans = net_out['tar_trans']
rec_exps = net_out['rec_exps']
bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
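                # If the model runs below 30 fps, upsample both target and generated poses to 30 fps
                # by linear interpolation before converting back to axis-angle for export.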
if (30/self.args.pose_fps) != 1:
assert 30%self.args.pose_fps == 0
n *= int(30/self.args.pose_fps)
tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1)
rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1)
rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
tar_pose_np = tar_pose.detach().cpu().numpy()
rec_pose_np = rec_pose.detach().cpu().numpy()
rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3)
rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100)
tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100)
tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3)
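                # Save the generated motion as an SMPL-X npz (poses, expressions, translation; betas
                # taken from the example file) so it can be rendered here or imported into Blender.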
gt_npz = np.load("./demo/examples/2_scott_0_1_1.npz", allow_pickle=True)
results_npz_file_save_path = results_save_path+f"result_{self.time_name_expend[:-1]}"+'.npz'
np.savez(results_npz_file_save_path,
betas=gt_npz["betas"],
poses=rec_pose_np,
expressions=rec_exp_np,
trans=rec_trans_np,
model='smplx2020',
gender='neutral',
mocap_frame_rate = 30,
)
total_length += n
render_vid_path = other_tools_hf.render_one_sequence_no_gt(
results_npz_file_save_path,
# results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz',
results_save_path,
self.audio_path,
self.args.data_path_1+"smplx_models/",
use_matplotlib = False,
args = self.args,
)
result = [
gr.Video(value=render_vid_path, visible=True),
gr.File(value=results_npz_file_save_path, label="download motion and visualize in blender"),
]
end_time = time.time() - start_time
logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
return result
@logger.catch
@spaces.GPU
def syntalker(audio_path, sample_strategy):
    args = config.parse_args()
    if sample_strategy == 0:
        args.use_ddim = True
    elif sample_strategy == 1:
        args.use_ddim = False
    print(sample_strategy)
    print(args.use_ddim)
#os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/"
if not sys.warnoptions:
warnings.simplefilter("ignore")
# dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
#logger_tools.set_args_and_logger(args, rank)
other_tools_hf.set_random_seed(args)
other_tools_hf.print_exp_info(args)
    # return one instance of the trainer
trainer = BaseTrainer(args, ap = audio_path)
other_tools.load_checkpoints(trainer.model, args.test_ckpt, args.g_name)
result = trainer.test_demo(999)
return result
examples = [
["demo/examples/2_scott_0_1_1.wav"],
["demo/examples/2_scott_0_2_2.wav"],
["demo/examples/2_scott_0_3_3.wav"],
["demo/examples/2_scott_0_4_4.wav"],
["demo/examples/2_scott_0_5_5.wav"],
]
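# Gradio UI: an audio input plus a DDIM/DDPM radio button (type="index" maps the choice to
# 0 for DDIM, 1 for DDPM); outputs are the rendered video and a downloadable .npz motion file.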
demo = gr.Interface(
syntalker, # function
inputs=[
# gr.File(label="Please upload SMPL-X file with npz format here.", file_types=["npz", "NPZ"]),
gr.Audio(),
gr.Radio(choices=["DDIM", "DDPM"], label="Please select a sample strategy", type="index", value="DDIM"), # 0 for DDIM, 1 for DDPM
# gr.File(label="Please upload textgrid format file here.", file_types=["TextGrid", "Textgrid", "textgrid"])
], # input type
outputs=[
gr.Video(format="mp4", visible=True),
gr.File(label="download motion and visualize in blender")
],
title='SynTalker: Enabling Synergistic Full-Body Control in Prompt-Based Co-Speech Motion Generation',
    description="1. Upload your audio.  \n"
                "2. Then, sit back and wait for the rendering to finish. This may take a while (e.g. 2 minutes).  \n"
                "3. Afterwards, you can view the video.  \n"
                "4. Note that we use a fixed face animation; our method only produces body motion.  \n"
                "5. The DDPM sampling strategy generates better results but takes more inference time.",
    article="Project links: [SynTalker](https://robinwitch.github.io/SynTalker-Page).  \n"
            "Reference links: [EMAGE](https://pantomatrix.github.io/EMAGE/).",
examples=examples,
)
if __name__ == "__main__":
os.environ["MASTER_ADDR"]='127.0.0.1'
os.environ["MASTER_PORT"]='8675'
#os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
demo.launch(share=True)