import clip
import numpy as np
import torch
import torch.nn as nn

from ..builder import SUBMODULES


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp32."""

    def _convert_weights_to_fp32(m):
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            m.weight.data = m.weight.data.float()
            if m.bias is not None:
                m.bias.data = m.bias.data.float()

        if isinstance(m, nn.MultiheadAttention):
            attr_list = [f"{s}_proj_weight" for s in ["in", "q", "k", "v"]]
            attr_list += ["in_proj_bias", "bias_k", "bias_v"]
            for attr in attr_list:
                tensor = getattr(m, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(m, name):
                attr = getattr(m, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp32)


@SUBMODULES.register_module()
class MDMTransformer(nn.Module):

    def __init__(self,
                 input_feats=263,
                 latent_dim=256,
                 ff_size=1024,
                 num_layers=8,
                 num_heads=4,
                 dropout=0.1,
                 activation="gelu",
                 clip_dim=512,
                 clip_version=None,
                 guide_scale=1.0,
                 cond_mask_prob=0.1,
                 use_official_ckpt=False,
                 **kwargs):
        super().__init__()
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.clip_dim = clip_dim
        self.input_feats = input_feats
        self.guide_scale = guide_scale
        self.use_official_ckpt = use_official_ckpt
        self.cond_mask_prob = cond_mask_prob

        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(
            self.latent_dim, self.dropout)
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=self.num_layers)
        self.embed_timestep = TimestepEmbedder(self.latent_dim,
                                               self.sequence_pos_encoder)
        self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)

        self.clip_version = clip_version
        self.clip_model = self.load_and_freeze_clip(clip_version)

        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)

    def load_and_freeze_clip(self, clip_version):
        clip_model, _ = clip.load(clip_version, device='cpu', jit=False)
        # CLIP's own helper converts its weights to fp16.
        clip.model.convert_weights(clip_model)

        # Freeze the text encoder.
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def mask_cond(self, cond, force_mask=False):
        bs = cond.shape[0]
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            mask = torch.ones(bs, device=cond.device) * self.cond_mask_prob
            # 1 -> use null_cond, 0 -> use real cond
            mask = torch.bernoulli(mask).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        max_text_len = 20
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2  # start_token + 20 + end_token
            assert context_length < default_context_length
            texts = clip.tokenize(
                raw_text, context_length=context_length,
                truncate=True).to(device)
            zero_pad = torch.zeros(
                [texts.shape[0], default_context_length - context_length],
                dtype=texts.dtype,
                device=texts.device)
            texts = torch.cat([texts, zero_pad], dim=1)
        return self.clip_model.encode_text(texts).float()

    def get_precompute_condition(self, text, device=None, **kwargs):
        if not self.training and device == torch.device('cpu'):
            # CLIP weights are stored in fp16; convert back to fp32 so the
            # text encoder can run on CPU.
            convert_weights(self.clip_model)
        text_feat = self.encode_text(text)
        return {'text_feat': text_feat}

    def post_process(self, motion):
        assert len(motion.shape) == 3
        if self.use_official_ckpt:
            motion[:, :, :4] = motion[:, :, :4] * 25
        return motion

    def forward(self, motion, timesteps, text_feat=None, **kwargs):
        """Forward pass.

        Args:
            motion: Tensor of shape [B, T, D].
            timesteps: Tensor of shape [batch_size] (int).
        """
        B, T, D = motion.shape
        if text_feat is None:
            enc_text = self.get_precompute_condition(**kwargs)['text_feat']
        else:
            enc_text = text_feat

        if self.training:
            # T, B, D
            motion = self.poseEmbedding(motion).permute(1, 0, 2)

            emb = self.embed_timestep(timesteps)  # [1, bs, d]
            emb += self.embed_text(self.mask_cond(enc_text, force_mask=False))

            xseq = self.sequence_pos_encoder(torch.cat((emb, motion), axis=0))
            output = self.seqTransEncoder(xseq)[1:]

            # B, T, D
            output = self.poseFinal(output).permute(1, 0, 2)
            return output
        else:
            # T, B, D
            motion = self.poseEmbedding(motion).permute(1, 0, 2)

            emb = self.embed_timestep(timesteps)  # [1, bs, d]
            emb_uncond = emb + \
                self.embed_text(self.mask_cond(enc_text, force_mask=True))
            emb_text = emb + \
                self.embed_text(self.mask_cond(enc_text, force_mask=False))

            xseq = self.sequence_pos_encoder(
                torch.cat((emb_uncond, motion), axis=0))
            xseq_text = self.sequence_pos_encoder(
                torch.cat((emb_text, motion), axis=0))
            output = self.seqTransEncoder(xseq)[1:]
            output_text = self.seqTransEncoder(xseq_text)[1:]

            # B, T, D
            output = self.poseFinal(output).permute(1, 0, 2)
            output_text = self.poseFinal(output_text).permute(1, 0, 2)
            # Classifier-free guidance: blend the unconditional and
            # text-conditional predictions with guide_scale.
            scale = self.guide_scale
            output = output + scale * (output_text - output)
            return output


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.arange(0, d_model, 2).float() * \
            (-np.log(10000.0) / d_model)
        div_term = torch.exp(div_term)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        output = self.time_embed(self.sequence_pos_encoder.pe[timesteps])
        output = output.permute(1, 0, 2)
        return output
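

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module's public API). It shows
# how a caller might drive MDMTransformer's inference path, assuming the CLIP
# checkpoint name 'ViT-B/32', a HumanML3D-style feature size of 263, and two
# made-up captions; all values below are examples, not fixed by this module.
# Because of the relative import above, run it through the package
# (e.g. `python -m <package>.<this_module>`) rather than as a bare script,
# and note that clip.load downloads the CLIP weights on first use.
if __name__ == '__main__':
    model = MDMTransformer(input_feats=263, clip_version='ViT-B/32')
    model.eval()

    motion = torch.randn(2, 60, 263)           # noisy motion: [B, T, D]
    timesteps = torch.randint(0, 1000, (2, ))  # one diffusion step per sample
    captions = ['a person walks forward', 'a person jumps in place']

    with torch.no_grad():
        # Passing device=torch.device('cpu') makes get_precompute_condition
        # convert the frozen CLIP weights back to fp32 for CPU inference;
        # in eval mode the forward pass applies classifier-free guidance.
        out = model(
            motion, timesteps, text=captions, device=torch.device('cpu'))
    print(out.shape)  # torch.Size([2, 60, 263])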