import clip
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mogen.models.utils.misc import set_requires_grad
from ..builder import SUBMODULES
loss_ce = nn.CrossEntropyLoss()
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.0, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.arange(0, d_model, 2).float() * \
(-np.log(10000.0) / d_model)
div_term = torch.exp(div_term)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # add the sinusoidal encoding for the first x.shape[1] positions
        # (batch-first input of shape (B, T, d_model))
        x = x + self.pe[:x.shape[1], :].unsqueeze(0)
        return self.dropout(x)
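# Usage sketch (illustrative only): PositionalEncoding expects batch-first
# input of shape (B, T, d_model) and adds sinusoidal encodings over T, e.g.
#   pos_enc = PositionalEncoding(d_model=1024, dropout=0.1)
#   y = pos_enc(torch.randn(2, 16, 1024))  # -> shape (2, 16, 1024)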
class MotionEncoder(nn.Module):
def __init__(self, input_dim, latent_dim, ff_size, num_layers, num_heads,
dropout, activation):
super().__init__()
self.input_feats = input_dim
self.latent_dim = latent_dim
self.ff_size = ff_size
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.activation = activation
self.query_token = nn.Parameter(torch.randn(1, self.latent_dim))
self.embed_motion = nn.Linear(self.input_feats * 2, self.latent_dim)
self.sequence_pos_encoder = PositionalEncoding(self.latent_dim,
self.dropout,
max_len=2000)
seqTransEncoderLayer = nn.TransformerEncoderLayer(
d_model=self.latent_dim,
nhead=self.num_heads,
dim_feedforward=self.ff_size,
dropout=self.dropout,
activation=self.activation)
self.transformer = nn.TransformerEncoder(seqTransEncoderLayer,
num_layers=self.num_layers)
self.out_ln = nn.LayerNorm(self.latent_dim)
self.out = nn.Linear(self.latent_dim, 512)
    def forward(self, motion, motion_mask):
        x, mask = motion, motion_mask
        B, T = x.shape[:2]
        # split the two-person motion, drop the last 4 features of each
        # person, then flatten back to (B, T, 2 * input_feats)
        x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1)
        x_emb = self.embed_motion(x)
        # prepend one learnable query token per sample to pool the sequence
        idx = torch.zeros(B, dtype=torch.long, device=x.device)
        emb = torch.cat([self.query_token[idx][:, None], x_emb], dim=1)
        # the query token is always valid; frames are valid where mask > 0.5
        seq_mask = (mask > 0.5)
        token_mask = torch.ones((B, 1), dtype=torch.bool, device=x.device)
        valid_mask = torch.cat([token_mask, seq_mask], dim=1)
        h = self.sequence_pos_encoder(emb)
        # nn.TransformerEncoder expects (seq_len, batch, dim)
        h = h.permute(1, 0, 2)
        h = self.transformer(h, src_key_padding_mask=~valid_mask).permute(
            1, 0, 2)
        h = self.out_ln(h)
        # the pooled query token is projected to the 512-d embedding space
        motion_emb = self.out(h[:, 0])
        return motion_emb
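# Usage sketch (illustrative shapes only): MotionEncoder takes a two-person
# motion tensor of shape (B, T, 2 * (input_dim + 4)) and a (B, T) frame mask,
# and returns one (B, 512) embedding per sequence, e.g.
#   enc = MotionEncoder(input_dim=258, latent_dim=1024, ff_size=2048,
#                       num_layers=8, num_heads=8, dropout=0.1,
#                       activation="gelu")
#   emb = enc(torch.randn(2, 64, 524), torch.ones(2, 64))  # -> (2, 512)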
@SUBMODULES.register_module()
class InterCLIP(BaseModule):
def __init__(self,
input_dim=258,
latent_dim=1024,
ff_size=2048,
num_layers=8,
num_heads=8,
dropout=0.1,
activation="gelu",
init_cfg=None):
super().__init__()
self.latent_dim = latent_dim
self.motion_encoder = MotionEncoder(input_dim=input_dim,
latent_dim=latent_dim,
ff_size=ff_size,
num_layers=num_layers,
num_heads=num_heads,
dropout=dropout,
activation=activation)
        # reuse CLIP ViT-L/14@336px only for its (frozen) token embedding,
        # positional embedding and dtype; the rest of CLIP is discarded
        clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False)
        self.token_embedding = clip_model.token_embedding
        self.positional_embedding = clip_model.positional_embedding
        self.dtype = clip_model.dtype
        # learnable scale applied to the normalized text/motion embeddings
        self.latent_scale = nn.Parameter(torch.Tensor([1]))
        set_requires_grad(self.token_embedding, False)
textTransEncoderLayer = nn.TransformerEncoderLayer(
d_model=768,
nhead=8,
dim_feedforward=ff_size,
dropout=0.1,
activation="gelu")
self.textTransEncoder = nn.TransformerEncoder(textTransEncoderLayer,
num_layers=8)
self.text_ln = nn.LayerNorm(768)
self.out = nn.Linear(768, 512)
self.clip_training = "text_"
self.l1_criterion = torch.nn.L1Loss(reduction='mean')
assert init_cfg['type'] == 'Pretrained'
self.load_pretrained(init_cfg['checkpoint'])
    def compute_loss(self, batch):
        # contrastive training path; assumes batch-level encode_* variants and
        # a compute_clip_losses implementation that are not defined in this file
        losses = {}
        losses["total"] = 0
        batch = self.encode_text(batch)
        batch = self.encode_motion(batch)
        mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)
        losses.update(clip_losses)
        losses["total"] += mixed_clip_loss
        return losses["total"], losses
    def generate_src_mask(self, T, length):
        # (B, T) float mask: 1 for valid frames, 0 for padding past length[i]
        B = length.shape[0]
        src_mask = torch.ones(B, T)
        for i in range(B):
            src_mask[i, int(length[i]):] = 0
        return src_mask
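    # Example (illustrative values): self.generate_src_mask(5, torch.tensor(
    # [3, 5])) -> [[1., 1., 1., 0., 0.], [1., 1., 1., 1., 1.]], matching the
    # motion_mask convention consumed by MotionEncoder above.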
def encode_motion(self,
motion,
motion_length=None,
motion_mask=None,
**kwargs):
motion_emb = self.motion_encoder(motion, motion_mask)
motion_emb = motion_emb / motion_emb.norm(dim=-1, keepdim=True)
motion_emb = motion_emb * self.latent_scale
return motion_emb
def encode_text(self, text, device=None, **kwargs):
raw_text = text
with torch.no_grad():
text = clip.tokenize(raw_text, truncate=True).to(device)
x = self.token_embedding(text).type(self.dtype)
pe_tokens = x + self.positional_embedding.type(self.dtype)
pe_tokens = pe_tokens.permute(1, 0, 2)
out = self.textTransEncoder(pe_tokens)
out = out.permute(1, 0, 2)
out = self.text_ln(out)
out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
out = self.out(out)
text_emb = out
text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
text_emb = text_emb * self.latent_scale
return text_emb
    def load_pretrained(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        # strip any "model." prefix from checkpoint keys before loading
        for k in list(state_dict.keys()):
            if "model" in k:
                state_dict[k.replace("model.", "")] = state_dict.pop(k)
        self.load_state_dict(state_dict, strict=True)
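# Usage sketch (illustrative only; 'interclip.pth' is a placeholder for the
# pretrained evaluator checkpoint that this module requires):
#   model = InterCLIP(init_cfg=dict(type='Pretrained',
#                                   checkpoint='interclip.pth'))
#   text_emb = model.encode_text(["two people hug"], device="cpu")
#   motion_emb = model.encode_motion(torch.randn(1, 64, 524),
#                                    motion_mask=torch.ones(1, 64))
#   sim = text_emb @ motion_emb.t()  # scaled cosine similarity, shape (1, 1)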