from typing import Optional

import torch
from torch import nn

from mogen.models.utils.misc import zero_module
from mogen.models.utils.position_encoding import timestep_embedding
from mogen.models.utils.stylization_block import StylizationBlock

from ..builder import SUBMODULES, build_attention
from .remodiffuse import ReMoDiffuseTransformer


class FFN(nn.Module):
    """
    A feed-forward network (FFN) followed by a stylization block.

    The same FFN weights are applied to each of the two motion streams that
    arrive concatenated along the channel dimension.

    Args:
        latent_dim (int): The dimension of the input and output latent space.
        ffn_dim (int): The hidden dimension of the feed-forward network.
        dropout (float): The dropout rate applied after the activation.
        time_embed_dim (int): The dimension of the time embedding.
    """

    def __init__(self, latent_dim: int, ffn_dim: int, dropout: float,
                 time_embed_dim: int):
        super().__init__()
        self.latent_dim = latent_dim
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        # Zero-initialized so the residual branch starts as an identity map.
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x: torch.Tensor, emb: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """
        Forward pass of the FFN layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, T, latent_dim * 2),
                holding two motion streams concatenated along the last
                dimension.
            emb (torch.Tensor): Time embedding tensor.

        Returns:
            torch.Tensor: Output tensor of the same shape as the input, after
                the FFN and stylization block.
        """
        # Split the packed input into its two motion streams and run each
        # through the same (weight-shared) FFN and stylization residual.
        x1 = x[:, :, :self.latent_dim].contiguous()
        x2 = x[:, :, self.latent_dim:].contiguous()
        y1 = self.linear2(self.dropout(self.activation(self.linear1(x1))))
        y1 = x1 + self.proj_out(y1, emb)
        y2 = self.linear2(self.dropout(self.activation(self.linear1(x2))))
        y2 = x2 + self.proj_out(y2, emb)
        y = torch.cat((y1, y2), dim=-1)
        return y
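
# Shape sketch (illustrative values, not part of the module): with
# latent_dim = 512 the FFN maps a packed two-stream tensor of shape
# (B, T, 1024) back to (B, T, 1024); each (B, T, 512) half passes through
# the same linear1 -> GELU -> dropout -> linear2 branch before its
# stylization residual is added.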


class DecoderLayer(nn.Module):
    """
    A single decoder layer consisting of a cross-attention block and a
    feed-forward network (FFN).

    Args:
        ca_block_cfg (Optional[dict]): Configuration for the cross-attention
            block.
        ffn_cfg (Optional[dict]): Configuration for the feed-forward network.
    """

    def __init__(self,
                 ca_block_cfg: Optional[dict] = None,
                 ffn_cfg: Optional[dict] = None):
        super().__init__()
        # Either sub-block may be omitted by passing None for its config.
        if ca_block_cfg is not None:
            self.ca_block = build_attention(ca_block_cfg)
        else:
            self.ca_block = None
        self.ffn = FFN(**ffn_cfg) if ffn_cfg is not None else None

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Forward pass of the decoder layer.

        Args:
            **kwargs: Keyword arguments forwarded to the cross-attention and
                FFN sub-blocks; must include the input tensor under the key
                'x'.

        Returns:
            torch.Tensor: Output tensor after passing through the layer.
        """
        x = kwargs['x']
        if self.ca_block is not None:
            x = self.ca_block(**kwargs)
            kwargs.update({'x': x})
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x
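
# Call sketch (illustrative): the layers below are invoked as
#   layer(x=h, xf=xf_out, emb=emb, src_mask=src_mask,
#         cond_type=cond_type, re_dict=re_dict)
# (see forward_train / forward_test); the attention block consumes all of
# these, while the FFN only uses x and emb.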


@SUBMODULES.register_module()
class MoMatMoGenTransformer(ReMoDiffuseTransformer):
    """
    A motion generation transformer that uses ReMoDiffuseTransformer as its
    base and processes two motion streams in parallel with shared weights.
    """

    def build_temporal_blocks(self, sa_block_cfg: Optional[dict],
                              ca_block_cfg: Optional[dict],
                              ffn_cfg: Optional[dict]):
        """
        Build the temporal decoder blocks from the provided configurations.

        Args:
            sa_block_cfg (Optional[dict]): Self-attention block configuration
                (accepted for interface compatibility, but unused here).
            ca_block_cfg (Optional[dict]): Cross-attention block configuration.
            ffn_cfg (Optional[dict]): Feed-forward network configuration.
        """
        self.temporal_decoder_blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.temporal_decoder_blocks.append(
                DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg))
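
    # Stack sketch (illustrative): with num_layers = 4 this registers four
    # identical DecoderLayer(ca_block_cfg, ffn_cfg) modules; since
    # sa_block_cfg is ignored, the stack contains no self-attention blocks.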

    def forward(self,
                motion: torch.Tensor,
                timesteps: torch.Tensor,
                motion_mask: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """
        Forward pass for motion generation.

        Args:
            motion (torch.Tensor): Input motion tensor of shape
                (B, T, 2 * input_feats), packing two motion streams along the
                last dimension.
            timesteps (torch.Tensor): Diffusion timesteps of shape (B,).
            motion_mask (Optional[torch.Tensor]): Motion mask of shape
                (B, T) or (B, T, 1).

        Returns:
            torch.Tensor: Output tensor after processing the motion data.
        """
        T = motion.shape[1]
        conditions = self.get_precompute_condition(device=motion.device,
                                                   **kwargs)
        if len(motion_mask.shape) == 2:
            src_mask = motion_mask.clone().unsqueeze(-1)
        else:
            src_mask = motion_mask.clone()

        # Embed the diffusion timestep.
        if self.time_embedding_type == 'sinusoidal':
            emb = self.time_embed(
                timestep_embedding(timesteps, self.latent_dim))
        else:
            emb = self.time_embed(self.time_tokens(timesteps))

        if self.use_text_proj:
            emb = emb + conditions['xf_proj']

        # Embed the two motion streams with shared joint and positional
        # embeddings, then re-pack them along the channel dimension.
        motion1 = motion[:, :, :self.input_feats].contiguous()
        motion2 = motion[:, :, self.input_feats:].contiguous()
        h1 = self.joint_embed(motion1)
        h2 = self.joint_embed(motion2)
        if self.use_pos_embedding:
            h1 = h1 + self.sequence_embedding.unsqueeze(0)[:, :T, :]
            h2 = h2 + self.sequence_embedding.unsqueeze(0)[:, :T, :]
        h = torch.cat((h1, h2), dim=-1)

        if self.training:
            output = self.forward_train(h=h,
                                        src_mask=src_mask,
                                        emb=emb,
                                        timesteps=timesteps,
                                        **conditions)
        else:
            output = self.forward_test(h=h,
                                       src_mask=src_mask,
                                       emb=emb,
                                       timesteps=timesteps,
                                       **conditions)
        if self.use_residual_connection:
            output = motion + output
        return output
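
    # Layout sketch (illustrative): the packed input corresponds to
    #   motion = torch.cat((motion_a, motion_b), dim=-1)
    # with each stream of shape (B, T, input_feats); both streams share the
    # joint embedding, positional embedding, and output head throughout.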

    def forward_train(self,
                      h: Optional[torch.Tensor] = None,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      re_dict: Optional[dict] = None,
                      **kwargs) -> torch.Tensor:
        """
        Training forward pass of the motion generation transformer.

        Args:
            h (Optional[torch.Tensor]): Embedded input tensor.
            src_mask (Optional[torch.Tensor]): Source mask.
            emb (Optional[torch.Tensor]): Time embedding tensor.
            xf_out (Optional[torch.Tensor]): Text encoder output used as the
                cross-attention context.
            re_dict (Optional[dict]): Dictionary of retrieved
                (retrieval-augmented) features.

        Returns:
            torch.Tensor: Output tensor after processing.
        """
        B, T = h.shape[0], h.shape[1]
        # Randomly sample a condition type per sequence so that the text and
        # retrieval conditions are dropped at random during training.
        cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=cond_type,
                       re_dict=re_dict)

        # Decode the two streams with the shared output head and re-pack.
        out1 = self.out(h[:, :, :self.latent_dim].contiguous())
        out1 = out1.view(B, T, -1).contiguous()
        out2 = self.out(h[:, :, self.latent_dim:].contiguous())
        out2 = out2.view(B, T, -1).contiguous()
        output = torch.cat((out1, out2), dim=-1)
        return output
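
    # Condition-type encoding (consistent with the constants in forward_test
    # below): cond_type % 10 > 0 enables the text condition and
    # cond_type // 10 > 0 enables the retrieval condition, so
    #   99 -> text + retrieval, 1 -> text only, 10 -> retrieval only,
    #   0 -> unconditioned.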

    def forward_test(self,
                     h: Optional[torch.Tensor] = None,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     re_dict: Optional[dict] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     **kwargs) -> torch.Tensor:
        """
        Inference forward pass of the motion generation transformer.

        Args:
            h (Optional[torch.Tensor]): Embedded input tensor.
            src_mask (Optional[torch.Tensor]): Source mask.
            emb (Optional[torch.Tensor]): Time embedding tensor.
            xf_out (Optional[torch.Tensor]): Text encoder output used as the
                cross-attention context.
            re_dict (Optional[dict]): Dictionary of retrieved
                (retrieval-augmented) features.
            timesteps (Optional[torch.Tensor]): Diffusion timesteps of shape
                (B,).

        Returns:
            torch.Tensor: Output tensor after processing.
        """
        B, T = h.shape[0], h.shape[1]
        # Build the four guidance branches: both conditions (99), text only
        # (1), retrieval only (10), and unconditioned (0).
        both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99
        text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1
        retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10
        none_cond_type = torch.zeros(B, 1, 1).to(h.device)

        # Stack the four branches into a single batch of size 4 * B.
        all_cond_type = torch.cat(
            (both_cond_type, text_cond_type, retr_cond_type, none_cond_type),
            dim=0)
        h = h.repeat(4, 1, 1)
        xf_out = xf_out.repeat(4, 1, 1)
        emb = emb.repeat(4, 1)
        src_mask = src_mask.repeat(4, 1, 1)
        if re_dict['re_motion'].shape[0] != h.shape[0]:
            re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1)
            re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1)
            re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1)

        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=all_cond_type,
                       re_dict=re_dict)

        # Decode the two streams with the shared output head and re-pack.
        out1 = self.out(h[:, :, :self.latent_dim].contiguous())
        out1 = out1.view(4 * B, T, -1).contiguous()
        out2 = self.out(h[:, :, self.latent_dim:].contiguous())
        out2 = out2.view(4 * B, T, -1).contiguous()
        out = torch.cat((out1, out2), dim=-1)
        out_both = out[:B].contiguous()
        out_text = out[B:2 * B].contiguous()
        out_retr = out[2 * B:3 * B].contiguous()
        out_none = out[3 * B:].contiguous()

        # Blend the four branches with timestep-dependent guidance
        # coefficients.
        coef_cfg = self.scale_func(int(timesteps[0]))
        both_coef = coef_cfg['both_coef']
        text_coef = coef_cfg['text_coef']
        retr_coef = coef_cfg['retr_coef']
        none_coef = coef_cfg['none_coef']
        output = out_both * both_coef
        output += out_text * text_coef
        output += out_retr * retr_coef
        output += out_none * none_coef
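
        # Worked example (the coefficients come from self.scale_func and
        # depend on the timestep; the values here are made up for
        # illustration): with both_coef = 0.4, text_coef = 0.3,
        # retr_coef = 0.2 and none_coef = 0.1,
        #   output = 0.4 * out_both + 0.3 * out_text
        #            + 0.2 * out_retr + 0.1 * out_none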
        return output
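

if __name__ == "__main__":
    # Minimal smoke test for the two-stream FFN; a sketch only, assuming
    # `mogen` is importable and StylizationBlock accepts a
    # (B, time_embed_dim) embedding. The hyper-parameters are arbitrary.
    B, T, latent_dim, time_embed_dim = 2, 16, 64, 256
    ffn = FFN(latent_dim=latent_dim, ffn_dim=128, dropout=0.1,
              time_embed_dim=time_embed_dim)
    x = torch.randn(B, T, latent_dim * 2)
    emb = torch.randn(B, time_embed_dim)
    # Expected: torch.Size([2, 16, 128]) -- the packed shape is preserved.
    print(ffn(x, emb).shape)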