File size: 10,880 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
import torch
from torch import nn
from typing import Optional
from mogen.models.utils.misc import zero_module
from mogen.models.utils.position_encoding import timestep_embedding
from mogen.models.utils.stylization_block import StylizationBlock
from ..builder import SUBMODULES, build_attention
from .remodiffuse import ReMoDiffuseTransformer
class FFN(nn.Module):
    """
    Feed-forward network applied separately to the two halves of the input's
    last dimension, each followed by a stylization block and a residual add.

    Both halves go through the *same* ``linear1``/``linear2``/``proj_out``
    modules; the two results are concatenated back along the last dimension.

    Args:
        latent_dim (int): Width of each half of the input (and output).
        ffn_dim (int): Hidden width of the feed-forward network.
        dropout (float): Dropout rate used after the activation; it is also
            passed to the stylization block.
        time_embed_dim (int): Width of the time embedding fed to the
            stylization block.
    """

    def __init__(self, latent_dim: int, ffn_dim: int, dropout: float, time_embed_dim: int):
        super().__init__()
        # NOTE: construction order is kept as-is so parameter registration
        # and random initialisation stay identical.
        self.latent_dim = latent_dim
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        # Wrapped by zero_module (presumably zero-initialises the layer; see
        # mogen.models.utils.misc).
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x: torch.Tensor, emb: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Run the shared FFN on each half of ``x`` and concatenate the results.

        Args:
            x (torch.Tensor): Input of shape (B, T, 2 * latent_dim); the first
                ``latent_dim`` channels form one half, the rest the other.
            emb (torch.Tensor): Time embedding passed to the stylization block.
            **kwargs: Ignored; accepted so callers can pass a shared kwargs dict.

        Returns:
            torch.Tensor: Tensor with the same layout as ``x``.
        """
        width = self.latent_dim
        halves_out = []
        # First half is processed before the second, matching the original
        # order of dropout RNG consumption.
        for half in (x[:, :, :width], x[:, :, width:]):
            half = half.contiguous()
            hidden = self.dropout(self.activation(self.linear1(half)))
            halves_out.append(half + self.proj_out(self.linear2(hidden), emb))
        return torch.cat(halves_out, dim=-1)
class DecoderLayer(nn.Module):
    """
    A single decoder layer: an optional cross-attention block followed by an
    optional feed-forward network (FFN).

    Args:
        ca_block_cfg (Optional[dict]): Config passed to ``build_attention`` for
            the cross-attention block.
        ffn_cfg (Optional[dict]): Keyword arguments for ``FFN``. When None, no
            FFN is built and the layer skips that stage.
    """

    def __init__(self, ca_block_cfg: Optional[dict] = None, ffn_cfg: Optional[dict] = None):
        super().__init__()
        self.ca_block = build_attention(ca_block_cfg)
        # ffn_cfg is declared Optional; FFN(**None) would raise TypeError, so
        # only build the FFN when a config is actually provided.
        self.ffn = FFN(**ffn_cfg) if ffn_cfg is not None else None

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Forward pass of the decoder layer.

        Args:
            **kwargs: Shared arguments for both sub-blocks. ``x`` is the layer
                input; after the cross-attention block runs, ``x`` is replaced
                by its output before the FFN is called.

        Returns:
            torch.Tensor: Output of the last sub-block that ran, or the input
            ``x`` unchanged when neither sub-block is present.
        """
        # Start from the input so `x` is always bound, even if both
        # sub-blocks are disabled.
        x = kwargs.get('x')
        if self.ca_block is not None:
            x = self.ca_block(**kwargs)
            kwargs['x'] = x
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x
@SUBMODULES.register_module()
class MoMatMoGenTransformer(ReMoDiffuseTransformer):
    """
    Two-stream motion generation transformer built on ReMoDiffuseTransformer.

    The last dimension of the input motion is split at ``self.input_feats``
    into two streams. Both streams pass through the same ``joint_embed``, so
    they are expected to have equal width. In latent space the streams are
    kept side by side as ``(B, T, 2 * latent_dim)``. They share the decoder
    layers and the output projection ``self.out``.

    All configuration is inherited from ``ReMoDiffuseTransformer``. This
    class overrides the temporal blocks and the forward passes.
    """

    def build_temporal_blocks(self, sa_block_cfg: Optional[dict], ca_block_cfg: Optional[dict], ffn_cfg: Optional[dict]):
        """
        Build ``self.num_layers`` decoder layers into ``self.temporal_decoder_blocks``.

        Each layer is a cross-attention block followed by an FFN. No
        self-attention block is built, so ``sa_block_cfg`` is ignored here
        (presumably kept to match the base-class signature).

        Args:
            sa_block_cfg (Optional[dict]): Self-attention config (unused).
            ca_block_cfg (Optional[dict]): Cross-attention block configuration.
            ffn_cfg (Optional[dict]): Feed-forward network configuration.
        """
        self.temporal_decoder_blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.temporal_decoder_blocks.append(
                DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg))

    def _split_streams(self, x: torch.Tensor, width: int):
        """Split the last dim of a (B, T, C) tensor at ``width`` into two contiguous tensors."""
        return x[:, :, :width].contiguous(), x[:, :, width:].contiguous()

    def _project_output(self, h: torch.Tensor) -> torch.Tensor:
        """Apply ``self.out`` to each latent stream of ``h`` and concatenate along the last dim."""
        B, T = h.shape[0], h.shape[1]
        h1, h2 = self._split_streams(h, self.latent_dim)
        out1 = self.out(h1).view(B, T, -1).contiguous()
        out2 = self.out(h2).view(B, T, -1).contiguous()
        return torch.cat((out1, out2), dim=-1)

    def forward(self,
                motion: torch.Tensor,
                timesteps: torch.Tensor,
                motion_mask: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """
        Forward pass for motion generation.

        Args:
            motion (torch.Tensor): Input motion of shape (B, T, D). The first
                ``input_feats`` channels form one stream and the rest the other.
            timesteps (torch.Tensor): Diffusion timesteps. They are embedded
                here via ``timestep_embedding`` or ``self.time_tokens``.
            motion_mask (Optional[torch.Tensor]): Mask of shape (B, T) or
                (B, T, 1). A 2-D mask gets a trailing axis. When None, every
                frame is treated as valid.
            **kwargs: Forwarded to ``get_precompute_condition``.

        Returns:
            torch.Tensor: Output of the same layout as ``motion``. When
            ``use_residual_connection`` is set, ``motion`` is added to it.
        """
        B, T = motion.shape[0], motion.shape[1]
        conditions = self.get_precompute_condition(device=motion.device,
                                                   **kwargs)
        if motion_mask is None:
            # Previously this crashed on None despite the default.
            # NOTE(review): assumes the attention modules expect a float
            # (B, T, 1) mask, like the unsqueezed 2-D case — confirm.
            src_mask = torch.ones(B, T, 1, device=motion.device, dtype=motion.dtype)
        elif motion_mask.dim() == 2:
            src_mask = motion_mask.clone().unsqueeze(-1)
        else:
            src_mask = motion_mask.clone()

        if self.time_embedding_type == 'sinusoidal':
            emb = self.time_embed(
                timestep_embedding(timesteps, self.latent_dim))
        else:
            emb = self.time_embed(self.time_tokens(timesteps))
        if self.use_text_proj:
            emb = emb + conditions['xf_proj']

        # Embed both streams with the shared joint embedding.
        motion1, motion2 = self._split_streams(motion, self.input_feats)
        h1 = self.joint_embed(motion1)
        h2 = self.joint_embed(motion2)
        if self.use_pos_embedding:
            pos = self.sequence_embedding.unsqueeze(0)[:, :T, :]
            h1 = h1 + pos
            h2 = h2 + pos
        h = torch.cat((h1, h2), dim=-1)

        if self.training:
            output = self.forward_train(h=h,
                                        src_mask=src_mask,
                                        emb=emb,
                                        timesteps=timesteps,
                                        **conditions)
        else:
            output = self.forward_test(h=h,
                                       src_mask=src_mask,
                                       emb=emb,
                                       timesteps=timesteps,
                                       **conditions)
        if self.use_residual_connection:
            output = motion + output
        return output

    def forward_train(self,
                      h: Optional[torch.Tensor] = None,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      re_dict: Optional[dict] = None,
                      **kwargs) -> torch.Tensor:
        """
        Training forward pass for the motion generation transformer.

        Args:
            h (Optional[torch.Tensor]): Latent input of shape (B, T, 2 * latent_dim).
            src_mask (Optional[torch.Tensor]): Source mask.
            emb (Optional[torch.Tensor]): Time (and optionally text) embedding.
            xf_out (Optional[torch.Tensor]): Conditioning features from
                ``get_precompute_condition``, passed to each layer as ``xf``.
            re_dict (Optional[dict]): Retrieval-related tensors passed to each layer.
            **kwargs: Other precomputed conditions (ignored here).

        Returns:
            torch.Tensor: Projected output for both streams, concatenated on the last dim.
        """
        B = h.shape[0]
        # One random integer in [0, 100) per sample, forwarded to every layer
        # as `cond_type`. How it is interpreted is up to the attention blocks.
        cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=cond_type,
                       re_dict=re_dict)
        return self._project_output(h)

    def forward_test(self,
                     h: Optional[torch.Tensor] = None,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     re_dict: Optional[dict] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     **kwargs) -> torch.Tensor:
        """
        Inference forward pass with four-way guidance.

        The batch is replicated 4x. Each copy gets a different ``cond_type``:
        99 ("both"), 1 ("text"), 10 ("retr") and 0 ("none"). The four outputs
        are combined with coefficients from ``self.scale_func(int(timesteps[0]))``.

        Args:
            h (Optional[torch.Tensor]): Latent input of shape (B, T, 2 * latent_dim).
            src_mask (Optional[torch.Tensor]): Source mask (3-D; repeated along batch).
            emb (Optional[torch.Tensor]): Time embedding (2-D; repeated along batch).
            xf_out (Optional[torch.Tensor]): Conditioning features, passed to
                each layer as ``xf`` (3-D; repeated along batch).
            re_dict (Optional[dict]): Retrieval-related tensors. This dict is
                modified IN PLACE: ``'re_motion'``, ``'re_text'`` and
                ``'re_mask'`` are repeated 4x along the batch dim unless
                already expanded.
            timesteps (Optional[torch.Tensor]): Diffusion timesteps. Only
                ``timesteps[0]`` is used, to select the guidance coefficients.
            **kwargs: Other precomputed conditions (ignored here).

        Returns:
            torch.Tensor: Guidance-weighted combination of the four outputs, batch size B.
        """
        B = h.shape[0]
        both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99
        text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1
        retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10
        none_cond_type = torch.zeros(B, 1, 1).to(h.device)
        all_cond_type = torch.cat(
            (both_cond_type, text_cond_type, retr_cond_type, none_cond_type),
            dim=0)

        h = h.repeat(4, 1, 1)
        xf_out = xf_out.repeat(4, 1, 1)
        emb = emb.repeat(4, 1)
        src_mask = src_mask.repeat(4, 1, 1)
        # Guard against expanding the shared dict twice across calls.
        if re_dict['re_motion'].shape[0] != h.shape[0]:
            re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1)
            re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1)
            re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1)

        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=all_cond_type,
                       re_dict=re_dict)

        out = self._project_output(h)
        out_both = out[:B].contiguous()
        out_text = out[B:2 * B].contiguous()
        out_retr = out[2 * B:3 * B].contiguous()
        out_none = out[3 * B:].contiguous()

        coef_cfg = self.scale_func(int(timesteps[0]))
        both_coef = coef_cfg['both_coef']
        text_coef = coef_cfg['text_coef']
        retr_coef = coef_cfg['retr_coef']
        none_coef = coef_cfg['none_coef']
        output = out_both * both_coef
        output += out_text * text_coef
        output += out_retr * retr_coef
        output += out_none * none_coef
        return output
|