|
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist

from mmcv.runner import BaseModule
|
|
|
|
|
def to_cpu(x: Any) -> Any:
    """Detach a tensor from the computation graph and move it to CPU.

    Non-tensor inputs are returned unchanged, so this helper can safely be
    mapped over heterogeneous result dictionaries (the previous annotation
    of ``torch.Tensor`` was misleading — the fall-through branch has always
    accepted arbitrary values).

    Args:
        x: The input value, typically a ``torch.Tensor``.

    Returns:
        The tensor detached and moved to CPU, or ``x`` unchanged when it is
        not a tensor.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu()
    return x
|
|
|
|
|
class BaseArchitecture(BaseModule):
    """Base class for mogen architecture.

    Subclasses implement :meth:`forward_train` / :meth:`forward_test`;
    :meth:`forward` dispatches between them based on ``self.training``.

    Args:
        init_cfg (dict, optional): Initialization config for the module.
    """

    def __init__(self, init_cfg: Optional[dict] = None):
        super(BaseArchitecture, self).__init__(init_cfg)

    def forward_train(self, **kwargs):
        """Forward computation during training.

        Intended to be overridden by subclasses; the base implementation
        is a no-op.
        """
        pass

    def forward_test(self, **kwargs):
        """Forward computation during testing.

        Intended to be overridden by subclasses; the base implementation
        is a no-op.
        """
        pass

    def _parse_losses(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network. Values must be
                tensors or lists of tensors; anything else raises.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars)
                - loss is the total loss tensor used for backpropagation,
                  summed over every entry whose key contains ``'loss'``.
                - log_vars contains all the (scalar) variables to be logged.

        Raises:
            TypeError: If a loss value is neither a tensor nor a list of
                tensors.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f'{loss_name} is not a tensor or list of tensors')

        # Only entries whose key contains 'loss' contribute to the total;
        # other entries (metrics, stats) are logged but not backpropagated.
        loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # In distributed training, average each logged value across
            # ranks so every worker reports the same number. The clone
            # keeps the all-reduce from mutating the tensor used for
            # backpropagation.
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def _run_step(self, data: Dict) -> Dict:
        """Shared body of :meth:`train_step` and :meth:`val_step`.

        Runs a forward pass, parses the losses and packages the outputs.
        The batch size is taken from the length of ``data['motion']``.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        return dict(loss=loss, log_vars=log_vars, num_samples=len(data['motion']))

    def train_step(self, data: Dict, optimizer: torch.optim.Optimizer) -> Dict:
        """The iteration step during training.

        This method defines an iteration step during training, excluding
        backpropagation and optimizer updating, which are handled by an
        optimizer hook.

        Args:
            data (dict): The output of the dataloader.
            optimizer (torch.optim.Optimizer): The optimizer object (unused;
                kept for the runner's calling convention).

        Returns:
            dict: A dictionary with the keys:
                - ``loss``: A tensor for backpropagation, which may be a
                  weighted sum of multiple losses.
                - ``log_vars``: All the variables to be logged.
                - ``num_samples``: The number of samples in the batch.
        """
        return self._run_step(data)

    def val_step(self, data: Dict, optimizer: torch.optim.Optimizer = None) -> Dict:
        """The iteration step during validation.

        Args:
            data (dict): The output of the dataloader.
            optimizer (torch.optim.Optimizer, optional): The optimizer
                object (unused; kept for the runner's calling convention).

        Returns:
            dict: Same structure as :meth:`train_step`: ``loss``,
            ``log_vars`` and ``num_samples``.
        """
        return self._run_step(data)

    def forward(self, **kwargs):
        """Dispatch to train or test forward based on ``self.training``."""
        if self.training:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def split_results(self, results: Dict[str, torch.Tensor]) -> List[Dict]:
        """Split batched results into individual per-instance outputs.

        Args:
            results (dict): Batched model outputs containing at least
                'motion', 'pred_motion', 'motion_length' and 'motion_mask';
                optionally 'pred_motion_length', 'pred_motion_mask' and
                'motion_metas'.

        Returns:
            list[dict]: One dictionary per batch instance, with all tensors
            detached and moved to CPU.
        """
        batch_size = results['motion'].shape[0]
        output = []
        for i in range(batch_size):
            batch_output = dict()
            batch_output['motion'] = to_cpu(results['motion'][i])
            batch_output['pred_motion'] = to_cpu(results['pred_motion'][i])
            batch_output['motion_length'] = to_cpu(results['motion_length'][i])
            # Zero out padded frames beyond the valid motion length
            # (assumes motion is laid out as (frames, features) — TODO
            # confirm against the dataset pipeline).
            batch_output['motion'][batch_output['motion_length']:, :] = 0
            batch_output['motion_mask'] = to_cpu(results['motion_mask'][i])
            # Fall back to the ground-truth length/mask when the model does
            # not produce its own predictions for them.
            if 'pred_motion_length' in results:
                batch_output['pred_motion_length'] = to_cpu(results['pred_motion_length'][i])
            else:
                batch_output['pred_motion_length'] = to_cpu(results['motion_length'][i])
            batch_output['pred_motion'][batch_output['pred_motion_length']:, :] = 0
            if 'pred_motion_mask' in results:
                batch_output['pred_motion_mask'] = to_cpu(results['pred_motion_mask'][i])
            else:
                batch_output['pred_motion_mask'] = to_cpu(results['motion_mask'][i])
            if 'motion_metas' in results:
                motion_metas = results['motion_metas'][i]
                # Surface commonly-used meta fields at the top level for
                # convenience; the full metas dict is kept as well.
                if 'text' in motion_metas:
                    batch_output['text'] = motion_metas['text']
                if 'token' in motion_metas:
                    batch_output['token'] = motion_metas['token']
                if 'meta_data' in motion_metas and 'category_id' in motion_metas['meta_data']:
                    batch_output['category_id'] = motion_metas['meta_data']['category_id']
                batch_output['motion_metas'] = motion_metas
            output.append(batch_output)
        return output
|
|