# encoding = 'utf-8'
import math
import os.path as osp
import sys

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rich.progress import track

from .talking_head_dit import TalkingHeadDiT_models
from ..schedulers.scheduling_ddim import DDIMScheduler
from ..schedulers.flow_matching import ModelSamplingDiscreteFlow

sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))

scheduler_map = {
    "ddim": DDIMScheduler,
    # "ddpm": DiffusionSchedule,
    "flow_matching": ModelSamplingDiscreteFlow,
}

lip_dims = [18, 19, 20, 36, 37, 38, 42, 43, 44, 51, 52, 53, 57, 58, 59, 60, 61, 62]
class MotionDiffusion(nn.Module):
    def __init__(self, config, device="cuda", dtype=torch.float32, smo_wsize=3, loss_type="l2"):
        super().__init__()
        self.config = config
        self.smo_wsize = smo_wsize
        print("================================== Init Motion GeneratorV2 ==================================")
        print(OmegaConf.to_yaml(self.config))

        motion_gen_config = config.motion_generator
        motion_gen_params = motion_gen_config.params
        audio_proj_config = config.audio_projector
        audio_proj_params = audio_proj_config.params
        scheduler_config = config.noise_scheduler
        scheduler_params = scheduler_config.params
        self.device = device

        # init motion generator
        self.talking_head_dit = TalkingHeadDiT_models[config.model_name](
            input_dim=motion_gen_params.input_dim * 2,
            output_dim=motion_gen_params.output_dim,
            seq_len=motion_gen_params.n_pred_frames,
            audio_unit_len=audio_proj_params.sequence_length,
            audio_blocks=audio_proj_params.blocks,
            audio_dim=audio_proj_params.audio_feat_dim,
            audio_tokens=audio_proj_params.context_tokens,
            audio_embedder_type=audio_proj_params.audio_embedder_type,
            audio_cond_dim=audio_proj_params.audio_cond_dim,
            norm_type=motion_gen_params.norm_type,
            qk_norm=motion_gen_params.qk_norm,
            exp_dim=motion_gen_params.exp_dim,
        )
        self.input_dim = motion_gen_params.input_dim
        self.exp_dim = motion_gen_params.exp_dim
        self.audio_feat_dim = audio_proj_params.audio_feat_dim
        self.audio_seq_len = audio_proj_params.sequence_length
        self.audio_blocks = audio_proj_params.blocks
        self.audio_margin = (audio_proj_params.sequence_length - 1) // 2
        self.indices = (
            torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        ).unsqueeze(0)  # e.g. [-2, -1, 0, 1, 2] when audio_margin == 2; size 1 x (2 * audio_margin + 1)
        self.n_prev_frames = motion_gen_params.n_prev_frames
        self.n_pred_frames = motion_gen_params.n_pred_frames

        # init diffusion scheduler
        self.scheduler = scheduler_map[scheduler_config.type](
            num_train_timesteps=scheduler_params.num_train_timesteps,
            beta_start=scheduler_params.beta_start,
            beta_end=scheduler_params.beta_end,
            beta_schedule=scheduler_params.mode,
            prediction_type=scheduler_config.sample_mode,
            time_shifting=scheduler_params.time_shifting,
        )
        self.scheduler_type = scheduler_config.type
        self.eta = scheduler_params.eta
        self.scheduler.set_timesteps(scheduler_params.num_inference_steps, device=self.device)
        self.timesteps = self.scheduler.timesteps
        print(f"time steps: {self.timesteps}")
        self.sample_mode = scheduler_config.sample_mode
        assert self.sample_mode in ["noise", "sample"], \
            f"Unknown sample mode {self.sample_mode}, should be noise or sample"

        # init other params
        self.audio_drop_ratio = config.train.audio_drop_ratio
        self.pre_drop_ratio = config.train.pre_drop_ratio
        self.null_audio_feat = nn.Parameter(
            torch.randn(1, 1, 1, 1, self.audio_feat_dim),
            requires_grad=True
        ).to(device=self.device, dtype=dtype)
        self.null_motion_feat = nn.Parameter(
            torch.randn(1, 1, self.input_dim),
            requires_grad=True
        ).to(device=self.device, dtype=dtype)

        # for segments fusion
        self.overlap_len = min(16, self.n_pred_frames - 16)
        self.fuse_alpha = torch.arange(self.overlap_len, device=self.device, dtype=dtype).reshape(1, -1, 1) / self.overlap_len

        self.dtype = dtype
        self.loss_type = loss_type
        total_params = sum(p.numel() for p in self.parameters())
        print('Number of parameters: % .4fM' % (total_params / 1e6))
        print("================================== Init Motion GeneratorV2: Done ==================================")
    def _smooth(self, motion):
        # motion: B x L x D
        if self.smo_wsize <= 1:
            return motion
        new_motion = motion.clone()
        n = motion.shape[1]
        half_k = self.smo_wsize // 2
        for i in range(n):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            # only smooth head pose motion
            motion[:, i, self.exp_dim:] = torch.mean(new_motion[:, ss:ee, self.exp_dim:], dim=1)
        return motion
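
    # Note on the helper below (added commentary, not original code): _fuse
    # cross-fades the last `overlap_len` frames of the previous segment with
    # the first `overlap_len` frames of the current one using the linear ramp
    # `fuse_alpha` = [0, 1/overlap_len, ..., (overlap_len - 1)/overlap_len].
    # Illustrative case (not the shipped config) with overlap_len = 4: the
    # weights on the current segment are [0.00, 0.25, 0.50, 0.75], so the
    # first overlapped frame keeps the previous segment entirely and the last
    # one is dominated by the current segment.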
    def _fuse(self, prev_motion, cur_motion):
        r1 = prev_motion[:, -self.overlap_len:]
        r2 = cur_motion[:, :self.overlap_len]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha
        prev_motion[:, -self.overlap_len:] = r_fuse  # fuse last
        return prev_motion
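
    # Note on the method below (added commentary, not original code): each
    # denoising step runs the DiT on an unconditional batch (built from the
    # learned null audio feature) and a conditional batch, then extrapolates
    # with classifier-free guidance:
    #     pred = uncond + cfg_scale * (cond - uncond)
    # e.g. with the default cfg_scale = 1.15 the prediction is pushed 15%
    # beyond the conditional output along the cond-uncond direction. The
    # previous motion segment is concatenated to the noisy latents along the
    # channel axis, which is why the DiT is built with input_dim * 2.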
    def sample_subclip(
        self,
        audio,
        ref_kp,
        prev_motion,
        emo=None,
        cfg_scale=1.15,
        init_latents=None,
        dynamic_threshold=None
    ):
        # prepare audio feat
        batch_size = audio.shape[0]
        audio = audio.to(self.device)
        if audio.ndim == 4:
            audio = audio.unsqueeze(2)
        # reference keypoints
        ref_kp = ref_kp.view(batch_size, 1, -1)
        # classifier-free guidance: stack an unconditional batch on top of the conditional one
        if cfg_scale > 1:
            uncond_audio = self.null_audio_feat.expand(
                batch_size, self.n_pred_frames, self.audio_seq_len, self.audio_blocks, -1
            )
            audio = torch.cat([uncond_audio, audio], dim=0)
            ref_kp = torch.cat([ref_kp] * 2, dim=0)
            if emo is not None:
                uncond_emo = torch.Tensor([self.talking_head_dit.num_emo_class]).long().to(self.device)
                emo = torch.cat([uncond_emo, emo], dim=0)
        ref_kp = ref_kp.repeat(1, audio.shape[1], 1)  # B, L, kD
        # prepare noisy motion
        if init_latents is None:
            latents = torch.randn((batch_size, self.n_pred_frames, self.input_dim)).to(self.device)
        else:
            latents = init_latents
        prev_motion = prev_motion.expand_as(latents).to(dtype=self.dtype)
        latents = latents.to(dtype=self.dtype)
        audio = audio.to(dtype=self.dtype)
        ref_kp = ref_kp.to(dtype=self.dtype)

        for t in track(self.timesteps, description='🚀Denoising', total=len(self.timesteps)):
            motion_in = torch.cat([prev_motion, latents], dim=-1)
            step_in = torch.tensor([t] * batch_size, device=self.device, dtype=self.dtype)
            if cfg_scale > 1:
                motion_in = torch.cat([motion_in] * 2, dim=0)
                step_in = torch.cat([step_in] * 2, dim=0)
            # predict
            pred = self.talking_head_dit(
                motion=motion_in,
                times=step_in,
                audio=audio,
                emo=emo,
                audio_cond=ref_kp
            )
            if dynamic_threshold:
                dt_ratio, dt_min, dt_max = dynamic_threshold
                # use the actual (possibly CFG-doubled) batch size of the prediction
                abs_results = pred.reshape(pred.shape[0], -1).abs()
                s = torch.quantile(abs_results, dt_ratio, dim=1)
                s = torch.clamp(s, min=dt_min, max=dt_max)
                s = s[..., None, None]
                pred = torch.clamp(pred, min=-s, max=s)
            # CFG
            if cfg_scale > 1:
                # uncond_pred, emo_cond_pred, all_cond_pred = pred.chunk(3, dim=0)
                # pred = uncond_pred + 8 * (emo_cond_pred - uncond_pred) + 1.2 * (all_cond_pred - emo_cond_pred)
                uncond_pred, cond_pred = pred.chunk(2, dim=0)
                pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
            # scheduler step
            latents = self.scheduler.step(pred, t, latents, eta=self.eta, return_dict=False)[0]

        self.talking_head_dit.bank = []
        return latents
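
    # Note on the method below (added commentary, not original code): long
    # audio is split into overlapping windows of n_pred_frames frames with
    # stride = n_pred_frames - overlap_len, and the audio is padded so the
    # windows tile it exactly. Illustrative numbers (not the shipped config):
    # with n_pred_frames = 80, overlap_len = 16 (stride = 64) and a 200-frame
    # clip, n_subdivision = ceil((200 - 80) / 64) + 1 = 3 windows covering
    # 80 + 64 * 2 = 208 frames, i.e. 8 padded frames that are trimmed from the
    # last segment before the segments are cross-faded together.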
    def sample(self, audio, ref_kp, prev_motion, cfg_scale=1.15, audio_pad_mode="zero", emo=None, dynamic_threshold=None):
        # prev_motion: B, 1, D
        # for inference with audio of any length:
        # crop audio into n_subdivision windows according to n_pred_frames
        clip_len = audio.shape[0]
        stride = self.n_pred_frames - self.overlap_len
        if clip_len <= self.n_pred_frames:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil((clip_len - self.n_pred_frames) / stride) + 1
        # padding
        n_padding_frames = self.n_pred_frames + stride * (n_subdivision - 1) - clip_len
        if n_padding_frames > 0:
            if audio_pad_mode == 'zero':
                padding_value = torch.zeros_like(audio[-1:])
            elif audio_pad_mode == 'replicate':
                padding_value = audio[-1:]
            else:
                raise ValueError(f'Unknown pad mode: {audio_pad_mode}')
            audio = torch.cat(
                [audio[:1]] * self.audio_margin
                + [audio] + [padding_value] * n_padding_frames
                + [audio[-1:]] * self.audio_margin,
                dim=0
            )
        center_indices = torch.arange(
            self.audio_margin,
            audio.shape[0] - self.audio_margin
        ).unsqueeze(1) + self.indices
        audio_tensor = audio[center_indices]  # T, L, b, aD
        # add reference keypoints
        motion_lst = []
        # init_latents = torch.randn((1, self.n_pred_frames, self.motion_dim)).to(device=self.device)
        init_latents = None
        # emotion label
        if emo is not None:
            emo = torch.Tensor([emo]).long().to(self.device)

        start_idx = 0
        for i in range(n_subdivision):
            print(f"Sample subclip {i + 1}/{n_subdivision}")
            end_idx = start_idx + self.n_pred_frames
            audio_segment = audio_tensor[start_idx:end_idx].unsqueeze(0)
            start_idx += stride
            motion_segment = self.sample_subclip(
                audio=audio_segment,
                ref_kp=ref_kp,
                prev_motion=prev_motion,
                emo=emo,
                cfg_scale=cfg_scale,
                init_latents=init_latents,
                dynamic_threshold=dynamic_threshold
            )
            # smooth
            motion_segment = self._smooth(motion_segment)
            # update prev motion for the next window
            prev_motion = motion_segment[:, stride - 1:stride].clone()
            # save results
            motion_coef = motion_segment
            if i == n_subdivision - 1 and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]  # delete padded frames
            if len(motion_lst) > 0:
                # fuse segments
                motion_lst[-1] = self._fuse(motion_lst[-1], motion_coef)
                motion_lst.append(motion_coef[:, self.overlap_len:])
            else:
                motion_lst.append(motion_coef)
        motion = torch.cat(motion_lst, dim=1)
        # smooth the full clip
        motion = self._smooth(motion)
        motion = motion.squeeze()
        return motion.float()
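

# ---------------------------------------------------------------------------
# Usage sketch (added commentary, not original code). A minimal, hedged
# example of driving MotionDiffusion.sample() end to end. The config path and
# the dummy tensor shapes below are assumptions for illustration only; the
# real pipeline supplies per-frame audio features, reference keypoints and a
# seed motion frame from its own preprocessing. Requires a CUDA device,
# matching the module's device default, and must be run as a module
# (python -m ...) so the relative imports resolve.
if __name__ == "__main__":
    cfg = OmegaConf.load("configs/motion_diffusion.yaml")  # hypothetical config path
    model = MotionDiffusion(cfg, device="cuda", dtype=torch.float32).cuda().eval()

    mg_params = cfg.motion_generator.params
    ap_params = cfg.audio_projector.params

    n_frames = 250  # ~10 s at 25 fps, arbitrary clip length
    # per-frame audio features: (frames, blocks, audio_feat_dim)
    audio = torch.randn(n_frames, ap_params.blocks, ap_params.audio_feat_dim, device="cuda")
    # flattened reference keypoints for one identity: (1, cond_dim) -- assumed shape
    ref_kp = torch.randn(1, ap_params.audio_cond_dim, device="cuda")
    # seed motion frame used as the autoregressive context: (1, 1, input_dim)
    prev_motion = torch.zeros(1, 1, mg_params.input_dim, device="cuda")

    with torch.no_grad():
        motion = model.sample(audio, ref_kp, prev_motion, cfg_scale=1.15)
    print(motion.shape)  # expected: (n_frames, input_dim)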