|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from typing import Optional, Dict, Any |
|
|
|
from ..builder import ATTENTIONS |
|
from ..utils.stylization_block import StylizationBlock |
|
|
|
|
|
def zero_module(module: nn.Module) -> nn.Module:
    """
    Overwrite every parameter of ``module`` with zeros, in place.

    Args:
        module (nn.Module): Module whose parameters are reset.

    Returns:
        nn.Module: The same ``module`` object that was passed in, so the call
            can wrap a constructor expression directly.
    """
    for param in module.parameters():
        # The detached tensor shares storage with the parameter, so the
        # in-place write changes the parameter without being recorded by
        # autograd.
        detached = param.detach()
        detached.zero_()
    return module
|
|
|
|
|
@ATTENTIONS.register_module()
class SemanticsModulatedAttention(nn.Module):
    """
    Attention layer in which motion tokens attend jointly to text features,
    retrieved samples, and the motion sequence itself.

    Keys are soft-maxed over the token axis and queries over the per-head
    channel axis. The key/value product is formed first, so no (T x N)
    attention map is materialised.

    Args:
        latent_dim (int): Dimensionality of the latent (motion) features.
            Must be divisible by ``num_heads``.
        text_latent_dim (int): Dimensionality of the text features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate.
        time_embed_dim (int): Dimensionality of time embeddings.

    Raises:
        ValueError: If ``latent_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int):
        super().__init__()
        # forward() splits the channel axis into ``num_heads`` equal chunks via
        # ``.view``; fail here with a clear message instead of there.
        if latent_dim % num_heads != 0:
            raise ValueError(
                f"latent_dim ({latent_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads

        # Input normalisation for the motion and text streams.
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)

        # Projections for the query and for the text / motion key-value pairs.
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key_text = nn.Linear(text_latent_dim, latent_dim)
        self.value_text = nn.Linear(text_latent_dim, latent_dim)
        self.key_motion = nn.Linear(latent_dim, latent_dim)
        self.value_motion = nn.Linear(latent_dim, latent_dim)

        # Retrieval branch: keys are built from concatenated motion+text
        # features (2 * latent_dim), values from the retrieved motion only.
        self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
        self.retr_norm2 = nn.LayerNorm(latent_dim)
        self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
        # Zero-initialised, so the retrieval values are all zero until trained.
        self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))

        # NOTE: ``self.dropout`` is registered but never applied in forward().
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor,
                cond_type: torch.Tensor, re_dict: dict) -> torch.Tensor:
        """
        Forward pass of SemanticsModulatedAttention.

        Args:
            x (torch.Tensor): Motion features of shape (B, T, D).
            xf (torch.Tensor): Text features of shape (B, N, L).
            emb (torch.Tensor): Time embedding, passed through to ``proj_out``.
            src_mask (torch.Tensor): Source mask for the input motion features;
                must broadcast against (B, T, D) (presumably (B, T, 1) with 1
                marking valid frames -- confirm against the caller).
            cond_type (torch.Tensor): Condition type tensor.
                ``cond_type % 10 > 0`` enables the text branch and
                ``cond_type // 10 > 0`` enables the retrieval branch. It must
                broadcast against the key/value tensors (presumably
                (B, 1, 1) -- confirm against the caller).
            re_dict (dict): Retrieval data with keys ``'re_motion'``,
                ``'re_text'`` and ``'re_mask'``. ``re_motion`` and ``re_text``
                are indexed as 4-D tensors here (presumably (B, R, M, D) and
                (B, R, 1, D) -- confirm against the producer); ``re_mask`` is
                reshaped to (B, R * M, 1).

        Returns:
            torch.Tensor: Output tensor after attention modulation, shape (B, T, D).
        """
        B, T, D = x.shape
        H = self.num_heads
        re_motion = re_dict['re_motion']
        re_text = re_dict['re_text']
        re_mask = re_dict['re_mask'].reshape(B, -1, 1)
        # Total number of key/value tokens: text + retrieval + motion.
        N = xf.shape[1] + T + re_motion.shape[1] * re_motion.shape[2]

        # Normalise each input once and reuse it for every projection.
        x_norm = self.norm(x)
        xf_norm = self.text_norm(xf)

        # 1.0 where the corresponding condition is enabled, else 0.0.
        text_cond_type = (cond_type % 10 > 0).float()
        retr_cond_type = (cond_type // 10 > 0).float()

        # Tile the retrieved text along dim 2 to match re_motion, then pair the
        # two feature-wise to form the retrieval key input.
        re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
        re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
        re_feat_value = re_motion.reshape(B, -1, D)

        # Adding -1e6 to all channels of a token makes its weight ~0 after the
        # softmax over the token axis below; this disables unused conditions
        # and masked positions.
        key_text = self.key_text(xf_norm) + (1 - text_cond_type) * -1000000
        key_retr = self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000
        key_motion = self.key_motion(x_norm) + (1 - src_mask) * -1000000
        key = torch.cat((key_text, key_retr, key_motion), dim=1)

        # Query: softmax over per-head channels. Key: softmax over tokens.
        query = F.softmax(self.query(x_norm).view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, N, H, -1), dim=1)

        # Values of disabled / masked tokens are zeroed as well.
        value_text = self.value_text(xf_norm) * text_cond_type
        value_retr = self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask
        value_motion = self.value_motion(x_norm) * src_mask
        value = torch.cat((value_text, value_retr, value_motion), dim=1).view(B, N, H, -1)

        # Per-head (d x l) summary of all key/value tokens, then read out by
        # each query token.
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)

        # Residual connection around the time-conditioned output projection.
        y = x + self.proj_out(y, emb)
        return y
|
|
|
|
|
|
|
@ATTENTIONS.register_module()
class DualSemanticsModulatedAttention(nn.Module):
    """
    Dual semantics-modulated attention module that handles two streams of
    motion features and integrates them with text and retrieval features.

    The input is split channel-wise into two streams of width ``latent_dim``.
    Each stream attends to the text tokens, the retrieved samples, itself, and
    the other stream; all projection weights are shared between the streams.

    Args:
        latent_dim (int): Dimensionality of the latent (motion) features of
            ONE stream. Must be divisible by ``num_heads``.
        text_latent_dim (int): Dimensionality of the text features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate.
        time_embed_dim (int): Dimensionality of time embeddings.

    Raises:
        ValueError: If ``latent_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int):
        super().__init__()
        # forward() splits the channel axis into ``num_heads`` equal chunks via
        # ``.view``; fail here with a clear message instead of there.
        if latent_dim % num_heads != 0:
            raise ValueError(
                f"latent_dim ({latent_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.latent_dim = latent_dim

        # Input normalisation for the motion and text streams.
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)

        # Projections for the query and for the text / own-motion /
        # other-stream ("inter") key-value pairs.
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key_text = nn.Linear(text_latent_dim, latent_dim)
        self.value_text = nn.Linear(text_latent_dim, latent_dim)
        self.key_motion = nn.Linear(latent_dim, latent_dim)
        self.value_motion = nn.Linear(latent_dim, latent_dim)
        self.key_inter = nn.Linear(latent_dim, latent_dim)
        self.value_inter = nn.Linear(latent_dim, latent_dim)

        # Retrieval branch: keys are built from concatenated motion+text
        # features (2 * latent_dim), values from the retrieved motion only.
        self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
        self.retr_norm2 = nn.LayerNorm(latent_dim)
        self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
        # Zero-initialised, so the retrieval values are all zero until trained.
        self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))

        # NOTE: ``self.dropout`` is registered but never applied in forward().
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def _attend_stream(self, own: torch.Tensor, other: torch.Tensor, key_shared: torch.Tensor,
                       value_shared: torch.Tensor, src_mask: torch.Tensor, pad_bias: torch.Tensor,
                       num_tokens: int) -> torch.Tensor:
        """Attention output (B, T, D) for one stream; ``own``/``other`` are already layer-normalised."""
        B, T, D = own.shape
        H = self.num_heads

        # Query: softmax over per-head channels.
        query = F.softmax(self.query(own).view(B, T, H, -1), dim=-1)

        # Token order along dim 1: shared (text, retrieval), own motion, other stream.
        key = torch.cat(
            (key_shared, self.key_motion(own) + pad_bias, self.key_inter(other) + pad_bias), dim=1)
        # Key: softmax over the token axis.
        key = F.softmax(key.view(B, num_tokens, H, -1), dim=1)

        value = torch.cat(
            (value_shared, self.value_motion(own) * src_mask, self.value_inter(other) * src_mask),
            dim=1).view(B, num_tokens, H, -1)

        # Per-head (d x l) summary of all key/value tokens, then read out by
        # each query token.
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        return torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)

    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor,
                cond_type: torch.Tensor, re_dict: dict) -> torch.Tensor:
        """
        Forward pass of DualSemanticsModulatedAttention.

        Args:
            x (torch.Tensor): Motion features of shape (B, T, 2*D); the first
                ``latent_dim`` channels form stream 1, the rest stream 2.
            xf (torch.Tensor): Text features of shape (B, N, L).
            emb (torch.Tensor): Time embedding, passed through to ``proj_out``.
            src_mask (torch.Tensor): Source mask for the input motion features,
                applied identically to both streams; must broadcast against
                (B, T, D) (presumably (B, T, 1) with 1 marking valid frames --
                confirm against the caller).
            cond_type (torch.Tensor): Condition type tensor.
                ``cond_type % 10 > 0`` enables the text branch and
                ``cond_type // 10 > 0`` enables the retrieval branch. It must
                broadcast against the key/value tensors (presumably
                (B, 1, 1) -- confirm against the caller).
            re_dict (dict): Retrieval data with keys ``'re_motion'``,
                ``'re_text'`` and ``'re_mask'``. ``re_motion`` and ``re_text``
                are indexed as 4-D tensors here (presumably (B, R, M, D) and
                (B, R, 1, D) -- confirm against the producer); ``re_mask`` is
                reshaped to (B, R * M, 1).

        Returns:
            torch.Tensor: Output tensor after dual attention modulation, shape (B, T, 2*D).
        """
        x1 = x[:, :, :self.latent_dim].contiguous()
        x2 = x[:, :, self.latent_dim:].contiguous()
        B, T, D = x1.shape
        re_motion = re_dict['re_motion']
        re_text = re_dict['re_text']
        re_mask = re_dict['re_mask'].reshape(B, -1, 1)
        # Total key/value tokens per stream: text + retrieval + own motion + other stream.
        N = xf.shape[1] + 2 * T + re_motion.shape[1] * re_motion.shape[2]

        # Normalise each input once and reuse it for every projection.
        x1_norm = self.norm(x1)
        x2_norm = self.norm(x2)
        xf_norm = self.text_norm(xf)

        # 1.0 where the corresponding condition is enabled, else 0.0.
        text_cond_type = (cond_type % 10 > 0).float()
        retr_cond_type = (cond_type // 10 > 0).float()

        # Tile the retrieved text along dim 2 to match re_motion, then pair the
        # two feature-wise to form the retrieval key input.
        re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
        re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
        re_feat_value = re_motion.reshape(B, -1, D)

        # Text and retrieval keys/values are identical for both streams.
        # Adding -1e6 to all channels of a token makes its weight ~0 after the
        # softmax over the token axis; the matching values are zeroed as well.
        key_text = self.key_text(xf_norm) + (1 - text_cond_type) * -1000000
        key_retr = self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000
        key_shared = torch.cat((key_text, key_retr), dim=1)

        value_text = self.value_text(xf_norm) * text_cond_type
        value_retr = self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask
        value_shared = torch.cat((value_text, value_retr), dim=1)

        # Same large negative offset, applied to masked motion positions.
        pad_bias = (1 - src_mask) * -1000000

        # Each stream uses itself as "motion" and the other stream as "inter".
        y1 = self._attend_stream(x1_norm, x2_norm, key_shared, value_shared, src_mask, pad_bias, N)
        y2 = self._attend_stream(x2_norm, x1_norm, key_shared, value_shared, src_mask, pad_bias, N)

        # Residual connection around the shared, time-conditioned projection;
        # stream 1 is projected before stream 2.
        y1 = x1 + self.proj_out(y1, emb)
        y2 = x2 + self.proj_out(y2, emb)
        y = torch.cat((y1, y2), dim=-1)

        return y
|
|