import clip
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule

from mogen.models.utils.misc import set_requires_grad

from ..builder import SUBMODULES

loss_ce = nn.CrossEntropyLoss()


class PositionalEncoding(nn.Module):
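    """Sinusoidal positional encoding, added along the time axis."""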

    def __init__(self, d_model, dropout=0.0, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

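        # Precompute a (max_len, d_model) table of interleaved sin/cos terms.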
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.arange(0, d_model, 2).float() * \
            (-np.log(10000.0) / d_model)
        div_term = torch.exp(div_term)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe)

    def forward(self, x):
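        # x: (B, T, d_model); add the encodings for the first T positions.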
        x = x + self.pe[:x.shape[1], :].unsqueeze(0)
        return self.dropout(x)


class MotionEncoder(nn.Module):
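    """Transformer encoder that embeds a two-person motion sequence.

    A learned query token is prepended to the embedded frames; its output
    state is projected to a 512-d motion embedding.
    """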

    def __init__(self, input_dim, latent_dim, ff_size, num_layers, num_heads,
                 dropout, activation):
        super().__init__()

        self.input_feats = input_dim
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.query_token = nn.Parameter(torch.randn(1, self.latent_dim))

        self.embed_motion = nn.Linear(self.input_feats * 2, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim,
                                                       self.dropout,
                                                       max_len=2000)

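        # Transformer encoder applied over [query token; embedded frames].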
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation)
        self.transformer = nn.TransformerEncoder(seqTransEncoderLayer,
                                                 num_layers=self.num_layers)
        self.out_ln = nn.LayerNorm(self.latent_dim)
        self.out = nn.Linear(self.latent_dim, 512)

    def forward(self, motion, motion_mask):
        x, mask = motion, motion_mask
        B, T = x.shape[:2]

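        # Split the feature dim into the two persons, drop the last four
        # channels of each, and flatten back to (B, T, -1).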
        x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1)
        x_emb = self.embed_motion(x)

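        # Prepend the learned query token; its output state is read out as
        # the motion embedding below.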
        idx = torch.zeros(B, dtype=torch.long, device=x.device)
        emb = torch.cat([self.query_token[idx][:, None], x_emb], dim=1)

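        # Key padding: the query token is always valid; frames are valid
        # where motion_mask > 0.5.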
        seq_mask = (mask > 0.5)
        token_mask = torch.ones((B, 1), dtype=torch.bool, device=x.device)
        valid_mask = torch.cat([token_mask, seq_mask], dim=1)

        h = self.sequence_pos_encoder(emb)

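        # nn.TransformerEncoder expects (T, B, C); padded positions are
        # excluded via src_key_padding_mask.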
        h = h.permute(1, 0, 2)
        h = self.transformer(h, src_key_padding_mask=~valid_mask).permute(
            1, 0, 2)
        h = self.out_ln(h)
        motion_emb = self.out(h[:, 0])
        return motion_emb


@SUBMODULES.register_module()
class InterCLIP(BaseModule):
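    """CLIP-style text-motion embedding model for two-person interactions.

    Text and motion are mapped into a shared 512-d space: motion through
    ``MotionEncoder``, text through the frozen CLIP token/positional
    embeddings followed by a trainable transformer. Both embeddings are
    L2-normalized and scaled by a learned ``latent_scale``.

    Rough usage sketch (the checkpoint path and tensors are placeholders;
    the real values come from the evaluation config and dataset):

        model = InterCLIP(init_cfg=dict(
            type='Pretrained', checkpoint='path/to/interclip.pth'))
        text_emb = model.encode_text(['two people shake hands'],
                                     device='cpu')
        motion_emb = model.encode_motion(motion, motion_mask=motion_mask)
    """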

    def __init__(self,
                 input_dim=258,
                 latent_dim=1024,
                 ff_size=2048,
                 num_layers=8,
                 num_heads=8,
                 dropout=0.1,
                 activation="gelu",
                 init_cfg=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.motion_encoder = MotionEncoder(input_dim=input_dim,
                                            latent_dim=latent_dim,
                                            ff_size=ff_size,
                                            num_layers=num_layers,
                                            num_heads=num_heads,
                                            dropout=dropout,
                                            activation=activation)

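        # Load CLIP ViT-L/14@336px only to reuse its token embedding,
        # positional embedding and dtype; the visual branch is not kept.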
        clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False)

        self.token_embedding = clip_model.token_embedding
        self.positional_embedding = clip_model.positional_embedding
        self.dtype = clip_model.dtype
        self.latent_scale = nn.Parameter(torch.Tensor([1]))

        set_requires_grad(self.token_embedding, False)

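        # Trainable text transformer on top of the CLIP token embeddings
        # (width 768 matches the CLIP ViT-L/14 text encoder).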
        textTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=768,
            nhead=8,
            dim_feedforward=ff_size,
            dropout=0.1,
            activation="gelu")
        self.textTransEncoder = nn.TransformerEncoder(textTransEncoderLayer,
                                                      num_layers=8)
        self.text_ln = nn.LayerNorm(768)
        self.out = nn.Linear(768, 512)

        self.clip_training = "text_"
        self.l1_criterion = torch.nn.L1Loss(reduction='mean')
        assert init_cfg['type'] == 'Pretrained'
        self.load_pretrained(init_cfg['checkpoint'])

    def compute_loss(self, batch):
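        # NOTE: relies on ``compute_clip_losses``, which is not defined in
        # this module.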
        losses = {}
        losses["total"] = 0

        batch = self.encode_text(batch)
        batch = self.encode_motion(batch)

        mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)
        losses.update(clip_losses)
        losses["total"] += mixed_clip_loss

        return losses["total"], losses

    def generate_src_mask(self, T, length):
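        # Binary (B, T) mask: 1 for valid frames, 0 beyond each sample's
        # length.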
        B = length.shape[0]
        src_mask = torch.ones(B, T)
        for i in range(B):
            for j in range(length[i], T):
                src_mask[i, j] = 0
        return src_mask

    def encode_motion(self,
                      motion,
                      motion_length=None,
                      motion_mask=None,
                      **kwargs):
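        # Returns an L2-normalized 512-d motion embedding scaled by the
        # learned latent_scale.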
        motion_emb = self.motion_encoder(motion, motion_mask)
        motion_emb = motion_emb / motion_emb.norm(dim=-1, keepdim=True)
        motion_emb = motion_emb * self.latent_scale
        return motion_emb

    def encode_text(self, text, device=None, **kwargs):
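        # Tokenize with CLIP and embed with the frozen token/positional
        # embeddings, then re-encode with the trainable text transformer.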
        raw_text = text
        with torch.no_grad():
            text = clip.tokenize(raw_text, truncate=True).to(device)
            x = self.token_embedding(text).type(self.dtype)
            pe_tokens = x + self.positional_embedding.type(self.dtype)

        pe_tokens = pe_tokens.permute(1, 0, 2)
        out = self.textTransEncoder(pe_tokens)
        out = out.permute(1, 0, 2)

        out = self.text_ln(out)

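        # Take the feature at the end-of-text token (the position of the
        # highest token id) and project it into the shared 512-d space.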
        out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        out = self.out(out)

        text_emb = out
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        text_emb = text_emb * self.latent_scale
        return text_emb

    def load_pretrained(self, ckpt_path):
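        # Strip the 'model.' prefix from checkpoint keys so they match this
        # module's parameter names, then load strictly.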
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        for k in list(state_dict.keys()):
            if "model" in k:
                state_dict[k.replace("model.", "")] = state_dict.pop(k)
        self.load_state_dict(state_dict, strict=True)