import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm
from diffab.modules.common.layers import mask_zero, LayerNorm
from diffab.utils.protein.constants import BBHeavyAtom


def _alpha_from_logits(logits, mask, inf=1e5):
    """
    Masked softmax over the key axis of pairwise attention logits.

    Args:
        logits: Logit matrices, (N, L_i, L_j, num_heads).
        mask:   Per-residue boolean masks, (N, L).
        inf:    Large constant subtracted from logits at masked key
                positions so they vanish after the softmax.
    Returns:
        alpha: Attention weights, (N, L, L, num_heads). Rows belonging to
               masked query positions are zeroed out entirely.
    """
    N, L, _, _ = logits.size()
    mask_row = mask.view(N, L, 1, 1).expand_as(logits)   # (N, L, *, *)
    # Pair mask: both the query (row) and key (column) residue must be valid.
    mask_pair = mask_row * mask_row.permute(0, 2, 1, 3)  # (N, L, L, *)

    # NOTE(review): `torch.where` requires a boolean condition, so `mask`
    # is assumed to be a bool tensor — confirm against callers.
    logits = torch.where(mask_pair, logits, logits - inf)
    alpha = torch.softmax(logits, dim=2)  # (N, L, L, num_heads)
    alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
    return alpha


def _heads(x, n_heads, n_ch):
    """
    Split the last dimension into separate (head, channel) dimensions.

    Args:
        x: (..., num_heads * num_channels)
    Returns:
        (..., num_heads, num_channels)
    """
    s = list(x.size())[:-1] + [n_heads, n_ch]
    return x.view(*s)


class GABlock(nn.Module):
    """
    Geometric attention block (invariant-point-attention style).

    Combines three attention-logit terms — node (scalar QK), pair-feature
    bias, and spatial (distances between per-head 3D query/key points
    expressed in the global frame) — then aggregates pair, node, and
    spatial value features into an updated node representation.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
                 num_value_points=8, num_heads=12, bias=False):
        super().__init__()
        self.node_feat_dim = node_feat_dim
        self.pair_feat_dim = pair_feat_dim
        self.value_dim = value_dim
        self.query_key_dim = query_key_dim
        self.num_query_points = num_query_points
        self.num_value_points = num_value_points
        self.num_heads = num_heads

        # Node
        self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)

        # Pair
        self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)

        # Spatial: per-head learnable weight gamma, parameterized through
        # softplus; initialized so that softplus(coef) == 1.
        self.spatial_coef = nn.Parameter(
            torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
            requires_grad=True,
        )
        # Key points intentionally share `num_query_points`: queries and keys
        # must have matching point counts for the distance computation.
        self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)

        # Output: concatenation of pair, node, and spatial (coords + norm +
        # direction, hence 3 + 1 + 3 values per point) aggregations.
        self.out_transform = nn.Linear(
            in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
                num_heads * num_value_points * (3 + 3 + 1)),
            out_features=node_feat_dim,
        )

        self.layer_norm_1 = LayerNorm(node_feat_dim)
        self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim))
        self.layer_norm_2 = LayerNorm(node_feat_dim)

    def _node_logits(self, x):
        """Scaled dot-product logits from node features, (N, L, L, n_heads)."""
        query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim)      # (N, L, n_heads, qk_ch)
        logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
                       (1 / np.sqrt(self.query_key_dim))).sum(-1)  # (N, L, L, num_heads)
        return logits_node

    def _pair_logits(self, z):
        """Per-head bias derived from pair features, (N, L, L, n_heads)."""
        logits_pair = self.proj_pair_bias(z)
        return logits_pair

    def _spatial_logits(self, R, t, x):
        """
        Logits from squared distances between per-head query/key points
        mapped into the global frame. Returns (N, L, L, n_heads).
        """
        N, L, _ = t.size()

        # Query
        query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
                              3)  # (N, L, n_heads * n_pnts, 3)
        query_points = local_to_global(R, t, query_points)  # Global query coordinates, (N, L, n_heads * n_pnts, 3)
        query_s = query_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)

        # Key
        key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
                            3)  # (N, L, n_heads * n_pnts, 3)
        key_points = local_to_global(R, t, key_points)  # Global key coordinates, (N, L, n_heads * n_pnts, 3)
        key_s = key_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)

        # Q-K Product
        sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1)  # (N, L, L, n_heads)
        gamma = F.softplus(self.spatial_coef)
        logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
                                        / 2)  # (N, L, L, n_heads)
        return logits_spatial

    def _pair_aggregation(self, alpha, z):
        """Attention-weighted sum of pair features, flattened to (N, L, n_heads*C)."""
        N, L = z.shape[:2]
        feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2)  # (N, L, L, n_heads, C)
        feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
        return feat_p2n.reshape(N, L, -1)

    def _node_aggregation(self, alpha, x):
        """Attention-weighted sum of node values, flattened to (N, L, n_heads*v_ch)."""
        N, L = x.shape[:2]
        # BUGFIX: proj_value emits value_dim * num_heads features, so the
        # per-head channel count here must be value_dim (the original used
        # query_key_dim, which only worked while the two defaults coincided).
        value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim)  # (N, L, n_heads, v_ch)
        feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
        feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
        return feat_node.reshape(N, L, -1)

    def _spatial_aggregation(self, alpha, R, t, x):
        """
        Attention-weighted sum of global value points, mapped back into each
        residue's local frame and summarized as coordinates, norms, and unit
        directions. Returns (N, L, n_heads * n_v_pnts * 7).
        """
        N, L, _ = t.size()
        value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
                              3)  # (N, L, n_heads * n_v_pnts, 3)
        value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
                                                                 3))  # (N, L, n_heads, n_v_pnts, 3)
        aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
                      value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
        aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)

        feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
        feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
        feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)

        feat_spatial = torch.cat([
            feat_points.reshape(N, L, -1),
            feat_distance.reshape(N, L, -1),
            feat_direction.reshape(N, L, -1),
        ], dim=-1)

        return feat_spatial

    def forward(self, R, t, x, z, mask):
        """
        Args:
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3).
            x: Node-wise features, (N, L, F).
            z: Pair-wise features, (N, L, L, C).
            mask: Masks, (N, L).
        Returns:
            x': Updated node-wise features, (N, L, F).
        """
        # Attention logits
        logits_node = self._node_logits(x)
        logits_pair = self._pair_logits(z)
        logits_spatial = self._spatial_logits(R, t, x)
        # Summing logits up and apply `softmax`; 1/sqrt(3) balances the
        # variance of the three summed logit terms.
        logits_sum = logits_node + logits_pair + logits_spatial
        alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask)  # (N, L, L, n_heads)

        # Aggregate features
        feat_p2n = self._pair_aggregation(alpha, z)
        feat_node = self._node_aggregation(alpha, x)
        feat_spatial = self._spatial_aggregation(alpha, R, t, x)

        # Finally: project, re-mask, and apply two residual+LayerNorm stages.
        feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1))  # (N, L, F)
        feat_all = mask_zero(mask.unsqueeze(-1), feat_all)
        x_updated = self.layer_norm_1(x + feat_all)
        x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
        return x_updated


class GAEncoder(nn.Module):
    """Stack of `GABlock`s applied sequentially to the node features."""

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}):
        # NOTE(review): mutable default `ga_block_opt={}` is kept for
        # interface compatibility; it is only read, never mutated here.
        super(GAEncoder, self).__init__()
        self.blocks = nn.ModuleList([
            GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt)
            for _ in range(num_layers)
        ])

    def forward(self, R, t, res_feat, pair_feat, mask):
        for i, block in enumerate(self.blocks):
            res_feat = block(R, t, res_feat, pair_feat, mask)
        return res_feat