import numpy as np
import torch
from typing import Optional, Dict, List
from ..builder import SUBMODULES
from .motion_transformer import MotionTransformer
@SUBMODULES.register_module()
class MotionDiffuseTransformer(MotionTransformer):
"""
MotionDiffuseTransformer is a subclass of DiffusionTransformer designed for motion generation.
It uses a diffusion-based approach with optional guidance during training and inference.
Args:
guidance_cfg (dict, optional): Configuration for guidance during inference and training.
'type' can be 'constant' or dynamically calculated based on timesteps.
kwargs: Additional keyword arguments for the DiffusionTransformer base class.
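
    Example (illustrative values only; ``scale_func`` reads just the ``type`` and
    ``scale`` keys):
        guidance_cfg = dict(type='constant', scale=3.5)  # fixed guidance scale
        guidance_cfg = dict(type='linear', scale=3.5)    # any other type string: scale grows with the timestep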
"""
def __init__(self, guidance_cfg: Optional[dict] = None, **kwargs):
"""
Initialize the MotionDiffuseTransformer.
Args:
guidance_cfg (Optional[dict]): Configuration for the guidance.
kwargs: Additional arguments passed to the base class.
"""
super().__init__(**kwargs)
self.guidance_cfg = guidance_cfg
def scale_func(self, timestep: int) -> dict:
"""
        Compute the mixing coefficients for the text-conditioned and unconditioned outputs.
Args:
timestep (int): The current diffusion timestep.
Returns:
dict: A dictionary containing 'text_coef' and 'none_coef' that control the mix of text-conditioned and
non-text-conditioned outputs.
"""
if self.guidance_cfg['type'] == 'constant':
w = self.guidance_cfg['scale']
return {'text_coef': w, 'none_coef': 1 - w}
else:
scale = self.guidance_cfg['scale']
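            # Dynamic schedule: w = timestep / 1000 * scale + 1 (the formula assumes
            # 1000 diffusion steps), so the text-conditioned branch is weighted most
            # heavily at large, noisy timesteps and decays to w = 1 at timestep 0.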
w = (1 - (1000 - timestep) / 1000) * scale + 1
output = {'text_coef': w, 'none_coef': 1 - w}
return output
def get_precompute_condition(self,
text: Optional[torch.Tensor] = None,
xf_proj: Optional[torch.Tensor] = None,
xf_out: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
clip_feat: Optional[torch.Tensor] = None,
**kwargs) -> dict:
"""
Precompute the conditions for text-based guidance using a text encoder.
Args:
text (Optional[torch.Tensor]): The input text data.
xf_proj (Optional[torch.Tensor]): Precomputed text projection.
xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder.
device (Optional[torch.device]): The device on which the model is running.
clip_feat (Optional[torch.Tensor]): CLIP features for text guidance.
kwargs: Additional keyword arguments.
Returns:
dict: A dictionary containing the text projection and output from the encoder.
"""
if xf_out is None:
if self.use_text_proj:
xf_proj, xf_out = self.encode_text(text, clip_feat, device)
else:
xf_out = self.encode_text(text, clip_feat, device)
return {'xf_proj': xf_proj, 'xf_out': xf_out}
def post_process(self, motion: torch.Tensor) -> torch.Tensor:
"""
        Post-process the generated motion data by de-normalizing it with the dataset mean and standard deviation.
Args:
motion (torch.Tensor): The generated motion data.
Returns:
torch.Tensor: Post-processed motion data.
"""
if self.post_process_cfg is not None:
if self.post_process_cfg.get("unnormalized_infer", False):
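                # Map the normalized model output back to the original data scale
                # using the dataset statistics stored at mean_path / std_path.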
mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path']))
mean = mean.type_as(motion)
std = torch.from_numpy(np.load(self.post_process_cfg['std_path']))
std = std.type_as(motion)
motion = motion * std + mean
return motion
def forward_train(self,
h: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
emb: Optional[torch.Tensor] = None,
xf_out: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Forward pass during training.
Args:
h (torch.Tensor): Input motion tensor of shape (B, T, D).
src_mask (Optional[torch.Tensor]): Source mask for masking the input.
            emb (Optional[torch.Tensor]): Time-step embeddings.
xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder.
kwargs: Additional keyword arguments.
Returns:
torch.Tensor: Output motion data after processing by the temporal decoder blocks.
"""
B, T = h.shape[0], h.shape[1]
if self.guidance_cfg is None:
for module in self.temporal_decoder_blocks:
h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
else:
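            # Draw a random conditioning code per sample; the temporal decoder blocks
            # use this code to modulate whether the text condition is applied, so both
            # text-conditioned and unconditioned behaviour are seen during training
            # (needed for the guidance mixing performed in forward_test).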
cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
for module in self.temporal_decoder_blocks:
h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=cond_type)
output = self.out(h).view(B, T, -1).contiguous()
return output
def forward_test(self,
h: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
emb: Optional[torch.Tensor] = None,
xf_out: Optional[torch.Tensor] = None,
timesteps: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Forward pass during testing/inference.
Args:
h (torch.Tensor): Input motion tensor of shape (B, T, D).
src_mask (Optional[torch.Tensor]): Source mask for masking the input.
            emb (Optional[torch.Tensor]): Time-step embeddings.
xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder.
timesteps (Optional[torch.Tensor]): Current diffusion timesteps.
kwargs: Additional keyword arguments.
Returns:
torch.Tensor: Output motion data after processing by the temporal decoder blocks.
"""
B, T = h.shape[0], h.shape[1]
if self.guidance_cfg is None:
for module in self.temporal_decoder_blocks:
h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
output = self.out(h).view(B, T, -1).contiguous()
else:
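            # Guidance at inference: duplicate the batch, run the decoder once with
            # the text condition enabled (cond_type == 1) and once disabled
            # (cond_type == 0), then blend the two predictions with the
            # coefficients from scale_func.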
            text_cond_type = torch.ones(B, 1, 1, device=h.device)
            none_cond_type = torch.zeros(B, 1, 1, device=h.device)
all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0)
h = h.repeat(2, 1, 1)
xf_out = xf_out.repeat(2, 1, 1)
emb = emb.repeat(2, 1)
src_mask = src_mask.repeat(2, 1, 1)
for module in self.temporal_decoder_blocks:
h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=all_cond_type)
out = self.out(h).view(2 * B, T, -1).contiguous()
out_text = out[:B].contiguous()
out_none = out[B:].contiguous()
coef_cfg = self.scale_func(int(timesteps[0]))
text_coef = coef_cfg['text_coef']
none_coef = coef_cfg['none_coef']
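            # Since none_coef = 1 - text_coef, this is equivalent to
            # out_none + w * (out_text - out_none), the usual guidance update.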
output = out_text * text_coef + out_none * none_coef
return output