from abc import ABCMeta, abstractmethod

import clip
import torch
from torch import nn
from mmcv.runner import BaseModule

from ..builder import build_attention
from mogen.models.utils.position_encoding import timestep_embedding
from mogen.models.utils.stylization_block import StylizationBlock
from mogen.models.utils.misc import set_requires_grad, zero_module


class CLIPWrapper:
    """Thin wrapper around a frozen CLIP model.

    Keeps the CLIP weights on the same device (CPU/GPU) as the incoming
    text tokens so that encoding never hits a device mismatch.
    """

    def __init__(self, clip_model):
        self.clip_model = clip_model
        self.device = "cpu"

    def __call__(self, **kwargs):
        return self.clip_model(**kwargs)

    def encode_text(self, text):
        # Move the CLIP model to match the device of the token tensor.
        if text.is_cuda and self.device == "cpu":
            self.clip_model = self.clip_model.cuda()
            self.device = "cuda"
        if not text.is_cuda and self.device == "cuda":
            self.clip_model = self.clip_model.cpu()
            self.device = "cpu"
        return self.clip_model.encode_text(text)

    def to(self, device):
        self.clip_model = self.clip_model.to(device)


class FFN(nn.Module):
    """Feed-forward block with an optional time-conditioned stylization output."""

    def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim=None):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        if time_embed_dim is not None:
            self.proj_out = StylizationBlock(latent_dim, time_embed_dim,
                                             dropout)
        else:
            self.proj_out = None

    def forward(self, x, emb=None, **kwargs):
        y = self.linear2(self.dropout(self.activation(self.linear1(x))))
        if self.proj_out is not None:
            # Modulate the residual branch with the timestep embedding.
            y = x + self.proj_out(y, emb)
        else:
            y = x + y
        return y


class DecoderLayer(nn.Module):
    """One temporal decoder block: optional self-attention, optional
    cross-attention, then a feed-forward network, all built from configs."""

    def __init__(self, sa_block_cfg=None, ca_block_cfg=None, ffn_cfg=None):
        super().__init__()
        self.sa_block = build_attention(sa_block_cfg)
        self.ca_block = build_attention(ca_block_cfg)
        self.ffn = FFN(**ffn_cfg)

    def forward(self, **kwargs):
        if self.sa_block is not None:
            x = self.sa_block(**kwargs)
            kwargs.update({'x': x})
        if self.ca_block is not None:
            x = self.ca_block(**kwargs)
            kwargs.update({'x': x})
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x


class MotionTransformer(BaseModule, metaclass=ABCMeta):
    """Abstract motion transformer backbone with diffusion-timestep and
    optional CLIP text conditioning.

    Subclasses must implement ``get_precompute_condition``, ``forward_train``
    and ``forward_test``.
    """

    def __init__(self,
                 input_feats,
                 max_seq_len=240,
                 latent_dim=512,
                 time_embed_dim=2048,
                 num_layers=8,
                 sa_block_cfg=None,
                 ca_block_cfg=None,
                 ffn_cfg=None,
                 text_encoder=None,
                 use_pos_embedding=True,
                 use_residual_connection=False,
                 time_embedding_type='sinusoidal',
                 post_process_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.time_embed_dim = time_embed_dim
        self.use_pos_embedding = use_pos_embedding
        if self.use_pos_embedding:
            # Learnable positional embedding over the temporal axis.
            self.sequence_embedding = nn.Parameter(
                torch.randn(max_seq_len, latent_dim))

        self.build_text_encoder(text_encoder)

        # Input Embedding
        self.joint_embed = nn.Linear(self.input_feats, self.latent_dim)

        # Timestep embedding: sinusoidal features or learnable timestep tokens.
        self.time_embedding_type = time_embedding_type
        if time_embedding_type != 'none':
            if time_embedding_type == 'learnable':
                self.time_tokens = nn.Embedding(1000, self.latent_dim)
            self.time_embed = nn.Sequential(
                nn.Linear(self.latent_dim, self.time_embed_dim),
                nn.SiLU(),
                nn.Linear(self.time_embed_dim, self.time_embed_dim),
            )
        self.build_temporal_blocks(sa_block_cfg, ca_block_cfg, ffn_cfg)

        # Output Module, zero-initialized so predictions start at zero.
        self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats))
        self.use_residual_connection = use_residual_connection
        self.post_process_cfg = post_process_cfg

    def build_temporal_blocks(self, sa_block_cfg, ca_block_cfg, ffn_cfg):
        self.temporal_decoder_blocks = nn.ModuleList()
        for i in range(self.num_layers):
            self.temporal_decoder_blocks.append(
                DecoderLayer(
                    sa_block_cfg=sa_block_cfg,
                    ca_block_cfg=ca_block_cfg,
                    ffn_cfg=ffn_cfg))

    def build_text_encoder(self, text_encoder):
        if text_encoder is None:
            self.use_text_proj = False
            return
        text_latent_dim = text_encoder['latent_dim']
        num_text_layers = text_encoder.get('num_layers', 0)
        text_ff_size = text_encoder.get('ff_size', 2048)
        pretrained_model = text_encoder['pretrained_model']
        text_num_heads = text_encoder.get('num_heads', 4)
        dropout = text_encoder.get('dropout', 0)
        activation = text_encoder.get('activation', 'gelu')
        self.use_text_proj = text_encoder.get('use_text_proj', False)

        if pretrained_model == 'clip':
            # Frozen CLIP ViT-B/32 text encoder; gradients are disabled.
            clip_model, _ = clip.load('ViT-B/32', "cpu")
            set_requires_grad(clip_model, False)
            self.clip = CLIPWrapper(clip_model)
            if text_latent_dim != 512:
                self.text_pre_proj = nn.Linear(512, text_latent_dim)
            else:
                self.text_pre_proj = nn.Identity()
        else:
            raise NotImplementedError()

        if num_text_layers > 0:
            # Optional trainable transformer layers on top of CLIP features.
            self.use_text_finetune = True
            textTransEncoderLayer = nn.TransformerEncoderLayer(
                d_model=text_latent_dim,
                nhead=text_num_heads,
                dim_feedforward=text_ff_size,
                dropout=dropout,
                activation=activation)
            self.textTransEncoder = nn.TransformerEncoder(
                textTransEncoderLayer, num_layers=num_text_layers)
        else:
            self.use_text_finetune = False
        self.text_ln = nn.LayerNorm(text_latent_dim)
        if self.use_text_proj:
            self.text_proj = nn.Sequential(
                nn.Linear(text_latent_dim, self.time_embed_dim))

    def encode_text(self, text, clip_feat, device):
        B = len(text)
        if isinstance(text[0], dict):
            # Per-body-part prompts: flatten each dict into a fixed key order.
            knames = ["head", "stem", "left_arm", "right_arm",
                      "left_leg", "right_leg", "pelvis", "all"]
            new_text = []
            for item in text:
                for kname in knames:
                    new_text.append(item[kname])
            text = new_text
        text = clip.tokenize(text, truncate=True).to(device)
        if clip_feat is None:
            # Run the frozen CLIP text transformer to get token features.
            with torch.no_grad():
                if isinstance(self.clip, CLIPWrapper):
                    self.clip.to(device)
                    dtype = self.clip.clip_model.dtype
                    # [batch_size, n_ctx, d_model]
                    x = self.clip.clip_model.token_embedding(text).type(dtype)
                    x = x + self.clip.clip_model.positional_embedding.type(dtype)
                    x = x.permute(1, 0, 2)  # NLD -> LND
                    x = self.clip.clip_model.transformer(x)
                    x = self.clip.clip_model.ln_final(x).type(dtype)
                else:
                    dtype = self.clip.dtype
                    # [batch_size, n_ctx, d_model]
                    x = self.clip.token_embedding(text).type(dtype)
                    x = x + self.clip.positional_embedding.type(dtype)
                    x = x.permute(1, 0, 2)  # NLD -> LND
                    x = self.clip.transformer(x)
                    x = self.clip.ln_final(x).type(dtype)
        else:
            # Precomputed CLIP token features were supplied instead.
            x = clip_feat.float().to(device)
            if len(x.shape) == 4:
                x = x.permute(1, 0, 2, 3)
                x = x.reshape(
                    [x.shape[0], x.shape[1] * x.shape[2], x.shape[3]])
            else:
                x = x.permute(1, 0, 2)

        # T, B, D
        x = self.text_pre_proj(x)
        xf_out = self.textTransEncoder(x)
        xf_out = self.text_ln(xf_out)
        if self.use_text_proj:
            # Pool at the end-of-text token (highest token id in CLIP's vocab).
            xf_proj = self.text_proj(
                xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
            # B, T, D
            xf_out = xf_out.permute(1, 0, 2)
            return xf_proj, xf_out
        else:
            xf_out = xf_out.permute(1, 0, 2)
            return xf_out

    @abstractmethod
    def get_precompute_condition(self, **kwargs):
        pass

    @abstractmethod
    def forward_train(self, h, src_mask, emb, **kwargs):
        pass

    @abstractmethod
    def forward_test(self, h, src_mask, emb, **kwargs):
        pass

    def forward(self,
                motion,
                timesteps=None,
                motion_mask=None,
                motion_length=None,
                num_intervals=1,
                **kwargs):
        """
        motion: B, T, D
        """
        B, T = motion.shape[0], motion.shape[1]
        conditions = self.get_precompute_condition(
            device=motion.device, motion_length=motion_length, **kwargs)
        if len(motion_mask.shape) == 2:
            src_mask = motion_mask.clone().unsqueeze(-1)
        else:
            src_mask = motion_mask.clone()

        if self.time_embedding_type != 'none':
            if self.time_embedding_type == 'sinusoidal':
                emb = self.time_embed(
                    timestep_embedding(timesteps, self.latent_dim))
            else:
                emb = self.time_embed(self.time_tokens(timesteps))

            if self.use_text_proj:
                emb = emb + conditions['xf_proj']
        else:
            emb = None

        # B, T, latent_dim
        h = self.joint_embed(motion)
        if self.use_pos_embedding:
            h = h + self.sequence_embedding.unsqueeze(0)[:, :T, :]

        if self.training:
            output = self.forward_train(
                h=h,
                src_mask=src_mask,
                emb=emb,
                timesteps=timesteps,
                motion_length=motion_length,
                num_intervals=num_intervals,
                motion=motion,
                **conditions)
        else:
            output = self.forward_test(
                h=h,
                src_mask=src_mask,
                emb=emb,
                timesteps=timesteps,
                motion_length=motion_length,
                num_intervals=num_intervals,
                **conditions)
        if self.use_residual_connection:
            output = motion + output
        return output
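

# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the original module): a minimal
# concrete subclass showing the contract that `get_precompute_condition`,
# `forward_train` and `forward_test` are expected to fulfil. The class name
# `PlainMotionTransformer` and the keyword arguments passed to each decoder
# block are hypothetical; real subclasses in this package typically add text
# conditioning and post-processing on top of this pattern.
# ---------------------------------------------------------------------------
class PlainMotionTransformer(MotionTransformer):

    def get_precompute_condition(self, device=None, motion_length=None,
                                 **kwargs):
        # No text or auxiliary conditions are precomputed in this sketch.
        return {}

    def _run_blocks(self, h, src_mask, emb):
        # Pass the motion tokens through every temporal decoder block and
        # project back to the input feature dimension.
        for block in self.temporal_decoder_blocks:
            h = block(x=h, emb=emb, src_mask=src_mask)
        return self.out(h)

    def forward_train(self, h=None, src_mask=None, emb=None, **kwargs):
        return self._run_blocks(h, src_mask, emb)

    def forward_test(self, h=None, src_mask=None, emb=None, **kwargs):
        return self._run_blocks(h, src_mask, emb)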