import numpy as np
import torch

from typing import Optional

from ..builder import SUBMODULES
from .motion_transformer import MotionTransformer


@SUBMODULES.register_module()
class MotionDiffuseTransformer(MotionTransformer):
    """
    MotionDiffuseTransformer is a subclass of DiffusionTransformer designed for motion generation.
    It uses a diffusion-based approach with optional guidance during training and inference.

    Args:
        guidance_cfg (dict, optional): Configuration for guidance during inference and training.
                                        'type' can be 'constant' or dynamically calculated based on timesteps.
        kwargs: Additional keyword arguments for the DiffusionTransformer base class.
    """

    def __init__(self, guidance_cfg: Optional[dict] = None, **kwargs):
        """
        Initialize the MotionDiffuseTransformer.

        Args:
            guidance_cfg (Optional[dict]): Configuration for the guidance.
            kwargs: Additional arguments passed to the base class.
        """
        super().__init__(**kwargs)
        self.guidance_cfg = guidance_cfg

    def scale_func(self, timestep: int) -> dict:
        """
        Compute the scaling coefficients for text-based guidance and no-guidance.

        Args:
            timestep (int): The current diffusion timestep.

        Returns:
            dict: A dictionary containing 'text_coef' and 'none_coef' that control the mix of text-conditioned and 
                  non-text-conditioned outputs.
        """
        if self.guidance_cfg['type'] == 'constant':
            w = self.guidance_cfg['scale']
        else:
            # Linear schedule: w ramps from 1 at timestep 0 to scale + 1 at
            # timestep 1000 (the 1000-step horizon is hard-coded here).
            scale = self.guidance_cfg['scale']
            w = (1 - (1000 - timestep) / 1000) * scale + 1
        return {'text_coef': w, 'none_coef': 1 - w}

    def get_precompute_condition(self,
                                 text: Optional[torch.Tensor] = None,
                                 xf_proj: Optional[torch.Tensor] = None,
                                 xf_out: Optional[torch.Tensor] = None,
                                 device: Optional[torch.device] = None,
                                 clip_feat: Optional[torch.Tensor] = None,
                                 **kwargs) -> dict:
        """
        Precompute the conditions for text-based guidance using a text encoder.

        Args:
            text (Optional[torch.Tensor]): The input text data.
            xf_proj (Optional[torch.Tensor]): Precomputed text projection.
            xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder.
            device (Optional[torch.device]): The device on which the model is running.
            clip_feat (Optional[torch.Tensor]): CLIP features for text guidance.
            kwargs: Additional keyword arguments.

        Returns:
            dict: A dictionary containing the text projection and output from the encoder.
        """
        if xf_out is None:
            if self.use_text_proj:
                xf_proj, xf_out = self.encode_text(text, clip_feat, device)
            else:
                xf_out = self.encode_text(text, clip_feat, device)
        return {'xf_proj': xf_proj, 'xf_out': xf_out}

    def post_process(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Post-process the generated motion data by re-normalizing it using mean and standard deviation.

        Args:
            motion (torch.Tensor): The generated motion data.

        Returns:
            torch.Tensor: Post-processed motion data.
        """
        if self.post_process_cfg is not None:
            if self.post_process_cfg.get("unnormalized_infer", False):
                # Load dataset statistics saved as .npy files and match the
                # dtype/device of the generated motion before un-normalizing.
                mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path']))
                mean = mean.type_as(motion)
                std = torch.from_numpy(np.load(self.post_process_cfg['std_path']))
                std = std.type_as(motion)
                motion = motion * std + mean
        return motion

    def forward_train(self,
                      h: torch.Tensor,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      **kwargs) -> torch.Tensor:
        """
        Forward pass during training.

        Args:
            h (torch.Tensor): Input motion tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask for masking the input.
            emb (Optional[torch.Tensor]): Time-step embeddings.
            xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder.
            kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: Output motion data after processing by the temporal decoder blocks.
        """
        B, T = h.shape[0], h.shape[1]
        if self.guidance_cfg is None:
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
        else:
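            # Classifier-free guidance training: draw a random integer in
            # [0, 100) per sample and let the decoder blocks use it to pick a
            # conditioning branch. (Assumption: the blocks treat ``cond_type``
            # as a per-sample switch that drops the text condition for some
            # fraction of samples; the exact thresholds live inside the
            # attention modules, not here.)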
            cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=cond_type)
        output = self.out(h).view(B, T, -1).contiguous()
        return output

    def forward_test(self,
                     h: torch.Tensor,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     **kwargs) -> torch.Tensor:
        """
        Forward pass during testing/inference.

        Args:
            h (torch.Tensor): Input motion tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask for masking the input.
            emb (Optional[torch.Tensor]): Time-step embeddings.
            xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder.
            timesteps (Optional[torch.Tensor]): Current diffusion timesteps.
            kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: Output motion data after processing by the temporal decoder blocks.
        """
        B, T = h.shape[0], h.shape[1]
        if self.guidance_cfg is None:
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
            output = self.out(h).view(B, T, -1).contiguous()
        else:
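            # Classifier-free guidance sampling: run the text-conditioned and
            # unconditioned branches in a single forward pass by doubling the
            # batch (first B samples conditioned, last B unconditioned), then
            # blend the two outputs with the coefficients from scale_func.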
            text_cond_type = torch.ones(B, 1, 1, device=h.device)
            none_cond_type = torch.zeros(B, 1, 1, device=h.device)
            all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0)
            h = h.repeat(2, 1, 1)
            xf_out = xf_out.repeat(2, 1, 1)
            emb = emb.repeat(2, 1)
            src_mask = src_mask.repeat(2, 1, 1)
            for module in self.temporal_decoder_blocks:
                h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=all_cond_type)
            out = self.out(h).view(2 * B, T, -1).contiguous()
            out_text = out[:B].contiguous()
            out_none = out[B:].contiguous()
            coef_cfg = self.scale_func(int(timesteps[0]))
            text_coef = coef_cfg['text_coef']
            none_coef = coef_cfg['none_coef']
            output = out_text * text_coef + out_none * none_coef
        return output
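
# A minimal config sketch (hypothetical values; the field names follow the
# guidance_cfg accesses above, and registry usage follows the SUBMODULES
# import):
#
#   model = dict(
#       type='MotionDiffuseTransformer',
#       guidance_cfg=dict(
#           type='linear',  # any value other than 'constant' selects the
#           scale=3.5),     # timestep-dependent linear schedule
#       ...)                # remaining MotionTransformer arguments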