import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
    that are temporally closest to the current frame at `frame_idx`. Here, we take
    - a) the closest conditioning frame before `frame_idx` (if any);
    - b) the closest conditioning frame after `frame_idx` (if any);
    - c) any other temporally closest conditioning frames until reaching a total
         of `max_cond_frame_num` conditioning frames.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        selected_outputs = cond_frame_outputs
        unselected_outputs = {}
    else:
        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
        selected_outputs = {}

        # the closest conditioning frame before `frame_idx` (if any)
        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
        if idx_before is not None:
            selected_outputs[idx_before] = cond_frame_outputs[idx_before]

        # the closest conditioning frame at or after `frame_idx` (if any)
        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
        if idx_after is not None:
            selected_outputs[idx_after] = cond_frame_outputs[idx_after]

        # add other temporally closest conditioning frames until reaching a total
        # of `max_cond_frame_num` conditioning frames
        num_remain = max_cond_frame_num - len(selected_outputs)
        inds_remain = sorted(
            (t for t in cond_frame_outputs if t not in selected_outputs),
            key=lambda x: abs(x - frame_idx),
        )[:num_remain]
        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
        unselected_outputs = {
            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
        }

    return selected_outputs, unselected_outputs
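
# Illustrative example (hypothetical values, not part of the original code): with
# cond_frame_outputs = {0: out0, 5: out5, 20: out20, 30: out30}, frame_idx=18 and
# max_cond_frame_num=2, the selected frames are 5 (closest before 18) and 20
# (closest at/after 18), while frames 0 and 30 end up in `unselected_outputs`.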


def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Get 1D sine positional embedding as in the original Transformer paper.
    """
    pe_dim = dim // 2
    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)

    pos_embed = pos_inds.unsqueeze(-1) / dim_t
    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
    return pos_embed
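
# For instance (shapes only, as an illustration): `pos_inds` of shape (N,) and an even
# `dim` yield an embedding of shape (N, dim), whose first dim // 2 channels are sines
# and last dim // 2 channels are cosines of the scaled positions.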


def get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def get_clones(module, N):
    """Return an nn.ModuleList of N deep copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class DropPath(nn.Module):
    """
    Per-sample drop path (stochastic depth): during training, each sample's
    activations are zeroed with probability `drop_prob` (and rescaled by
    1 / keep_prob when `scale_by_keep` is True); at eval time this module
    is the identity.
    """

    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli mask entry per sample, broadcast over all other dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor
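
# For example (hypothetical rate): DropPath(drop_prob=0.1) zeroes each sample's
# activations with probability 0.1 during training and rescales survivors by 1 / 0.9;
# after model.eval() it passes inputs through unchanged.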


class MLP(nn.Module):
    """
    A simple multi-layer perceptron with `num_layers` linear layers, applying the
    given activation after every layer except the last (and an optional sigmoid
    on the output).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x
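
# As an illustration (hypothetical dimensions): MLP(input_dim=256, hidden_dim=256,
# output_dim=4, num_layers=3) builds Linear(256, 256) -> act -> Linear(256, 256)
# -> act -> Linear(256, 4).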


class LayerNorm2d(nn.Module):
    """
    LayerNorm over the channel dimension of NCHW feature maps.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # normalize over the channel dimension (dim 1) at each spatial location
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
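

# Optional smoke test (not part of the original module); the shapes and values below
# are made-up examples purely for illustration.
if __name__ == "__main__":
    # conditioning-frame selection around frame 12 with at most 3 frames kept
    outs = {0: "f0", 3: "f3", 10: "f10", 15: "f15", 25: "f25"}
    sel, unsel = select_closest_cond_frames(12, outs, max_cond_frame_num=3)
    assert set(sel) == {3, 10, 15} and set(unsel) == {0, 25}

    # sine positional embedding for 4 positions in a 256-dim space
    pe = get_1d_sine_pe(torch.arange(4), dim=256)
    assert pe.shape == (4, 256)

    # a 3-layer MLP mapping 256-dim tokens to 4 outputs
    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    assert mlp(torch.randn(2, 256)).shape == (2, 4)

    # channel-wise LayerNorm on an NCHW feature map
    ln = LayerNorm2d(num_channels=32)
    assert ln(torch.randn(2, 32, 8, 8)).shape == (2, 32, 8, 8)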