# HMR2.0 / hmr2 / models / components / pose_transformer.py
# brjathu
# Adding HF files
# 29a229f
# raw
# history blame
# No virus
# 11.3 kB
from inspect import isfunction
from typing import Callable, Optional
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn
from .t_cond_mlp import (
AdaptiveLayerNorm1D,
FrequencyEmbedder,
normalization_layer,
)
# from .vit import Attention, FeedForward
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When ``d`` is a plain Python function (as judged by
    ``inspect.isfunction``) it is called with no arguments and its result is
    returned; any other ``d`` is returned unchanged.
    """
    if not exists(val):
        # Lazily evaluate function-valued fallbacks only when needed.
        if isfunction(d):
            return d()
        return d
    return val
class PreNorm(nn.Module):
    """Normalize the input, then hand the result to a wrapped callable.

    The normalization module is whatever
    ``normalization_layer(norm, dim, norm_cond_dim)`` returns.

    Args:
        dim: Passed to ``normalization_layer`` as the feature size.
        fn: Callable applied to the normalized input.
        norm: Normalization kind, forwarded to ``normalization_layer``.
        norm_cond_dim: Conditioning size, forwarded to ``normalization_layer``.
    """

    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Apply the norm to ``x`` and call ``fn`` on the normalized tensor.

        Extra positional arguments are passed to the norm only when it is an
        ``AdaptiveLayerNorm1D``; for every other norm they are dropped.
        Keyword arguments always go to ``fn``.
        """
        uses_conditioning = isinstance(self.norm, AdaptiveLayerNorm1D)
        normed = self.norm(x, *args) if uses_conditioning else self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Stages are listed in execution order so the Sequential indices
        # (and therefore the parameter names net.0.* / net.3.*) stay fixed.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        out = self.net(x)
        return out
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Feature size of the input (and of the output).
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout probability applied to the attention weights and,
            when an output projection exists, after that projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is skipped only when a single head already
        # has the same width as the input.
        single_full_width_head = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # One bias-free projection produces queries, keys and values at once.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if single_full_width_head:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        """Attend over the tokens of ``x``.

        ``x`` is treated as a 3-D tensor laid out as (batch, tokens, features);
        the result has the same layout.
        """
        # Split the fused projection into q, k, v and move heads to axis 1.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        mixed = torch.matmul(weights, v)
        # Fold the heads back into the feature axis.
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
class CrossAttention(nn.Module):
    """Multi-head attention whose keys/values come from a separate context.

    Args:
        dim: Feature size of the query input (and of the output).
        context_dim: Feature size of the context; falls back to ``dim`` when
            ``None``.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout probability applied to the attention weights and,
            when an output projection exists, after that projection.
    """

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # A projection back to ``dim`` is needed unless there is exactly one
        # head whose width already equals ``dim``.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        context_dim = default(context_dim, dim)
        # Keys and values share one bias-free projection of the context.
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, context=None):
        """Attend from the tokens of ``x`` to the tokens of ``context``.

        When ``context`` is ``None`` the input itself is used as context.
        Both tensors are treated as 3-D (batch, tokens, features).
        """
        context = default(context, x)
        keys, values = self.to_kv(context).chunk(2, dim=-1)
        queries = self.to_q(x)

        # Move the head axis next to the batch axis for each projection.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in (queries, keys, values)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        mixed = torch.matmul(weights, v)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
class Transformer(nn.Module):
    """Stack of pre-norm self-attention + feed-forward layers with residuals.

    Args:
        dim: Feature size of the token stream.
        depth: Number of (attention, feed-forward) layers.
        heads: Attention heads per layer.
        dim_head: Feature size of each attention head.
        mlp_dim: Hidden width of each feed-forward block.
        dropout: Dropout probability used inside attention and feed-forward.
        norm: Normalization kind forwarded to every ``PreNorm``.
        norm_cond_dim: Conditioning size forwarded to every ``PreNorm``.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        norm_options = {"norm": norm, "norm_cond_dim": norm_cond_dim}
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Build both sub-modules first, then wrap them, so the order in
            # which modules are instantiated is unchanged.
            self_attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            layer = nn.ModuleList([])
            layer.append(PreNorm(dim, self_attention, **norm_options))
            layer.append(PreNorm(dim, feed_forward, **norm_options))
            self.layers.append(layer)

    def forward(self, x: torch.Tensor, *args):
        """Run ``x`` through every layer; ``*args`` are passed to each ``PreNorm``."""
        for layer in self.layers:
            attn, ff = layer
            # Residual connection around each pre-normalized sub-layer.
            x = attn(x, *args) + x
            x = ff(x, *args) + x
        return x
class TransformerCrossAttn(nn.Module):
    """Stack of pre-norm self-attention, cross-attention and feed-forward layers.

    Args:
        dim: Feature size of the token stream.
        depth: Number of layers.
        heads: Attention heads per attention module.
        dim_head: Feature size of each attention head.
        mlp_dim: Hidden width of each feed-forward block.
        dropout: Dropout probability used inside the sub-modules.
        norm: Normalization kind forwarded to every ``PreNorm``.
        norm_cond_dim: Conditioning size forwarded to every ``PreNorm``.
        context_dim: Feature size of the cross-attention context, forwarded
            to ``CrossAttention``.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        norm_options = {"norm": norm, "norm_cond_dim": norm_cond_dim}
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Instantiate the three sub-modules before wrapping them so the
            # module creation order is unchanged.
            self_attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            cross_attention = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            layer = nn.ModuleList([])
            layer.append(PreNorm(dim, self_attention, **norm_options))
            layer.append(PreNorm(dim, cross_attention, **norm_options))
            layer.append(PreNorm(dim, feed_forward, **norm_options))
            self.layers.append(layer)

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        """Run ``x`` through every layer.

        Args:
            x: Token stream.
            *args: Extra positional arguments passed to each ``PreNorm``.
            context: Context used by every layer when ``context_list`` is
                not given.
            context_list: One context per layer; overrides ``context``.

        Raises:
            ValueError: If ``context_list`` does not have one entry per layer.
        """
        num_layers = len(self.layers)
        if context_list is None:
            # Share the single context across all layers.
            context_list = [context] * num_layers
        if len(context_list) != num_layers:
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")

        for layer_idx, layer in enumerate(self.layers):
            self_attn, cross_attn, ff = layer
            layer_context = context_list[layer_idx]
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=layer_context) + x
            x = ff(x, *args) + x
        return x
class DropTokenDropout(nn.Module):
    """Token dropout that removes whole tokens from the sequence.

    In training mode a single Bernoulli(``p``) mask over the token axis is
    sampled and shared by every batch element; masked tokens are removed, so
    the returned sequence can be shorter than the input.

    Args:
        p: Probability of dropping each token.

    Raises:
        ValueError: If ``p`` is below 0 or above 1.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p

    def forward(self, x: torch.Tensor):
        """Return ``x`` with randomly selected tokens removed (training only)."""
        # x: (batch_size, seq_len, dim)
        if not (self.training and self.p > 0):
            return x

        # One draw per token position, taken from the first batch element's
        # layout so the mask matches x's dtype and device.
        probabilities = torch.full_like(x[0, :, 0], self.p)
        drop = torch.bernoulli(probabilities).bool()
        # TODO: permutation idx for each batch using torch.argsort
        if not drop.any():
            # Nothing to remove: hand back the very same tensor.
            return x
        return x[:, ~drop, :]
class ZeroTokenDropout(nn.Module):
    """Token dropout that zeroes whole tokens instead of removing them.

    In training mode an independent Bernoulli(``p``) draw is made for every
    (batch, token) pair; the feature vectors of the selected tokens are set
    to zero. The sequence length is unchanged.

    Args:
        p: Probability of zeroing each token.

    Raises:
        ValueError: If ``p`` is below 0 or above 1.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        """Return ``x`` with randomly selected tokens zeroed (training only).

        The input tensor is never modified; when dropout is inactive the
        input object itself is returned.
        """
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Zero-out the masked tokens out-of-place: an indexed assignment
            # would overwrite the caller's tensor and raises for leaf tensors
            # that require grad.
            x = x.masked_fill(zero_mask.unsqueeze(-1), 0)
        return x
class TransformerEncoder(nn.Module):
    """Embed input tokens, add learned positions and run a ``Transformer``.

    Args:
        num_tokens: Number of positions in the learned positional embedding.
        token_dim: Feature size of each input token.
        dim: Feature size inside the transformer.
        depth: Number of transformer layers.
        heads: Attention heads per layer.
        mlp_dim: Hidden width of the feed-forward blocks.
        dim_head: Feature size of each attention head.
        dropout: Dropout probability inside the transformer.
        emb_dropout: Probability given to the token-dropout module.
        emb_dropout_type: ``"drop"`` (``DropTokenDropout``) or ``"zero"``
            (``ZeroTokenDropout``).
        emb_dropout_loc: Where token dropout runs: ``"input"`` (before the
            embedding), ``"token"`` (after the embedding) or
            ``"token_afterpos"`` (after positions are added). Any other value
            means the dropout module is never applied.
        norm: Normalization kind forwarded to the transformer.
        norm_cond_dim: Conditioning size forwarded to the transformer.
        token_pe_numfreq: When positive, tokens pass through a
            ``FrequencyEmbedder`` before the linear embedding.

    Raises:
        ValueError: If ``emb_dropout_type`` is not recognised.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = "drop",
        emb_dropout_loc: str = "token",
        norm: str = "layer",
        norm_cond_dim: int = -1,
        token_pe_numfreq: int = -1,
    ):
        super().__init__()

        if token_pe_numfreq > 0:
            # Width of each token after the frequency embedding.
            token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
            embedding_stages = [
                Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
                FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
                Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
                nn.Linear(token_dim_new, dim),
            ]
            self.to_token_embedding = nn.Sequential(*embedding_stages)
        else:
            self.to_token_embedding = nn.Linear(token_dim, dim)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))

        dropout_classes = {"drop": DropTokenDropout, "zero": ZeroTokenDropout}
        if emb_dropout_type not in dropout_classes:
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
        self.dropout = dropout_classes[emb_dropout_type](emb_dropout)
        self.emb_dropout_loc = emb_dropout_loc

        self.transformer = Transformer(
            dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
        )

    def forward(self, inp: torch.Tensor, *args, **kwargs):
        """Encode ``inp``; ``*args`` go to the transformer, ``**kwargs`` are unused."""
        loc = self.emb_dropout_loc

        x = self.dropout(inp) if loc == "input" else inp
        x = self.to_token_embedding(x)
        if loc == "token":
            x = self.dropout(x)

        # The token count is read after any earlier dropout, so the first
        # ``n`` positional embeddings are added to whatever tokens remain.
        _, n, _ = x.shape
        x += self.pos_embedding[:, :n]

        if loc == "token_afterpos":
            x = self.dropout(x)

        return self.transformer(x, *args)
class TransformerDecoder(nn.Module):
    """Embed query tokens, add learned positions and run a ``TransformerCrossAttn``.

    Args:
        num_tokens: Number of positions in the learned positional embedding.
        token_dim: Feature size of each input token.
        dim: Feature size inside the transformer.
        depth: Number of transformer layers.
        heads: Attention heads per attention module.
        mlp_dim: Hidden width of the feed-forward blocks.
        dim_head: Feature size of each attention head.
        dropout: Dropout probability inside the transformer.
        emb_dropout: Probability given to the embedding-dropout module.
        emb_dropout_type: ``"drop"`` (``DropTokenDropout``), ``"zero"``
            (``ZeroTokenDropout``) or ``"normal"`` (``nn.Dropout``).
        norm: Normalization kind forwarded to the transformer.
        norm_cond_dim: Conditioning size forwarded to the transformer.
        context_dim: Feature size of the cross-attention context.
        skip_token_embedding: Use the input tokens as-is instead of a linear
            embedding; requires ``token_dim == dim``.

    Raises:
        ValueError: If ``skip_token_embedding`` is set while
            ``token_dim != dim``, or if ``emb_dropout_type`` is not recognised.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fail at construction time rather than leaving ``self.dropout``
            # undefined and hitting an AttributeError in ``forward``.
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        """Decode ``inp`` against the given context(s).

        Args:
            inp: Input tokens; after embedding they must form a 3-D tensor.
            *args: Extra positional arguments passed to the transformer.
            context: Context shared by every layer.
            context_list: One context per layer; forwarded to the transformer.
        """
        x = self.to_token_embedding(inp)
        x = self.dropout(x)
        # Read the token count after dropout: a token-dropping module can
        # shorten the sequence, and the positional slice must match it.
        _, n, _ = x.shape
        # Add positions out-of-place. With an identity embedding and inactive
        # dropout ``x`` can be the caller's own tensor, which an in-place add
        # would overwrite. Casting back keeps the dtype that the in-place
        # form would have produced.
        x = (x + self.pos_embedding[:, :n]).to(x.dtype)
        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x