import numpy as np
import torch
from torch import nn
from typing import Callable, Dict, List, Optional
from mogen.models.utils.misc import zero_module
from ..builder import SUBMODULES, build_attention
from ..utils.stylization_block import StylizationBlock
from .motion_transformer import MotionTransformer
def get_kit_slice(idx: int) -> List[int]:
"""
Get the slice indices for the KIT skeleton.
Args:
idx (int): The index of the skeleton part.
Returns:
List[int]: Slice indices for the specified skeleton part.
"""
if idx == 0:
return [0, 1, 2, 3, 184, 185, 186, 247, 248, 249, 250]
return [
4 + (idx - 1) * 3,
4 + (idx - 1) * 3 + 1,
4 + (idx - 1) * 3 + 2,
64 + (idx - 1) * 6,
64 + (idx - 1) * 6 + 1,
64 + (idx - 1) * 6 + 2,
64 + (idx - 1) * 6 + 3,
64 + (idx - 1) * 6 + 4,
64 + (idx - 1) * 6 + 5,
184 + idx * 3,
184 + idx * 3 + 1,
184 + idx * 3 + 2,
]
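

# Illustrative examples (not from the original file) of how get_kit_slice maps a
# joint index to feature columns: joint 0 gathers the global root columns at the
# start and end of the feature vector, while every other joint gathers three
# blocks of width 3, 6 and 3 (starting at offsets 4, 64 and 184 respectively):
#
#   >>> get_kit_slice(0)
#   [0, 1, 2, 3, 184, 185, 186, 247, 248, 249, 250]
#   >>> get_kit_slice(1)
#   [4, 5, 6, 64, 65, 66, 67, 68, 69, 187, 188, 189]
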
def get_t2m_slice(idx: int) -> List[int]:
"""
Get the slice indices for the T2M skeleton.
Args:
idx (int): The index of the skeleton part.
Returns:
List[int]: Slice indices for the specified skeleton part.
"""
if idx == 0:
return [0, 1, 2, 3, 193, 194, 195, 259, 260, 261, 262]
return [
4 + (idx - 1) * 3,
4 + (idx - 1) * 3 + 1,
4 + (idx - 1) * 3 + 2,
67 + (idx - 1) * 6,
67 + (idx - 1) * 6 + 1,
67 + (idx - 1) * 6 + 2,
67 + (idx - 1) * 6 + 3,
67 + (idx - 1) * 6 + 4,
67 + (idx - 1) * 6 + 5,
193 + idx * 3,
193 + idx * 3 + 1,
193 + idx * 3 + 2,
]
def get_part_slice(idx_list: List[int], func: Callable[[int], List[int]]) -> List[int]:
"""
Get the slice indices for a list of indices.
Args:
idx_list (List[int]): List of part indices.
func (Callable): Function to get slice indices for each part.
Returns:
List[int]: Concatenated list of slice indices for the parts.
"""
result = []
for idx in idx_list:
result.extend(func(idx))
return result
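

# Example (illustrative): get_part_slice simply concatenates per-joint slices,
# so the "head" part of the HumanML3D skeleton (joints 12 and 15) yields
# 2 * 12 = 24 feature columns:
#
#   >>> len(get_part_slice([12, 15], get_t2m_slice))
#   24
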
class PoseEncoder(nn.Module):
"""
Pose Encoder to process motion data and encode body parts into latent representations.
"""
def __init__(self,
dataset_name: str = "human_ml3d",
latent_dim: int = 64,
input_dim: int = 263):
super().__init__()
self.dataset_name = dataset_name
if dataset_name == "human_ml3d":
func = get_t2m_slice
self.head_slice = get_part_slice([12, 15], func)
self.stem_slice = get_part_slice([3, 6, 9], func)
self.larm_slice = get_part_slice([14, 17, 19, 21], func)
self.rarm_slice = get_part_slice([13, 16, 18, 20], func)
self.lleg_slice = get_part_slice([2, 5, 8, 11], func)
self.rleg_slice = get_part_slice([1, 4, 7, 10], func)
self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(22)), func)
elif dataset_name == "kit_ml":
func = get_kit_slice
self.head_slice = get_part_slice([4], func)
self.stem_slice = get_part_slice([1, 2, 3], func)
self.larm_slice = get_part_slice([8, 9, 10], func)
self.rarm_slice = get_part_slice([5, 6, 7], func)
self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func)
self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func)
self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(21)), func)
else:
            raise ValueError(f"Unsupported dataset_name: {dataset_name}")
self.head_embed = nn.Linear(len(self.head_slice), latent_dim)
self.stem_embed = nn.Linear(len(self.stem_slice), latent_dim)
self.larm_embed = nn.Linear(len(self.larm_slice), latent_dim)
self.rarm_embed = nn.Linear(len(self.rarm_slice), latent_dim)
self.lleg_embed = nn.Linear(len(self.lleg_slice), latent_dim)
self.rleg_embed = nn.Linear(len(self.rleg_slice), latent_dim)
self.root_embed = nn.Linear(len(self.root_slice), latent_dim)
self.body_embed = nn.Linear(len(self.body_slice), latent_dim)
assert len(set(self.body_slice)) == input_dim
def forward(self, motion: torch.Tensor) -> torch.Tensor:
"""
Forward pass for encoding the motion into body part embeddings.
Args:
motion (torch.Tensor): Input motion tensor of shape (B, T, D).
Returns:
torch.Tensor: Concatenated latent representations of body parts.
"""
head_feat = self.head_embed(motion[:, :, self.head_slice].contiguous())
stem_feat = self.stem_embed(motion[:, :, self.stem_slice].contiguous())
larm_feat = self.larm_embed(motion[:, :, self.larm_slice].contiguous())
rarm_feat = self.rarm_embed(motion[:, :, self.rarm_slice].contiguous())
lleg_feat = self.lleg_embed(motion[:, :, self.lleg_slice].contiguous())
rleg_feat = self.rleg_embed(motion[:, :, self.rleg_slice].contiguous())
root_feat = self.root_embed(motion[:, :, self.root_slice].contiguous())
body_feat = self.body_embed(motion[:, :, self.body_slice].contiguous())
feat = torch.cat((head_feat, stem_feat, larm_feat, rarm_feat,
lleg_feat, rleg_feat, root_feat, body_feat),
dim=-1)
return feat
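

# Minimal usage sketch for PoseEncoder (illustrative values, not executed at import):
#
#   encoder = PoseEncoder(dataset_name="human_ml3d", latent_dim=64, input_dim=263)
#   motion = torch.randn(2, 196, 263)   # (B, T, D) HumanML3D-style features
#   feat = encoder(motion)              # (2, 196, 8 * 64): one 64-d embedding per
#                                       # part (head, stem, arms, legs, root, body)
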
class PoseDecoder(nn.Module):
"""
Pose Decoder to decode the latent representations of body parts back into motion.
"""
def __init__(self,
dataset_name: str = "human_ml3d",
latent_dim: int = 64,
output_dim: int = 263):
super().__init__()
self.dataset_name = dataset_name
self.latent_dim = latent_dim
self.output_dim = output_dim
if dataset_name == "human_ml3d":
func = get_t2m_slice
self.head_slice = get_part_slice([12, 15], func)
self.stem_slice = get_part_slice([3, 6, 9], func)
self.larm_slice = get_part_slice([14, 17, 19, 21], func)
self.rarm_slice = get_part_slice([13, 16, 18, 20], func)
self.lleg_slice = get_part_slice([2, 5, 8, 11], func)
self.rleg_slice = get_part_slice([1, 4, 7, 10], func)
self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(22)), func)
elif dataset_name == "kit_ml":
func = get_kit_slice
self.head_slice = get_part_slice([4], func)
self.stem_slice = get_part_slice([1, 2, 3], func)
self.larm_slice = get_part_slice([8, 9, 10], func)
self.rarm_slice = get_part_slice([5, 6, 7], func)
self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func)
self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func)
self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice(list(range(21)), func)
else:
            raise ValueError(f"Unsupported dataset_name: {dataset_name}")
self.head_out = nn.Linear(latent_dim, len(self.head_slice))
self.stem_out = nn.Linear(latent_dim, len(self.stem_slice))
self.larm_out = nn.Linear(latent_dim, len(self.larm_slice))
self.rarm_out = nn.Linear(latent_dim, len(self.rarm_slice))
self.lleg_out = nn.Linear(latent_dim, len(self.lleg_slice))
self.rleg_out = nn.Linear(latent_dim, len(self.rleg_slice))
self.root_out = nn.Linear(latent_dim, len(self.root_slice))
self.body_out = nn.Linear(latent_dim, len(self.body_slice))
def forward(self, motion: torch.Tensor) -> torch.Tensor:
"""
Forward pass to decode the latent body part features back to motion.
Args:
            motion (torch.Tensor): Latent body-part features of shape (B, T, 8 * latent_dim).
Returns:
torch.Tensor: Output motion tensor of shape (B, T, output_dim).
"""
B, T = motion.shape[:2]
D = self.latent_dim
head_feat = self.head_out(motion[:, :, :D].contiguous())
stem_feat = self.stem_out(motion[:, :, D:2 * D].contiguous())
larm_feat = self.larm_out(motion[:, :, 2 * D:3 * D].contiguous())
rarm_feat = self.rarm_out(motion[:, :, 3 * D:4 * D].contiguous())
lleg_feat = self.lleg_out(motion[:, :, 4 * D:5 * D].contiguous())
rleg_feat = self.rleg_out(motion[:, :, 5 * D:6 * D].contiguous())
root_feat = self.root_out(motion[:, :, 6 * D:7 * D].contiguous())
body_feat = self.body_out(motion[:, :, 7 * D:].contiguous())
output = torch.zeros(B, T, self.output_dim).type_as(motion)
output[:, :, self.head_slice] = head_feat
output[:, :, self.stem_slice] = stem_feat
output[:, :, self.larm_slice] = larm_feat
output[:, :, self.rarm_slice] = rarm_feat
output[:, :, self.lleg_slice] = lleg_feat
output[:, :, self.rleg_slice] = rleg_feat
output[:, :, self.root_slice] = root_feat
        # Average the per-part reconstruction with the whole-body reconstruction.
        output = (output + body_feat) / 2.0
return output
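

# Minimal usage sketch for PoseDecoder (illustrative, mirrors PoseEncoder above):
#
#   decoder = PoseDecoder(dataset_name="human_ml3d", latent_dim=64, output_dim=263)
#   latent = torch.randn(2, 196, 8 * 64)   # concatenated body-part features
#   motion = decoder(latent)               # (2, 196, 263); per-part and whole-body
#                                          # reconstructions are averaged
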
class SFFN(nn.Module):
"""
A Stylized Feed-Forward Network (SFFN) module for transformer layers.
Args:
latent_dim (int): Dimensionality of the input.
ffn_dim (int): Dimensionality of the feed-forward layer.
dropout (float): Dropout probability.
time_embed_dim (int): Dimensionality of the time embedding.
        norm (str): Normalization type; only "None" (identity) is supported.
        activation (str): Activation function; only "GELU" is supported.
"""
def __init__(self,
latent_dim: int,
ffn_dim: int,
dropout: float,
time_embed_dim: int,
norm: str = "None",
activation: str = "GELU",
**kwargs):
super().__init__()
self.linear1_list = nn.ModuleList()
self.linear2_list = nn.ModuleList()
        channel_mul = 1  # hidden-width multiplier; kept at 1 in this implementation
        if activation == "GELU":
            self.activation = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        # One feed-forward branch per body-part token (7 parts + whole body).
        for _ in range(8):
            self.linear1_list.append(nn.Linear(latent_dim, ffn_dim * channel_mul))
            self.linear2_list.append(nn.Linear(ffn_dim, latent_dim))
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim * 8, time_embed_dim, dropout)
        if norm == "None":
            self.norm = nn.Identity()
        else:
            raise ValueError(f"Unsupported norm: {norm}")
def forward(self, x: torch.Tensor, emb: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Forward pass of the SFFN layer.
Args:
x (torch.Tensor): Input tensor of shape (B, T, D).
emb (torch.Tensor): Embedding tensor for time step modulation.
Returns:
torch.Tensor: Output tensor of shape (B, T, D).
"""
B, T, D = x.shape
x = self.norm(x)
x = x.reshape(B, T, 8, -1)
output = []
for i in range(8):
feat = x[:, :, i].contiguous()
feat = self.dropout(self.activation(self.linear1_list[i](feat)))
feat = self.linear2_list[i](feat)
output.append(feat)
y = torch.cat(output, dim=-1)
y = x.reshape(B, T, D) + self.proj_out(y, emb)
return y
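

# Sketch of the SFFN contract (illustrative; assumes emb has shape
# (B, time_embed_dim), as expected by StylizationBlock):
#
#   ffn = SFFN(latent_dim=64, ffn_dim=256, dropout=0.1, time_embed_dim=512)
#   x = torch.randn(2, 196, 8 * 64)      # 8 body-part tokens of width 64
#   y = ffn(x, emb=torch.randn(2, 512))  # same shape as x; each of the 8 chunks
#                                        # passes through its own two-layer MLP
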
class DecoderLayer(nn.Module):
"""
A transformer decoder layer with cross-attention and feed-forward network (SFFN).
Args:
ca_block_cfg (Optional[Dict]): Configuration for the cross-attention block.
ffn_cfg (Optional[Dict]): Configuration for the feed-forward network (SFFN).
"""
def __init__(self, ca_block_cfg: Optional[Dict] = None, ffn_cfg: Optional[Dict] = None):
super().__init__()
self.ca_block = build_attention(ca_block_cfg)
self.ffn = SFFN(**ffn_cfg)
def forward(self, **kwargs) -> torch.Tensor:
"""
Forward pass of the decoder layer.
Args:
kwargs: Keyword arguments for attention and feed-forward layers.
Returns:
torch.Tensor: Output of the decoder layer.
"""
if self.ca_block is not None:
x = self.ca_block(**kwargs)
kwargs.update({'x': x})
if self.ffn is not None:
x = self.ffn(**kwargs)
return x
@SUBMODULES.register_module()
class FineMoGenTransformer(MotionTransformer):
"""
    A diffusion-based transformer for motion generation with fine-grained control over individual body parts.
Args:
scale_func_cfg (Optional[Dict]): Configuration for scaling function.
pose_encoder_cfg (Optional[Dict]): Configuration for the PoseEncoder.
pose_decoder_cfg (Optional[Dict]): Configuration for the PoseDecoder.
moe_route_loss_weight (float): Weight for the Mixture of Experts (MoE) routing loss.
template_kl_loss_weight (float): Weight for the KL loss in template generation.
fine_mode (bool): Whether to enable fine mode for control over body parts.
"""
def __init__(self,
scale_func_cfg: Optional[Dict] = None,
pose_encoder_cfg: Optional[Dict] = None,
pose_decoder_cfg: Optional[Dict] = None,
moe_route_loss_weight: float = 1.0,
template_kl_loss_weight: float = 0.0001,
fine_mode: bool = False,
**kwargs):
super().__init__(**kwargs)
self.scale_func_cfg = scale_func_cfg
self.joint_embed = PoseEncoder(**pose_encoder_cfg)
self.out = zero_module(PoseDecoder(**pose_decoder_cfg))
self.moe_route_loss_weight = moe_route_loss_weight
self.template_kl_loss_weight = template_kl_loss_weight
        # Dataset mean/std statistics; note these paths are hard-coded to kit_ml.
        self.mean = np.load("data/datasets/kit_ml/mean.npy")
        self.std = np.load("data/datasets/kit_ml/std.npy")
self.fine_mode = fine_mode
def build_temporal_blocks(self, sa_block_cfg: Optional[Dict], ca_block_cfg: Optional[Dict], ffn_cfg: Optional[Dict]):
"""
Build temporal decoder blocks for the model.
Args:
sa_block_cfg (Optional[Dict]): Configuration for self-attention blocks.
ca_block_cfg (Optional[Dict]): Configuration for cross-attention blocks.
ffn_cfg (Optional[Dict]): Configuration for feed-forward networks.
"""
self.temporal_decoder_blocks = nn.ModuleList()
for i in range(self.num_layers):
if isinstance(ffn_cfg, list):
ffn_cfg_block = ffn_cfg[i]
else:
ffn_cfg_block = ffn_cfg
self.temporal_decoder_blocks.append(DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg_block))
def scale_func(self, timestep: int) -> Dict[str, float]:
"""
Scaling function for text and none coefficient based on timestep.
Args:
timestep (int): Current diffusion timestep.
Returns:
Dict[str, float]: Scaling factors for text and non-text conditioning.
"""
scale = self.scale_func_cfg['scale']
w = (1 - (1000 - timestep) / 1000) * scale + 1
return {'text_coef': w, 'none_coef': 1 - w}
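
    # Worked example for scale_func (illustrative): with scale = 2.0,
    # timestep = 500 gives text_coef = 2.0 and none_coef = -1.0, while
    # timestep = 0 gives text_coef = 1.0 and none_coef = 0.0.
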
def aux_loss(self) -> Dict[str, torch.Tensor]:
"""
Auxiliary loss computation for MoE routing and KL loss.
Returns:
Dict[str, torch.Tensor]: Computed auxiliary losses.
"""
aux_loss = 0
kl_loss = 0
for module in self.temporal_decoder_blocks:
if hasattr(module.ca_block, 'aux_loss'):
aux_loss = aux_loss + module.ca_block.aux_loss
if hasattr(module.ca_block, 'kl_loss'):
kl_loss = kl_loss + module.ca_block.kl_loss
losses = {}
if aux_loss > 0:
losses['moe_route_loss'] = aux_loss * self.moe_route_loss_weight
if kl_loss > 0:
losses['template_kl_loss'] = kl_loss * self.template_kl_loss_weight
return losses
def get_precompute_condition(self,
text: Optional[str] = None,
motion_length: Optional[torch.Tensor] = None,
xf_out: Optional[torch.Tensor] = None,
re_dict: Optional[Dict] = None,
device: Optional[torch.device] = None,
sample_idx: Optional[int] = None,
clip_feat: Optional[torch.Tensor] = None,
**kwargs) -> Dict[str, torch.Tensor]:
"""
Precompute conditioning features for text or other modalities.
Args:
text (Optional[str]): Text input for conditioning.
motion_length (Optional[torch.Tensor]): Length of the motion sequence.
xf_out (Optional[torch.Tensor]): Precomputed text features.
re_dict (Optional[Dict]): Additional features dictionary.
device (Optional[torch.device]): Target device for the model.
sample_idx (Optional[int]): Sample index for specific conditioning.
clip_feat (Optional[torch.Tensor]): Precomputed CLIP features.
Returns:
Dict[str, torch.Tensor]: Precomputed conditioning features.
"""
if xf_out is None:
xf_out = self.encode_text(text, clip_feat, device)
output = {'xf_out': xf_out}
return output
def post_process(self, motion: torch.Tensor) -> torch.Tensor:
"""
Post-process motion data by unnormalizing if necessary.
Args:
motion (torch.Tensor): Input motion data.
Returns:
torch.Tensor: Processed motion data.
"""
if self.post_process_cfg is not None:
if self.post_process_cfg.get("unnormalized_infer", False):
mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path'])).type_as(motion)
std = torch.from_numpy(np.load(self.post_process_cfg['std_path'])).type_as(motion)
motion = motion * std + mean
return motion
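
    # Example post_process_cfg (keys taken from the code above; paths are
    # placeholders): {'unnormalized_infer': True,
    #                 'mean_path': 'data/datasets/human_ml3d/mean.npy',
    #                 'std_path': 'data/datasets/human_ml3d/std.npy'}
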
def forward_train(self,
h: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
emb: Optional[torch.Tensor] = None,
xf_out: Optional[torch.Tensor] = None,
motion_length: Optional[torch.Tensor] = None,
num_intervals: int = 1,
**kwargs) -> torch.Tensor:
"""
Forward pass during training.
Args:
h (torch.Tensor): Input tensor of shape (B, T, D).
src_mask (Optional[torch.Tensor]): Source mask tensor.
emb (Optional[torch.Tensor]): Time embedding tensor.
xf_out (Optional[torch.Tensor]): Precomputed text features.
motion_length (Optional[torch.Tensor]): Lengths of motion sequences.
num_intervals (int): Number of intervals for processing.
Returns:
torch.Tensor: Output tensor of shape (B, T, D).
"""
B, T = h.shape[0], h.shape[1]
        # Randomly sampled condition type; in fine mode it is repeated across
        # the 8 body-part tokens before being passed to the attention blocks.
        if self.fine_mode:
            cond_type = torch.randint(0, 100, size=(B, 1, 1)).repeat(1, 8, 1).to(h.device)
        else:
            cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
for module in self.temporal_decoder_blocks:
h = module(x=h,
xf=xf_out,
emb=emb,
src_mask=src_mask,
cond_type=cond_type,
motion_length=motion_length,
num_intervals=num_intervals)
output = self.out(h).view(B, T, -1).contiguous()
return output
def forward_test(self,
h: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
emb: Optional[torch.Tensor] = None,
xf_out: Optional[torch.Tensor] = None,
timesteps: Optional[torch.Tensor] = None,
motion_length: Optional[torch.Tensor] = None,
num_intervals: int = 1,
**kwargs) -> torch.Tensor:
"""
Forward pass during inference.
Args:
h (torch.Tensor): Input tensor of shape (B, T, D).
src_mask (Optional[torch.Tensor]): Source mask tensor.
emb (Optional[torch.Tensor]): Time embedding tensor.
xf_out (Optional[torch.Tensor]): Precomputed text features.
timesteps (Optional[torch.Tensor]): Diffusion timesteps.
motion_length (Optional[torch.Tensor]): Lengths of motion sequences.
num_intervals (int): Number of intervals for processing.
Returns:
torch.Tensor: Output tensor of shape (B, T, D).
"""
        B, T = h.shape[0], h.shape[1]
        # Duplicate the batch so the first half is text-conditioned and the
        # second half is unconditioned; the two outputs are blended below.
        text_cond_type = torch.ones(B, 1, 1, device=h.device)
        none_cond_type = torch.zeros(B, 1, 1, device=h.device)
        all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0)
        h = h.repeat(2, 1, 1)
        xf_out = xf_out.repeat(2, 1, 1)
        emb = emb.repeat(2, 1)
        src_mask = src_mask.repeat(2, 1, 1)
        motion_length = motion_length.repeat(2, 1)
for module in self.temporal_decoder_blocks:
h = module(x=h,
xf=xf_out,
emb=emb,
src_mask=src_mask,
cond_type=all_cond_type,
motion_length=motion_length,
num_intervals=num_intervals)
out = self.out(h).view(2 * B, T, -1).contiguous()
out_text = out[:B].contiguous()
out_none = out[B:].contiguous()
coef_cfg = self.scale_func(int(timesteps[0]))
text_coef = coef_cfg['text_coef']
none_coef = coef_cfg['none_coef']
output = out_text * text_coef + out_none * none_coef
return output
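

if __name__ == "__main__":
    # Lightweight sanity checks (not part of the original module): the slice
    # helpers should tile the full feature vector exactly once, matching the
    # 263-dim HumanML3D and 251-dim KIT-ML representations assumed above.
    t2m_cols = get_part_slice(list(range(22)), get_t2m_slice)
    kit_cols = get_part_slice(list(range(21)), get_kit_slice)
    assert len(set(t2m_cols)) == len(t2m_cols) == 263
    assert len(set(kit_cols)) == len(kit_cols) == 251
    print("slice helpers cover 263 (HumanML3D) and 251 (KIT-ML) feature columns")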