import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, List, Dict, Union

from ..builder import ARCHITECTURES, build_loss, build_submodule
from ..utils.gaussian_diffusion import create_named_schedule_sampler, build_diffusion
from ..utils.mask_helper import expand_mask_to_all
from .base_architecture import BaseArchitecture


def set_requires_grad(nets: Union[torch.nn.Module, List[torch.nn.Module]], requires_grad: bool = False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


@ARCHITECTURES.register_module()
class MotionDiffusion(BaseArchitecture):
    """
    Motion Diffusion architecture for modeling and generating motion sequences using diffusion models.

    Args:
        dataset_name (Optional[str]): Name of the dataset being used (e.g., 'kit_ml', 'human_ml3d').
        model (dict): Configuration for the submodule (e.g., the motion generation model).
        loss_recon (dict): Configuration for the reconstruction loss.
        loss_reduction (str): Specifies the reduction method for the loss. Defaults to 'frame'.
        use_loss_score (bool): Whether to use a scoring mechanism for loss calculation. Defaults to False.
        diffusion_train (dict): Configuration for the diffusion model during training.
        diffusion_test (dict): Configuration for the diffusion model during testing.
        sampler_type (str): The type of sampler to use. Defaults to 'uniform'.
        init_cfg (dict): Initialization config for the module.
        inference_type (str): Type of inference to use ('ddpm' or 'ddim'). Defaults to 'ddpm'.
    """

    def __init__(self,
                 dataset_name: Optional[str] = None,
                 model: Optional[dict] = None,
                 loss_recon: Optional[dict] = None,
                 loss_reduction: str = "frame",
                 use_loss_score: bool = False,
                 diffusion_train: Optional[dict] = None,
                 diffusion_test: Optional[dict] = None,
                 sampler_type: str = 'uniform',
                 init_cfg: Optional[dict] = None,
                 inference_type: str = 'ddpm',
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test = build_diffusion(diffusion_test)
        self.sampler = create_named_schedule_sampler(sampler_type, self.diffusion_train)
        self.inference_type = inference_type
        self.loss_reduction = loss_reduction
        self.use_loss_score = use_loss_score
        self.dataset_name = dataset_name
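
        # Per-dataset mean/std statistics used to (de)normalize motion
        # features; the paths are relative to the project root.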
        if self.dataset_name == "kit_ml":
            self.mean = np.load("data/datasets/kit_ml/mean.npy")
            self.std = np.load("data/datasets/kit_ml/std.npy")
        elif self.dataset_name == "human_ml3d":
            self.mean = np.load("data/datasets/human_ml3d/mean.npy")
            self.std = np.load("data/datasets/human_ml3d/std.npy")
        elif self.dataset_name is not None:
            raise NotImplementedError(f"Unsupported dataset: {self.dataset_name}")

    def forward(self, **kwargs) -> Union[Dict, List]:
        """Forward pass of the model.

        Depending on whether the model is in training mode, this method performs the forward pass
        during training or inference, and calculates the relevant losses.

        Args:
            **kwargs: Keyword arguments containing the input data for the model.

        Returns:
            dict or list: The calculated losses during training or the generated motion during inference.
        """
        motion = kwargs['motion'].float()
        motion_mask = kwargs['motion_mask'].float()
        motion_length = kwargs['motion_length']
        num_intervals = kwargs.get('num_intervals', 1)
        sample_idx = kwargs.get('sample_idx', None)
        clip_feat = kwargs.get('clip_feat', None)
        B, T = motion.shape[:2]
        text = [kwargs['motion_metas'][i]['text'] for i in range(B)]

        if self.training:
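            # Sample one diffusion timestep per sequence (importance-weighted
            # samplers also return per-sample weights, unused here) and compute
            # the denoising reconstruction loss on the noised motion.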
            t, _ = self.sampler.sample(B, motion.device)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs={
                    'motion_mask': motion_mask,
                    'motion_length': motion_length,
                    'text': text,
                    'clip_feat': clip_feat,
                    'sample_idx': sample_idx,
                    'num_intervals': num_intervals
                }
            )
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target, reduction_override='none')
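
            # Optionally reweight the element-wise loss by a per-sample score
            # supplied with the batch.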
            if self.use_loss_score:
                loss_score = kwargs['score']
                recon_loss = recon_loss * loss_score.view(B, 1, -1)

            recon_loss = recon_loss.mean(dim=-1) * motion_mask
            recon_loss_batch = recon_loss.sum(dim=1) / motion_mask.sum(dim=1)
            recon_loss_frame = recon_loss.sum() / motion_mask.sum()
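
            # 'frame' reduction averages over all valid frames in the batch;
            # otherwise the loss is first averaged per sequence.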
            if self.loss_reduction == "frame":
                recon_loss = recon_loss_frame
            else:
                recon_loss = recon_loss_batch

            if hasattr(self.sampler, "update_with_local_losses"):
                self.sampler.update_with_local_losses(t, recon_loss_batch)

            loss = {'recon_loss': recon_loss.mean()}
            if hasattr(self.model, 'aux_loss'):
                loss.update(self.model.aux_loss())
            return loss

        else:
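            # Inference: run the reverse diffusion process (ancestral DDPM
            # sampling or deterministic DDIM) to generate motion of shape
            # (B, T, dim_pose).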
            dim_pose = kwargs['motion'].shape[-1]
            model_kwargs = self.model.get_precompute_condition(
                device=motion.device, text=text, **kwargs
            )
            model_kwargs.update({
                'motion_mask': motion_mask,
                'sample_idx': sample_idx,
                'motion_length': motion_length,
                'num_intervals': num_intervals
            })

            inference_kwargs = kwargs.get('inference_kwargs', {})
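            # DDIM sampling uses eta=0, i.e. deterministic denoising given the
            # initial noise.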
            if self.inference_type == 'ddpm':
                output = self.diffusion_test.p_sample_loop(
                    self.model, (B, T, dim_pose), clip_denoised=False, progress=False,
                    model_kwargs=model_kwargs, **inference_kwargs
                )
            else:
                output = self.diffusion_test.ddim_sample_loop(
                    self.model, (B, T, dim_pose), clip_denoised=False, progress=False,
                    model_kwargs=model_kwargs, eta=0, **inference_kwargs
                )

            results = kwargs
            if getattr(self.model, "post_process", None) is not None:
                output = self.model.post_process(output)

            results['pred_motion'] = output
            results = self.split_results(results)
            return results


@ARCHITECTURES.register_module()
class UnifiedMotionDiffusion(BaseArchitecture):
    """
    Unified Motion Diffusion architecture for generating motion sequences using diffusion models.

    Args:
        model (dict): Configuration for the motion generation model.
        loss_recon (dict): Configuration for the reconstruction loss.
        loss_reduction (str): Specifies the reduction method for the loss. Defaults to 'frame'.
        random_mask (float): Probability or scaling factor for applying random masking. Defaults to 0.
        diffusion_train (dict): Configuration for the diffusion model during training.
        diffusion_test_dict (dict): Test-time diffusion configuration, holding a 'base' config plus
            dataset-specific respacing entries.
        sampler_type (str): The type of sampler to use. Defaults to 'uniform'.
        init_cfg (dict): Initialization config for the module.
        inference_type (str): Type of inference to use ('ddpm' or 'ddim'). Defaults to 'ddpm'.
        body_scale (float): Scaling factor for the body motion mask. Defaults to 1.0.
        hand_scale (float): Scaling factor for the hand motion mask. Defaults to 1.0.
        face_scale (float): Scaling factor for the face motion mask. Defaults to 1.0.
        train_repeat (int): Number of times each training sample is repeated in the batch. Defaults to 1.
        loss_weight (str): Optional path to a ``.npy`` file with per-dataset, per-channel loss weights.
            Defaults to None.
    """

    def __init__(self,
                 model: Optional[dict] = None,
                 loss_recon: Optional[dict] = None,
                 loss_reduction: str = "frame",
                 random_mask: float = 0,
                 diffusion_train: Optional[dict] = None,
                 diffusion_test_dict: Optional[dict] = None,
                 sampler_type: str = 'uniform',
                 init_cfg: Optional[dict] = None,
                 inference_type: str = 'ddpm',
                 body_scale: float = 1.0,
                 hand_scale: float = 1.0,
                 face_scale: float = 1.0,
                 train_repeat: int = 1,
                 loss_weight: Optional[str] = None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test_dict = diffusion_test_dict
        self.sampler = create_named_schedule_sampler(sampler_type, self.diffusion_train)
        self.inference_type = inference_type
        self.loss_reduction = loss_reduction
        self.random_mask = random_mask
        self.body_scale = body_scale
        self.hand_scale = hand_scale
        self.face_scale = face_scale
        self.train_repeat = train_repeat
        # loss_weight is assumed to point to a .npy array of per-dataset,
        # per-channel weights; the training branch converts it with
        # torch.from_numpy and indexes it by dataset.
        self.loss_weight = np.load(loss_weight) if loss_weight is not None else None
        if init_cfg is not None:
            self.init_weights()

    def repeat_data(self, **kwargs):
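        """Repeat every sample in the batch ``train_repeat`` times along the
        batch dimension, covering the motion tensors, masks, lengths, metas and
        any per-modality conditioning features that are present.
        """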
        if self.train_repeat == 1:
            return kwargs
        N = self.train_repeat
        kwargs['motion'] = kwargs['motion'].float().repeat(N, 1, 1)
        kwargs['motion_mask'] = kwargs['motion_mask'].float().repeat(N, 1, 1)
        kwargs['motion_length'] = kwargs['motion_length'].repeat(N, 1)
        kwargs['motion_metas'] = kwargs['motion_metas'] * N

        for modality in ('text', 'music', 'speech', 'video'):
            for suffix, reps in (('seq_feat', (N, 1)),
                                 ('word_feat', (N, 1, 1)),
                                 ('cond', (N, 1))):
                key = f'{modality}_{suffix}'
                if key in kwargs:
                    kwargs[key] = kwargs[key].repeat(*reps)
        return kwargs

    def forward(self, **kwargs) -> Dict:
        """Forward pass for training or inference in the unified motion diffusion model.

        Args:
            **kwargs: Keyword arguments containing the input data for the model.

        Returns:
            dict: The calculated losses during training or the generated motion during inference.
        """
        if self.training:
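            # Replicate the batch `train_repeat` times (see repeat_data above).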
            kwargs = self.repeat_data(**kwargs)

        motion = kwargs['motion'].float()
        B, T = motion.shape[:2]
        motion_mask = kwargs['motion_mask'].float()
        motion_length = kwargs['motion_length']
        num_intervals = kwargs.get('num_intervals', 1)
        sample_idx = kwargs.get('sample_idx', None)
        motion_metas = kwargs['motion_metas']
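
        # Per-modality conditioning: word-level features, sequence-level
        # features, and a per-sample condition indicator (zeros when the
        # modality is absent from the batch).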
        text_word_feat = kwargs.get('text_word_feat', None)
        text_seq_feat = kwargs.get('text_seq_feat', None)
        text_cond = kwargs.get('text_cond', torch.zeros(B).type_as(motion))

        music_word_feat = kwargs.get('music_word_feat', None)
        music_seq_feat = kwargs.get('music_seq_feat', None)
        music_cond = kwargs.get('music_cond', torch.zeros(B).type_as(motion))

        speech_word_feat = kwargs.get('speech_word_feat', None)
        speech_seq_feat = kwargs.get('speech_seq_feat', None)
        speech_cond = kwargs.get('speech_cond', torch.zeros(B).type_as(motion))

        video_word_feat = kwargs.get('video_word_feat', None)
        video_seq_feat = kwargs.get('video_seq_feat', None)
        video_cond = kwargs.get('video_cond', torch.zeros(B).type_as(motion))

        if self.training:
            t, _ = self.sampler.sample(B, motion.device)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs={
                    'motion_mask': motion_mask,
                    'motion_length': motion_length,
                    'num_intervals': num_intervals,
                    'motion_metas': motion_metas,
                    'text_word_feat': text_word_feat,
                    'text_seq_feat': text_seq_feat,
                    'text_cond': text_cond,
                    'music_word_feat': music_word_feat,
                    'music_seq_feat': music_seq_feat,
                    'music_cond': music_cond,
                    'speech_word_feat': speech_word_feat,
                    'speech_seq_feat': speech_seq_feat,
                    'speech_cond': speech_cond,
                    'video_word_feat': video_word_feat,
                    'video_seq_feat': video_seq_feat,
                    'video_cond': video_cond,
                })
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target, reduction_override='none')
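
            # Expand the per-frame mask across all pose channels, weighting the
            # body, hand and face channel groups by their configured scales.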
            motion_mask = expand_mask_to_all(
                motion_mask, self.body_scale, self.hand_scale, self.face_scale
            )
            if self.loss_weight is not None:
                loss_weight = torch.from_numpy(self.loss_weight).type_as(motion_mask)
                dataset_idx = self.model.dataset_idx
                loss_weight = loss_weight.index_select(0, dataset_idx).unsqueeze(1)
                motion_mask = motion_mask * loss_weight
                recon_loss = (recon_loss * motion_mask).sum(dim=-1)
                motion_mask = motion_mask.sum(dim=-1)
            else:
                recon_loss = (recon_loss * motion_mask).mean(dim=-1)
                motion_mask = motion_mask.mean(dim=-1)

            recon_loss_batch = recon_loss.sum(dim=1) / motion_mask.sum(dim=1)
            recon_loss_frame = recon_loss.sum() / motion_mask.sum()

            if self.loss_reduction == "frame":
                recon_loss = recon_loss_frame
            else:
                recon_loss = recon_loss_batch

            if hasattr(self.sampler, "update_with_local_losses"):
                self.sampler.update_with_local_losses(t, recon_loss_batch)

            loss = {'recon_loss': recon_loss.mean()}

            if hasattr(self.model, 'aux_loss'):
                loss.update(self.model.aux_loss())

            return loss
        else:
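            # Fixed output dimensionality of the unified pose representation
            # (assumed to concatenate body, hand and face channels).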
            dim_pose = 669
            model_kwargs = self.model.get_precompute_condition(
                device=motion.device, **kwargs
            )
            model_kwargs.update({
                'motion_mask': motion_mask,
                'sample_idx': sample_idx,
                'motion_length': motion_length,
                'num_intervals': num_intervals,
                'motion_metas': motion_metas,
                'text_word_feat': text_word_feat,
                'text_seq_feat': text_seq_feat,
                'text_cond': text_cond,
                'music_word_feat': music_word_feat,
                'music_seq_feat': music_seq_feat,
                'music_cond': music_cond,
                'speech_word_feat': speech_word_feat,
                'speech_seq_feat': speech_seq_feat,
                'speech_cond': speech_cond,
                'video_word_feat': video_word_feat,
                'video_seq_feat': video_seq_feat,
                'video_cond': video_cond,
            })
            inference_kwargs = kwargs.get('inference_kwargs', {})
            inference_kwargs['gt_motion'] = motion
            inference_kwargs['context_mask'] = kwargs.get('context_mask', None)
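            # Build the test-time diffusion with a dataset-specific respacing
            # schedule taken from diffusion_test_dict.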
            dataset_name = motion_metas[0]['meta_data']['dataset_name']
            diffusion_test_cfg = dict(self.diffusion_test_dict['base'])
            diffusion_test_cfg.update(respace=self.diffusion_test_dict[dataset_name])
            diffusion_test = build_diffusion(diffusion_test_cfg)
            if self.inference_type == 'ddpm':
                output = diffusion_test.p_sample_loop(
                    self.model, (B, T, dim_pose), clip_denoised=False,
                    progress=False, model_kwargs=model_kwargs, **inference_kwargs
                )
            else:
                output = diffusion_test.ddim_sample_loop(
                    self.model, (B, T, dim_pose), clip_denoised=False,
                    progress=False, model_kwargs=model_kwargs, eta=0,
                    **inference_kwargs
                )

            results = kwargs
            if getattr(self.model, "post_process", None) is not None:
                output = self.model.post_process(output)

            results['pred_motion'] = output
            results = self.split_results(results)

            return results