# FlowMDM/model/FlowMDM.py
import numpy as np
import torch
import torch.nn as nn
from model.rotation2xyz import Rotation2xyz
from model.MDM import InputProcess, OutputProcess
from model.base_models import TextConditionalModel
from model.x_transformers.x_transformers import ContinuousTransformerWrapper, Encoder


class BPE_Schedule:
    """Blended Positional Encodings (BPE) schedule.

    Decides, for each diffusion timestep, whether the transformer uses absolute (APE)
    or relative (RPE) positional encodings:
      - training: random dropout, RPE with probability `training_rate`.
      - inference: binary step function, switching from APE to RPE at denoising step
        `inference_step` (-1 keeps APE for the whole chain, 0 keeps RPE).
    """

    def __init__(self, training_rate: float, inference_step: int, max_steps: int) -> None:
        assert 0 <= training_rate <= 1, "training_rate must be between 0 and 1"
        assert inference_step == -1 or 0 <= inference_step <= max_steps, "inference_step must be between 0 and max_steps"
        self.training_rate = training_rate
        self.inference_step = inference_step
        self.max_steps = max_steps
        self.last_random = None  # cached per-batch random draw, refreshed by step()

    def step(self, t: torch.Tensor, training: bool):
        # draw one random number per batch element; get_schedule_fn() reuses it for the
        # APE/RPE dropout at training time
        self.last_random = torch.rand(t.shape[0], device=t.device)

    def get_schedule_fn(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # False --> absolute (APE) / True --> relative (RPE)
        if training:  # at TRAINING: random dropout
            return self.last_random < self.training_rate
        # at INFERENCE: binary step function as BPE schedule
        elif self.inference_step == -1:  # --> whole denoising chain with APE (absolute)
            return torch.zeros_like(t, dtype=torch.bool)
        elif self.inference_step == 0:  # --> whole denoising chain with RPE (relative)
            return torch.ones_like(t, dtype=torch.bool)
        else:  # --> BPE: switch from APE to RPE at denoising step `self.inference_step`
            return t <= self.max_steps - self.inference_step  # equivalent to ~(t > max_steps - inference_step)

    def use_bias(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # returns True if the absolute bias should be used (only relevant for multi-segment **inference**)
        assert (t[0] == t).all(), "Bias from mixed schedule only supported when using same timestep for all batch elements: " + str(t)
        return ~self.get_schedule_fn(t[0], training)  # if APE --> use the bias to limit attention to each subsequence

    def get_time_weights(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # 0 --> absolute (APE) / 1 --> relative (RPE)
        return self.get_schedule_fn(t, training).to(torch.int32)
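

# Illustrative worked example (added sketch, not part of the original file), assuming the
# hypothetical values max_steps=1000 and inference_step=125, i.e. BPE_Schedule(0.5, 125, 1000):
#   * t > 875  (early denoising steps) --> get_schedule_fn(...) is False --> APE (absolute)
#   * t <= 875 (later denoising steps) --> get_schedule_fn(...) is True  --> RPE (relative)
# With inference_step=-1 the whole chain uses APE; with inference_step=0 it uses RPE.
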
class FlowMDM(TextConditionalModel):
    """Text-conditional motion diffusion model with Blended Positional Encodings.

    The transformer encoder blends absolute (APE) and relative (RPE) positional encodings
    according to `self.bpe_schedule`, and merges the condition into the attention queries
    through the per-layer PCCAT projections passed as `custom_query_fn`.
    """
def __init__(self, njoints, nfeats, translation, pose_rep, glob, glob_rot,
latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
data_rep='rot6d', dataset='babel',
clip_dim=512, clip_version=None, cond_mode="no_cond", cond_mask_prob=0.,
**kargs):
super().__init__(latent_dim=latent_dim, cond_mode=cond_mode, cond_mask_prob=cond_mask_prob, dropout=dropout, clip_dim=clip_dim, clip_version=clip_version)
self.njoints = njoints
self.nfeats = nfeats
self.data_rep = data_rep
self.dataset = dataset
self.pose_rep = pose_rep
self.glob = glob
self.glob_rot = glob_rot
self.translation = translation
self.latent_dim = latent_dim
self.ff_size = ff_size
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.input_feats = self.njoints * self.nfeats
self.max_seq_att = kargs.get('max_seq_att', 1024)
self.input_process = InputProcess(self.data_rep, self.input_feats, self.latent_dim)
        # one dense layer per transformer layer, used as `custom_query_fn` below to merge the
        # condition into the attention queries (PCCAT)
        self.process_cond_input = [nn.Linear(2*self.latent_dim, self.latent_dim) for _ in range(self.num_layers)]
        print("FlowMDM init")
self.use_chunked_att = kargs.get('use_chunked_att', False)
        bpe_training_rate = kargs.get('bpe_training_ratio', 0.5)  # probability of using RPE (vs. APE) at training time
bpe_inference_step = kargs.get('bpe_denoising_step', None)
diffusion_steps = kargs.get('diffusion_steps', None)
self.bpe_schedule = BPE_Schedule(bpe_training_rate, bpe_inference_step, diffusion_steps)
ws = kargs.get('rpe_horizon', -1) # Max attention horizon
self.local_attn_window_size = 200 if ws == -1 else ws
print("[Training] RPE/APE rate:", bpe_training_rate)
print(f"[Inference] BPE switch from APE to RPE at denoising step {bpe_inference_step}/{diffusion_steps}.")
print("Local attention window size:", self.local_attn_window_size)
self.seqTransEncoder = ContinuousTransformerWrapper(
dim_in = self.latent_dim, dim_out = self.latent_dim,
emb_dropout = self.dropout,
max_seq_len = self.max_seq_att,
use_abs_pos_emb = True,
absolute_bpe_schedule = self.bpe_schedule, # bpe schedule for absolute embeddings (APE)
attn_layers = Encoder(
dim = self.latent_dim,
depth = self.num_layers,
heads = self.num_heads,
                ff_mult = int(np.round(self.ff_size / self.latent_dim)),  # e.g., 2 with the original MDM hyperparameters (ff_size=1024, latent_dim=512)
layer_dropout = self.dropout, cross_attn_tokens_dropout = 0,
# ======== FLOWMDM ========
custom_layers=('A', 'f'), # A --> PCCAT
custom_query_fn = self.process_cond_input, # function that merges the condition into the query --> PCCAT dense layer (see Fig. 3)
attn_max_attend_past = self.local_attn_window_size,
attn_max_attend_future = self.local_attn_window_size,
# ======== RELATIVE POSITIONAL EMBEDDINGS ========
rotary_pos_emb = True, # rotary embeddings
rotary_bpe_schedule = self.bpe_schedule, # bpe schedule for rotary embeddings (RPE)
)
)
self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
self.nfeats)
self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)
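
    # Illustrative construction sketch (added, not part of the original file). The values below
    # are hypothetical placeholders; only the keyword names mirror the constructor arguments and
    # the `kargs` entries read above. Note that `bpe_denoising_step` and `diffusion_steps` have
    # no usable defaults here (BPE_Schedule expects integers):
    #   model = FlowMDM(njoints=135, nfeats=1, translation=True, pose_rep='rot6d',
    #                   glob=True, glob_rot=True, latent_dim=256, ff_size=1024,
    #                   num_layers=8, num_heads=4, dataset='babel', cond_mode='text',
    #                   bpe_denoising_step=125, diffusion_steps=1000, rpe_horizon=100)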

    def forward(self, x, timesteps, y):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        y: model_kwargs dict with mask, pe_bias, pos_pe_abs, conditions_mask. See DiffusionWrapper_FlowMDM.
        Returns a tensor of shape [batch_size, njoints, nfeats, max_frames].
        """
bs, njoints, nfeats, nframes = x.shape
        mask = y['mask'].reshape(bs, nframes).to(x.device).bool()  # [bs, max_frames]
self.bpe_schedule.step(timesteps, self.training) # update the BPE scheduler (decides either APE or RPE for each timestep)
        if self.training or self.bpe_schedule.use_bias(timesteps, self.training):
            # APE mode: the bias limits attention to inside each conditioned subsequence.
            # At training time, the BPE dropout decides whether it is actually applied.
            pe_bias = y.get("pe_bias", None)
            chunked_attn = False
        else:
            # RPE mode at inference: no subsequence bias, so attention is not restricted to each subsequence
            pe_bias = None
            chunked_attn = self.use_chunked_att  # faster chunked attention for very long sequences (see the Longformer paper)
# store info needed for the relative PE --> rotary embedding
rotary_kwargs = {'timesteps': timesteps, 'pos_pe_abs': y.get("pos_pe_abs", None), 'training': self.training, 'pe_bias': pe_bias }
# ============== INPUT PROCESSING ==============
emb = self.compute_embedding(x, timesteps, y)
x = self.input_process(x) # [seqlen, bs, d]
# ============== MAIN ARCHITECTURE ==============
# APE or RPE is injected inside seqTransEncoder forward function
x, emb = x.permute(1, 0, 2), emb.permute(1, 0, 2)
output = self.seqTransEncoder(x, mask=mask, cond_tokens=emb, attn_bias=pe_bias, rotary_kwargs=rotary_kwargs, chunked_attn=chunked_attn) # [bs, seqlen, d]
output = output.permute(1, 0, 2) # [seqlen, bs, d]
# ============== OUTPUT PROCESSING ==============
return self.output_process(output) # [bs, njoints, nfeats, nframes]
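
    # Illustrative call sketch (added, not part of the original file; the keys and shapes of `y`
    # are assumptions based on the docstring above and on how `forward` reads them):
    #   x = torch.randn(bs, njoints, nfeats, nframes)                      # noisy motion x_t
    #   t = torch.randint(0, diffusion_steps, (bs,))                       # one timestep per element
    #   y = {'mask': torch.ones(bs, 1, 1, nframes, dtype=torch.bool),      # reshaped to [bs, nframes]
    #        'pe_bias': None, 'pos_pe_abs': None, ...}                     # see DiffusionWrapper_FlowMDM
    #   out = model(x, t, y)                                               # [bs, njoints, nfeats, nframes]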

    def _apply(self, fn):
        super()._apply(fn)
        self.rot2xyz.smpl_model._apply(fn)
        return self  # keep the nn.Module contract so calls like .to(device) can chain

    def train(self, *args, **kwargs):
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)
        return self
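

# --- Minimal smoke test (added sketch, not part of the original file) ---
# It exercises only BPE_Schedule, so it needs no CLIP/SMPL assets; the step counts below are
# arbitrary example values, not the official configuration.
if __name__ == "__main__":
    max_steps, switch_step = 1000, 125
    schedule = BPE_Schedule(training_rate=0.5, inference_step=switch_step, max_steps=max_steps)

    # Inference: walk the denoising chain from t = max_steps - 1 down to 0 and count the modes.
    modes = []
    for t_val in reversed(range(max_steps)):
        t = torch.full((4,), t_val, dtype=torch.long)
        schedule.step(t, training=False)
        modes.append(bool(schedule.get_schedule_fn(t, training=False)[0]))  # False=APE, True=RPE
    print("APE steps:", modes.count(False), "| RPE steps:", modes.count(True))  # 124 APE / 876 RPE

    # Training: the APE/RPE choice is a per-element Bernoulli(training_rate) dropout.
    t = torch.randint(0, max_steps, (8,))
    schedule.step(t, training=True)
    print("Training RPE mask:", schedule.get_schedule_fn(t, training=True))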