import numpy as np
import torch
import torch.nn as nn
from model.rotation2xyz import Rotation2xyz
from model.MDM import InputProcess, OutputProcess
from model.base_models import TextConditionalModel
from model.x_transformers.x_transformers import ContinuousTransformerWrapper, Encoder


class BPE_Schedule():
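    """Blended Positional Encodings (BPE) schedule.

    Decides, per denoising timestep, whether the transformer uses absolute (APE) or
    relative/rotary (RPE) positional encodings:
    - training: RPE is chosen per batch element with probability `training_rate`, APE otherwise.
    - inference: a binary step function switches from APE to RPE at denoising step
      `inference_step` (-1 keeps APE for the whole chain, 0 keeps RPE for the whole chain).
    """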
    def __init__(self, training_rate: float, inference_step: int, max_steps: int) -> None:
        assert training_rate >= 0 and training_rate <= 1, "training_rate must be between 0 and 1"
        assert inference_step == -1 or (inference_step >= 0 and inference_step <= max_steps), "inference_step must be between 0 and max_steps"
        self.training_rate = training_rate
        self.inference_step = inference_step
        self.max_steps = max_steps
        self.last_random = None

    def step(self, t: torch.Tensor, training: bool):
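        """Draw one uniform sample per batch element; get_schedule_fn compares it
        against training_rate to pick APE vs. RPE during training."""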
        self.last_random = torch.rand(t.shape[0], device=t.device)

    def get_schedule_fn(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # False --> absolute (APE)
        # True --> relative (RPE)
        if training:  # at TRAINING: random dropout between APE and RPE
            return self.last_random < self.training_rate
        # at INFERENCE: binary step function as BPE schedule
        elif self.inference_step == -1:  # --> whole denoising chain with APE (absolute)
            return torch.zeros_like(t, dtype=torch.bool)
        elif self.inference_step == 0:  # --> whole denoising chain with RPE (relative)
            return torch.ones_like(t, dtype=torch.bool)
        else:  # --> BPE with binary step function: switch from APE to RPE at "self.inference_step"
            return ~(t > self.max_steps - self.inference_step)

    def use_bias(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # returns True if we should use the absolute bias (only when using multi-segment **inference**)
        assert (t[0] == t).all(), "Bias from mixed schedule only supported when using same timestep for all batch elements: " + str(t)
        return ~self.get_schedule_fn(t[0], training)  # if APE --> use bias to limit attention to each subsequence

    def get_time_weights(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # 0 --> absolute
        # 1 --> relative
        return self.get_schedule_fn(t, training).to(torch.int32)


class FlowMDM(TextConditionalModel):
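    """Text-conditioned motion diffusion model with Blended Positional Encodings (BPE).

    A transformer encoder (the repo's modified x-transformers ContinuousTransformerWrapper)
    denoises the motion sequence; absolute (APE) and rotary (RPE) positional encodings are
    blended at runtime by a BPE_Schedule, and attention can be restricted to a local window
    and, via an optional bias, to each conditioned subsequence.
    """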
    def __init__(self, njoints, nfeats, translation, pose_rep, glob, glob_rot,
                 latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
                 data_rep='rot6d', dataset='babel',
                 clip_dim=512, clip_version=None, cond_mode="no_cond", cond_mask_prob=0.,
                 **kargs):
        super().__init__(latent_dim=latent_dim, cond_mode=cond_mode, cond_mask_prob=cond_mask_prob,
                         dropout=dropout, clip_dim=clip_dim, clip_version=clip_version)

        self.njoints = njoints
        self.nfeats = nfeats
        self.data_rep = data_rep
        self.dataset = dataset
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.input_feats = self.njoints * self.nfeats
        self.max_seq_att = kargs.get('max_seq_att', 1024)
        self.input_process = InputProcess(self.data_rep, self.input_feats, self.latent_dim)
        self.process_cond_input = [nn.Linear(2 * self.latent_dim, self.latent_dim) for _ in range(self.num_layers)]
        print("FlowMDM init")

        self.use_chunked_att = kargs.get('use_chunked_att', False)
        bpe_training_rate = kargs.get('bpe_training_ratio', 0.5)  # during training, dropout with prob. 50% between APE and RPE
        bpe_inference_step = kargs.get('bpe_denoising_step', None)
        diffusion_steps = kargs.get('diffusion_steps', None)
        self.bpe_schedule = BPE_Schedule(bpe_training_rate, bpe_inference_step, diffusion_steps)
        ws = kargs.get('rpe_horizon', -1)  # max attention horizon
        self.local_attn_window_size = 200 if ws == -1 else ws

        print("[Training] RPE/APE rate:", bpe_training_rate)
        print(f"[Inference] BPE switch from APE to RPE at denoising step {bpe_inference_step}/{diffusion_steps}.")
        print("Local attention window size:", self.local_attn_window_size)
        self.seqTransEncoder = ContinuousTransformerWrapper(
            dim_in=self.latent_dim, dim_out=self.latent_dim,
            emb_dropout=self.dropout,
            max_seq_len=self.max_seq_att,
            use_abs_pos_emb=True,
            absolute_bpe_schedule=self.bpe_schedule,  # BPE schedule for absolute embeddings (APE)
            attn_layers=Encoder(
                dim=self.latent_dim,
                depth=self.num_layers,
                heads=self.num_heads,
                ff_mult=int(np.round(self.ff_size / self.latent_dim)),  # 2 for MDM hyper params
                layer_dropout=self.dropout, cross_attn_tokens_dropout=0,
                # ======== FLOWMDM ========
                custom_layers=('A', 'f'),  # A --> PCCAT
                custom_query_fn=self.process_cond_input,  # function that merges the condition into the query --> PCCAT dense layer (see Fig. 3)
                attn_max_attend_past=self.local_attn_window_size,
                attn_max_attend_future=self.local_attn_window_size,
                # ======== RELATIVE POSITIONAL EMBEDDINGS ========
                rotary_pos_emb=True,  # rotary embeddings
                rotary_bpe_schedule=self.bpe_schedule,  # BPE schedule for rotary embeddings (RPE)
            )
        )

        self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
                                            self.nfeats)
        self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)

    def forward(self, x, timesteps, y):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        inside y: model_kwargs with mask, pe_bias, pos_pe_abs, conditions_mask. See DiffusionWrapper_FlowMDM.
        """
        bs, njoints, nfeats, nframes = x.shape
        mask = (y['mask'].reshape((bs, nframes))[:, :nframes].to(x.device)).bool()  # [bs, max_frames]

        self.bpe_schedule.step(timesteps, self.training)  # update the BPE scheduler (decides either APE or RPE for each timestep)
        if self.training or self.bpe_schedule.use_bias(timesteps, self.training):
            # bias that limits attention to inside each conditioned subsequence; the BPE schedule
            # decides whether it is applied, depending on the APE/RPE dropout at training time
            pe_bias = y.get("pe_bias", None)
            chunked_attn = False
        else:  # RPE at inference --> no per-subsequence bias; attention is limited only by the local window
            pe_bias = None
            chunked_attn = self.use_chunked_att  # faster attention for inference with RPE on very long sequences (see LongFormer paper for details)

        # store info needed for the relative PE --> rotary embedding
        rotary_kwargs = {'timesteps': timesteps, 'pos_pe_abs': y.get("pos_pe_abs", None), 'training': self.training, 'pe_bias': pe_bias}

        # ============== INPUT PROCESSING ==============
        emb = self.compute_embedding(x, timesteps, y)
        x = self.input_process(x)  # [seqlen, bs, d]

        # ============== MAIN ARCHITECTURE ==============
        # APE or RPE is injected inside the seqTransEncoder forward function
        x, emb = x.permute(1, 0, 2), emb.permute(1, 0, 2)
        output = self.seqTransEncoder(x, mask=mask, cond_tokens=emb, attn_bias=pe_bias,
                                      rotary_kwargs=rotary_kwargs, chunked_attn=chunked_attn)  # [bs, seqlen, d]
        output = output.permute(1, 0, 2)  # [seqlen, bs, d]

        # ============== OUTPUT PROCESSING ==============
        return self.output_process(output)  # [bs, njoints, nfeats, nframes]

    def _apply(self, fn):
        # keep the SMPL body model (not a registered submodule) in sync with device/dtype changes
        super()._apply(fn)
        self.rot2xyz.smpl_model._apply(fn)

    def train(self, *args, **kwargs):
        # forward train/eval mode changes to the SMPL body model as well
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)
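

if __name__ == "__main__":
    # Minimal usage sketch: exercises BPE_Schedule on its own to show which denoising steps use
    # absolute (APE) vs. relative (RPE) encodings. The numbers below (1000 diffusion steps,
    # switch at step 125) are illustrative assumptions, not necessarily the repo's defaults.
    max_steps = 1000
    schedule = BPE_Schedule(training_rate=0.5, inference_step=125, max_steps=max_steps)

    # Inference: t runs from max_steps - 1 down to 0; early (noisy) steps use APE, later steps RPE.
    for t_val in [999, 900, 876, 875, 100, 0]:
        t = torch.tensor([t_val, t_val])
        schedule.step(t, training=False)
        relative = schedule.get_schedule_fn(t, training=False)
        print(f"t={t_val:4d} -> {'RPE' if bool(relative[0]) else 'APE'}",
              "| use_bias:", bool(schedule.use_bias(t, training=False)))

    # Training: APE vs. RPE is a per-element coin flip with probability training_rate.
    t = torch.randint(0, max_steps, (4,))
    schedule.step(t, training=True)
    print("training mask (True=RPE):", schedule.get_schedule_fn(t, training=True))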