# NOTE(review): removed HuggingFace file-viewer scrape artifacts that preceded
# this line (avatar caption, commit message/hash, "raw history blame", file size);
# they were not Python and made the module syntactically invalid.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm
from diffab.modules.common.layers import mask_zero, LayerNorm
from diffab.utils.protein.constants import BBHeavyAtom
def _alpha_from_logits(logits, mask, inf=1e5):
"""
Args:
logits: Logit matrices, (N, L_i, L_j, num_heads).
mask: Masks, (N, L).
Returns:
alpha: Attention weights.
"""
N, L, _, _ = logits.size()
mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *)
mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *)
logits = torch.where(mask_pair, logits, logits - inf)
alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads)
alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
return alpha
def _heads(x, n_heads, n_ch):
"""
Args:
x: (..., num_heads * num_channels)
Returns:
(..., num_heads, num_channels)
"""
s = list(x.size())[:-1] + [n_heads, n_ch]
return x.view(*s)
class GABlock(nn.Module):
    """Geometric-attention block (invariant-point-attention style).

    Combines three logit sources — node-feature attention, pair-feature bias,
    and distance-based spatial attention computed in the global frame — then
    aggregates pair, node, and point values into an updated node feature.

    FIX(review): `_node_aggregation` previously reshaped the value projection
    with `query_key_dim` instead of `value_dim`; that only worked because both
    default to 32 and broke for any `value_dim != query_key_dim`.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
                 num_value_points=8, num_heads=12, bias=False):
        super().__init__()
        self.node_feat_dim = node_feat_dim
        self.pair_feat_dim = pair_feat_dim
        self.value_dim = value_dim
        self.query_key_dim = query_key_dim
        self.num_query_points = num_query_points
        self.num_value_points = num_value_points
        self.num_heads = num_heads
        # Node-feature attention projections.
        self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)
        # Pair-feature bias (one scalar logit per head).
        self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)
        # Spatial attention: learnable per-head weight, parameterized through
        # softplus; initialized so softplus(coef) == 1.
        self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
                                         requires_grad=True)
        self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)
        # Output transform over concatenated [pair, node, spatial] features;
        # spatial part is (points 3 + direction 3 + distance 1) per value point.
        self.out_transform = nn.Linear(
            in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
                    num_heads * num_value_points * (3 + 3 + 1)),
            out_features=node_feat_dim,
        )
        self.layer_norm_1 = LayerNorm(node_feat_dim)
        self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim))
        self.layer_norm_2 = LayerNorm(node_feat_dim)

    def _node_logits(self, x):
        """Scaled dot-product logits from node features, (N, L, L, num_heads)."""
        query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
                       (1 / np.sqrt(self.query_key_dim))).sum(-1)  # (N, L, L, num_heads)
        return logits_node

    def _pair_logits(self, z):
        """Per-head bias logits projected from pair features, (N, L, L, num_heads)."""
        logits_pair = self.proj_pair_bias(z)
        return logits_pair

    def _spatial_logits(self, R, t, x):
        """Distance-based logits between query/key points in the global frame."""
        N, L, _ = t.size()
        # Query points, expressed in each residue's local frame then mapped to
        # global coordinates.
        query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
                              3)  # (N, L, n_heads * n_pnts, 3)
        query_points = local_to_global(R, t, query_points)  # Global query coordinates, (N, L, n_heads * n_pnts, 3)
        query_s = query_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
        # Key points (same count as query points so distances pair up).
        key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
                            3)  # (N, L, n_heads * n_pnts, 3)
        key_points = local_to_global(R, t, key_points)  # Global key coordinates, (N, L, n_heads * n_pnts, 3)
        key_s = key_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
        # Sum of squared distances between every query/key residue pair.
        sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1)  # (N, L, L, n_heads)
        gamma = F.softplus(self.spatial_coef)
        # Negative scaled distance: closer point clouds get larger logits.
        logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
                                        / 2)  # (N, L, L, n_heads)
        return logits_spatial

    def _pair_aggregation(self, alpha, z):
        """Attention-weighted sum of pair features, flattened to (N, L, n_heads*C)."""
        N, L = z.shape[:2]
        feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2)  # (N, L, L, n_heads, C)
        feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
        return feat_p2n.reshape(N, L, -1)

    def _node_aggregation(self, alpha, x):
        """Attention-weighted sum of node values, flattened to (N, L, n_heads*v_ch)."""
        N, L = x.shape[:2]
        # FIX: proj_value outputs value_dim * num_heads channels, so the head
        # split must use value_dim (was query_key_dim).
        value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim)  # (N, L, n_heads, v_ch)
        feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
        feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
        return feat_node.reshape(N, L, -1)

    def _spatial_aggregation(self, alpha, R, t, x):
        """Aggregate value points in the global frame, then re-express locally.

        Returns concatenated [local points, distances, unit directions],
        (N, L, n_heads * n_v_pnts * (3 + 1 + 3)).
        """
        N, L, _ = t.size()
        value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
                              3)  # (N, L, n_heads * n_v_pnts, 3)
        value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
                                                                  3))  # (N, L, n_heads, n_v_pnts, 3)
        aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
                      value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
        aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)
        # Back to each residue's local frame so the features are invariant to
        # global rotation/translation.
        feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
        feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
        feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)
        feat_spatial = torch.cat([
            feat_points.reshape(N, L, -1),
            feat_distance.reshape(N, L, -1),
            feat_direction.reshape(N, L, -1),
        ], dim=-1)
        return feat_spatial

    def forward(self, R, t, x, z, mask):
        """
        Args:
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3).
            x: Node-wise features, (N, L, F).
            z: Pair-wise features, (N, L, L, C).
            mask: Masks, (N, L).
        Returns:
            x': Updated node-wise features, (N, L, F).
        """
        # Attention logits from the three sources.
        logits_node = self._node_logits(x)
        logits_pair = self._pair_logits(z)
        logits_spatial = self._spatial_logits(R, t, x)
        # Sum the logits (scaled by 1/sqrt(3) for the three terms) and apply
        # the masked softmax.
        logits_sum = logits_node + logits_pair + logits_spatial
        alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask)  # (N, L, L, n_heads)
        # Aggregate features from each source.
        feat_p2n = self._pair_aggregation(alpha, z)
        feat_node = self._node_aggregation(alpha, x)
        feat_spatial = self._spatial_aggregation(alpha, R, t, x)
        # Project, zero masked residues, and apply the two residual/LayerNorm
        # stages (attention update, then MLP transition).
        feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1))  # (N, L, F)
        feat_all = mask_zero(mask.unsqueeze(-1), feat_all)
        x_updated = self.layer_norm_1(x + feat_all)
        x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
        return x_updated
class GAEncoder(nn.Module):
    """Stack of `GABlock` layers applied sequentially to node features.

    FIX(review): replaced the mutable default argument `ga_block_opt={}` with
    a `None` sentinel (shared-default anti-pattern); behavior is unchanged.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt=None):
        """
        Args:
            node_feat_dim: Node (residue) feature dimension.
            pair_feat_dim: Pair feature dimension.
            num_layers: Number of stacked GABlocks.
            ga_block_opt: Optional dict of extra keyword arguments forwarded
                to each GABlock (e.g. num_heads, value_dim).
        """
        super().__init__()
        opt = {} if ga_block_opt is None else ga_block_opt
        self.blocks = nn.ModuleList([
            GABlock(node_feat_dim, pair_feat_dim, **opt)
            for _ in range(num_layers)
        ])

    def forward(self, R, t, res_feat, pair_feat, mask):
        """Run every block in order, threading the node features through.

        Args:
            R: Frame basis matrices, (N, L, 3, 3).
            t: Frame translations, (N, L, 3).
            res_feat: Node features, (N, L, F).
            pair_feat: Pair features, (N, L, L, C).
            mask: Residue masks, (N, L).
        Returns:
            Updated node features, (N, L, F).
        """
        for block in self.blocks:
            res_feat = block(R, t, res_feat, pair_feat, mask)
        return res_feat