File size: 8,185 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
import clip
import numpy as np
import torch
import torch.nn as nn
from ..builder import SUBMODULES
def convert_weights(model: nn.Module):
    """Cast applicable parameters of ``model`` to fp32, in place.

    Visits ``model`` and every submodule via ``nn.Module.apply`` and converts:

    * ``weight`` / ``bias`` of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear``;
    * the input-projection weights/bias and ``bias_k`` / ``bias_v`` of
      ``nn.MultiheadAttention``;
    * any ``text_projection`` or ``proj`` attribute a module carries.

    Entries that are ``None`` are skipped; every other parameter (for example
    embeddings or layer norms) keeps its dtype.

    Args:
        model: module to convert; modified in place. Nothing is returned.
    """
    mha_tensor_names = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )
    projection_names = ("text_projection", "proj")

    def _cast_module(module):
        # Collect candidate tensors in the same order the conversions apply.
        targets = []
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            targets.extend((module.weight, module.bias))
        if isinstance(module, nn.MultiheadAttention):
            # Unused projection slots are registered as None on MHA modules.
            targets.extend(getattr(module, name) for name in mha_tensor_names)
        targets.extend(
            getattr(module, name) for name in projection_names
            if hasattr(module, name))
        for tensor in targets:
            if tensor is not None:
                # Rebinding .data keeps the same Parameter object (and any
                # optimizer/state-dict references to it) while changing dtype.
                tensor.data = tensor.data.float()

    model.apply(_cast_module)
@SUBMODULES.register_module()
class MDMTransformer(nn.Module):
    """Motion denoising transformer conditioned on CLIP text features.

    Each motion frame is projected to ``latent_dim``, a single condition
    token (timestep embedding + projected text feature) is prepended, and the
    sequence goes through a sequence-first ``nn.TransformerEncoder``. The
    condition token's output is dropped and the remaining frames are projected
    back to ``input_feats``.

    In eval mode the network is evaluated twice (with zeroed and with real
    text features) and the two outputs are combined with classifier-free
    guidance weighted by ``guide_scale``.

    Args:
        input_feats: per-frame feature size of the motion input/output.
        latent_dim: transformer model width.
        ff_size: feed-forward width inside each encoder layer.
        num_layers: number of encoder layers.
        num_heads: number of attention heads.
        dropout: dropout used by the encoder layers and positional encoding.
        activation: encoder-layer activation name.
        clip_dim: size of the CLIP text feature fed to ``embed_text``.
        clip_version: model name passed to ``clip.load``.
        guide_scale: classifier-free guidance weight used in eval mode.
        cond_mask_prob: probability of zeroing a sample's text feature
            during training.
        use_official_ckpt: enables the rescaling in ``post_process``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 input_feats=263,
                 latent_dim=256,
                 ff_size=1024,
                 num_layers=8,
                 num_heads=4,
                 dropout=0.1,
                 activation="gelu",
                 clip_dim=512,
                 clip_version=None,
                 guide_scale=1.0,
                 cond_mask_prob=0.1,
                 use_official_ckpt=False,
                 **kwargs):
        super().__init__()
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.clip_dim = clip_dim
        self.input_feats = input_feats
        self.guide_scale = guide_scale
        self.use_official_ckpt = use_official_ckpt
        self.cond_mask_prob = cond_mask_prob
        # Submodule attribute names are state-dict keys and the construction
        # order drives random initialisation; keep both stable.
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim,
                                                       self.dropout)
        # batch_first is not set, so the encoder expects (T, B, D) input.
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=self.num_layers)
        # The timestep embedder reuses the same positional-encoding table.
        self.embed_timestep = TimestepEmbedder(self.latent_dim,
                                               self.sequence_pos_encoder)
        self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
        self.clip_version = clip_version
        self.clip_model = self.load_and_freeze_clip(clip_version)
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)

    def load_and_freeze_clip(self, clip_version):
        """Load a CLIP model on CPU, put it in eval mode and freeze it.

        Args:
            clip_version: model name passed to ``clip.load``.

        Returns:
            The CLIP model with ``requires_grad=False`` on every parameter.
        """
        clip_model, _ = clip.load(clip_version, device='cpu', jit=False)
        # NOTE(review): upstream CLIP's convert_weights casts to fp16, which
        # is why get_precompute_condition casts back to fp32 on CPU — confirm
        # against the installed clip package.
        clip.model.convert_weights(clip_model)
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False
        return clip_model

    def mask_cond(self, cond, force_mask=False):
        """Zero out text conditions.

        Args:
            cond: condition tensor whose first dim is the batch (presumably
                (B, clip_dim) — the per-sample mask is viewed as (B, 1)).
            force_mask: if True, zero every sample.

        Returns:
            ``zeros_like(cond)`` when ``force_mask``; in training mode with
            ``cond_mask_prob > 0`` each sample is zeroed independently with
            that probability; otherwise ``cond`` unchanged.
        """
        bs = cond.shape[0]
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            mask = torch.ones(bs, device=cond.device) * self.cond_mask_prob
            # 1-> use null_cond, 0-> use real cond
            mask = torch.bernoulli(mask).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def encode_text(self, raw_text, max_text_len=20):
        """Encode raw text with the frozen CLIP text encoder.

        Args:
            raw_text: string or list of strings passed to ``clip.tokenize``.
            max_text_len: number of text tokens kept (excluding the start and
                end tokens). Tokens are truncated to ``max_text_len + 2`` and
                zero-padded back to CLIP's default context length of 77.
                ``None`` tokenizes with the default context length directly.

        Returns:
            ``clip_model.encode_text`` output cast to float32.

        Raises:
            ValueError: if ``max_text_len`` is negative or exceeds 75.
        """
        device = next(self.parameters()).device
        default_context_length = 77
        if max_text_len is None:
            texts = clip.tokenize(raw_text, truncate=True).to(device)
            return self.clip_model.encode_text(texts).float()
        if not 0 <= max_text_len <= default_context_length - 2:
            raise ValueError(
                f'max_text_len must be in [0, {default_context_length - 2}], '
                f'got {max_text_len}')
        context_length = max_text_len + 2  # start_token + text + end_token
        texts = clip.tokenize(raw_text,
                              context_length=context_length,
                              truncate=True).to(device)
        # Pad back to the context length CLIP's positional embedding expects.
        zero_pad = torch.zeros(
            [texts.shape[0], default_context_length - context_length],
            dtype=texts.dtype,
            device=texts.device)
        texts = torch.cat([texts, zero_pad], dim=1)
        return self.clip_model.encode_text(texts).float()

    def get_precompute_condition(self, text, device=None, **kwargs):
        """Compute the CLIP text feature used to condition ``forward``.

        Args:
            text: raw text passed to ``encode_text``.
            device: target device (``torch.device``, string, or None). In eval
                mode on a CPU device the CLIP weights are cast to fp32 first.
            **kwargs: ignored.

        Returns:
            ``{'text_feat': tensor}``.
        """
        # Normalise the device so that e.g. the string 'cpu' is recognised
        # as well as torch.device('cpu').
        on_cpu = device is not None and torch.device(device).type == 'cpu'
        if not self.training and on_cpu:
            convert_weights(self.clip_model)
        text_feat = self.encode_text(text)
        return {'text_feat': text_feat}

    def post_process(self, motion):
        """Rescale model output when using an official checkpoint.

        When ``use_official_ckpt`` is set, the first 4 feature channels are
        multiplied by 25. ``motion`` is modified in place and also returned.
        """
        # NOTE(review): the meaning of channels 0-3 and the factor 25 are not
        # visible here; presumably they undo a normalisation used by the
        # official checkpoint — confirm.
        assert len(motion.shape) == 3
        if self.use_official_ckpt:
            motion[:, :, :4] = motion[:, :, :4] * 25
        return motion

    def _denoise(self, cond_emb, motion_emb):
        """Run the encoder on ``[cond_emb; motion_emb]`` and project back.

        Args:
            cond_emb: condition token of shape (1, B, latent_dim).
            motion_emb: embedded motion of shape (T, B, latent_dim).

        Returns:
            Tensor of shape (B, T, input_feats).
        """
        xseq = self.sequence_pos_encoder(
            torch.cat((cond_emb, motion_emb), dim=0))
        output = self.seqTransEncoder(xseq)[1:]  # drop the condition token
        return self.poseFinal(output).permute(1, 0, 2)

    def forward(self, motion, timesteps, text_feat=None, **kwargs):
        """Predict the denoised output for a batch of motions.

        Args:
            motion: tensor of shape (B, T, input_feats).
            timesteps: integer tensor of shape (B,), used to index the
                positional-encoding table.
            text_feat: precomputed text features (presumably (B, clip_dim));
                when ``None`` they are computed via
                ``get_precompute_condition(**kwargs)``, which needs ``text``.
            **kwargs: forwarded to ``get_precompute_condition``.

        Returns:
            Tensor of shape (B, T, input_feats). In training mode the text
            condition is randomly dropped per sample (``cond_mask_prob``); in
            eval mode the unconditional and text-conditioned outputs are
            combined as ``uncond + guide_scale * (text - uncond)``.

        Raises:
            ValueError: if ``motion`` is not 3-D.
        """
        if motion.dim() != 3:
            raise ValueError(
                f'motion must be 3-D (B, T, D), got shape {tuple(motion.shape)}')
        if text_feat is None:
            enc_text = self.get_precompute_condition(**kwargs)['text_feat']
        else:
            enc_text = text_feat
        # (B, T, D) -> (T, B, latent_dim)
        motion_emb = self.poseEmbedding(motion).permute(1, 0, 2)
        t_emb = self.embed_timestep(timesteps)  # [1, bs, d]
        if self.training:
            cond = t_emb + self.embed_text(
                self.mask_cond(enc_text, force_mask=False))
            return self._denoise(cond, motion_emb)
        cond_uncond = t_emb + self.embed_text(
            self.mask_cond(enc_text, force_mask=True))
        cond_text = t_emb + self.embed_text(
            self.mask_cond(enc_text, force_mask=False))
        out_uncond = self._denoise(cond_uncond, motion_emb)
        out_text = self._denoise(cond_text, motion_emb)
        return out_uncond + self.guide_scale * (out_text - out_uncond)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    The table ``pe`` has shape ``(max_len, 1, d_model)``. Even channels hold
    ``sin(pos * w_i)`` and odd channels ``cos(pos * w_i)``, with
    ``w_i = 10000 ** (-2i / d_model)``. It is registered as a buffer, so it
    follows the module's device and is stored in its state dict.

    Args:
        d_model: embedding size. Odd sizes are supported (the last sine
            channel has no cosine partner).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one more even channel than odd channel, so
        # the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model): broadcasts over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.shape[0]`` positions, then dropout.

        Args:
            x: sequence-first tensor, presumably (T, B, d_model); the
                (T, 1, d_model) slice of ``pe`` broadcasts over the batch axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
class TimestepEmbedder(nn.Module):
    """Turn integer diffusion timesteps into a single condition token.

    Each timestep indexes the ``pe`` table of the shared positional encoder,
    and the looked-up row goes through a two-layer MLP
    (Linear -> SiLU -> Linear) of width ``latent_dim``.

    Args:
        latent_dim: width of both the table rows and the MLP.
        sequence_pos_encoder: object exposing a ``pe`` tensor indexable by
            timestep; assumed to be (max_len, 1, latent_dim) as built by
            ``PositionalEncoding`` — confirm for other encoders.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder
        width = latent_dim
        # Keep the Sequential layout (indices 0/1/2): checkpoint keys
        # such as time_embed.0.weight depend on it.
        layers = [nn.Linear(width, width), nn.SiLU(), nn.Linear(width, width)]
        self.time_embed = nn.Sequential(*layers)

    def forward(self, timesteps):
        """Embed ``timesteps`` and move the singleton axis to the front.

        With the assumed ``pe`` layout, returns shape
        ``(1, len(timesteps), latent_dim)``.
        """
        table_rows = self.sequence_pos_encoder.pe[timesteps]
        embedded = self.time_embed(table_rows)
        return embedded.permute(1, 0, 2)
|