import numpy as np
import torch
from torch import nn

from typing import Optional, Dict, List

from mogen.models.utils.misc import zero_module

from ..builder import SUBMODULES, build_attention
from ..utils.stylization_block import StylizationBlock
from .motion_transformer import MotionTransformer

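# NOTE (assumed feature layout): the hard-coded offsets in get_kit_slice and
# get_t2m_slice below are consistent with the standard KIT-ML (251-dim) and
# HumanML3D (263-dim) motion representations: 4 root values, then local joint
# positions (3 values for joints 1..J-1), 6D joint rotations (6 values for
# joints 1..J-1), joint velocities (3 values for joints 0..J-1), and 4
# foot-contact labels. With J = 21 (KIT-ML) the sections start at indices
# 4, 64, 184 and 247; with J = 22 (HumanML3D) they start at 4, 67, 193 and 259.
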
def get_kit_slice(idx: int) -> List[int]:
    """
    Get the slice indices for the KIT skeleton.

    Args:
        idx (int): The index of the skeleton part.

    Returns:
        List[int]: Slice indices for the specified skeleton part.
    """
    if idx == 0:
        return [0, 1, 2, 3, 184, 185, 186, 247, 248, 249, 250]
    return [
        4 + (idx - 1) * 3,
        4 + (idx - 1) * 3 + 1,
        4 + (idx - 1) * 3 + 2,
        64 + (idx - 1) * 6,
        64 + (idx - 1) * 6 + 1,
        64 + (idx - 1) * 6 + 2,
        64 + (idx - 1) * 6 + 3,
        64 + (idx - 1) * 6 + 4,
        64 + (idx - 1) * 6 + 5,
        184 + idx * 3,
        184 + idx * 3 + 1,
        184 + idx * 3 + 2,
    ]

def get_t2m_slice(idx: int) -> List[int]:
    """
    Get the slice indices for the T2M (HumanML3D) skeleton.

    Args:
        idx (int): The index of the skeleton part.

    Returns:
        List[int]: Slice indices for the specified skeleton part.
    """
    if idx == 0:
        return [0, 1, 2, 3, 193, 194, 195, 259, 260, 261, 262]
    return [
        4 + (idx - 1) * 3,
        4 + (idx - 1) * 3 + 1,
        4 + (idx - 1) * 3 + 2,
        67 + (idx - 1) * 6,
        67 + (idx - 1) * 6 + 1,
        67 + (idx - 1) * 6 + 2,
        67 + (idx - 1) * 6 + 3,
        67 + (idx - 1) * 6 + 4,
        67 + (idx - 1) * 6 + 5,
        193 + idx * 3,
        193 + idx * 3 + 1,
        193 + idx * 3 + 2,
    ]

def get_part_slice(idx_list: List[int], func) -> List[int]:
    """
    Get the slice indices for a list of part indices.

    Args:
        idx_list (List[int]): List of part indices.
        func (Callable): Function that returns the slice indices for one part.

    Returns:
        List[int]: Concatenated list of slice indices for the parts.
    """
    result = []
    for idx in idx_list:
        result.extend(func(idx))
    return result

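# Example (illustrative): with the HumanML3D layout, the head is built from
# joints 12 and 15, so get_part_slice([12, 15], get_t2m_slice) returns the
# 24 feature indices (2 joints x 12 values each) that form the head slice
# used by PoseEncoder and PoseDecoder below.
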
class PoseEncoder(nn.Module):
    """
    Pose Encoder to process motion data and encode body parts into latent
    representations.
    """

    def __init__(self,
                 dataset_name: str = "human_ml3d",
                 latent_dim: int = 64,
                 input_dim: int = 263):
        super().__init__()
        self.dataset_name = dataset_name
        if dataset_name == "human_ml3d":
            func = get_t2m_slice
            self.head_slice = get_part_slice([12, 15], func)
            self.stem_slice = get_part_slice([3, 6, 9], func)
            self.larm_slice = get_part_slice([14, 17, 19, 21], func)
            self.rarm_slice = get_part_slice([13, 16, 18, 20], func)
            self.lleg_slice = get_part_slice([2, 5, 8, 11], func)
            self.rleg_slice = get_part_slice([1, 4, 7, 10], func)
            self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(22)), func)
        elif dataset_name == "kit_ml":
            func = get_kit_slice
            self.head_slice = get_part_slice([4], func)
            self.stem_slice = get_part_slice([1, 2, 3], func)
            self.larm_slice = get_part_slice([8, 9, 10], func)
            self.rarm_slice = get_part_slice([5, 6, 7], func)
            self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func)
            self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func)
            self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(21)), func)
        else:
            raise ValueError(f"Unsupported dataset_name: {dataset_name}")

        self.head_embed = nn.Linear(len(self.head_slice), latent_dim)
        self.stem_embed = nn.Linear(len(self.stem_slice), latent_dim)
        self.larm_embed = nn.Linear(len(self.larm_slice), latent_dim)
        self.rarm_embed = nn.Linear(len(self.rarm_slice), latent_dim)
        self.lleg_embed = nn.Linear(len(self.lleg_slice), latent_dim)
        self.rleg_embed = nn.Linear(len(self.rleg_slice), latent_dim)
        self.root_embed = nn.Linear(len(self.root_slice), latent_dim)
        self.body_embed = nn.Linear(len(self.body_slice), latent_dim)

        assert len(set(self.body_slice)) == input_dim

    def forward(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for encoding the motion into body part embeddings.

        Args:
            motion (torch.Tensor): Input motion tensor of shape (B, T, D).

        Returns:
            torch.Tensor: Concatenated latent representations of the body
                parts, of shape (B, T, 8 * latent_dim).
        """
        head_feat = self.head_embed(motion[:, :, self.head_slice].contiguous())
        stem_feat = self.stem_embed(motion[:, :, self.stem_slice].contiguous())
        larm_feat = self.larm_embed(motion[:, :, self.larm_slice].contiguous())
        rarm_feat = self.rarm_embed(motion[:, :, self.rarm_slice].contiguous())
        lleg_feat = self.lleg_embed(motion[:, :, self.lleg_slice].contiguous())
        rleg_feat = self.rleg_embed(motion[:, :, self.rleg_slice].contiguous())
        root_feat = self.root_embed(motion[:, :, self.root_slice].contiguous())
        body_feat = self.body_embed(motion[:, :, self.body_slice].contiguous())
        feat = torch.cat((head_feat, stem_feat, larm_feat, rarm_feat,
                          lleg_feat, rleg_feat, root_feat, body_feat),
                         dim=-1)
        return feat

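# The encoder output packs the eight part embeddings in a fixed order
# (head, stem, larm, rarm, lleg, rleg, root, body), each of size latent_dim.
# PoseDecoder and SFFN below rely on this ordering when splitting the
# (B, T, 8 * latent_dim) tensor back into per-part chunks.
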
class PoseDecoder(nn.Module):
    """
    Pose Decoder to decode the latent representations of body parts back
    into motion.
    """

    def __init__(self,
                 dataset_name: str = "human_ml3d",
                 latent_dim: int = 64,
                 output_dim: int = 263):
        super().__init__()
        self.dataset_name = dataset_name
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        if dataset_name == "human_ml3d":
            func = get_t2m_slice
            self.head_slice = get_part_slice([12, 15], func)
            self.stem_slice = get_part_slice([3, 6, 9], func)
            self.larm_slice = get_part_slice([14, 17, 19, 21], func)
            self.rarm_slice = get_part_slice([13, 16, 18, 20], func)
            self.lleg_slice = get_part_slice([2, 5, 8, 11], func)
            self.rleg_slice = get_part_slice([1, 4, 7, 10], func)
            self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(22)), func)
        elif dataset_name == "kit_ml":
            func = get_kit_slice
            self.head_slice = get_part_slice([4], func)
            self.stem_slice = get_part_slice([1, 2, 3], func)
            self.larm_slice = get_part_slice([8, 9, 10], func)
            self.rarm_slice = get_part_slice([5, 6, 7], func)
            self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func)
            self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func)
            self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(21)), func)
        else:
            raise ValueError(f"Unsupported dataset_name: {dataset_name}")

        self.head_out = nn.Linear(latent_dim, len(self.head_slice))
        self.stem_out = nn.Linear(latent_dim, len(self.stem_slice))
        self.larm_out = nn.Linear(latent_dim, len(self.larm_slice))
        self.rarm_out = nn.Linear(latent_dim, len(self.rarm_slice))
        self.lleg_out = nn.Linear(latent_dim, len(self.lleg_slice))
        self.rleg_out = nn.Linear(latent_dim, len(self.rleg_slice))
        self.root_out = nn.Linear(latent_dim, len(self.root_slice))
        self.body_out = nn.Linear(latent_dim, len(self.body_slice))

    def forward(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to decode the latent body part features back to motion.

        Args:
            motion (torch.Tensor): Input tensor of shape
                (B, T, 8 * latent_dim).

        Returns:
            torch.Tensor: Output motion tensor of shape (B, T, output_dim).
        """
        B, T = motion.shape[:2]
        D = self.latent_dim
        head_feat = self.head_out(motion[:, :, :D].contiguous())
        stem_feat = self.stem_out(motion[:, :, D:2 * D].contiguous())
        larm_feat = self.larm_out(motion[:, :, 2 * D:3 * D].contiguous())
        rarm_feat = self.rarm_out(motion[:, :, 3 * D:4 * D].contiguous())
        lleg_feat = self.lleg_out(motion[:, :, 4 * D:5 * D].contiguous())
        rleg_feat = self.rleg_out(motion[:, :, 5 * D:6 * D].contiguous())
        root_feat = self.root_out(motion[:, :, 6 * D:7 * D].contiguous())
        body_feat = self.body_out(motion[:, :, 7 * D:].contiguous())
        output = torch.zeros(B, T, self.output_dim).type_as(motion)
        output[:, :, self.head_slice] = head_feat
        output[:, :, self.stem_slice] = stem_feat
        output[:, :, self.larm_slice] = larm_feat
        output[:, :, self.rarm_slice] = rarm_feat
        output[:, :, self.lleg_slice] = lleg_feat
        output[:, :, self.rleg_slice] = rleg_feat
        output[:, :, self.root_slice] = root_feat
        # Average the per-part reconstruction with the whole-body one.
        output = (output + body_feat) / 2.0
        return output

class SFFN(nn.Module):
    """
    A Stylized Feed-Forward Network (SFFN) module for transformer layers.

    Args:
        latent_dim (int): Dimensionality of each body-part chunk of the input.
        ffn_dim (int): Dimensionality of the feed-forward layer.
        dropout (float): Dropout probability.
        time_embed_dim (int): Dimensionality of the time embedding.
        norm (str): Normalization type (only 'None' is supported).
        activation (str): Activation function (only 'GELU' is supported).
    """

    def __init__(self,
                 latent_dim: int,
                 ffn_dim: int,
                 dropout: float,
                 time_embed_dim: int,
                 norm: str = "None",
                 activation: str = "GELU",
                 **kwargs):
        super().__init__()
        self.linear1_list = nn.ModuleList()
        self.linear2_list = nn.ModuleList()

        channel_mul = 1
        if activation == "GELU":
            self.activation = nn.GELU()
        else:
            raise NotImplementedError(f"Unsupported activation: {activation}")

        # One two-layer MLP per body-part chunk.
        for i in range(8):
            self.linear1_list.append(
                nn.Linear(latent_dim, ffn_dim * channel_mul))
            self.linear2_list.append(nn.Linear(ffn_dim, latent_dim))

        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim * 8, time_embed_dim,
                                         dropout)

        if norm == "None":
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Unsupported norm: {norm}")

    def forward(self, x: torch.Tensor, emb: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """
        Forward pass of the SFFN layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, T, D), where D is
                8 * latent_dim.
            emb (torch.Tensor): Embedding tensor for time-step modulation.

        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        B, T, D = x.shape
        x = self.norm(x)
        x = x.reshape(B, T, 8, -1)
        output = []
        for i in range(8):
            feat = x[:, :, i].contiguous()
            feat = self.dropout(self.activation(self.linear1_list[i](feat)))
            feat = self.linear2_list[i](feat)
            output.append(feat)
        y = torch.cat(output, dim=-1)
        # Residual connection followed by time-conditioned stylization.
        y = x.reshape(B, T, D) + self.proj_out(y, emb)
        return y

class DecoderLayer(nn.Module):
    """
    A transformer decoder layer with cross-attention and a feed-forward
    network (SFFN).

    Args:
        ca_block_cfg (Optional[Dict]): Configuration for the cross-attention
            block.
        ffn_cfg (Optional[Dict]): Configuration for the feed-forward network
            (SFFN).
    """

    def __init__(self,
                 ca_block_cfg: Optional[Dict] = None,
                 ffn_cfg: Optional[Dict] = None):
        super().__init__()
        self.ca_block = build_attention(ca_block_cfg)
        self.ffn = SFFN(**ffn_cfg)

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Forward pass of the decoder layer.

        Args:
            kwargs: Keyword arguments for the attention and feed-forward
                layers.

        Returns:
            torch.Tensor: Output of the decoder layer.
        """
        if self.ca_block is not None:
            x = self.ca_block(**kwargs)
            kwargs.update({'x': x})
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x

@SUBMODULES.register_module()
class FineMoGenTransformer(MotionTransformer):
    """
    A diffusion-based transformer for motion generation with fine-grained
    control.

    Args:
        scale_func_cfg (Optional[Dict]): Configuration for the scaling
            function.
        pose_encoder_cfg (Optional[Dict]): Configuration for the PoseEncoder.
        pose_decoder_cfg (Optional[Dict]): Configuration for the PoseDecoder.
        moe_route_loss_weight (float): Weight for the Mixture-of-Experts (MoE)
            routing loss.
        template_kl_loss_weight (float): Weight for the KL loss in template
            generation.
        fine_mode (bool): Whether to enable fine mode for per-body-part
            control.
    """

    def __init__(self,
                 scale_func_cfg: Optional[Dict] = None,
                 pose_encoder_cfg: Optional[Dict] = None,
                 pose_decoder_cfg: Optional[Dict] = None,
                 moe_route_loss_weight: float = 1.0,
                 template_kl_loss_weight: float = 0.0001,
                 fine_mode: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.scale_func_cfg = scale_func_cfg
        self.joint_embed = PoseEncoder(**pose_encoder_cfg)
        # The output pose decoder is zero-initialized via zero_module.
        self.out = zero_module(PoseDecoder(**pose_decoder_cfg))
        self.moe_route_loss_weight = moe_route_loss_weight
        self.template_kl_loss_weight = template_kl_loss_weight
        self.mean = np.load("data/datasets/kit_ml/mean.npy")
        self.std = np.load("data/datasets/kit_ml/std.npy")
        self.fine_mode = fine_mode

    def build_temporal_blocks(self,
                              sa_block_cfg: Optional[Dict],
                              ca_block_cfg: Optional[Dict],
                              ffn_cfg: Optional[Dict]):
        """
        Build the temporal decoder blocks of the model.

        Args:
            sa_block_cfg (Optional[Dict]): Configuration for self-attention
                blocks (unused in this override).
            ca_block_cfg (Optional[Dict]): Configuration for cross-attention
                blocks.
            ffn_cfg (Optional[Dict]): Configuration for feed-forward networks;
                either a single config or a list with one config per layer.
        """
        self.temporal_decoder_blocks = nn.ModuleList()
        for i in range(self.num_layers):
            if isinstance(ffn_cfg, list):
                ffn_cfg_block = ffn_cfg[i]
            else:
                ffn_cfg_block = ffn_cfg
            self.temporal_decoder_blocks.append(
                DecoderLayer(ca_block_cfg=ca_block_cfg,
                             ffn_cfg=ffn_cfg_block))

    def scale_func(self, timestep: int) -> Dict[str, float]:
        """
        Compute the mixing weights for the text-conditioned and unconditioned
        branches at a given diffusion timestep.

        Args:
            timestep (int): Current diffusion timestep.

        Returns:
            Dict[str, float]: Scaling factors for text and non-text
                conditioning.
        """
        scale = self.scale_func_cfg['scale']
        # w grows linearly from 1 (timestep 0) to scale + 1 (timestep 1000).
        w = (1 - (1000 - timestep) / 1000) * scale + 1
        return {'text_coef': w, 'none_coef': 1 - w}

    def aux_loss(self) -> Dict[str, torch.Tensor]:
        """
        Auxiliary loss computation for MoE routing and KL loss.

        Returns:
            Dict[str, torch.Tensor]: Computed auxiliary losses.
        """
        aux_loss = 0
        kl_loss = 0
        for module in self.temporal_decoder_blocks:
            if hasattr(module.ca_block, 'aux_loss'):
                aux_loss = aux_loss + module.ca_block.aux_loss
            if hasattr(module.ca_block, 'kl_loss'):
                kl_loss = kl_loss + module.ca_block.kl_loss
        losses = {}
        if aux_loss > 0:
            losses['moe_route_loss'] = aux_loss * self.moe_route_loss_weight
        if kl_loss > 0:
            losses['template_kl_loss'] = kl_loss * self.template_kl_loss_weight
        return losses

    def get_precompute_condition(self,
                                 text: Optional[str] = None,
                                 motion_length: Optional[torch.Tensor] = None,
                                 xf_out: Optional[torch.Tensor] = None,
                                 re_dict: Optional[Dict] = None,
                                 device: Optional[torch.device] = None,
                                 sample_idx: Optional[int] = None,
                                 clip_feat: Optional[torch.Tensor] = None,
                                 **kwargs) -> Dict[str, torch.Tensor]:
        """
        Precompute conditioning features for text or other modalities.

        Args:
            text (Optional[str]): Text input for conditioning.
            motion_length (Optional[torch.Tensor]): Length of the motion
                sequence.
            xf_out (Optional[torch.Tensor]): Precomputed text features.
            re_dict (Optional[Dict]): Additional features dictionary.
            device (Optional[torch.device]): Target device for the model.
            sample_idx (Optional[int]): Sample index for specific conditioning.
            clip_feat (Optional[torch.Tensor]): Precomputed CLIP features.

        Returns:
            Dict[str, torch.Tensor]: Precomputed conditioning features.
        """
        if xf_out is None:
            xf_out = self.encode_text(text, clip_feat, device)
        output = {'xf_out': xf_out}
        return output

    def post_process(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Post-process motion data by unnormalizing if necessary.

        Args:
            motion (torch.Tensor): Input motion data.

        Returns:
            torch.Tensor: Processed motion data.
        """
        if self.post_process_cfg is not None:
            if self.post_process_cfg.get("unnormalized_infer", False):
                mean = torch.from_numpy(
                    np.load(self.post_process_cfg['mean_path'])).type_as(motion)
                std = torch.from_numpy(
                    np.load(self.post_process_cfg['std_path'])).type_as(motion)
                motion = motion * std + mean
        return motion

    def forward_train(self,
                      h: torch.Tensor,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      motion_length: Optional[torch.Tensor] = None,
                      num_intervals: int = 1,
                      **kwargs) -> torch.Tensor:
        """
        Forward pass during training.

        Args:
            h (torch.Tensor): Input tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask tensor.
            emb (Optional[torch.Tensor]): Time embedding tensor.
            xf_out (Optional[torch.Tensor]): Precomputed text features.
            motion_length (Optional[torch.Tensor]): Lengths of motion
                sequences.
            num_intervals (int): Number of intervals for processing.

        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        B, T = h.shape[0], h.shape[1]
        # Randomly sample a condition type per sample (one per body part in
        # fine mode); the cross-attention blocks use it to decide how the
        # text condition is applied.
        if self.fine_mode:
            cond_type = torch.randint(
                0, 100, size=(B, 1, 1)).repeat(1, 8, 1).to(h.device)
        else:
            cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=cond_type,
                       motion_length=motion_length,
                       num_intervals=num_intervals)

        output = self.out(h).view(B, T, -1).contiguous()
        return output

    def forward_test(self,
                     h: torch.Tensor,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     motion_length: Optional[torch.Tensor] = None,
                     num_intervals: int = 1,
                     **kwargs) -> torch.Tensor:
        """
        Forward pass during inference.

        Args:
            h (torch.Tensor): Input tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask tensor.
            emb (Optional[torch.Tensor]): Time embedding tensor.
            xf_out (Optional[torch.Tensor]): Precomputed text features.
            timesteps (Optional[torch.Tensor]): Diffusion timesteps.
            motion_length (Optional[torch.Tensor]): Lengths of motion
                sequences.
            num_intervals (int): Number of intervals for processing.

        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        B, T = h.shape[0], h.shape[1]
        text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1
        none_cond_type = torch.zeros(B, 1, 1).to(h.device)

        # Run the text-conditioned and unconditioned branches in one batch of
        # size 2B, then mix the two outputs with the timestep-dependent
        # weights from scale_func.
        all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0)
        h = h.repeat(2, 1, 1)
        xf_out = xf_out.repeat(2, 1, 1)
        emb = emb.repeat(2, 1)
        src_mask = src_mask.repeat(2, 1, 1)
        motion_length = motion_length.repeat(2, 1)
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=all_cond_type,
                       motion_length=motion_length,
                       num_intervals=num_intervals)
        out = self.out(h).view(2 * B, T, -1).contiguous()
        out_text = out[:B].contiguous()
        out_none = out[B:].contiguous()

        coef_cfg = self.scale_func(int(timesteps[0]))
        text_coef = coef_cfg['text_coef']
        none_coef = coef_cfg['none_coef']
        output = out_text * text_coef + out_none * none_coef
        return output
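

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original API):
    # round-trip a random HumanML3D-style motion tensor through PoseEncoder
    # and PoseDecoder and check the shapes. Assumes the mogen package is
    # importable so that the module-level imports above succeed.
    batch_size, num_frames, latent_dim, input_dim = 2, 16, 64, 263
    encoder = PoseEncoder(dataset_name="human_ml3d",
                          latent_dim=latent_dim,
                          input_dim=input_dim)
    decoder = PoseDecoder(dataset_name="human_ml3d",
                          latent_dim=latent_dim,
                          output_dim=input_dim)
    motion = torch.randn(batch_size, num_frames, input_dim)
    latent = encoder(motion)  # (B, T, 8 * latent_dim)
    recon = decoder(latent)   # (B, T, input_dim)
    assert latent.shape == (batch_size, num_frames, 8 * latent_dim)
    assert recon.shape == (batch_size, num_frames, input_dim)
    print("PoseEncoder/PoseDecoder round-trip OK:", latent.shape, recon.shape)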