|
from typing import Optional, Tuple |
|
import torch |
|
from torch import Tensor, nn |
|
from torch_geometric.nn import MessagePassing |
|
from torch_geometric.utils import softmax |
|
from torch_scatter import scatter |
|
from torch_sparse import SparseTensor |
|
import loralib as lora |
|
from esm.multihead_attention import MultiheadAttention |
|
import math |
|
from torch import _dynamo |
|
_dynamo.config.suppress_errors = True |
|
from ..module.utils import ( |
|
CosineCutoff, |
|
act_class_mapping, |
|
get_template_fn, |
|
gelu |
|
) |
|
|
|
|
|
|
|
class EquivariantMultiHeadAttention(MessagePassing): |
|
"""Equivariant multi-head attention layer.""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
super(EquivariantMultiHeadAttention, self).__init__( |
|
aggr="mean", node_dim=0) |
|
assert x_hidden_channels % num_heads == 0 \ |
|
and vec_channels % num_heads == 0, ( |
|
f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " |
|
f"and vec_channels ({vec_channels}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
assert vec_hidden_channels == x_channels, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and vec_hidden_channels ({vec_hidden_channels}) " |
|
f"must be equal" |
|
) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
|
|
self.vec_head_dim = vec_channels // num_heads |
|
self.share_kv = share_kv |
|
self.layernorm = nn.LayerNorm(x_channels) |
|
self.act = activation() |
|
self.attn_activation = act_class_mapping[attn_activation]() |
|
self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) |
|
|
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None |
|
self.v_proj = lora.Linear( |
|
x_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
self.o_proj = lora.Linear( |
|
x_hidden_channels, x_channels * 2 + vec_channels, r=use_lora) |
|
self.vec_proj = lora.Linear( |
|
vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora) |
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None |
|
self.v_proj = nn.Linear( |
|
x_channels, x_hidden_channels + vec_channels * 2) |
|
self.o_proj = nn.Linear( |
|
x_hidden_channels, x_channels * 2 + vec_channels) |
|
self.vec_proj = nn.Linear( |
|
vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False) |
|
|
|
self.dk_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
self.dv_proj = None |
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
else: |
|
self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.k_proj.weight) |
|
self.k_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.v_proj.weight) |
|
self.v_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.o_proj.weight) |
|
self.o_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.vec_proj.weight) |
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
if self.dv_proj: |
|
nn.init.xavier_uniform_(self.dv_proj.weight) |
|
self.dv_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, return_attn=False): |
|
x = self.layernorm(x) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
v = self.v_proj(x).reshape(-1, self.num_heads, |
|
self.x_head_dim + self.vec_head_dim * 2) |
|
if self.share_kv: |
|
k = v[:, :, :self.x_head_dim] |
|
else: |
|
k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
|
|
vec1, vec2, vec3 = torch.split(self.vec_proj(vec), |
|
[self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) |
|
vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) |
|
vec_dot = (vec1 * vec2).sum(dim=1) |
|
|
|
dk = ( |
|
self.act(self.dk_proj(f_ij)).reshape(-1, |
|
self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None |
|
else None |
|
) |
|
dv = ( |
|
self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, |
|
self.x_head_dim + self.vec_head_dim * 2) |
|
if self.dv_proj is not None |
|
else None |
|
) |
|
|
|
|
|
|
|
x, vec, attn = self.propagate( |
|
edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
vec=vec, |
|
dk=dk, |
|
dv=dv, |
|
r_ij=r_ij, |
|
d_ij=d_ij, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
vec = vec.reshape(-1, 3, self.vec_channels) |
|
|
|
o1, o2, o3 = torch.split(self.o_proj( |
|
x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) |
|
dx = vec_dot * o2 + o3 |
|
dvec = vec3 * o1.unsqueeze(1) + vec |
|
if return_attn: |
|
return dx, dvec, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return dx, dvec, None |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j * dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) |
|
|
|
|
|
if dv is not None: |
|
v_j = v_j * dv |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ |
|
d_ij.unsqueeze(2).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
x, vec, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size) |
|
vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) |
|
return x, vec, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: |
|
pass |
|
|
|
def edge_update(self) -> Tensor: |
|
pass |
|
|
|
|
|
|
|
class ESMMultiheadAttention(MultiheadAttention): |
|
"""Multi-headed attention. |
|
|
|
See "Attention Is All You Need" for more details. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim, |
|
num_heads, |
|
kdim=None, |
|
vdim=None, |
|
dropout=0.0, |
|
bias=True, |
|
add_bias_kv: bool = False, |
|
add_zero_attn: bool = False, |
|
self_attention: bool = False, |
|
encoder_decoder_attention: bool = False, |
|
use_rotary_embeddings: bool = False, |
|
): |
|
super().__init__(embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv, add_zero_attn, self_attention, |
|
encoder_decoder_attention, use_rotary_embeddings) |
|
|
|
self.k_proj = lora.Linear(self.kdim, embed_dim, bias=bias, r=16) |
|
self.v_proj = lora.Linear(self.vdim, embed_dim, bias=bias, r=16) |
|
self.q_proj = lora.Linear(embed_dim, embed_dim, bias=bias, r=16) |
|
self.out_proj = lora.Linear(embed_dim, embed_dim, bias=bias, r=16) |
|
|
|
|
|
|
|
class EquivariantPAEMultiHeadAttention(EquivariantMultiHeadAttention): |
|
"""Equivariant multi-head attention layer.""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
super(EquivariantPAEMultiHeadAttention, self).__init__( |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
edge_attr_channels=edge_attr_channels, |
|
distance_influence=distance_influence, |
|
num_heads=num_heads, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
use_lora=use_lora) |
|
|
|
self.cutoff = None |
|
|
|
self.dk_dist_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) |
|
self.dv_dist_proj = None |
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
else: |
|
self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) |
|
if self.dk_dist_proj: |
|
nn.init.xavier_uniform_(self.dk_dist_proj.weight) |
|
self.dk_dist_proj.bias.data.fill_(0) |
|
if self.dv_dist_proj: |
|
nn.init.xavier_uniform_(self.dv_dist_proj.weight) |
|
self.dv_dist_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, return_attn=False): |
|
|
|
|
|
x = self.layernorm(x) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
v = self.v_proj(x).reshape(-1, self.num_heads, |
|
self.x_head_dim + self.vec_head_dim * 2) |
|
if self.share_kv: |
|
k = v[:, :, :self.x_head_dim] |
|
else: |
|
k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
|
|
vec1, vec2, vec3 = torch.split(self.vec_proj(vec), |
|
[self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) |
|
vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) |
|
vec_dot = (vec1 * vec2).sum(dim=1) |
|
|
|
dk = ( |
|
self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None |
|
else None |
|
) |
|
dk_dist = ( |
|
self.act(self.dk_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_dist_proj is not None |
|
else None |
|
) |
|
dv = ( |
|
self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) |
|
if self.dv_proj is not None |
|
else None |
|
) |
|
dv_dist = ( |
|
self.act(self.dv_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) |
|
if self.dv_dist_proj is not None |
|
else None |
|
) |
|
|
|
|
|
|
|
x, vec, attn = self.propagate( |
|
edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
vec=vec, |
|
dk=dk, |
|
dk_dist=dk_dist, |
|
dv=dv, |
|
dv_dist=dv_dist, |
|
d_ij=d_ij, |
|
w_ij=w_ij, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
vec = vec.reshape(-1, 3, self.vec_channels) |
|
|
|
o1, o2, o3 = torch.split(self.o_proj( |
|
x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) |
|
dx = vec_dot * o2 * plddt.unsqueeze(1) + o3 |
|
dvec = vec3 * o1.unsqueeze(1) * plddt.unsqueeze(1).unsqueeze(2) + vec |
|
if return_attn: |
|
return dx, dvec, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return dx, dvec, None |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij): |
|
|
|
attn = (q_i * k_j) |
|
if dk is not None: |
|
attn += dk |
|
if dk_dist is not None: |
|
attn += dk_dist * w_ij.unsqueeze(1).unsqueeze(2) |
|
attn = attn.sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) |
|
|
|
|
|
if dv is not None: |
|
v_j += dv |
|
if dv_dist is not None: |
|
v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ |
|
d_ij.unsqueeze(2).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
|
|
|
|
class EquivariantWeightedPAEMultiHeadAttention(EquivariantMultiHeadAttention): |
|
"""Equivariant multi-head attention layer.""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
super(EquivariantWeightedPAEMultiHeadAttention, self).__init__( |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
edge_attr_channels=edge_attr_channels, |
|
distance_influence=distance_influence, |
|
num_heads=num_heads, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
use_lora=use_lora) |
|
|
|
self.cutoff = None |
|
|
|
self.pae_weight = nn.Linear(1, 1, bias=True) |
|
self.pae_weight.weight.data.fill_(-0.5) |
|
self.pae_weight.bias.data.fill_(7.5) |
|
|
|
self.dk_dist_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) |
|
self.dv_dist_proj = None |
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
else: |
|
self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) |
|
if self.dk_dist_proj: |
|
nn.init.xavier_uniform_(self.dk_dist_proj.weight) |
|
self.dk_dist_proj.bias.data.fill_(0) |
|
if self.dv_dist_proj: |
|
nn.init.xavier_uniform_(self.dv_dist_proj.weight) |
|
self.dv_dist_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, return_attn=False): |
|
|
|
|
|
x = self.layernorm(x) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
v = self.v_proj(x).reshape(-1, self.num_heads, |
|
self.x_head_dim + self.vec_head_dim * 2) |
|
if self.share_kv: |
|
k = v[:, :, :self.x_head_dim] |
|
else: |
|
k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
|
|
vec1, vec2, vec3 = torch.split(self.vec_proj(vec), |
|
[self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) |
|
vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) |
|
vec_dot = (vec1 * vec2).sum(dim=1) |
|
|
|
dk = ( |
|
self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None |
|
else None |
|
) |
|
dk_dist = ( |
|
self.act(self.dk_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_dist_proj is not None |
|
else None |
|
) |
|
dv = ( |
|
self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) |
|
if self.dv_proj is not None |
|
else None |
|
) |
|
dv_dist = ( |
|
self.act(self.dv_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) |
|
if self.dv_dist_proj is not None |
|
else None |
|
) |
|
|
|
|
|
|
|
x, vec, attn = self.propagate( |
|
edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
vec=vec, |
|
dk=dk, |
|
dk_dist=dk_dist, |
|
dv=dv, |
|
dv_dist=dv_dist, |
|
d_ij=d_ij, |
|
w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
vec = vec.reshape(-1, 3, self.vec_channels) |
|
|
|
o1, o2, o3 = torch.split(self.o_proj( |
|
x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) |
|
dx = vec_dot * o2 * plddt.unsqueeze(1) + o3 |
|
dvec = vec3 * o1.unsqueeze(1) * plddt.unsqueeze(1).unsqueeze(2) + vec |
|
if return_attn: |
|
return dx, dvec, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return dx, dvec, None |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij): |
|
|
|
attn = (q_i * k_j) |
|
if dk_dist is not None: |
|
if dk is not None: |
|
attn *= (dk + dk_dist * w_ij.unsqueeze(1).unsqueeze(2)) |
|
else: |
|
attn *= dk_dist * w_ij |
|
else: |
|
if dk is not None: |
|
attn *= dk |
|
attn = attn.sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) |
|
|
|
|
|
if dv is not None: |
|
v_j += dv |
|
if dv_dist is not None: |
|
v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ |
|
d_ij.unsqueeze(2).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
|
|
class EquivariantPAEMultiHeadAttentionSoftMaxFullGraph(nn.Module): |
|
"""Equivariant multi-head attention layer with softmax, apply attention on full graph by default""" |
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
|
|
super(EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, self).__init__() |
|
assert x_hidden_channels % num_heads == 0 \ |
|
and vec_channels % num_heads == 0, ( |
|
f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " |
|
f"and vec_channels ({vec_channels}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
assert vec_hidden_channels == x_channels, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and vec_hidden_channels ({vec_hidden_channels}) " |
|
f"must be equal" |
|
) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
|
|
self.vec_head_dim = vec_channels // num_heads |
|
self.share_kv = share_kv |
|
self.layernorm = nn.LayerNorm(x_channels) |
|
self.act = activation() |
|
self.cutoff = None |
|
self.scaling = self.x_head_dim**-0.5 |
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None |
|
self.v_proj = lora.Linear(x_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
self.o_proj = lora.Linear(x_hidden_channels, x_channels * 2 + vec_channels, r=use_lora) |
|
self.vec_proj = lora.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora) |
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None |
|
self.v_proj = nn.Linear(x_channels, x_hidden_channels + vec_channels * 2) |
|
self.o_proj = nn.Linear(x_hidden_channels, x_channels * 2 + vec_channels) |
|
self.vec_proj = nn.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False) |
|
|
|
self.dk_proj = None |
|
self.dk_dist_proj = None |
|
self.dv_proj = None |
|
self.dv_dist_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) |
|
|
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
else: |
|
self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2) |
|
self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) |
|
|
|
self.pae_weight = nn.Linear(1, 1, bias=True) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.k_proj.weight) |
|
self.k_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.v_proj.weight) |
|
self.v_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.o_proj.weight) |
|
self.o_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.vec_proj.weight) |
|
self.pae_weight.weight.data.fill_(-0.5) |
|
self.pae_weight.bias.data.fill_(7.5) |
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
if self.dv_proj: |
|
nn.init.xavier_uniform_(self.dv_proj.weight) |
|
self.dv_proj.bias.data.fill_(0) |
|
if self.dk_dist_proj: |
|
nn.init.xavier_uniform_(self.dk_dist_proj.weight) |
|
self.dk_dist_proj.bias.data.fill_(0) |
|
if self.dv_dist_proj: |
|
nn.init.xavier_uniform_(self.dv_dist_proj.weight) |
|
self.dv_dist_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask, return_attn=False): |
|
|
|
|
|
|
|
x = self.layernorm(x) |
|
q = self.q_proj(x) * self.scaling |
|
v = self.v_proj(x) |
|
|
|
|
|
|
|
k = self.k_proj(x) |
|
|
|
vec1, vec2, vec3 = torch.split(self.vec_proj(vec), |
|
[self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) |
|
vec_dot = (vec1 * vec2).sum(dim=-2) |
|
|
|
dk = self.act(self.dk_proj(f_ij)) |
|
dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) |
|
dv = self.act(self.dv_proj(f_ij)) |
|
dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) |
|
|
|
|
|
x, vec, attn = self.attention( |
|
q=q, |
|
k=k, |
|
v=v, |
|
vec=vec, |
|
dk=dk, |
|
dk_dist=dk_dist, |
|
dv=dv, |
|
dv_dist=dv_dist, |
|
d_ij=d_ij, |
|
w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), |
|
key_padding_mask=key_padding_mask, |
|
) |
|
o1, o2, o3 = torch.split(self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=-1) |
|
dx = vec_dot * o2 * plddt.unsqueeze(-1) + o3 |
|
dvec = vec3 * o1.unsqueeze(-2) * plddt.unsqueeze(-1).unsqueeze(-2) + vec |
|
|
|
dx = dx.masked_fill(key_padding_mask.unsqueeze(-1), 0) |
|
dvec = dvec.masked_fill(key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0) |
|
if return_attn: |
|
return dx, dvec, attn |
|
else: |
|
return dx, dvec, None |
|
|
|
def attention(self, q, k, v, vec, dk, dk_dist, dv, dv_dist, d_ij, w_ij, key_padding_mask=None, need_head_weights=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bsz, tgt_len, _ = q.size() |
|
src_len = k.size(1) |
|
|
|
|
|
q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).transpose(0, 1).contiguous() |
|
|
|
vec = vec.permute(1, 2, 0, 3).reshape(src_len, 3, bsz * self.num_heads, self.vec_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
assert key_padding_mask.size(0) == bsz |
|
assert key_padding_mask.size(1) == src_len |
|
|
|
attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) |
|
|
|
|
|
|
|
|
|
assert w_ij is not None |
|
|
|
attn_weights *= (dk + dk_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim)) |
|
|
|
v = v.unsqueeze(1) + dv + dv_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim + 2 * self.vec_head_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights = attn_weights.sum(dim=-1) |
|
|
|
|
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() |
|
attn_weights = attn_weights.masked_fill( |
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") |
|
) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() |
|
|
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) |
|
|
|
x, vec1, vec2 = torch.split(v, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=-1) |
|
|
|
x_out = torch.einsum('bts,btsh->bth', attn_weights, x) |
|
|
|
vec_out_1 = torch.einsum('bsih,btsh->btih', vec, vec1) |
|
|
|
vec_out_2 = torch.einsum('btsi,btsh->btih', d_ij, vec2) |
|
|
|
vec_out = vec_out_1 + vec_out_2 |
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) |
|
|
|
|
|
attn_weights = attn_weights.mean(dim=0) |
|
|
|
x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() |
|
|
|
vec_out = vec_out.permute(1, 2, 0, 3).reshape(tgt_len, 3, bsz, self.num_heads * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() |
|
return x_out, vec_out, attn_weights |
|
|
|
|
|
class MultiHeadAttentionSoftMaxFullGraph(nn.Module): |
|
""" |
|
Multi-head attention layer with softmax, apply attention on full graph by default |
|
No equivariant property, but can take structure information as input, just didn't use it |
|
""" |
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
|
|
super(MultiHeadAttentionSoftMaxFullGraph, self).__init__() |
|
assert x_hidden_channels % num_heads == 0 \ |
|
and vec_channels % num_heads == 0, ( |
|
f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " |
|
f"and vec_channels ({vec_channels}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
assert vec_hidden_channels == x_channels, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and vec_hidden_channels ({vec_hidden_channels}) " |
|
f"must be equal" |
|
) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
|
|
self.vec_head_dim = vec_channels // num_heads |
|
self.share_kv = share_kv |
|
self.layernorm = nn.LayerNorm(x_channels) |
|
self.act = activation() |
|
self.cutoff = None |
|
self.scaling = self.x_head_dim**-0.5 |
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None |
|
self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.o_proj = lora.Linear(x_hidden_channels, x_channels, r=use_lora) |
|
|
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None |
|
self.v_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.o_proj = nn.Linear(x_hidden_channels, x_channels) |
|
|
|
|
|
self.dk_proj = None |
|
self.dk_dist_proj = None |
|
self.dv_proj = None |
|
self.dv_dist_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
|
|
else: |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
|
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
|
|
else: |
|
self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
|
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.k_proj.weight) |
|
self.k_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.v_proj.weight) |
|
self.v_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.o_proj.weight) |
|
self.o_proj.bias.data.fill_(0) |
|
|
|
|
|
|
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
if self.dv_proj: |
|
nn.init.xavier_uniform_(self.dv_proj.weight) |
|
self.dv_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask, return_attn=False): |
|
|
|
|
|
|
|
x = self.layernorm(x) |
|
q = self.q_proj(x) * self.scaling |
|
v = self.v_proj(x) |
|
|
|
|
|
|
|
k = self.k_proj(x) |
|
|
|
|
|
|
|
|
|
|
|
dk = self.act(self.dk_proj(f_ij)) |
|
|
|
dv = self.act(self.dv_proj(f_ij)) |
|
|
|
|
|
|
|
x, vec, attn = self.attention( |
|
q=q, |
|
k=k, |
|
v=v, |
|
vec=vec, |
|
dk=dk, |
|
|
|
dv=dv, |
|
|
|
|
|
|
|
key_padding_mask=key_padding_mask, |
|
) |
|
|
|
|
|
dx = self.o_proj(x) |
|
|
|
|
|
dx = dx.masked_fill(key_padding_mask.unsqueeze(-1), 0) |
|
|
|
if return_attn: |
|
return dx, vec, attn |
|
else: |
|
return dx, vec, None |
|
|
|
def attention(self, q, k, v, vec, dk, dv, key_padding_mask=None, need_head_weights=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bsz, tgt_len, _ = q.size() |
|
src_len = k.size(1) |
|
|
|
|
|
q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
|
|
dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
assert key_padding_mask.size(0) == bsz |
|
assert key_padding_mask.size(1) == src_len |
|
|
|
attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights *= dk |
|
|
|
v = v.unsqueeze(1) + dv |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights = attn_weights.sum(dim=-1) |
|
|
|
|
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() |
|
attn_weights = attn_weights.masked_fill( |
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") |
|
) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() |
|
|
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) |
|
|
|
|
|
|
|
x_out = torch.einsum('bts,btsh->bth', attn_weights, v) |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) |
|
|
|
|
|
attn_weights = attn_weights.mean(dim=0) |
|
|
|
x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() |
|
|
|
|
|
return x_out, vec, attn_weights |
|
|
|
|
|
class PAEMultiHeadAttentionSoftMaxStarGraph(nn.Module): |
|
"""Equivariant multi-head attention layer with softmax, apply attention on full graph by default""" |
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
|
|
super(PAEMultiHeadAttentionSoftMaxStarGraph, self).__init__() |
|
assert x_hidden_channels % num_heads == 0 \ |
|
and vec_channels % num_heads == 0, ( |
|
f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " |
|
f"and vec_channels ({vec_channels}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
assert vec_hidden_channels == x_channels, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and vec_hidden_channels ({vec_hidden_channels}) " |
|
f"must be equal" |
|
) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
|
|
self.vec_head_dim = vec_channels // num_heads |
|
self.share_kv = share_kv |
|
self.layernorm = nn.LayerNorm(x_channels) |
|
self.act = activation() |
|
self.cutoff = None |
|
self.scaling = self.x_head_dim**-0.5 |
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None |
|
self.v_proj = lora.Linear(x_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None |
|
self.v_proj = nn.Linear(x_channels, x_hidden_channels) |
|
|
|
self.dk_proj = None |
|
self.dk_dist_proj = None |
|
self.dv_proj = None |
|
self.dv_dist_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) |
|
|
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) |
|
else: |
|
self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2) |
|
self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) |
|
|
|
self.pae_weight = nn.Linear(1, 1, bias=True) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.k_proj.weight) |
|
self.k_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.v_proj.weight) |
|
self.v_proj.bias.data.fill_(0) |
|
self.pae_weight.weight.data.fill_(-0.5) |
|
self.pae_weight.bias.data.fill_(7.5) |
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
if self.dv_proj: |
|
nn.init.xavier_uniform_(self.dv_proj.weight) |
|
self.dv_proj.bias.data.fill_(0) |
|
if self.dk_dist_proj: |
|
nn.init.xavier_uniform_(self.dk_dist_proj.weight) |
|
self.dk_dist_proj.bias.data.fill_(0) |
|
if self.dv_dist_proj: |
|
nn.init.xavier_uniform_(self.dv_dist_proj.weight) |
|
self.dv_dist_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, return_attn=False): |
|
|
|
|
|
|
|
x = self.layernorm(x) |
|
q = self.q_proj(x[x_center_index].unsqueeze(1)) * self.scaling |
|
v = self.v_proj(x) |
|
|
|
|
|
|
|
k = self.k_proj(x) |
|
|
|
dk = self.act(self.dk_proj(f_ij)) |
|
dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) |
|
dv = self.act(self.dv_proj(f_ij)) |
|
dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) |
|
|
|
|
|
x, attn = self.attention( |
|
q=q, |
|
k=k, |
|
v=v, |
|
dk=dk, |
|
dk_dist=dk_dist, |
|
dv=dv, |
|
dv_dist=dv_dist, |
|
w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), |
|
key_padding_mask=key_padding_mask, |
|
) |
|
if return_attn: |
|
return x, attn |
|
else: |
|
return x, None |
|
|
|
def attention(self, q, k, v, dk, dk_dist, dv, dv_dist, w_ij, key_padding_mask=None, need_head_weights=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bsz, tgt_len, _ = q.size() |
|
src_len = k.size(1) |
|
|
|
|
|
q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
|
|
|
|
|
|
dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
assert key_padding_mask.size(0) == bsz |
|
assert key_padding_mask.size(1) == src_len |
|
|
|
attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) |
|
|
|
|
|
|
|
|
|
assert w_ij is not None |
|
|
|
attn_weights *= (dk + dk_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim)) |
|
|
|
v = v.unsqueeze(1) + dv + dv_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim + 2 * self.vec_head_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights = attn_weights.sum(dim=-1) |
|
|
|
|
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() |
|
attn_weights = attn_weights.masked_fill( |
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") |
|
) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() |
|
|
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) |
|
|
|
x_out = torch.einsum('bts,btsh->bth', attn_weights, v) |
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) |
|
|
|
|
|
attn_weights = attn_weights.mean(dim=0) |
|
|
|
x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() |
|
return x_out, attn_weights |
|
|
|
|
|
class MultiHeadAttentionSoftMaxStarGraph(nn.Module): |
|
"""Equivariant multi-head attention layer with softmax, apply attention on full graph by default""" |
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
|
|
super(MultiHeadAttentionSoftMaxStarGraph, self).__init__() |
|
assert x_hidden_channels % num_heads == 0 \ |
|
and vec_channels % num_heads == 0, ( |
|
f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " |
|
f"and vec_channels ({vec_channels}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
assert vec_hidden_channels == x_channels, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and vec_hidden_channels ({vec_hidden_channels}) " |
|
f"must be equal" |
|
) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
|
|
self.vec_head_dim = vec_channels // num_heads |
|
self.share_kv = share_kv |
|
self.layernorm = nn.LayerNorm(x_channels) |
|
self.act = activation() |
|
self.cutoff = None |
|
self.scaling = self.x_head_dim**-0.5 |
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None |
|
self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None |
|
self.v_proj = nn.Linear(x_channels, x_hidden_channels) |
|
|
|
self.dk_proj = None |
|
|
|
self.dv_proj = None |
|
|
|
if distance_influence in ["keys", "both"]: |
|
if use_lora is not None: |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
|
|
else: |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
|
|
if distance_influence in ["values", "both"]: |
|
if use_lora is not None: |
|
self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
|
|
else: |
|
self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
|
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.k_proj.weight) |
|
self.k_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.v_proj.weight) |
|
self.v_proj.bias.data.fill_(0) |
|
|
|
|
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
if self.dv_proj: |
|
nn.init.xavier_uniform_(self.dv_proj.weight) |
|
self.dv_proj.bias.data.fill_(0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, return_attn=False): |
|
|
|
|
|
|
|
x = self.layernorm(x) |
|
q = self.q_proj(x[x_center_index].unsqueeze(1)) * self.scaling |
|
v = self.v_proj(x) |
|
|
|
|
|
|
|
k = self.k_proj(x) |
|
|
|
dk = self.act(self.dk_proj(f_ij)) |
|
|
|
dv = self.act(self.dv_proj(f_ij)) |
|
|
|
|
|
|
|
x, attn = self.attention( |
|
q=q, |
|
k=k, |
|
v=v, |
|
dk=dk, |
|
|
|
dv=dv, |
|
|
|
|
|
key_padding_mask=key_padding_mask, |
|
) |
|
if return_attn: |
|
return x, attn |
|
else: |
|
return x, None |
|
|
|
def attention(self, q, k, v, dk, dv, key_padding_mask=None, need_head_weights=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bsz, tgt_len, _ = q.size() |
|
src_len = k.size(1) |
|
|
|
|
|
q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() |
|
|
|
|
|
|
|
dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
|
|
dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
assert key_padding_mask.size(0) == bsz |
|
assert key_padding_mask.size(1) == src_len |
|
|
|
attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights *= dk |
|
|
|
v = v.unsqueeze(1) + dv |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_weights = attn_weights.sum(dim=-1) |
|
|
|
|
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() |
|
attn_weights = attn_weights.masked_fill( |
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") |
|
) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() |
|
|
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) |
|
|
|
x_out = torch.einsum('bts,btsh->bth', attn_weights, v) |
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) |
|
|
|
|
|
attn_weights = attn_weights.mean(dim=0) |
|
|
|
x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() |
|
return x_out, attn_weights |
|
|
|
|
|
|
|
class EquivariantProMultiHeadAttention(MessagePassing): |
|
"""Equivariant multi-head attention layer.""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
): |
|
super(EquivariantMultiHeadAttention, self).__init__( |
|
aggr="mean", node_dim=0) |
|
assert x_hidden_channels % num_heads == 0 \ |
|
and vec_channels % num_heads == 0, ( |
|
f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " |
|
f"and vec_channels ({vec_channels}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
assert vec_hidden_channels == x_channels, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and vec_hidden_channels ({vec_hidden_channels}) " |
|
f"must be equal" |
|
) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
|
|
self.vec_head_dim = vec_channels // num_heads |
|
|
|
self.layernorm = nn.LayerNorm(x_channels) |
|
self.act = activation() |
|
self.attn_activation = act_class_mapping[attn_activation]() |
|
self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) |
|
|
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
|
|
self.kv_proj = nn.Linear( |
|
x_channels, x_hidden_channels + vec_channels * 2) |
|
self.o_proj = nn.Linear( |
|
x_hidden_channels, x_channels * 2 + vec_channels) |
|
|
|
self.vec_proj = nn.Linear( |
|
vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False) |
|
|
|
self.dk_proj = None |
|
if distance_influence in ["keys", "both"]: |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
self.dv_proj = None |
|
if distance_influence in ["values", "both"]: |
|
self.dv_proj = nn.Linear( |
|
edge_attr_channels, x_hidden_channels + vec_channels * 2) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
|
|
|
|
nn.init.xavier_uniform_(self.kv_proj.weight) |
|
self.kv_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.o_proj.weight) |
|
self.o_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.vec_proj.weight) |
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
if self.dv_proj: |
|
nn.init.xavier_uniform_(self.dv_proj.weight) |
|
self.dv_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, return_attn=False): |
|
x = self.layernorm(x) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
|
|
v = self.kv_proj(x).reshape(-1, self.num_heads, |
|
self.x_head_dim + self.vec_head_dim * 2) |
|
k = v[:, :, :self.x_head_dim] |
|
|
|
vec1, vec2, vec3 = torch.split(self.vec_proj(vec), |
|
[self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) |
|
vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) |
|
vec_dot = (vec1 * vec2).sum(dim=1) |
|
|
|
dk = ( |
|
self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None |
|
else None |
|
) |
|
dv = ( |
|
self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) |
|
if self.dv_proj is not None |
|
else None |
|
) |
|
|
|
|
|
|
|
x, vec, attn = self.propagate( |
|
edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
vec=vec, |
|
dk=dk, |
|
dv=dv, |
|
r_ij=r_ij, |
|
d_ij=d_ij, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
vec = vec.reshape(-1, 3, self.vec_channels) |
|
|
|
o1, o2, o3 = torch.split(self.o_proj( |
|
x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) |
|
dx = vec_dot * o2 + o3 |
|
dvec = vec3 * o1.unsqueeze(1) + vec |
|
if return_attn: |
|
return dx, dvec, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return dx, dvec, None |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j * dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) |
|
|
|
|
|
if dv is not None: |
|
v_j = v_j * dv |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ |
|
d_ij.unsqueeze(2).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
x, vec, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size) |
|
vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) |
|
return x, vec, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: |
|
pass |
|
|
|
def edge_update(self) -> Tensor: |
|
pass |
|
|
|
|
|
|
|
class EquivariantMultiHeadAttentionSoftMax(EquivariantMultiHeadAttention): |
|
"""Equivariant multi-head attention layer with softmax""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
super(EquivariantMultiHeadAttentionSoftMax, self).__init__(x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
edge_attr_channels=edge_attr_channels, |
|
distance_influence=distance_influence, |
|
num_heads=num_heads, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
use_lora=use_lora) |
|
self.attn_activation = nn.LeakyReLU(0.2) |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij, |
|
index: Tensor, |
|
ptr: Optional[Tensor], |
|
size_i: Optional[int]): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j * dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) |
|
attn = softmax(attn, index, ptr, size_i) |
|
|
|
|
|
|
|
if dv is not None: |
|
v_j = v_j * dv |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \ |
|
* attn.unsqueeze(1).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
|
|
|
|
class EquivariantPAEMultiHeadAttentionSoftMax(EquivariantPAEMultiHeadAttention): |
|
"""Equivariant multi-head attention layer with softmax""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
super(EquivariantPAEMultiHeadAttentionSoftMax, self).__init__( |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
edge_attr_channels=edge_attr_channels, |
|
edge_attr_dist_channels=edge_attr_dist_channels, |
|
distance_influence=distance_influence, |
|
num_heads=num_heads, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
use_lora=use_lora) |
|
self.attn_activation = nn.LeakyReLU(0.2) |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij, |
|
index: Tensor, |
|
ptr: Optional[Tensor], |
|
size_i: Optional[int]): |
|
|
|
attn = (q_i * k_j) |
|
if dk is not None: |
|
attn += dk |
|
if dk_dist is not None: |
|
attn += dk_dist * w_ij.unsqueeze(1).unsqueeze(2) |
|
attn = attn.sum(dim=-1) |
|
|
|
attn = self.attn_activation(attn) |
|
attn = softmax(attn, index, ptr, size_i) |
|
|
|
|
|
|
|
if dv is not None: |
|
v_j += dv |
|
if dv_dist is not None: |
|
v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \ |
|
* attn.unsqueeze(1).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
|
|
class EquivariantWeightedPAEMultiHeadAttentionSoftMax(EquivariantWeightedPAEMultiHeadAttention): |
|
"""Equivariant multi-head attention layer with softmax""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
share_kv, |
|
edge_attr_channels, |
|
edge_attr_dist_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
use_lora=None, |
|
): |
|
super(EquivariantWeightedPAEMultiHeadAttentionSoftMax, self).__init__( |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
edge_attr_channels=edge_attr_channels, |
|
edge_attr_dist_channels=edge_attr_dist_channels, |
|
distance_influence=distance_influence, |
|
num_heads=num_heads, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
use_lora=use_lora) |
|
self.attn_activation = nn.LeakyReLU(0.2) |
|
|
|
def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij, |
|
index: Tensor, |
|
ptr: Optional[Tensor], |
|
size_i: Optional[int]): |
|
|
|
attn = (q_i * k_j) |
|
if dk_dist is not None: |
|
if dk is not None: |
|
attn *= (dk + dk_dist * w_ij.unsqueeze(1).unsqueeze(2)) |
|
else: |
|
attn *= dk_dist * w_ij |
|
else: |
|
if dk is not None: |
|
attn *= dk |
|
attn = attn.sum(dim=-1) |
|
|
|
attn = self.attn_activation(attn) |
|
attn = softmax(attn, index, ptr, size_i) |
|
|
|
|
|
|
|
if dv is not None: |
|
v_j += dv |
|
if dv_dist is not None: |
|
v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) |
|
x, vec1, vec2 = torch.split( |
|
v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) |
|
|
|
|
|
x = x * attn.unsqueeze(2) |
|
|
|
vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \ |
|
* attn.unsqueeze(1).unsqueeze(3) |
|
return x, vec, attn |
|
|
|
|
|
|
|
class MSAEncoder(nn.Module): |
|
def __init__(self, num_species, pairwise_type, weighting_schema): |
|
"""[summary] |
|
|
|
Args: |
|
num_species (int): Number of species to use from MSA. [1,200] // 200 used in default gMVP |
|
pairwise_type ([str]): method for calculating pairwise coevolution. only "cov" supported |
|
weighting_schema ([str]): species weighting type; "spe" -> use dense layer to weight speices |
|
"none" -> uniformly weight species |
|
|
|
Raises: |
|
NotImplementedError: [description] |
|
""" |
|
super(MSAEncoder, self).__init__() |
|
self.num_species = num_species |
|
self.pairwise_type = pairwise_type |
|
self.weighting_schema = weighting_schema |
|
if self.weighting_schema == 'spe': |
|
self.W = nn.parameter.Parameter( |
|
torch.zeros((1, num_species)), |
|
requires_grad=True) |
|
|
|
elif self.weighting_schema == 'none': |
|
self.W = torch.tensor(1.0 / self.num_species).repeat(self.num_species) |
|
else: |
|
raise NotImplementedError |
|
|
|
def forward(self, x, edge_index): |
|
|
|
shape = x.shape |
|
L, N = shape[0], shape[1] |
|
E = edge_index.shape[1] |
|
|
|
A = 21 |
|
x = x[:, :self.num_species] |
|
if self.weighting_schema == 'spe': |
|
sm = torch.nn.Softmax(dim=-1) |
|
W = sm(self.W) |
|
else: |
|
W = self.W |
|
x = nn.functional.one_hot(x.type(torch.int64), A).type(torch.float32) |
|
x1 = torch.matmul(W[:, None], x) |
|
|
|
if self.pairwise_type == 'fre': |
|
x2 = torch.matmul(x[edge_index[0], :, :, None], x[edge_index[1], :, None, :]) |
|
x2 = x2.reshape((E, N, A * A)) |
|
x2 = (W[:, :, None] * x2).sum(dim=1) |
|
elif self.pairwise_type == 'cov': |
|
|
|
x2 = torch.matmul(x[edge_index[0], :, :, None], x[edge_index[1], :, None, :]) |
|
x2 = (W[:, :, None, None] * x2).sum(dim=1) |
|
x2_t = x1[edge_index[0], 0, :, None] * x1[edge_index[1], 0, None, :] |
|
x2 = (x2 - x2_t).reshape(E, A * A) |
|
x2 = x2.reshape(E, A * A) |
|
norm = torch.sqrt(torch.sum(torch.square(x2), dim=-1, keepdim=True) + 1e-12) |
|
x2 = torch.cat([x2, norm], dim=-1) |
|
elif self.pairwise_type == 'cov_all': |
|
print('cov_all not implemented in EvolEncoder2') |
|
raise NotImplementedError |
|
elif self.pairwise_type == 'inv_cov': |
|
print('in_cov not implemented in EvolEncoder2') |
|
raise NotImplementedError |
|
elif self.pairwise_type == 'none': |
|
x2 = None |
|
else: |
|
raise NotImplementedError( |
|
f'pairwise_type {self.pairwise_type} not implemented') |
|
|
|
x1 = torch.squeeze(x1, dim=1) |
|
|
|
return x1, x2 |
|
|
|
|
|
|
|
class MSAEncoderFullGraph(nn.Module): |
|
def __init__(self, num_species, pairwise_type, weighting_schema): |
|
"""[summary] |
|
|
|
Args: |
|
num_species (int): Number of species to use from MSA. [1,200] // 200 used in default gMVP |
|
pairwise_type ([str]): method for calculating pairwise coevolution. only "cov" supported |
|
weighting_schema ([str]): species weighting type; "spe" -> use dense layer to weight speices |
|
"none" -> uniformly weight species |
|
|
|
Raises: |
|
NotImplementedError: [description] |
|
""" |
|
super(MSAEncoderFullGraph, self).__init__() |
|
self.num_species = num_species |
|
self.pairwise_type = pairwise_type |
|
self.weighting_schema = weighting_schema |
|
if self.weighting_schema == 'spe': |
|
self.W = nn.parameter.Parameter( |
|
torch.zeros((num_species)), |
|
requires_grad=True) |
|
|
|
elif self.weighting_schema == 'none': |
|
self.W = torch.tensor(1.0 / self.num_species).repeat(self.num_species) |
|
else: |
|
raise NotImplementedError |
|
|
|
def forward(self, x): |
|
|
|
shape = x.shape |
|
B, L, N = shape[0], shape[1], shape[2] |
|
A = 21 |
|
x = x[:, :, :self.num_species] |
|
if self.weighting_schema == 'spe': |
|
W = torch.nn.functional.softmax(self.W, dim=-1) |
|
else: |
|
W = self.W |
|
x = nn.functional.one_hot(x.type(torch.int64), A).type(torch.float32) |
|
x1 = torch.einsum('blna,n->bla', x, W) |
|
|
|
if self.pairwise_type == 'cov': |
|
|
|
|
|
|
|
|
|
|
|
x2 = (torch.einsum('bLnA,blna,n->bLlAa', x, x, W) - x1[:, :, None, :, None] * x1[:, None, :, None, :]).reshape(B, L, L, A * A) |
|
norm = torch.sqrt(torch.sum(torch.square(x2), dim=-1, keepdim=True) + 1e-12) |
|
x2 = torch.cat([x2, norm], dim=-1) |
|
elif self.pairwise_type == 'cov_all': |
|
print('cov_all not implemented in EvolEncoder2') |
|
raise NotImplementedError |
|
elif self.pairwise_type == 'inv_cov': |
|
print('in_cov not implemented in EvolEncoder2') |
|
raise NotImplementedError |
|
elif self.pairwise_type == 'none': |
|
x2 = None |
|
else: |
|
raise NotImplementedError( |
|
f'pairwise_type {self.pairwise_type} not implemented') |
|
return x1, x2 |
|
|
|
|
|
class NodeToEdgeAttr(nn.Module): |
|
def __init__(self, node_channel, hidden_channel, edge_attr_channel, use_lora=None, layer_norm=False): |
|
super().__init__() |
|
self.layer_norm = layer_norm |
|
if layer_norm: |
|
self.layernorm = nn.LayerNorm(node_channel) |
|
if use_lora is not None: |
|
self.proj = lora.Linear(node_channel, hidden_channel * 2, bias=True, r=use_lora) |
|
self.o_proj = lora.Linear(2 * hidden_channel, edge_attr_channel, r=use_lora) |
|
else: |
|
self.proj = nn.Linear(node_channel, hidden_channel * 2, bias=True) |
|
self.o_proj = nn.Linear(2 * hidden_channel, edge_attr_channel, bias=True) |
|
|
|
torch.nn.init.zeros_(self.proj.bias) |
|
torch.nn.init.zeros_(self.o_proj.bias) |
|
|
|
def forward(self, x, edge_index): |
|
""" |
|
Inputs: |
|
x: N x sequence_state_dim |
|
|
|
Output: |
|
edge_attr: edge_index.shape[0] x pairwise_state_dim |
|
|
|
Intermediate state: |
|
B x L x L x 2*inner_dim |
|
""" |
|
x = self.layernorm(x) if self.layer_norm else x |
|
q, k = self.proj(x).chunk(2, dim=-1) |
|
|
|
prod = q[edge_index[0], :] * k[edge_index[1], :] |
|
diff = q[edge_index[0], :] - k[edge_index[1], :] |
|
|
|
edge_attr = torch.cat([prod, diff], dim=-1) |
|
edge_attr = self.o_proj(edge_attr) |
|
|
|
return edge_attr |
|
|
|
|
|
class MultiplicativeUpdate(MessagePassing): |
|
def __init__(self, vec_in_channel, hidden_channel, hidden_vec_channel, ee_channels=None, use_lora=None, layer_norm=True) -> None: |
|
super(MultiplicativeUpdate, self).__init__(aggr="mean") |
|
self.vec_in_channel = vec_in_channel |
|
self.hidden_channel = hidden_channel |
|
self.hidden_vec_channel = hidden_vec_channel |
|
|
|
if use_lora is not None: |
|
self.linear_a_p = lora.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False, r=use_lora) |
|
self.linear_b_p = lora.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False, r=use_lora) |
|
self.linear_g = lora.Linear(self.hidden_vec_channel, self.hidden_channel, r=use_lora) |
|
else: |
|
self.linear_a_p = nn.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False) |
|
self.linear_b_p = nn.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False) |
|
self.linear_g = nn.Linear(self.hidden_vec_channel, self.hidden_channel) |
|
if ee_channels is not None: |
|
if use_lora is not None: |
|
self.linear_ee = lora.Linear(ee_channels, self.hidden_channel, r=use_lora) |
|
else: |
|
self.linear_ee = nn.Linear(ee_channels, self.hidden_channel) |
|
else: |
|
self.linear_ee = None |
|
self.layer_norm = layer_norm |
|
if layer_norm: |
|
self.layer_norm_in = nn.LayerNorm(self.hidden_channel) |
|
self.layer_norm_out = nn.LayerNorm(self.hidden_channel) |
|
|
|
self.sigmoid = nn.Sigmoid() |
|
|
|
def forward(self, |
|
edge_attr: torch.Tensor, |
|
edge_vec: torch.Tensor, |
|
edge_edge_index: torch.Tensor, |
|
edge_edge_attr: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
edge_vec: |
|
[*, 3, in_channel] input tensor |
|
edge_attr: |
|
[*, hidden_channel] input mask |
|
Returns: |
|
[*, hidden_channel] output tensor |
|
""" |
|
if self.layer_norm: |
|
x = self.layer_norm_in(edge_attr) |
|
x = self.propagate(edge_index=edge_edge_index, |
|
a=self.linear_a_p(edge_vec).reshape(edge_attr.shape[0], -1), |
|
b=self.linear_b_p(edge_vec).reshape(edge_attr.shape[0], -1), |
|
edge_attr=x, |
|
ee_ij=edge_edge_attr, ) |
|
if self.layer_norm: |
|
x = self.layer_norm_out(x) |
|
edge_attr = edge_attr + x |
|
return edge_attr |
|
|
|
def message(self, a_i: Tensor, b_j: Tensor, edge_attr_j: Tensor, ee_ij: Tensor,) -> Tensor: |
|
|
|
|
|
s = (a_i.reshape(-1, 3, self.hidden_vec_channel).permute(0, 2, 1) \ |
|
* b_j.reshape(-1, 3, self.hidden_vec_channel).permute(0, 2, 1)).sum(dim=-1) |
|
if ee_ij is not None and self.linear_ee is not None: |
|
s = self.sigmoid(self.linear_ee(ee_ij) + self.linear_g(s)) |
|
else: |
|
s = self.sigmoid(self.linear_g(s)) |
|
return s * edge_attr_j |
|
|
|
|
|
|
|
class EquivariantTriAngularMultiHeadAttention(MessagePassing): |
|
"""Equivariant multi-head attention layer. Add Triangular update between edges.""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
triangular_update=False, |
|
ee_channels=None, |
|
): |
|
super(EquivariantTriAngularMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.ee_channels = ee_channels |
|
|
|
|
|
self.layernorm_in = nn.LayerNorm(x_channels) |
|
self.layernorm_out = nn.LayerNorm(x_hidden_channels) |
|
|
|
self.act = activation() |
|
self.attn_activation = act_class_mapping[attn_activation]() |
|
|
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.kv_proj = nn.Linear(x_channels, x_hidden_channels) |
|
|
|
self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels) |
|
self.out = nn.Linear(x_hidden_channels, x_channels) |
|
|
|
|
|
|
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
self.triangular_update = triangular_update |
|
if self.triangular_update: |
|
self.edge_triangle_start_update = MultiplicativeUpdate(vec_in_channel=vec_channels, |
|
hidden_channel=edge_attr_channels, |
|
hidden_vec_channel=vec_hidden_channels, |
|
ee_channels=ee_channels, ) |
|
self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, |
|
hidden_channel=edge_attr_channels, |
|
hidden_vec_channel=vec_hidden_channels, |
|
ee_channels=ee_channels, ) |
|
self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, |
|
hidden_channel=x_hidden_channels, |
|
edge_attr_channel=edge_attr_channels) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.layernorm_in.reset_parameters() |
|
self.layernorm_out.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.kv_proj.weight) |
|
self.kv_proj.bias.data.fill_(0) |
|
|
|
|
|
nn.init.xavier_uniform_(self.o_proj.weight) |
|
self.o_proj.bias.data.fill_(0) |
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
|
|
def get_start_index(self, edge_index): |
|
edge_start_index = [] |
|
start_node_count = edge_index[0].unique(return_counts=True) |
|
start_nodes = start_node_count[0][start_node_count[1] > 1] |
|
for i in start_nodes: |
|
node_start_index = torch.where(edge_index[0] == i)[0] |
|
candidates = torch.combinations(node_start_index, r=2).T |
|
edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_start_index = torch.concat(edge_start_index, dim=1) |
|
edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] |
|
return edge_start_index |
|
|
|
def get_end_index(self, edge_index): |
|
edge_end_index = [] |
|
end_node_count = edge_index[1].unique(return_counts=True) |
|
end_nodes = end_node_count[0][end_node_count[1] > 1] |
|
for i in end_nodes: |
|
node_end_index = torch.where(edge_index[1] == i)[0] |
|
candidates = torch.combinations(node_end_index, r=2).T |
|
edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_end_index = torch.concat(edge_end_index, dim=1) |
|
edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] |
|
return edge_end_index |
|
|
|
def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): |
|
residue = x |
|
x = self.layernorm_in(x) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
v = k |
|
|
|
|
|
if self.triangular_update: |
|
edge_attr += self.node_to_edge_attr(x, edge_index) |
|
|
|
|
|
|
|
edge_edge_index = self.get_start_index(edge_index) |
|
if self.ee_channels is not None: |
|
edge_edge_attr = coords[edge_index[1][edge_edge_index[0]], :, [0]] - coords[edge_index[1][edge_edge_index[1]], :, [0]] |
|
edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) |
|
else: |
|
edge_edge_attr = None |
|
edge_attr = self.edge_triangle_start_update( |
|
edge_attr, edge_vec, |
|
edge_edge_index, |
|
edge_edge_attr |
|
) |
|
edge_edge_index = self.get_end_index(edge_index) |
|
if self.ee_channels is not None: |
|
edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] |
|
edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) |
|
else: |
|
edge_edge_attr = None |
|
edge_attr = self.edge_triangle_end_update( |
|
edge_attr, edge_vec, |
|
edge_edge_index, |
|
edge_edge_attr |
|
) |
|
del edge_edge_attr, edge_edge_index |
|
|
|
dk = ( |
|
self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None else None |
|
) |
|
|
|
|
|
|
|
x, attn = self.propagate( |
|
edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
dk=dk, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
x = residue + x |
|
x = self.layernorm_out(x) |
|
x = gelu(self.o_proj(x)) |
|
x = self.out(x) |
|
del residue, q, k, v, dk |
|
if return_attn: |
|
return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return x, edge_attr, None |
|
|
|
def message(self, q_i, k_j, v_j, dk): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j * dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) |
|
|
|
|
|
x = v_j * attn.unsqueeze(2) |
|
return x, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
x, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
return x, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: |
|
pass |
|
|
|
def edge_update(self) -> Tensor: |
|
pass |
|
|
|
|
|
|
|
class EquivariantTriAngularDropMultiHeadAttention(MessagePassing): |
|
"""Equivariant multi-head attention layer. Add Triangular update between edges.""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
rbf_channels, |
|
triangular_update=False, |
|
ee_channels=None, |
|
drop_out_rate=0.0, |
|
use_lora=None, |
|
layer_norm=True, |
|
): |
|
super(EquivariantTriAngularDropMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.ee_channels = ee_channels |
|
self.rbf_channels = rbf_channels |
|
self.layer_norm = layer_norm |
|
|
|
if layer_norm: |
|
self.layernorm_in = nn.LayerNorm(x_channels) |
|
self.layernorm_out = nn.LayerNorm(x_hidden_channels) |
|
|
|
self.act = activation() |
|
self.attn_activation = act_class_mapping[attn_activation]() |
|
|
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.kv_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
self.o_proj = lora.Linear(x_hidden_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.kv_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels) |
|
|
|
self.triangular_drop = nn.Dropout(drop_out_rate) |
|
self.rbf_drop = nn.Dropout(drop_out_rate) |
|
self.dense_drop = nn.Dropout(drop_out_rate) |
|
self.dropout = nn.Dropout(drop_out_rate) |
|
self.triangular_update = triangular_update |
|
if self.triangular_update: |
|
self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, |
|
hidden_channel=edge_attr_channels, |
|
hidden_vec_channel=vec_hidden_channels, |
|
ee_channels=ee_channels, |
|
layer_norm=layer_norm, |
|
use_lora=use_lora) |
|
self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, |
|
hidden_channel=x_hidden_channels, |
|
edge_attr_channel=edge_attr_channels, |
|
use_lora=use_lora) |
|
self.triangle_update_dropout = nn.Dropout(0.5) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
if self.layer_norm: |
|
self.layernorm_in.reset_parameters() |
|
self.layernorm_out.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.kv_proj.weight) |
|
self.kv_proj.bias.data.fill_(0) |
|
|
|
|
|
nn.init.xavier_uniform_(self.o_proj.weight) |
|
self.o_proj.bias.data.fill_(0) |
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
|
|
def get_start_index(self, edge_index): |
|
edge_start_index = [] |
|
start_node_count = edge_index[0].unique(return_counts=True) |
|
start_nodes = start_node_count[0][start_node_count[1] > 1] |
|
for i in start_nodes: |
|
node_start_index = torch.where(edge_index[0] == i)[0] |
|
candidates = torch.combinations(node_start_index, r=2).T |
|
edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_start_index = torch.concat(edge_start_index, dim=1) |
|
edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] |
|
return edge_start_index |
|
|
|
def get_end_index(self, edge_index): |
|
edge_end_index = [] |
|
end_node_count = edge_index[1].unique(return_counts=True) |
|
end_nodes = end_node_count[0][end_node_count[1] > 1] |
|
for i in end_nodes: |
|
node_end_index = torch.where(edge_index[1] == i)[0] |
|
candidates = torch.combinations(node_end_index, r=2).T |
|
edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_end_index = torch.concat(edge_end_index, dim=1) |
|
edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] |
|
return edge_end_index |
|
|
|
def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): |
|
residue = x |
|
if self.layer_norm: |
|
x = self.layernorm_in(x) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
v = k |
|
|
|
|
|
if self.triangular_update: |
|
edge_attr += self.node_to_edge_attr(x, edge_index) |
|
|
|
|
|
edge_edge_index = self.get_end_index(edge_index) |
|
edge_edge_index = edge_edge_index[:, self.triangular_drop( |
|
torch.ones(edge_edge_index.shape[1], device=edge_edge_index.device) |
|
).to(torch.bool)] |
|
if self.ee_channels is not None: |
|
edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] |
|
edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) |
|
else: |
|
edge_edge_attr = None |
|
edge_attr = self.edge_triangle_end_update( |
|
edge_attr, edge_vec, |
|
edge_edge_index, |
|
edge_edge_attr |
|
) |
|
del edge_edge_attr, edge_edge_index |
|
|
|
|
|
edge_attr = torch.cat((edge_attr[:, :-self.rbf_channels], |
|
self.rbf_drop(edge_attr[:, -self.rbf_channels:])), |
|
dim=-1) |
|
|
|
dk = ( |
|
self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None else None |
|
) |
|
|
|
|
|
|
|
x, attn = self.propagate( |
|
edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
dk=dk, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
if self.layer_norm: |
|
x = self.layernorm_out(x) |
|
x = self.dense_drop(x) |
|
x = residue + gelu(x) |
|
x = self.o_proj(x) |
|
x = self.dropout(x) |
|
del residue, q, k, v, dk |
|
if return_attn: |
|
return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return x, edge_attr, None |
|
|
|
def message(self, q_i, k_j, v_j, dk): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j * dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) |
|
|
|
|
|
x = v_j * attn.unsqueeze(2) |
|
return x, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
x, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
return x, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: |
|
pass |
|
|
|
def edge_update(self) -> Tensor: |
|
pass |
|
|
|
|
|
|
|
class EquivariantTriAngularStarMultiHeadAttention(MessagePassing): |
|
""" |
|
Equivariant multi-head attention layer. Add Triangular update between edges. Only update the center node. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
cutoff_lower, |
|
cutoff_upper, |
|
triangular_update=False, |
|
ee_channels=None, |
|
): |
|
super(EquivariantTriAngularStarMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.ee_channels = ee_channels |
|
|
|
|
|
|
|
self.layernorm_out = nn.LayerNorm(x_hidden_channels) |
|
|
|
self.act = activation() |
|
self.attn_activation = act_class_mapping[attn_activation]() |
|
|
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.kv_proj = nn.Linear(x_channels, x_hidden_channels) |
|
|
|
|
|
|
|
|
|
|
|
self.gru = nn.GRUCell(x_channels, x_channels) |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
self.triangular_update = triangular_update |
|
if self.triangular_update: |
|
|
|
|
|
|
|
|
|
self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, |
|
hidden_channel=edge_attr_channels, |
|
hidden_vec_channel=vec_hidden_channels, |
|
ee_channels=ee_channels, ) |
|
self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, |
|
hidden_channel=x_hidden_channels, |
|
edge_attr_channel=edge_attr_channels) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
|
|
self.layernorm_out.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.kv_proj.weight) |
|
self.kv_proj.bias.data.fill_(0) |
|
|
|
|
|
|
|
|
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
|
|
def get_start_index(self, edge_index): |
|
edge_start_index = [] |
|
start_node_count = edge_index[0].unique(return_counts=True) |
|
start_nodes = start_node_count[0][start_node_count[1] > 1] |
|
for i in start_nodes: |
|
node_start_index = torch.where(edge_index[0] == i)[0] |
|
candidates = torch.combinations(node_start_index, r=2).T |
|
edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_start_index = torch.concat(edge_start_index, dim=1) |
|
edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] |
|
return edge_start_index |
|
|
|
def get_end_index(self, edge_index): |
|
edge_end_index = [] |
|
end_node_count = edge_index[1].unique(return_counts=True) |
|
end_nodes = end_node_count[0][end_node_count[1] > 1] |
|
for i in end_nodes: |
|
node_end_index = torch.where(edge_index[1] == i)[0] |
|
candidates = torch.combinations(node_end_index, r=2).T |
|
edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_end_index = torch.concat(edge_end_index, dim=1) |
|
edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] |
|
return edge_end_index |
|
|
|
def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): |
|
|
|
end_node_count = edge_index[1].unique(return_counts=True) |
|
center_nodes = end_node_count[0][end_node_count[1] > 1] |
|
other_nodes = end_node_count[0][end_node_count[1] <= 1] |
|
residue = x[center_nodes] |
|
|
|
edge_attr = edge_attr[torch.isin(edge_index[1], center_nodes), :] |
|
edge_vec = edge_vec[torch.isin(edge_index[1], center_nodes), :] |
|
edge_index = edge_index[:, torch.isin(edge_index[1], center_nodes)] |
|
|
|
q = self.q_proj(residue).reshape(-1, self.num_heads, self.x_head_dim) |
|
kv = self.kv_proj(x[other_nodes]).reshape(-1, self.num_heads, self.x_head_dim) |
|
qkv = torch.zeros(x.shape[0], self.num_heads, self.x_head_dim).to(x.device, non_blocking=True) |
|
qkv[center_nodes] = q |
|
qkv[other_nodes] = kv |
|
|
|
if self.triangular_update: |
|
edge_attr += self.node_to_edge_attr(x, edge_index) |
|
|
|
|
|
edge_edge_index = self.get_end_index(edge_index) |
|
if self.ee_channels is not None: |
|
edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] |
|
edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) |
|
else: |
|
edge_edge_attr = None |
|
edge_attr = self.edge_triangle_end_update( |
|
edge_attr, edge_vec, |
|
edge_edge_index, |
|
edge_edge_attr |
|
) |
|
del edge_edge_attr, edge_edge_index |
|
|
|
dk = ( |
|
self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None else None |
|
) |
|
|
|
|
|
|
|
x, attn = self.propagate( |
|
edge_index, |
|
q=qkv, |
|
k=qkv, |
|
v=qkv, |
|
dk=dk, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
|
|
x = x[center_nodes] |
|
x = self.layernorm_out(x) |
|
x = self.gru(residue, x) |
|
del residue, dk |
|
if return_attn: |
|
return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return x, edge_attr, None |
|
|
|
def message(self, q_i, k_j, v_j, dk): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j + dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) / self.x_head_dim |
|
|
|
|
|
x = v_j * attn.unsqueeze(2) |
|
return x, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
x, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
return x, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: |
|
pass |
|
|
|
def edge_update(self) -> Tensor: |
|
pass |
|
|
|
|
|
|
|
class EquivariantTriAngularStarDropMultiHeadAttention(MessagePassing): |
|
""" |
|
Equivariant multi-head attention layer. Add Triangular update between edges. Only update the center node. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_channels, |
|
x_hidden_channels, |
|
vec_channels, |
|
vec_hidden_channels, |
|
edge_attr_channels, |
|
distance_influence, |
|
num_heads, |
|
activation, |
|
attn_activation, |
|
rbf_channels, |
|
triangular_update=False, |
|
ee_channels=None, |
|
drop_out_rate=0.0, |
|
use_lora=None, |
|
): |
|
super(EquivariantTriAngularStarDropMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) |
|
|
|
self.distance_influence = distance_influence |
|
self.num_heads = num_heads |
|
self.x_channels = x_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.x_head_dim = x_hidden_channels // num_heads |
|
self.vec_channels = vec_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.ee_channels = ee_channels |
|
self.rbf_channels = rbf_channels |
|
|
|
|
|
|
|
self.layernorm_out = nn.LayerNorm(x_hidden_channels) |
|
|
|
self.act = activation() |
|
self.attn_activation = act_class_mapping[attn_activation]() |
|
|
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.kv_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) |
|
self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) |
|
else: |
|
self.q_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.kv_proj = nn.Linear(x_channels, x_hidden_channels) |
|
self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) |
|
|
|
|
|
|
|
|
|
|
|
self.gru = nn.GRUCell(x_channels, x_channels) |
|
|
|
self.triangular_drop = nn.Dropout(drop_out_rate) |
|
self.rbf_drop = nn.Dropout(drop_out_rate) |
|
self.dense_drop = nn.Dropout(drop_out_rate) |
|
self.dropout = nn.Dropout(drop_out_rate) |
|
self.triangular_update = triangular_update |
|
if self.triangular_update: |
|
self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, |
|
hidden_channel=edge_attr_channels, |
|
hidden_vec_channel=vec_hidden_channels, |
|
ee_channels=ee_channels, |
|
use_lora=use_lora) |
|
self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, |
|
hidden_channel=x_hidden_channels, |
|
edge_attr_channel=edge_attr_channels, |
|
use_lora=use_lora) |
|
self.triangle_update_dropout = nn.Dropout(0.5) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
|
|
self.layernorm_out.reset_parameters() |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.kv_proj.weight) |
|
self.kv_proj.bias.data.fill_(0) |
|
|
|
|
|
|
|
|
|
if self.dk_proj: |
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.dk_proj.bias.data.fill_(0) |
|
|
|
def get_start_index(self, edge_index): |
|
edge_start_index = [] |
|
start_node_count = edge_index[0].unique(return_counts=True) |
|
start_nodes = start_node_count[0][start_node_count[1] > 1] |
|
for i in start_nodes: |
|
node_start_index = torch.where(edge_index[0] == i)[0] |
|
candidates = torch.combinations(node_start_index, r=2).T |
|
edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_start_index = torch.concat(edge_start_index, dim=1) |
|
edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] |
|
return edge_start_index |
|
|
|
def get_end_index(self, edge_index): |
|
edge_end_index = [] |
|
end_node_count = edge_index[1].unique(return_counts=True) |
|
end_nodes = end_node_count[0][end_node_count[1] > 1] |
|
for i in end_nodes: |
|
node_end_index = torch.where(edge_index[1] == i)[0] |
|
candidates = torch.combinations(node_end_index, r=2).T |
|
edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) |
|
edge_end_index = torch.concat(edge_end_index, dim=1) |
|
edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] |
|
return edge_end_index |
|
|
|
def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): |
|
|
|
end_node_count = edge_index[1].unique(return_counts=True) |
|
center_nodes = end_node_count[0][end_node_count[1] > 1] |
|
other_nodes = end_node_count[0][end_node_count[1] <= 1] |
|
residue = x[center_nodes] |
|
|
|
edge_attr = edge_attr[torch.isin(edge_index[1], center_nodes), :] |
|
edge_vec = edge_vec[torch.isin(edge_index[1], center_nodes), :] |
|
edge_index = edge_index[:, torch.isin(edge_index[1], center_nodes)] |
|
|
|
q = self.q_proj(residue).reshape(-1, self.num_heads, self.x_head_dim) |
|
kv = self.kv_proj(x[other_nodes]).reshape(-1, self.num_heads, self.x_head_dim) |
|
qkv = torch.zeros(x.shape[0], self.num_heads, self.x_head_dim).to(x.device, non_blocking=True) |
|
qkv[center_nodes] = q |
|
qkv[other_nodes] = kv |
|
|
|
if self.triangular_update: |
|
edge_attr += self.node_to_edge_attr(x, edge_index) |
|
|
|
|
|
edge_edge_index = self.get_end_index(edge_index) |
|
edge_edge_index = edge_edge_index[:, self.triangular_drop( |
|
torch.ones(edge_edge_index.shape[1], device=edge_edge_index.device) |
|
).to(torch.bool)] |
|
if self.ee_channels is not None: |
|
edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] |
|
edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) |
|
else: |
|
edge_edge_attr = None |
|
edge_attr = self.edge_triangle_end_update( |
|
edge_attr, edge_vec, |
|
edge_edge_index, |
|
edge_edge_attr |
|
) |
|
del edge_edge_attr, edge_edge_index |
|
|
|
|
|
edge_attr = torch.cat((edge_attr[:, :-self.rbf_channels], |
|
self.rbf_drop(edge_attr[:, -self.rbf_channels:])), |
|
dim=-1) |
|
|
|
dk = ( |
|
self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) |
|
if self.dk_proj is not None else None |
|
) |
|
|
|
|
|
|
|
x, attn = self.propagate( |
|
edge_index, |
|
q=qkv, |
|
k=qkv, |
|
v=qkv, |
|
dk=dk, |
|
size=None, |
|
) |
|
x = x.reshape(-1, self.x_hidden_channels) |
|
|
|
x = x[center_nodes] |
|
x = self.layernorm_out(x) |
|
x = self.dense_drop(x) |
|
x = self.gru(residue, x) |
|
x = self.dropout(x) |
|
del residue, dk |
|
if return_attn: |
|
return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) |
|
else: |
|
return x, edge_attr, None |
|
|
|
def message(self, q_i, k_j, v_j, dk): |
|
|
|
if dk is None: |
|
attn = (q_i * k_j).sum(dim=-1) |
|
else: |
|
attn = (q_i * k_j + dk).sum(dim=-1) |
|
|
|
|
|
attn = self.attn_activation(attn) / self.x_head_dim |
|
|
|
|
|
x = v_j * attn.unsqueeze(2) |
|
return x, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
x, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
return x, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: |
|
pass |
|
|
|
def edge_update(self) -> Tensor: |
|
pass |
|
|
|
|
|
|
|
class PairFeatureNet(nn.Module): |
|
|
|
def __init__(self, c_s, c_p, relpos_k=32, template_type="exp-normal-smearing-distance"): |
|
super(PairFeatureNet, self).__init__() |
|
|
|
self.c_s = c_s |
|
self.c_p = c_p |
|
|
|
self.linear_s_p_i = nn.Linear(c_s, c_p) |
|
self.linear_s_p_j = nn.Linear(c_s, c_p) |
|
|
|
self.relpos_k = relpos_k |
|
self.n_bin = 2 * relpos_k + 1 |
|
self.linear_relpos = nn.Linear(self.n_bin, c_p) |
|
|
|
|
|
self.template_fn, c_template = get_template_fn(template_type) |
|
self.linear_template = nn.Linear(c_template, c_p) |
|
|
|
def relpos(self, r): |
|
|
|
|
|
|
|
|
|
|
|
d = r[:, :, None] - r[:, None, :] |
|
|
|
|
|
v = torch.arange(-self.relpos_k, self.relpos_k + 1).to(r.device, non_blocking=True) |
|
|
|
|
|
v_reshaped = v.view(*((1,) * len(d.shape) + (len(v),))) |
|
|
|
|
|
b = torch.argmin(torch.abs(d[:, :, :, None] - v_reshaped), dim=-1) |
|
|
|
|
|
oh = nn.functional.one_hot(b, num_classes=len(v)).float() |
|
|
|
|
|
p = self.linear_relpos(oh) |
|
|
|
return p |
|
|
|
def template(self, t): |
|
return self.linear_template(self.template_fn(t)) |
|
|
|
def forward(self, s, t, r, mask): |
|
|
|
p_mask = mask.unsqueeze(1) * mask.unsqueeze(2) |
|
|
|
p_i = self.linear_s_p_i(s) |
|
p_j = self.linear_s_p_j(s) |
|
|
|
|
|
p = p_i[:, :, None, :] + p_j[:, None, :, :] |
|
|
|
|
|
p += self.relpos(r) |
|
p += self.template(t) |
|
|
|
|
|
p *= p_mask.unsqueeze(-1) |
|
|
|
return p |
|
|
|
|
|
|
|
|
|
class TriangularSelfAttentionBlock(nn.Module): |
|
def __init__( |
|
self, |
|
sequence_state_dim, |
|
pairwise_state_dim, |
|
sequence_head_width, |
|
pairwise_head_width, |
|
dropout=0, |
|
**__kwargs, |
|
): |
|
super().__init__() |
|
from openfold.model.triangular_multiplicative_update import ( |
|
TriangleMultiplicationIncoming, |
|
TriangleMultiplicationOutgoing, |
|
) |
|
from esm.esmfold.v1.misc import ( |
|
Attention, |
|
Dropout, |
|
PairToSequence, |
|
ResidueMLP, |
|
SequenceToPair, |
|
) |
|
assert sequence_state_dim % sequence_head_width == 0 |
|
assert pairwise_state_dim % pairwise_head_width == 0 |
|
sequence_num_heads = sequence_state_dim // sequence_head_width |
|
pairwise_num_heads = pairwise_state_dim // pairwise_head_width |
|
assert sequence_state_dim == sequence_num_heads * sequence_head_width |
|
assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width |
|
assert pairwise_state_dim % 2 == 0 |
|
|
|
self.sequence_state_dim = sequence_state_dim |
|
self.pairwise_state_dim = pairwise_state_dim |
|
|
|
self.layernorm_1 = nn.LayerNorm(sequence_state_dim) |
|
|
|
self.sequence_to_pair = SequenceToPair( |
|
sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim |
|
) |
|
self.pair_to_sequence = PairToSequence( |
|
pairwise_state_dim, sequence_num_heads) |
|
|
|
self.seq_attention = Attention( |
|
sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True |
|
) |
|
self.tri_mul_out = TriangleMultiplicationOutgoing( |
|
pairwise_state_dim, |
|
pairwise_state_dim, |
|
) |
|
self.tri_mul_in = TriangleMultiplicationIncoming( |
|
pairwise_state_dim, |
|
pairwise_state_dim, |
|
) |
|
|
|
self.mlp_seq = ResidueMLP( |
|
sequence_state_dim, 4 * sequence_state_dim, dropout=dropout) |
|
self.mlp_pair = ResidueMLP( |
|
pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout) |
|
|
|
assert dropout < 0.4 |
|
self.drop = nn.Dropout(dropout) |
|
self.row_drop = Dropout(dropout * 2, 2) |
|
self.col_drop = Dropout(dropout * 2, 1) |
|
|
|
torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight) |
|
torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias) |
|
torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight) |
|
torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias) |
|
|
|
torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight) |
|
torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias) |
|
torch.nn.init.zeros_(self.pair_to_sequence.linear.weight) |
|
torch.nn.init.zeros_(self.seq_attention.o_proj.weight) |
|
torch.nn.init.zeros_(self.seq_attention.o_proj.bias) |
|
torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight) |
|
torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias) |
|
torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight) |
|
torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias) |
|
|
|
def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): |
|
""" |
|
Inputs: |
|
sequence_state: B x L x sequence_state_dim |
|
pairwise_state: B x L x L x pairwise_state_dim |
|
mask: B x L boolean tensor of valid positions |
|
|
|
Output: |
|
sequence_state: B x L x sequence_state_dim |
|
pairwise_state: B x L x L x pairwise_state_dim |
|
""" |
|
assert len(sequence_state.shape) == 3 |
|
assert len(pairwise_state.shape) == 4 |
|
if mask is not None: |
|
assert len(mask.shape) == 2 |
|
|
|
batch_dim, seq_dim, sequence_state_dim = sequence_state.shape |
|
pairwise_state_dim = pairwise_state.shape[3] |
|
assert sequence_state_dim == self.sequence_state_dim |
|
assert pairwise_state_dim == self.pairwise_state_dim |
|
assert batch_dim == pairwise_state.shape[0] |
|
assert seq_dim == pairwise_state.shape[1] |
|
assert seq_dim == pairwise_state.shape[2] |
|
|
|
|
|
bias = self.pair_to_sequence(pairwise_state) |
|
|
|
|
|
y = self.layernorm_1(sequence_state) |
|
y, _ = self.seq_attention(y, mask=mask, bias=bias) |
|
sequence_state = sequence_state + self.drop(y) |
|
sequence_state = self.mlp_seq(sequence_state) |
|
|
|
|
|
pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) |
|
|
|
|
|
tri_mask = mask.unsqueeze( |
|
2) * mask.unsqueeze(1) if mask is not None else None |
|
pairwise_state = pairwise_state + self.row_drop( |
|
self.tri_mul_out(pairwise_state, mask=tri_mask) |
|
) |
|
pairwise_state = pairwise_state + self.col_drop( |
|
self.tri_mul_in(pairwise_state, mask=tri_mask) |
|
) |
|
|
|
|
|
pairwise_state = self.mlp_pair(pairwise_state) |
|
|
|
return sequence_state, pairwise_state |
|
|
|
|
|
|
|
class SeqPairAttentionOutput(nn.Module): |
|
def __init__(self, seq_state_dim, pairwise_state_dim, num_heads, output_dim, dropout): |
|
super(SeqPairAttentionOutput, self).__init__() |
|
from esm.esmfold.v1.misc import ( |
|
Attention, |
|
PairToSequence, |
|
ResidueMLP, |
|
) |
|
self.seq_state_dim = seq_state_dim |
|
self.pairwise_state_dim = pairwise_state_dim |
|
self.output_dim = output_dim |
|
seq_head_width = seq_state_dim // num_heads |
|
|
|
self.layernorm = nn.LayerNorm(seq_state_dim) |
|
self.seq_attention = Attention( |
|
seq_state_dim, num_heads, seq_head_width, gated=True |
|
) |
|
self.pair_to_sequence = PairToSequence(pairwise_state_dim, num_heads) |
|
self.mlp_seq = ResidueMLP( |
|
seq_state_dim, 4 * seq_state_dim, dropout=dropout) |
|
self.drop = nn.Dropout(dropout) |
|
|
|
def forward(self, sequence_state, pairwise_state, mask=None): |
|
|
|
bias = self.pair_to_sequence(pairwise_state) |
|
|
|
|
|
y = self.layernorm(sequence_state) |
|
y, _ = self.seq_attention(y, mask=mask, bias=bias) |
|
sequence_state = sequence_state + self.drop(y) |
|
sequence_state = self.mlp_seq(sequence_state) |
|
|
|
return sequence_state |
|
|