import numpy as np
import torch

from typing import Optional

from ..builder import SUBMODULES
from .motion_transformer import MotionTransformer


@SUBMODULES.register_module()
class MotionDiffuseTransformer(MotionTransformer):
    """MotionDiffuseTransformer is a subclass of MotionTransformer designed
    for motion generation.

    It uses a diffusion-based approach with optional classifier-free
    guidance during training and inference.

    Args:
        guidance_cfg (dict, optional): Configuration for guidance during
            training and inference. ``guidance_cfg['type']`` may be
            'constant' for a fixed guidance scale; any other value computes
            the scale dynamically from the current diffusion timestep.
        kwargs: Additional keyword arguments for the MotionTransformer
            base class.
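
    Example:
        Illustrative configurations (the scale value is an assumption, not
        a library default)::

            guidance_cfg = dict(type='constant', scale=3.5)
            # any type other than 'constant' selects the timestep-based
            # schedule; the name 'dynamic' here is only a placeholder:
            guidance_cfg = dict(type='dynamic', scale=3.5)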
    """

    def __init__(self, guidance_cfg: Optional[dict] = None, **kwargs):
        """Initialize the MotionDiffuseTransformer.

        Args:
            guidance_cfg (Optional[dict]): Configuration for the guidance.
            kwargs: Additional arguments passed to the base class.
        """
        super().__init__(**kwargs)
        self.guidance_cfg = guidance_cfg

    def scale_func(self, timestep: int) -> dict:
        """Compute the mixing coefficients for the text-conditioned and
        unconditioned outputs.

        Args:
            timestep (int): The current diffusion timestep.

        Returns:
            dict: A dictionary containing 'text_coef' and 'none_coef',
                which control the mix of text-conditioned and
                non-text-conditioned outputs and sum to 1.
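
        Example:
            A minimal sketch, assuming ``model`` is an instance configured
            with ``guidance_cfg = dict(type='constant', scale=3.5)``
            (illustrative values)::

                coefs = model.scale_func(timestep=500)
                # -> {'text_coef': 3.5, 'none_coef': -2.5}

            For any non-'constant' type the formula reduces to
            ``w = timestep / 1000 * scale + 1``, so high-noise (large
            timestep) steps weight the text-conditioned output more heavily.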
        """
        if self.guidance_cfg['type'] == 'constant':
            w = self.guidance_cfg['scale']
        else:
            scale = self.guidance_cfg['scale']
            w = (1 - (1000 - timestep) / 1000) * scale + 1
        return {'text_coef': w, 'none_coef': 1 - w}

    def get_precompute_condition(self,
                                 text: Optional[torch.Tensor] = None,
                                 xf_proj: Optional[torch.Tensor] = None,
                                 xf_out: Optional[torch.Tensor] = None,
                                 device: Optional[torch.device] = None,
                                 clip_feat: Optional[torch.Tensor] = None,
                                 **kwargs) -> dict:
        """Precompute the conditions for text-based guidance with the text
        encoder.

        Args:
            text (Optional[torch.Tensor]): The input text data.
            xf_proj (Optional[torch.Tensor]): Precomputed text projection.
            xf_out (Optional[torch.Tensor]): Precomputed output from the
                text encoder. If provided, text encoding is skipped.
            device (Optional[torch.device]): The device on which the model
                is running.
            clip_feat (Optional[torch.Tensor]): CLIP features for text
                guidance.
            kwargs: Additional keyword arguments.

        Returns:
            dict: A dictionary containing the text projection and the text
                encoder output.
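
        Example:
            A minimal caching sketch (``model``, ``texts`` and ``device``
            are assumptions)::

                cond = model.get_precompute_condition(text=texts,
                                                      device=device)
                # later denoising steps reuse the cached encoding:
                cond = model.get_precompute_condition(**cond)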
        """
        if xf_out is None:
            if self.use_text_proj:
                xf_proj, xf_out = self.encode_text(text, clip_feat, device)
            else:
                xf_out = self.encode_text(text, clip_feat, device)
        return {'xf_proj': xf_proj, 'xf_out': xf_out}

    def post_process(self, motion: torch.Tensor) -> torch.Tensor:
        """Post-process the generated motion by de-normalizing it with the
        dataset mean and standard deviation.

        Args:
            motion (torch.Tensor): The generated motion data.

        Returns:
            torch.Tensor: Post-processed motion data.
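
        Example:
            Illustrative ``post_process_cfg`` (paths are placeholders)::

                post_process_cfg = dict(
                    unnormalized_infer=True,
                    mean_path='data/mean.npy',
                    std_path='data/std.npy')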
        """
        if self.post_process_cfg is not None:
            if self.post_process_cfg.get('unnormalized_infer', False):
                mean = torch.from_numpy(
                    np.load(self.post_process_cfg['mean_path']))
                mean = mean.type_as(motion)
                std = torch.from_numpy(
                    np.load(self.post_process_cfg['std_path']))
                std = std.type_as(motion)
                # Undo dataset normalization: x = x_norm * std + mean.
                motion = motion * std + mean
        return motion

    def forward_train(self,
                      h: torch.Tensor,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      **kwargs) -> torch.Tensor:
        """Forward pass during training.

        Args:
            h (torch.Tensor): Input motion tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask for masking the
                input.
            emb (Optional[torch.Tensor]): Timestep embeddings.
            xf_out (Optional[torch.Tensor]): Precomputed output from the
                text encoder.
            kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: Output motion data after processing by the
                temporal decoder blocks.
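
        Example:
            Shape sketch (all sizes are illustrative assumptions)::

                h = torch.randn(4, 196, 512)  # (B, T, latent_dim)
                out = model.forward_train(h, src_mask=mask, emb=emb,
                                          xf_out=xf_out)
                # out: (4, 196, D_out), reshaped from self.out(h)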
        """
        B, T = h.shape[0], h.shape[1]
        if self.guidance_cfg is None:
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
        else:
            # Sample a random condition code per sequence; the decoder
            # blocks use it to drop the text condition stochastically,
            # enabling classifier-free guidance at inference time.
            cond_type = torch.randint(0, 100, size=(B, 1, 1),
                                      device=h.device)
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask,
                           cond_type=cond_type)
        output = self.out(h).view(B, T, -1).contiguous()
        return output

    def forward_test(self,
                     h: torch.Tensor,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     **kwargs) -> torch.Tensor:
        """Forward pass during testing/inference.

        Args:
            h (torch.Tensor): Input motion tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask for masking the
                input.
            emb (Optional[torch.Tensor]): Timestep embeddings.
            xf_out (Optional[torch.Tensor]): Precomputed output from the
                text encoder.
            timesteps (Optional[torch.Tensor]): Current diffusion timesteps.
            kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: Output motion data after processing by the
                temporal decoder blocks.
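
        Example:
            With guidance enabled, the batch is doubled so the
            text-conditioned and unconditioned branches run in one pass,
            then the two halves are mixed with the ``scale_func``
            coefficients (sketch)::

                output = text_coef * out_text + none_coef * out_none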
        """
        B, T = h.shape[0], h.shape[1]
        if self.guidance_cfg is None:
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
            output = self.out(h).view(B, T, -1).contiguous()
        else:
            # Classifier-free guidance: run a doubled batch where the
            # first half is text-conditioned and the second half is
            # unconditioned.
            text_cond_type = torch.ones(B, 1, 1, device=h.device)
            none_cond_type = torch.zeros(B, 1, 1, device=h.device)
            all_cond_type = torch.cat((text_cond_type, none_cond_type),
                                      dim=0)
            h = h.repeat(2, 1, 1)
            xf_out = xf_out.repeat(2, 1, 1)
            emb = emb.repeat(2, 1)
            src_mask = src_mask.repeat(2, 1, 1)
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask,
                           cond_type=all_cond_type)
            out = self.out(h).view(2 * B, T, -1).contiguous()
            out_text = out[:B].contiguous()
            out_none = out[B:].contiguous()
            # Mix the two branches with the guidance coefficients.
            coef_cfg = self.scale_func(int(timesteps[0]))
            text_coef = coef_cfg['text_coef']
            none_coef = coef_cfg['none_coef']
            output = out_text * text_coef + out_none * none_coef
        return output
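

# Example usage (illustrative sketch; the remaining base-class kwargs are
# omitted because they depend on the MotionTransformer signature):
#
#     model = SUBMODULES.build(
#         dict(type='MotionDiffuseTransformer',
#              guidance_cfg=dict(type='constant', scale=3.5)))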