File size: 11,053 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from ..builder import ATTENTIONS
from ..utils.stylization_block import StylizationBlock
def zero_module(module: nn.Module) -> nn.Module:
    """
    Clear every parameter of ``module`` in place and return the same module.

    Args:
        module (nn.Module): Module whose parameters should all become zero.

    Returns:
        nn.Module: The very same ``module`` object (not a copy), now holding
        all-zero parameters.
    """
    # Work outside autograd so the in-place writes on leaf parameters are
    # neither recorded nor rejected.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
@ATTENTIONS.register_module()
class SemanticsModulatedAttention(nn.Module):
    """
    Semantics-modulated attention that integrates motion, text, and retrieval
    features into a single linear-attention computation.

    Queries are projected from the motion sequence. Keys and values are the
    concatenation (along the token axis) of text tokens, retrieved-motion
    tokens, and the motion tokens themselves. Queries are softmax-normalized
    over the per-head feature axis and keys over the token axis, then combined
    through a per-head (head_dim x head_dim) context matrix, so no T x N
    attention matrix is ever formed.

    Args:
        latent_dim (int): Dimensionality of the latent (motion) features.
        text_latent_dim (int): Dimensionality of the text features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate (also passed to the output StylizationBlock).
        time_embed_dim (int): Dimensionality of time embeddings.
    """

    # Additive key bias for disabled / masked-out tokens: after the key softmax
    # over the token axis their weight is effectively zero.
    _MASK_BIAS = -1000000

    def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float,
                 time_embed_dim: int):
        super().__init__()
        self.num_heads = num_heads
        # Layer normalization for motion and text features
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        # Linear projections for queries, keys, and values
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key_text = nn.Linear(text_latent_dim, latent_dim)
        self.value_text = nn.Linear(text_latent_dim, latent_dim)
        self.key_motion = nn.Linear(latent_dim, latent_dim)
        self.value_motion = nn.Linear(latent_dim, latent_dim)
        # Retrieval feature processing: keys see [motion, text] (2 * latent_dim),
        # values see the retrieved motion only.
        self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
        self.retr_norm2 = nn.LayerNorm(latent_dim)
        self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
        # Zero-initialized (weights and bias), so the retrieval value branch
        # contributes nothing at initialization.
        self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))
        # NOTE: not applied in forward(); kept so configs / attribute lookups
        # that expect it keep working.
        self.dropout = nn.Dropout(dropout)
        # Output projection conditioned on the time embedding.
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def _linear_attention(self, query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor) -> torch.Tensor:
        """
        Multi-head linear attention.

        Args:
            query (torch.Tensor): Raw query projections, shape (B, T, D).
            key (torch.Tensor): Masked raw keys, shape (B, N, D).
            value (torch.Tensor): Masked values, shape (B, N, D).

        Returns:
            torch.Tensor: Attention output of shape (B, T, D).
        """
        B, T, D = query.shape
        N = key.shape[1]
        H = self.num_heads
        query = F.softmax(query.view(B, T, H, -1), dim=-1)  # over head features
        key = F.softmax(key.view(B, N, H, -1), dim=1)       # over tokens
        value = value.view(B, N, H, -1)
        # Per-head context matrix: (B, H, head_dim, head_dim)
        context = torch.einsum('bnhd,bnhl->bhdl', key, value)
        return torch.einsum('bnhd,bhdl->bnhl', query, context).reshape(B, T, D)

    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor,
                cond_type: torch.Tensor, re_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forward pass of SemanticsModulatedAttention.

        Args:
            x (torch.Tensor): Motion features of shape (B, T, D).
            xf (torch.Tensor): Text features of shape (B, N_text, text_latent_dim).
            emb (torch.Tensor): Time embedding, forwarded to ``proj_out``.
            src_mask (torch.Tensor): Motion mask, 1 = keep a frame and 0 = mask it;
                presumably (B, T, 1) so it broadcasts over features (TODO confirm).
            cond_type (torch.Tensor): Integer-coded condition switches:
                ``cond_type % 10 > 0`` enables text tokens and
                ``cond_type // 10 > 0`` enables retrieval tokens. Presumably
                shaped (B, 1, 1) to broadcast over tokens (TODO confirm).
            re_dict (Dict[str, torch.Tensor]): Retrieval data with keys
                're_motion', 're_text' and 're_mask'. Assumed shapes, inferred
                from the reshapes below (verify against the caller):
                re_motion (B, R, T_r, D), re_text (B, R, 1, D) and
                re_mask with B * R * T_r elements.

        Returns:
            torch.Tensor: ``x`` plus the attention output, shape (B, T, D).
        """
        B, T, D = x.shape
        mask_bias = self._MASK_BIAS
        re_motion = re_dict['re_motion']
        re_text = re_dict['re_text']
        re_mask = re_dict['re_mask'].reshape(B, -1, 1)

        text_cond_type = (cond_type % 10 > 0).float()
        retr_cond_type = (cond_type // 10 > 0).float()

        # Normalize each input once; the results feed several projections.
        x_norm = self.norm(x)
        xf_norm = self.text_norm(xf)

        # Retrieval keys pair each retrieved motion frame with its text feature.
        re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
        re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
        re_feat_value = re_motion.reshape(B, -1, D)

        # Keys: disabled or masked tokens receive a large negative bias.
        key_text = self.key_text(xf_norm) + (1 - text_cond_type) * mask_bias
        key_retr = (self.key_retr(self.retr_norm1(re_feat_key))
                    + (1 - retr_cond_type) * mask_bias
                    + (1 - re_mask) * mask_bias)
        key_motion = self.key_motion(x_norm) + (1 - src_mask) * mask_bias
        key = torch.cat((key_text, key_retr, key_motion), dim=1)

        # Values: disabled or masked tokens are zeroed out.
        value_text = self.value_text(xf_norm) * text_cond_type
        value_retr = self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask
        value_motion = self.value_motion(x_norm) * src_mask
        value = torch.cat((value_text, value_retr, value_motion), dim=1)

        y = self._linear_attention(self.query(x_norm), key, value)
        return x + self.proj_out(y, emb)
@ATTENTIONS.register_module()
class DualSemanticsModulatedAttention(nn.Module):
    """
    Dual semantics-modulated attention over two motion streams packed side by
    side along the feature axis (``x[..., :latent_dim]`` and
    ``x[..., latent_dim:]``).

    Each stream runs the same linear attention over text tokens,
    retrieved-motion tokens and its own motion tokens. It also attends to
    "inter" keys/values projected from the *other* stream, so the two streams
    exchange information. All projections, norms and the output block are
    shared between the streams.

    Args:
        latent_dim (int): Dimensionality of the latent (motion) features of one stream.
        text_latent_dim (int): Dimensionality of the text features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate (also passed to the output StylizationBlock).
        time_embed_dim (int): Dimensionality of time embeddings.
    """

    # Additive key bias for disabled / masked-out tokens: after the key softmax
    # over the token axis their weight is effectively zero.
    _MASK_BIAS = -1000000

    def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float,
                 time_embed_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.latent_dim = latent_dim  # width of one stream; used to split x
        # Layer normalization for motion and text features
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        # Linear projections for queries, keys, and values
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key_text = nn.Linear(text_latent_dim, latent_dim)
        self.value_text = nn.Linear(text_latent_dim, latent_dim)
        self.key_motion = nn.Linear(latent_dim, latent_dim)
        self.value_motion = nn.Linear(latent_dim, latent_dim)
        # Cross-stream projections (applied to the *other* stream)
        self.key_inter = nn.Linear(latent_dim, latent_dim)
        self.value_inter = nn.Linear(latent_dim, latent_dim)
        # Retrieval feature processing: keys see [motion, text] (2 * latent_dim),
        # values see the retrieved motion only.
        self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
        self.retr_norm2 = nn.LayerNorm(latent_dim)
        self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
        # Zero-initialized (weights and bias), so the retrieval value branch
        # contributes nothing at initialization.
        self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))
        # NOTE: not applied in forward(); kept so configs / attribute lookups
        # that expect it keep working.
        self.dropout = nn.Dropout(dropout)
        # Output projection conditioned on the time embedding (shared by both streams).
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def _linear_attention(self, query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor) -> torch.Tensor:
        """
        Multi-head linear attention for one stream.

        Args:
            query (torch.Tensor): Raw query projections, shape (B, T, D).
            key (torch.Tensor): Masked raw keys, shape (B, N, D).
            value (torch.Tensor): Masked values, shape (B, N, D).

        Returns:
            torch.Tensor: Attention output of shape (B, T, D).
        """
        B, T, D = query.shape
        N = key.shape[1]
        H = self.num_heads
        query = F.softmax(query.view(B, T, H, -1), dim=-1)  # over head features
        key = F.softmax(key.view(B, N, H, -1), dim=1)       # over tokens
        value = value.view(B, N, H, -1)
        # Per-head context matrix: (B, H, head_dim, head_dim)
        context = torch.einsum('bnhd,bnhl->bhdl', key, value)
        return torch.einsum('bnhd,bhdl->bnhl', query, context).reshape(B, T, D)

    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor,
                cond_type: torch.Tensor, re_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forward pass of DualSemanticsModulatedAttention.

        Args:
            x (torch.Tensor): Two packed motion streams, shape (B, T, 2*D) with
                D == latent_dim.
            xf (torch.Tensor): Text features of shape (B, N_text, text_latent_dim).
            emb (torch.Tensor): Time embedding, forwarded to ``proj_out``.
            src_mask (torch.Tensor): Motion mask shared by both streams,
                1 = keep a frame and 0 = mask it; presumably (B, T, 1) (TODO confirm).
            cond_type (torch.Tensor): Integer-coded condition switches:
                ``cond_type % 10 > 0`` enables text tokens and
                ``cond_type // 10 > 0`` enables retrieval tokens. Presumably
                shaped (B, 1, 1) to broadcast over tokens (TODO confirm).
            re_dict (Dict[str, torch.Tensor]): Retrieval data with keys
                're_motion', 're_text' and 're_mask'. Assumed shapes, inferred
                from the reshapes below (verify against the caller):
                re_motion (B, R, T_r, D), re_text (B, R, 1, D) and
                re_mask with B * R * T_r elements.

        Returns:
            torch.Tensor: Both streams with their residual attention outputs,
            re-packed to shape (B, T, 2*D).
        """
        x1 = x[:, :, :self.latent_dim].contiguous()
        x2 = x[:, :, self.latent_dim:].contiguous()
        B, T, D = x1.shape
        mask_bias = self._MASK_BIAS
        re_motion = re_dict['re_motion']
        re_text = re_dict['re_text']
        re_mask = re_dict['re_mask'].reshape(B, -1, 1)

        text_cond_type = (cond_type % 10 > 0).float()
        retr_cond_type = (cond_type // 10 > 0).float()

        # Normalize each input once; the same tensors feed the query, motion
        # and inter projections below.
        x1_norm = self.norm(x1)
        x2_norm = self.norm(x2)
        xf_norm = self.text_norm(xf)

        # Retrieval keys pair each retrieved motion frame with its text feature.
        re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
        re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
        re_feat_value = re_motion.reshape(B, -1, D)

        # Text and retrieval tokens are identical for both streams.
        key_text = self.key_text(xf_norm) + (1 - text_cond_type) * mask_bias
        key_retr = (self.key_retr(self.retr_norm1(re_feat_key))
                    + (1 - retr_cond_type) * mask_bias
                    + (1 - re_mask) * mask_bias)
        value_text = self.value_text(xf_norm) * text_cond_type
        value_retr = self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask

        # Both streams share src_mask, so the motion key bias is computed once.
        motion_bias = (1 - src_mask) * mask_bias

        # Stream 1: its own frames (motion) plus stream 2's frames (inter).
        key1 = torch.cat((key_text, key_retr,
                          self.key_motion(x1_norm) + motion_bias,
                          self.key_inter(x2_norm) + motion_bias), dim=1)
        value1 = torch.cat((value_text, value_retr,
                            self.value_motion(x1_norm) * src_mask,
                            self.value_inter(x2_norm) * src_mask), dim=1)

        # Stream 2: mirror of stream 1 with the roles of x1 and x2 swapped.
        key2 = torch.cat((key_text, key_retr,
                          self.key_motion(x2_norm) + motion_bias,
                          self.key_inter(x1_norm) + motion_bias), dim=1)
        value2 = torch.cat((value_text, value_retr,
                            self.value_motion(x2_norm) * src_mask,
                            self.value_inter(x1_norm) * src_mask), dim=1)

        y1 = self._linear_attention(self.query(x1_norm), key1, value1)
        y2 = self._linear_attention(self.query(x2_norm), key2, value2)

        # Residual connection through the shared output block, stream 1 first.
        y1 = x1 + self.proj_out(y1, emb)
        y2 = x2 + self.proj_out(y2, emb)
        return torch.cat((y1, y2), dim=-1)
|