import numpy as np
import torch
import torch.nn as nn
from model.rotation2xyz import Rotation2xyz
from model.MDM import InputProcess, OutputProcess
from model.base_models import TextConditionalModel
from model.x_transformers.x_transformers import ContinuousTransformerWrapper, Encoder


class BPE_Schedule():
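    """Blended Positional Encodings (BPE) schedule.

    Decides, for each diffusion timestep, whether the transformer uses absolute (APE)
    or relative (RPE) positional encodings: a random choice between the two at training
    time, and a binary step function over the denoising chain at inference time.
    """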
    def __init__(self, training_rate: float, inference_step: int, max_steps: int) -> None:
        assert training_rate >= 0 and training_rate <= 1, "training_rate must be between 0 and 1"
        assert inference_step == -1 or (inference_step >= 0 and inference_step <= max_steps), "inference_step must be between 0 and max_steps"
        self.training_rate = training_rate
        self.inference_step = inference_step
        self.max_steps = max_steps
        self.last_random = None

    def step(self, t: torch.Tensor, training: bool):
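        # draw a fresh random value per batch element; consumed by get_schedule_fn at training time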
        self.last_random = torch.rand(t.shape[0], device=t.device)

    def get_schedule_fn(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # False --> absolute
        # True --> relative
        if training: # at TRAINING: then random dropout
            return self.last_random < self.training_rate
        # at INFERENCE: step function as BPE schedule
        elif self.inference_step == -1: # --> whole denoising chain with APE (absolute)
            return torch.zeros_like(t, dtype=torch.bool)
        elif self.inference_step == 0: # --> whole denoising chain with RPE (relative)
            return torch.ones_like(t, dtype=torch.bool)
        else: # --> BPE with binary step function. Step from APE to RPE at "self.inference_step"
            return ~(t > self.max_steps - self.inference_step)
    
    def use_bias(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # returns True if the absolute attention bias should be used (only relevant for multi-segment **inference**)
        assert (t[0] == t).all(), "Bias from mixed schedule only supported when using same timestep for all batch elements: " + str(t)
        return ~self.get_schedule_fn(t[0], training) # if APE --> use bias to limit attention to each subsequence

    def get_time_weights(self, t: torch.Tensor, training: bool) -> torch.Tensor:
        # 0 --> absolute
        # 1 --> relative
        return self.get_schedule_fn(t, training).to(torch.int32)
    

class FlowMDM(TextConditionalModel):
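    """Text-conditional motion diffusion architecture that blends absolute and relative
    positional encodings (APE/RPE) through a BPE_Schedule, so that independently
    conditioned subsequences can be denoised into a single seamless motion.
    """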
    def __init__(self, njoints, nfeats, translation, pose_rep, glob, glob_rot,
                 latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
                 data_rep='rot6d', dataset='babel', 
                 clip_dim=512, clip_version=None, cond_mode="no_cond", cond_mask_prob=0.,
                 **kargs):
        super().__init__(latent_dim=latent_dim, cond_mode=cond_mode, cond_mask_prob=cond_mask_prob, dropout=dropout, clip_dim=clip_dim, clip_version=clip_version)
        self.njoints = njoints
        self.nfeats = nfeats
        self.data_rep = data_rep
        self.dataset = dataset

        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim

        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout

        self.input_feats = self.njoints * self.nfeats
        self.max_seq_att = kargs.get('max_seq_att', 1024)
        self.input_process = InputProcess(self.data_rep, self.input_feats, self.latent_dim)
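        # one linear projection per transformer layer (PCCAT): merges the condition token into the attention query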
        self.process_cond_input = [nn.Linear(2*self.latent_dim, self.latent_dim) for _ in range(self.num_layers)]

        print(f"FlowMDM init")
        self.use_chunked_att = kargs.get('use_chunked_att', False)
        bpe_training_rate = kargs.get('bpe_training_ratio', 0.5) # for training, we dropout with prob 50% --> APE vs RPE
        bpe_inference_step = kargs.get('bpe_denoising_step', None)
        diffusion_steps = kargs.get('diffusion_steps', None)
        self.bpe_schedule = BPE_Schedule(bpe_training_rate, bpe_inference_step, diffusion_steps)
        ws = kargs.get('rpe_horizon', -1) # Max attention horizon
        self.local_attn_window_size = 200 if ws == -1 else ws
        print("[Training] RPE/APE rate:", bpe_training_rate)
        print(f"[Inference] BPE switch from APE to RPE at denoising step {bpe_inference_step}/{diffusion_steps}.")
        print("Local attention window size:", self.local_attn_window_size)

        self.seqTransEncoder = ContinuousTransformerWrapper(
            dim_in = self.latent_dim, dim_out = self.latent_dim,
            emb_dropout = self.dropout,
            max_seq_len = self.max_seq_att,
            use_abs_pos_emb = True,
            absolute_bpe_schedule = self.bpe_schedule, # bpe schedule for absolute embeddings (APE)
            attn_layers = Encoder(
                dim = self.latent_dim,
                depth = self.num_layers,
                heads = self.num_heads,
                ff_mult = int(np.round(self.ff_size / self.latent_dim)), # e.g., 2 with the original MDM hyperparameters (ff_size=1024, latent_dim=512)
                layer_dropout = self.dropout, cross_attn_tokens_dropout = 0,

                # ======== FLOWMDM ========
                custom_layers=('A', 'f'), # A --> PCCAT
                custom_query_fn = self.process_cond_input, # per-layer dense layers that merge the condition into the query --> PCCAT (see Fig. 3)
                attn_max_attend_past = self.local_attn_window_size,
                attn_max_attend_future = self.local_attn_window_size,
                # ======== RELATIVE POSITIONAL EMBEDDINGS ========
                rotary_pos_emb = True, # rotary embeddings
                rotary_bpe_schedule = self.bpe_schedule, # bpe schedule for rotary embeddings (RPE)
            )
        )

        self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
                                            self.nfeats)
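        # converts rotation-based motion representations to 3D joint positions (SMPL)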
        self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)

    def forward(self, x, timesteps, y):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        inside y: model_kwargs with mask, pe_bias, pos_pe_abs, conditions_mask. See DiffusionWrapper_FlowMDM.
        """
        bs, njoints, nfeats, nframes = x.shape
        mask = (y['mask'].reshape((bs, nframes))[:, :nframes].to(x.device)).bool() # [bs, nframes]

        self.bpe_schedule.step(timesteps, self.training) # update the BPE scheduler (decides either APE or RPE for each timestep)
        if self.training or self.bpe_schedule.use_bias(timesteps, self.training):
            pe_bias = y.get("pe_bias", None) # bias that restricts attention to within each conditioned subsequence; the BPE dropout decides whether it is applied at training time
            chunked_attn = False
        else: # when using RPE at inference --> no bias; attention is instead limited by the local attention window
            pe_bias = None
            chunked_attn = self.use_chunked_att # faster attention for inference with RPE on very long sequences (see the Longformer paper for details)

        # store info needed for the relative PE --> rotary embedding
        rotary_kwargs = {'timesteps': timesteps, 'pos_pe_abs': y.get("pos_pe_abs", None), 'training': self.training, 'pe_bias': pe_bias }

        # ============== INPUT PROCESSING ==============
        emb = self.compute_embedding(x, timesteps, y)
        x = self.input_process(x) # [seqlen, bs, d]

        # ============== MAIN ARCHITECTURE ==============
        # APE or RPE is injected inside seqTransEncoder forward function
        x, emb = x.permute(1, 0, 2), emb.permute(1, 0, 2)
        output = self.seqTransEncoder(x, mask=mask, cond_tokens=emb, attn_bias=pe_bias, rotary_kwargs=rotary_kwargs, chunked_attn=chunked_attn)  # [bs, seqlen, d]
        output = output.permute(1, 0, 2)  # [seqlen, bs, d]

        # ============== OUTPUT PROCESSING ==============
        return self.output_process(output)  # [bs, njoints, nfeats, nframes]


    def _apply(self, fn):
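        # rot2xyz's SMPL model is not a registered submodule, so device/dtype changes are forwarded manually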
        super()._apply(fn)
        self.rot2xyz.smpl_model._apply(fn)
        return self


    def train(self, *args, **kwargs):
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)
        return self