|
from typing import Tuple, List |
|
import torch |
|
from torch import _dynamo |
|
# suppress_errors makes TorchDynamo fall back to eager execution instead of raising
# when it fails to compile a graph
_dynamo.config.suppress_errors = True
|
from torch import Tensor, nn |
|
import loralib as lora |
|
import math |
|
import esm |
|
from ..module.utils import ( |
|
NeighborEmbedding, |
|
Distance, |
|
DistanceV2, |
|
rbf_class_mapping, |
|
act_class_mapping |
|
) |
|
from ..module.attention import ( |
|
EquivariantMultiHeadAttention, |
|
EquivariantMultiHeadAttentionSoftMax, |
|
EquivariantPAEMultiHeadAttention, |
|
EquivariantPAEMultiHeadAttentionSoftMax, |
|
EquivariantWeightedPAEMultiHeadAttention, |
|
EquivariantWeightedPAEMultiHeadAttentionSoftMax, |
|
EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, |
|
MultiHeadAttentionSoftMaxFullGraph, |
|
MSAEncoderFullGraph, |
|
EquivariantTriAngularMultiHeadAttention, |
|
EquivariantTriAngularStarMultiHeadAttention, |
|
EquivariantTriAngularStarDropMultiHeadAttention, |
|
EquivariantTriAngularDropMultiHeadAttention, |
|
PairFeatureNet, |
|
TriangularSelfAttentionBlock, |
|
SeqPairAttentionOutput, |
|
MSAEncoder, |
|
ESMMultiheadAttention |
|
) |
|
|
|
|
|
class PassForward(nn.Module): |
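    """Pass-through encoder.

    Accepts the same constructor arguments and forward signature as the other encoders in
    this module but returns its inputs unchanged (and an empty list of attention weights).
    """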
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(PassForward, self).__init__() |
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
|
|
def reset_parameters(self): |
|
pass |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
edge_vec: Tensor = None, |
|
edge_vec_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
|
|
vec = node_vec_attr |
|
attn_weight_layers = [] |
|
return x, vec, pos, edge_attr, batch, attn_weight_layers |
|
|
|
|
|
class ESMTransformerLayer(nn.Module): |
|
"""Transformer layer block.""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim, |
|
ffn_embed_dim, |
|
attention_heads, |
|
add_bias_kv=True, |
|
use_esm1b_layer_norm=False, |
|
use_rotary_embeddings: bool = False, |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.ffn_embed_dim = ffn_embed_dim |
|
self.attention_heads = attention_heads |
|
self.use_rotary_embeddings = use_rotary_embeddings |
|
self._init_submodules(add_bias_kv, use_esm1b_layer_norm) |
|
|
|
def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm): |
|
BertLayerNorm = nn.LayerNorm |
|
|
|
self.self_attn = ESMMultiheadAttention( |
|
self.embed_dim, |
|
self.attention_heads, |
|
add_bias_kv=add_bias_kv, |
|
add_zero_attn=False, |
|
use_rotary_embeddings=self.use_rotary_embeddings, |
|
) |
|
self.self_attn_layer_norm = BertLayerNorm(self.embed_dim) |
|
|
|
self.fc1 = lora.Linear(self.embed_dim, self.ffn_embed_dim, r=16) |
|
self.fc2 = lora.Linear(self.ffn_embed_dim, self.embed_dim, r=16) |
|
|
|
self.final_layer_norm = BertLayerNorm(self.embed_dim) |
|
|
|
def gelu(self, x): |
|
"""Implementation of the gelu activation function. |
|
|
|
For information: OpenAI GPT's gelu is slightly different |
|
(and gives slightly different results): |
|
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
|
""" |
|
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) |
|
|
|
def forward( |
|
self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False |
|
): |
|
residual = x |
|
x = self.self_attn_layer_norm(x) |
|
x, attn = self.self_attn( |
|
query=x, |
|
key=x, |
|
value=x, |
|
key_padding_mask=self_attn_padding_mask, |
|
need_weights=True, |
|
need_head_weights=need_head_weights, |
|
attn_mask=self_attn_mask, |
|
) |
|
x = residual + x |
|
|
|
residual = x |
|
x = self.final_layer_norm(x) |
|
x = self.gelu(self.fc1(x)) |
|
x = self.fc2(x) |
|
x = residual + x |
|
|
|
return x, attn |
|
|
|
|
|
class LoRAESM2(nn.Module): |
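    """ESM-2 `esm2_t33_650M_UR50D` (33 layers, 1280-dim embeddings, 20 attention heads)
    rebuilt with LoRA-augmented token embeddings and feed-forward projections (rank 16),
    so that fine-tuning can be restricted to the low-rank adapter weights.
    """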
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(LoRAESM2, self).__init__() |
|
self.x_in_channels = x_in_channels |
|
self.x_channels = 1280 |
|
self.num_layers = 33 |
|
self.embed_dim = 1280 |
|
self.attention_heads = 20 |
|
self.embed_scale = 1 |
|
_, alphabet = esm.pretrained.esm2_t33_650M_UR50D() |
|
self.alphabet = alphabet |
|
self.alphabet_size = len(alphabet) |
|
self.padding_idx = alphabet.padding_idx |
|
self.mask_idx = alphabet.mask_idx |
|
self.cls_idx = alphabet.cls_idx |
|
self.eos_idx = alphabet.eos_idx |
|
self.prepend_bos = alphabet.prepend_bos |
|
self.append_eos = alphabet.append_eos |
|
self.token_dropout = True |
|
|
|
|
|
self.embed_tokens = lora.Embedding( |
|
self.alphabet_size, |
|
self.embed_dim, |
|
padding_idx=self.padding_idx, |
|
r=16, |
|
) |
|
self.layers = nn.ModuleList( |
|
[ |
|
ESMTransformerLayer( |
|
self.embed_dim, |
|
4 * self.embed_dim, |
|
self.attention_heads, |
|
add_bias_kv=False, |
|
use_esm1b_layer_norm=True, |
|
use_rotary_embeddings=True, |
|
) |
|
for _ in range(self.num_layers) |
|
] |
|
) |
|
self.emb_layer_norm_after = nn.LayerNorm(self.embed_dim) |
|
|
|
def reset_parameters(self): |
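        # load the published ESM-2 650M checkpoint; strict=False leaves parameters that only
        # exist in this LoRA variant (the low-rank adapter weights) at their fresh initialisation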
|
|
|
esm_weights, _ = esm.pretrained.esm2_t33_650M_UR50D() |
|
self.load_state_dict(esm_weights.state_dict(), strict=False) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
edge_vec: Tensor = None, |
|
edge_vec_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
|
|
vec = node_vec_attr |
|
attn_weight_layers = [] |
|
tokens = x |
|
|
|
assert tokens.ndim == 2 |
|
padding_mask = tokens.eq(self.padding_idx) |
|
|
|
x = self.embed_scale * self.embed_tokens(tokens) |
|
|
|
if self.token_dropout: |
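            # ESM-2's token-dropout correction: during pre-training 15% of tokens were selected
            # and 80% of those replaced by <mask> (hence mask_ratio_train = 0.12); embeddings are
            # rescaled so their expected magnitude matches between training and inference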
|
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) |
|
|
|
mask_ratio_train = 0.15 * 0.8 |
|
src_lengths = (~padding_mask).sum(-1) |
|
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths |
|
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] |
|
|
|
if padding_mask is not None: |
|
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) |
|
|
|
|
|
x = x.transpose(0, 1) |
|
|
|
if not padding_mask.any(): |
|
padding_mask = None |
|
|
|
for _, layer in enumerate(self.layers): |
|
x, attn = layer( |
|
x, |
|
self_attn_padding_mask=padding_mask, |
|
need_head_weights=False, |
|
) |
|
attn_weight_layers.append(attn) |
|
|
|
x = self.emb_layer_norm_after(x) |
|
x = x.transpose(0, 1) |
|
|
|
return x, vec, pos, edge_attr, batch, attn_weight_layers |
|
|
|
|
|
class eqTransformer(nn.Module): |
|
"""The equivariant Transformer architecture. |
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqTransformer, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.share_kv = share_kv |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.use_lora = use_lora |
|
self.use_msa = x_use_msa |
|
|
|
self.distance = Distance( |
|
cutoff_lower, |
|
cutoff_upper, |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
self.neighbor_embedding = ( |
|
NeighborEmbedding( |
|
x_channels, num_rbf + num_edge_attr, cutoff_lower, cutoff_upper, |
|
) |
|
if neighbor_embedding |
|
else None |
|
) |
|
self.msa_encoder = MSAEncoder( |
|
num_species=199, |
|
weighting_schema='spe', |
|
pairwise_type='cov', |
|
) if x_use_msa else None |
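        # when MSA input is used, the encoder converts the MSA block of the node features into
        # additional pairwise edge attributes that are concatenated onto edge_attr in forward()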
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
if x_in_embedding_type == "Linear": |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
|
elif x_in_embedding_type == "Linear_gelu": |
|
self.node_x_proj = nn.Sequential( |
|
nn.Linear(x_in_channels, x_channels), |
|
nn.GELU(), |
|
) |
|
else: |
|
self.node_x_proj = nn.Embedding(x_in_channels, x_channels) |
|
self.node_vec_proj = nn.Linear( |
|
vec_in_channels, vec_channels, bias=False) |
|
|
|
self.attention_layers = nn.ModuleList() |
|
self._set_attn_layers() |
|
self.drop = nn.Dropout(drop_out_rate) |
|
self.out_norm = nn.LayerNorm(x_channels) |
|
|
|
self.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
if self.neighbor_embedding is not None: |
|
self.neighbor_embedding.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
self.out_norm.reset_parameters() |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
edge_vec: Tensor = None, |
|
edge_vec_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
        if edge_vec is None:
            edge_index, edge_weight, edge_vec = self.distance(pos, edge_index)
        else:
            # edge vectors were supplied by the caller, so recover per-edge distances from
            # their norms (assumes edge_vec rows are the raw displacement vectors)
            edge_weight = torch.norm(edge_vec, dim=-1)
        assert (
            edge_vec is not None
        ), "Distance module did not return directional information"

        edge_attr_distance = self.distance_expansion(edge_weight)

        edge_attr = (
            torch.cat([edge_attr, edge_attr_distance], dim=-1)
            if edge_attr is not None
            else edge_attr_distance
        )
|
|
|
if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
|
if self.node_x_proj is not None: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
_, msa_edge_attr = self.msa_encoder(x_msa, edge_index) |
|
edge_attr = torch.cat([edge_attr, msa_edge_attr], dim=-1) |
|
        # normalise edge vectors to unit length, skipping self-loops whose vectors are zero
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1, keepdim=True)
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
|
if self.neighbor_embedding is not None: |
|
x = self.neighbor_embedding(x, edge_index, edge_weight, edge_attr) |
|
|
|
vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
|
else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
|
attn_weight_layers = [] |
|
for attn in self.attention_layers: |
|
dx, dvec, attn_weight = attn( |
|
x, vec, edge_index, edge_weight, edge_attr, edge_vec) |
|
x = x + self.drop(dx) |
|
vec = vec + self.drop(dvec) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
x = self.out_norm(x) |
|
|
|
return x, vec, pos, edge_attr, batch, attn_weight_layers |
|
|
|
def __repr__(self): |
|
return ( |
|
f"{self.__class__.__name__}(" |
|
f"x_channels={self.x_channels}, " |
|
f"x_hidden_channels={self.x_hidden_channels}, " |
|
f"vec_in_channels={self.vec_in_channels}, " |
|
f"vec_channels={self.vec_channels}, " |
|
f"vec_hidden_channels={self.vec_hidden_channels}, " |
|
f"num_layers={self.num_layers}, " |
|
f"num_rbf={self.num_rbf}, " |
|
f"rbf_type={self.rbf_type}, " |
|
f"trainable_rbf={self.trainable_rbf}, " |
|
f"activation={self.activation}, " |
|
f"attn_activation={self.attn_activation}, " |
|
f"neighbor_embedding={self.neighbor_embedding}, " |
|
f"num_heads={self.num_heads}, " |
|
f"distance_influence={self.distance_influence}, " |
|
f"cutoff_lower={self.cutoff_lower}, " |
|
f"cutoff_upper={self.cutoff_upper})" |
|
) |
|
|
|
|
|
|
|
class eqStarTransformer(eqTransformer): |
|
"""The equivariant Transformer architecture. |
|
    The first attention layer operates on the star graph; subsequent layers operate on the full graph.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqStarTransformer, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
edge_index, edge_weight, edge_vec = self.distance(pos, edge_index) |
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance( |
|
pos, edge_index_star) |
|
|
|
assert ( |
|
edge_vec is not None and edge_vec_star is not None |
|
), "Distance module did not return directional information" |
|
|
|
edge_attr_distance = self.distance_expansion( |
|
edge_weight) |
|
edge_attr_distance_star = self.distance_expansion( |
|
edge_weight_star) |
|
|
|
if edge_attr is not None: |
|
|
|
edge_attr = torch.cat([edge_attr, edge_attr_distance], dim=-1) |
|
else: |
|
edge_attr = edge_attr_distance |
|
if edge_attr_star is not None: |
|
edge_attr_star = torch.cat( |
|
[edge_attr_star, edge_attr_distance_star], dim=-1) |
|
else: |
|
edge_attr_star = edge_attr_distance_star |
|
|
|
if self.node_x_proj is not None: |
|
if x.shape[1] > self.x_in_channels: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x_msa = None |
|
elif x.shape[1] > self.x_channels: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
|
|
|
|
|
|
|
|
        # normalise edge vectors to unit length, skipping self-loops whose vectors are zero
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1, keepdim=True)
        mask = edge_index_star[0] != edge_index_star[1]
        edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1, keepdim=True)
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
if self.neighbor_embedding is not None: |
|
|
|
x = self.neighbor_embedding( |
|
x, edge_index_star, edge_weight_star, edge_attr_star) |
|
|
|
vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
|
else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
|
attn_weight_layers = [] |
|
for i, attn in enumerate(self.attention_layers): |
|
|
|
if i == 0: |
|
dx, dvec, attn_weight = attn(x, vec, |
|
edge_index_star, edge_weight_star, edge_attr_star, edge_vec_star, |
|
return_attn=return_attn) |
|
else: |
|
dx, dvec, attn_weight = attn(x, vec, |
|
edge_index, edge_weight, edge_attr, edge_vec, |
|
return_attn=return_attn) |
|
x = x + self.drop(dx) |
|
vec = vec + self.drop(dvec) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
        x = self.out_norm(x)

        return x, vec, pos, edge_attr_star, batch, attn_weight_layers
|
|
|
|
|
|
|
class eqTransformerSoftMax(eqTransformer): |
|
"""The equivariant Transformer architecture. |
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqTransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantMultiHeadAttentionSoftMax( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
|
|
|
|
class eqStarTransformerSoftMax(eqStarTransformer): |
|
"""The equivariant Transformer architecture. |
|
    The first attention layer operates on the star graph; subsequent layers operate on the full graph.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqStarTransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantMultiHeadAttentionSoftMax( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
|
|
class eqStar2TransformerSoftMax(eqStarTransformer): |
|
"""The equivariant Transformer architecture. |
|
    The first attention layer operates on the star graph; subsequent layers operate on the full graph.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqStar2TransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
def _set_attn_layers(self): |
|
assert self.num_layers > 0, "num_layers must be greater than 0" |
|
|
|
self.attention_layers.append( |
|
EquivariantMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
use_lora=self.use_lora, |
|
) |
|
) |
|
|
|
for _ in range(self.num_layers - 1): |
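            # layers after the first see full-graph edges, which do not carry the MSA-derived
            # pairwise features (assumed here to be 442 channels wide), hence the reduced
            # edge_attr_channels when MSA input is enabled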
|
layer = EquivariantMultiHeadAttentionSoftMax( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr - 442 if self.use_msa else self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
use_lora=self.use_lora, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
|
|
class eqStar2PAETransformerSoftMax(eqStar2TransformerSoftMax): |
|
"""The equivariant Transformer architecture. |
|
    The first attention layer operates on the star graph; subsequent layers operate on the full graph.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqStar2PAETransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
self.neighbor_embedding = ( |
|
NeighborEmbedding( |
|
x_channels, num_edge_attr, |
|
cutoff_lower, cutoff_upper, |
|
) |
|
if neighbor_embedding |
|
else None |
|
) |
|
        if self.neighbor_embedding is not None:
            self.neighbor_embedding.reset_parameters()
|
|
|
def _set_attn_layers(self): |
|
assert self.num_layers > 0, "num_layers must be greater than 0" |
|
|
|
self.attention_layers.append( |
|
EquivariantPAEMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_dist_channels=self.num_rbf, |
|
edge_attr_channels=self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
use_lora=self.use_lora, |
|
) |
|
) |
|
|
|
for _ in range(self.num_layers - 1): |
|
layer = EquivariantPAEMultiHeadAttentionSoftMax( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_dist_channels=self.num_rbf, |
|
edge_attr_channels=self.num_edge_attr - 442 if self.use_msa else self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
use_lora=self.use_lora, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
plddt: Tensor = None, |
|
edge_confidence: Tensor = None, |
|
edge_confidence_star: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
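        # PAE-aware variant: per-node pLDDT and per-edge predicted-aligned-error
        # ("confidence") tensors are passed to every attention layer alongside the
        # RBF-expanded distances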
|
edge_index, edge_weight, edge_vec = self.distance(pos, edge_index) |
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, edge_index_star) |
|
|
|
assert ( |
|
edge_vec is not None and edge_vec_star is not None |
|
), "Distance module did not return directional information" |
|
|
|
        edge_attr_distance = self.distance_expansion(edge_weight)
        edge_attr_distance_star = self.distance_expansion(edge_weight_star)
|
if self.node_x_proj is not None: |
|
if x.shape[1] > self.x_in_channels: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x_msa = None |
|
elif x.shape[1] > self.x_channels: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
if edge_attr_star is not None: |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
else: |
|
edge_attr_star = msa_edge_attr_star |
|
|
|
|
|
|
|
|
|
|
|
|
|
        # normalise edge vectors to unit length, skipping self-loops whose vectors are zero
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1, keepdim=True)
        mask = edge_index_star[0] != edge_index_star[1]
        edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1, keepdim=True)
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
if self.neighbor_embedding is not None: |
|
|
|
x = self.neighbor_embedding( |
|
x, edge_index_star, edge_weight_star, edge_attr_star) |
|
|
|
vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
|
else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
|
attn_weight_layers = [] |
|
for i, attn in enumerate(self.attention_layers): |
|
|
|
if i == 0: |
|
dx, dvec, attn_weight = attn(x, vec, |
|
edge_index_star, edge_confidence_star, |
|
edge_attr_distance_star, edge_attr_star, |
|
edge_vec_star, plddt, |
|
return_attn=return_attn) |
|
else: |
|
dx, dvec, attn_weight = attn(x, vec, |
|
edge_index, edge_confidence, |
|
edge_attr_distance, edge_attr, |
|
edge_vec, plddt, |
|
return_attn=return_attn) |
|
x = x + self.drop(dx) |
|
vec = vec + self.drop(dvec) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
x = self.out_norm(x) |
|
return x, vec, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
|
|
class eqStar2FullGraphPAETransformerSoftMax(nn.Module): |
|
"""The equivariant Transformer architecture. |
|
    Every attention layer operates on the dense full graph with PAE-aware equivariant attention.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqStar2FullGraphPAETransformerSoftMax, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.share_kv = share_kv |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.use_lora = use_lora |
|
self.use_msa = x_use_msa |
|
|
|
self.distance = Distance( |
|
cutoff_lower, |
|
cutoff_upper, |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
self.neighbor_embedding = None |
|
self.msa_encoder = MSAEncoderFullGraph( |
|
num_species=199, |
|
weighting_schema='spe', |
|
pairwise_type='cov', |
|
) if x_use_msa else None |
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
if x_in_embedding_type == "Linear": |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
|
elif x_in_embedding_type == "Linear_gelu": |
|
self.node_x_proj = nn.Sequential( |
|
nn.Linear(x_in_channels, x_channels), |
|
nn.GELU(), |
|
) |
|
else: |
|
self.node_x_proj = nn.Embedding(x_in_channels, x_channels) |
|
self.node_vec_proj = nn.Linear( |
|
vec_in_channels, vec_channels, bias=False) |
|
|
|
self.attention_layers = nn.ModuleList() |
|
self._set_attn_layers() |
|
self.drop = nn.Dropout(drop_out_rate) |
|
self.out_norm = nn.LayerNorm(x_channels) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
if self.neighbor_embedding is not None: |
|
self.neighbor_embedding.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
self.out_norm.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
assert self.num_layers > 0, "num_layers must be greater than 0" |
|
|
|
|
|
input_dic = { |
|
"x_channels": self.x_channels, |
|
"x_hidden_channels": self.x_hidden_channels, |
|
"vec_channels": self.vec_channels, |
|
"vec_hidden_channels": self.vec_hidden_channels, |
|
"share_kv": self.share_kv, |
|
"edge_attr_dist_channels": self.num_rbf, |
|
"edge_attr_channels": self.num_edge_attr, |
|
"distance_influence": self.distance_influence, |
|
"num_heads": self.num_heads, |
|
"activation": act_class_mapping[self.activation], |
|
"attn_activation": self.attn_activation, |
|
"cutoff_lower": self.cutoff_lower, |
|
"cutoff_upper": self.cutoff_upper, |
|
"use_lora": self.use_lora |
|
} |
|
for _ in range(self.num_layers): |
|
layer = EquivariantPAEMultiHeadAttentionSoftMaxFullGraph(**input_dic) |
|
self.attention_layers.append(layer) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor = None, |
|
x_padding_mask: Tensor = None, |
|
edge_index: Tensor = None, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
plddt: Tensor = None, |
|
edge_confidence: Tensor = None, |
|
edge_confidence_star: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
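        # dense full-graph formulation: pos is assumed to have shape [batch, seq_len, 3], so
        # edge_vec below is the [batch, seq_len, seq_len, 3] pairwise displacement tensor and
        # edge_weight holds the corresponding pairwise distances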
|
edge_vec = pos[:, :, None, :] - pos[:, None, :, :] |
|
edge_weight = torch.norm(edge_vec, dim=-1) |
|
|
|
|
|
edge_attr_distance = self.distance_expansion(edge_weight) |
|
|
|
|
|
        # the trailing columns of x are assumed to hold the MSA features consumed below;
        # this dense variant expects x_in_channels to be set so the two blocks can be split
        x, x_msa = x[..., :self.x_in_channels], x[..., self.x_in_channels:]

        if self.msa_encoder is not None:
            _, msa_edge_attr = self.msa_encoder(x_msa)
            edge_attr = torch.cat([edge_attr, msa_edge_attr], dim=-1) \
                if edge_attr is not None else msa_edge_attr
|
|
|
|
|
|
|
        # normalise pairwise displacement vectors to unit length, excluding the diagonal
        # (self-pairs) whose vectors are zero
        diag = torch.eye(edge_vec.shape[1], device=edge_vec.device, dtype=torch.bool)
        mask = ~diag.unsqueeze(0).expand(edge_vec.shape[0], -1, -1)
        edge_vec[mask] = edge_vec[mask] / (torch.norm(edge_vec[mask], dim=-1, keepdim=True) + 1e-12)
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
|
vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
|
else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
|
attn_weight_layers = [] |
|
for i, attn in enumerate(self.attention_layers): |
|
|
|
dx, dvec, attn_weight = attn(x, vec, |
|
edge_index, edge_confidence, |
|
edge_attr_distance, edge_attr, |
|
edge_vec, plddt, x_padding_mask, |
|
return_attn=return_attn) |
|
x = x + self.drop(dx) |
|
vec = vec + self.drop(dvec) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
x = self.out_norm(x) |
|
return x, vec, pos, [edge_confidence, edge_attr_distance, edge_attr, plddt], batch, attn_weight_layers |
|
|
|
|
|
class FullGraphPAETransformerSoftMax(eqStar2FullGraphPAETransformerSoftMax): |
|
"""The equivariant Transformer architecture. |
|
First Layer is Star Graph, next layer is full graph |
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(FullGraphPAETransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
def _set_attn_layers(self): |
|
assert self.num_layers > 0, "num_layers must be greater than 0" |
|
|
|
|
|
input_dic = { |
|
"x_channels": self.x_channels, |
|
"x_hidden_channels": self.x_hidden_channels, |
|
"vec_channels": self.vec_channels, |
|
"vec_hidden_channels": self.vec_hidden_channels, |
|
"share_kv": self.share_kv, |
|
"edge_attr_dist_channels": self.num_rbf, |
|
"edge_attr_channels": self.num_edge_attr, |
|
"distance_influence": self.distance_influence, |
|
"num_heads": self.num_heads, |
|
"activation": act_class_mapping[self.activation], |
|
"attn_activation": self.attn_activation, |
|
"cutoff_lower": self.cutoff_lower, |
|
"cutoff_upper": self.cutoff_upper, |
|
"use_lora": self.use_lora |
|
} |
|
for _ in range(self.num_layers): |
|
layer = MultiHeadAttentionSoftMaxFullGraph(**input_dic) |
|
self.attention_layers.append(layer) |
|
|
|
|
|
class eqStar2WeightedPAETransformerSoftMax(eqStar2PAETransformerSoftMax): |
|
"""The equivariant Transformer architecture. |
|
    The first attention layer operates on the star graph; subsequent layers operate on the full graph.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
(default: :obj:`"expnorm"`) |
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
embedding step. (default: :obj:`True`) |
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
share_kv=False, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnorm", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=True, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqStar2WeightedPAETransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
|
x_channels=x_channels, |
|
x_hidden_channels=x_hidden_channels, |
|
vec_in_channels=vec_in_channels, |
|
vec_channels=vec_channels, |
|
vec_hidden_channels=vec_hidden_channels, |
|
share_kv=share_kv, |
|
num_layers=num_layers, |
|
num_edge_attr=num_edge_attr, |
|
num_rbf=num_rbf, |
|
rbf_type=rbf_type, |
|
trainable_rbf=trainable_rbf, |
|
activation=activation, |
|
attn_activation=attn_activation, |
|
neighbor_embedding=neighbor_embedding, |
|
num_heads=num_heads, |
|
distance_influence=distance_influence, |
|
cutoff_lower=cutoff_lower, |
|
cutoff_upper=cutoff_upper, |
|
x_in_embedding_type=x_in_embedding_type, |
|
x_use_msa=x_use_msa, |
|
drop_out_rate=drop_out_rate, |
|
use_lora=use_lora) |
|
|
|
def _set_attn_layers(self): |
|
assert self.num_layers > 0, "num_layers must be greater than 0" |
|
|
|
self.attention_layers.append( |
|
EquivariantWeightedPAEMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_dist_channels=self.num_rbf, |
|
edge_attr_channels=self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
use_lora=self.use_lora, |
|
) |
|
) |
|
|
|
for _ in range(self.num_layers - 1): |
|
layer = EquivariantWeightedPAEMultiHeadAttentionSoftMax( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_channels, |
|
vec_hidden_channels=self.vec_hidden_channels, |
|
share_kv=self.share_kv, |
|
edge_attr_dist_channels=self.num_rbf, |
|
edge_attr_channels=self.num_edge_attr - 442 if self.use_msa else self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
use_lora=self.use_lora, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
|
|
class eqTriStarTransformer(nn.Module): |
|
"""The equivariant Transformer architecture. |
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
            (default: :obj:`"expnormunlim"`)
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
            embedding step. (default: :obj:`False`)
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
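
    Example (illustrative sketch only; the tensor shapes below are assumptions
    inferred from :meth:`forward` and are not guaranteed to match the real
    data pipeline or attention-module signatures)::

        import torch

        model = eqTriStarTransformer(x_channels=64, x_hidden_channels=64,
                                     vec_in_channels=4, vec_channels=16,
                                     num_edge_attr=145, num_layers=2,
                                     num_heads=8)
        n_nodes, n_edges = 16, 48
        x = torch.randn(n_nodes, 64)                # node (residue) features
        pos = torch.randn(n_nodes, 3)               # node coordinates
        node_vec_attr = torch.randn(n_nodes, 3, 4)  # per-node vector features
        edge_index = torch.randint(0, n_nodes, (2, n_edges))
        edge_index_star = torch.randint(0, n_nodes, (2, n_edges))
        edge_attr = torch.randn(n_edges, 145)
        edge_attr_star = torch.randn(n_edges, 145)
        batch = torch.zeros(n_nodes, dtype=torch.long)
        x_out, _, _, _, _, attn = model(x, pos, batch, edge_index,
                                        edge_index_star, edge_attr,
                                        edge_attr_star,
                                        node_vec_attr=node_vec_attr,
                                        return_attn=True)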
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnormunlim", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=False, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=False, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqTriStarTransformer, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
|
|
self.distance = DistanceV2( |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) if x_in_embedding_type == "Linear" \ |
|
else nn.Embedding(x_in_channels, x_channels) |
|
|
|
self.attention_layers = nn.ModuleList() |
|
self._set_attn_layers() |
|
self.drop = nn.Dropout(drop_out_rate) |
|
self.out_norm = nn.LayerNorm(x_channels) |
|
|
|
self.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantTriAngularMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_in_channels, |
|
vec_hidden_channels=self.vec_channels, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
self.out_norm.reset_parameters() |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
coords = node_vec_attr + pos.unsqueeze(2) |
|
edge_index, edge_weight, edge_vec = self.distance(pos, coords, edge_index) |
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
|
assert ( |
|
edge_vec is not None |
|
), "Distance module did not return directional information" |
|
|
|
|
|
|
|
|
|
|
|
|
|
        # append the RBF expansion of the edge lengths to the edge attributes
        edge_attr = torch.cat([edge_attr, self.distance_expansion(edge_weight)], dim=-1)
        edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1)
        # normalize edge direction vectors; self-loops are skipped to avoid dividing by a zero norm
        mask = edge_index[0] != edge_index[1]
        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
        mask = edge_index_star[0] != edge_index_star[1]
        edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1)
        del mask, edge_weight, edge_weight_star
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
|
attn_weight_layers = [] |
|
for i, attn in enumerate(self.attention_layers): |
|
if i == 0: |
|
dx, edge_attr_star, attn_weight = attn( |
|
x, edge_index_star, edge_attr_star, edge_vec_star) |
|
else: |
|
dx, edge_attr, attn_weight = attn( |
|
x, edge_index, edge_attr, edge_vec) |
|
x = x + self.drop(dx) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
x = self.out_norm(x) |
|
return x, None, pos, edge_attr, batch, attn_weight_layers |
|
|
|
def __repr__(self): |
|
return ( |
|
f"{self.__class__.__name__}(" |
|
f"x_channels={self.x_channels}, " |
|
f"x_hidden_channels={self.x_hidden_channels}, " |
|
f"vec_in_channels={self.vec_in_channels}, " |
|
f"vec_channels={self.vec_channels}, " |
|
f"vec_hidden_channels={self.vec_hidden_channels}, " |
|
f"num_layers={self.num_layers}, " |
|
f"num_rbf={self.num_rbf}, " |
|
f"rbf_type={self.rbf_type}, " |
|
f"trainable_rbf={self.trainable_rbf}, " |
|
f"activation={self.activation}, " |
|
f"attn_activation={self.attn_activation}, " |
|
f"neighbor_embedding={self.neighbor_embedding}, " |
|
f"num_heads={self.num_heads}, " |
|
f"distance_influence={self.distance_influence}, " |
|
f"cutoff_lower={self.cutoff_lower}, " |
|
f"cutoff_upper={self.cutoff_upper})" |
|
) |
|
|
|
|
|
class eqMSATriStarTransformer(nn.Module): |
|
"""The equivariant Transformer architecture. Edge attributes are MSA weights. |
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
            (default: :obj:`"expnormunlim"`)
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
            embedding step. (default: :obj:`False`)
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
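
    Example (illustrative sketch only; shapes are assumptions inferred from
    :meth:`forward`, and ``x_use_msa=False`` is chosen so that no raw MSA
    block has to be appended to ``x``)::

        import torch

        model = eqMSATriStarTransformer(x_channels=64, x_hidden_channels=64,
                                        vec_in_channels=4, vec_channels=16,
                                        num_edge_attr=145, num_layers=1,
                                        num_heads=8, x_use_msa=False)
        n_nodes, n_edges = 16, 48
        x = torch.randn(n_nodes, 64)
        pos = torch.randn(n_nodes, 3)
        node_vec_attr = torch.randn(n_nodes, 3, 4)
        edge_index = torch.randint(0, n_nodes, (2, n_edges))       # accepted but unused
        edge_index_star = torch.randint(0, n_nodes, (2, n_edges))
        edge_attr_star = torch.randn(n_edges, 145)
        batch = torch.zeros(n_nodes, dtype=torch.long)
        x_out, _, _, edge_attr_out, _, _ = model(
            x, pos, batch, edge_index, edge_index_star=edge_index_star,
            edge_attr=None, edge_attr_star=edge_attr_star,
            node_vec_attr=node_vec_attr)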
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnormunlim", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=False, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=True, |
|
triangular_update=True, |
|
ee_channels=None, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqMSATriStarTransformer, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.triangular_update = triangular_update |
|
|
|
self.distance = DistanceV2( |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
self.msa_encoder = MSAEncoder( |
|
num_species=199, |
|
weighting_schema='spe', |
|
pairwise_type='cov', |
|
) if x_use_msa else None |
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
if x_in_embedding_type == "Linear": |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
|
elif x_in_embedding_type == "Linear_gelu": |
|
self.node_x_proj = nn.Sequential( |
|
nn.Linear(x_in_channels, x_channels), |
|
nn.GELU(), |
|
) |
|
            else:
                # assign the embedding so that it is registered and actually used
                self.node_x_proj = nn.Embedding(x_in_channels, x_channels)
|
self.ee_channels = ee_channels |
|
self.attention_layers = nn.ModuleList() |
|
self._set_attn_layers() |
|
self.drop = nn.Dropout(drop_out_rate) |
|
self.out_norm = nn.LayerNorm(x_channels) |
|
|
|
self.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantTriAngularMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_in_channels, |
|
vec_hidden_channels=self.vec_channels, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
ee_channels=self.ee_channels, |
|
triangular_update=self.triangular_update, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
self.out_norm.reset_parameters() |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
coords = node_vec_attr + pos.unsqueeze(2) |
|
|
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
|
|
|
        # if x carries extra trailing columns beyond the expected feature width,
        # split them off as raw MSA features to be encoded separately below
        if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels:
|
if self.node_x_proj is not None: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
|
|
|
|
|
|
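        # when raw MSA features were split off above, encode them into pairwise
        # edge attributes and append them to the star-graph edge attributes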
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
del edge_attr |
|
edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
|
|
|
|
|
mask = edge_index_star[0] != edge_index_star[1] |
|
edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
|
del mask, edge_weight_star |
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
|
attn_weight_layers = [] |
|
        # note: as written, only the first layer attends over the star graph;
        # later iterations add a zero update to x
        for i, attn in enumerate(self.attention_layers):
            if i == 0:
                dx, edge_attr_star, attn_weight = attn(
                    x, coords, edge_index_star, edge_attr_star, edge_vec_star)
            else:
                dx = 0
|
x = x + self.drop(dx) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
x = self.out_norm(x) |
|
return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
def __repr__(self): |
|
return ( |
|
f"{self.__class__.__name__}(" |
|
f"x_channels={self.x_channels}, " |
|
f"x_hidden_channels={self.x_hidden_channels}, " |
|
f"vec_in_channels={self.vec_in_channels}, " |
|
f"vec_channels={self.vec_channels}, " |
|
f"vec_hidden_channels={self.vec_hidden_channels}, " |
|
f"num_layers={self.num_layers}, " |
|
f"num_rbf={self.num_rbf}, " |
|
f"rbf_type={self.rbf_type}, " |
|
f"trainable_rbf={self.trainable_rbf}, " |
|
f"activation={self.activation}, " |
|
f"attn_activation={self.attn_activation}, " |
|
f"neighbor_embedding={self.neighbor_embedding}, " |
|
f"num_heads={self.num_heads}, " |
|
f"distance_influence={self.distance_influence}, " |
|
f"cutoff_lower={self.cutoff_lower}, " |
|
f"cutoff_upper={self.cutoff_upper})" |
|
) |
|
|
|
|
|
class eqMSATriStarGRUTransformer(nn.Module): |
|
"""The equivariant Transformer architecture. Edge attributes are MSA weights. |
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
            (default: :obj:`"expnormunlim"`)
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
            embedding step. (default: :obj:`False`)
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
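
    Example (illustrative sketch only; shapes are assumptions inferred from
    :meth:`forward`: ``x_mask`` is a boolean per-node mask and ``x_center`` is
    assumed to be broadcastable to ``(n_nodes, x_channels)``)::

        import torch

        model = eqMSATriStarGRUTransformer(x_channels=64, x_hidden_channels=64,
                                           vec_in_channels=4, vec_channels=16,
                                           num_edge_attr=145, num_layers=2,
                                           num_heads=8, x_use_msa=False)
        n_nodes, n_edges = 16, 48
        x = torch.randn(n_nodes, 64)
        x_center = torch.randn(64)               # alternative "center" features
        x_mask = torch.ones(n_nodes, dtype=torch.bool)
        x_mask[0] = False                        # node 0 receives x_center
        pos = torch.randn(n_nodes, 3)
        node_vec_attr = torch.randn(n_nodes, 3, 4)
        edge_index = torch.randint(0, n_nodes, (2, n_edges))
        edge_index_star = torch.randint(0, n_nodes, (2, n_edges))
        edge_attr_star = torch.randn(n_edges, 145)
        batch = torch.zeros(n_nodes, dtype=torch.long)
        x_out, _, _, _, batch_center, _ = model(
            x, x_center, x_mask, pos, batch, edge_index,
            edge_index_star=edge_index_star, edge_attr_star=edge_attr_star,
            node_vec_attr=node_vec_attr)
        # batch_center indexes only the masked-out (center) positions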
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnormunlim", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=False, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=True, |
|
triangular_update=True, |
|
ee_channels=None, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqMSATriStarGRUTransformer, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.triangular_update = triangular_update |
|
|
|
self.distance = DistanceV2( |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
self.msa_encoder = MSAEncoder( |
|
num_species=199, |
|
weighting_schema='spe', |
|
pairwise_type='cov', |
|
) if x_use_msa else None |
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
if x_in_embedding_type == "Linear": |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
|
elif x_in_embedding_type == "Linear_gelu": |
|
self.node_x_proj = nn.Sequential( |
|
nn.Linear(x_in_channels, x_channels), |
|
nn.GELU(), |
|
) |
|
            else:
                # assign the embedding so that it is registered and actually used
                self.node_x_proj = nn.Embedding(x_in_channels, x_channels)
|
self.ee_channels = ee_channels |
|
self.attention_layers = nn.ModuleList() |
|
self._set_attn_layers() |
|
self.drop = nn.Dropout(drop_out_rate) |
|
|
|
|
|
self.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantTriAngularStarMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_in_channels, |
|
vec_hidden_channels=self.vec_channels, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
cutoff_lower=self.cutoff_lower, |
|
cutoff_upper=self.cutoff_upper, |
|
ee_channels=self.ee_channels, |
|
triangular_update=self.triangular_update, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
|
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
x_center: Tensor, |
|
x_mask: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
coords = node_vec_attr + pos.unsqueeze(2) |
|
|
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
|
|
|
if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
|
if self.node_x_proj is not None: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
|
|
|
|
|
|
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
del edge_attr |
|
edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
|
|
|
|
|
mask = edge_index_star[0] != edge_index_star[1] |
|
edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
|
del mask, edge_weight_star |
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
        # positions where x_mask is False receive the center representation
        # x_center; all other positions keep their (projected) node features
        x = x * x_mask.unsqueeze(1) + x_center * (~x_mask).unsqueeze(1)
|
|
|
attn_weight_layers = [] |
|
for _, attn in enumerate(self.attention_layers): |
|
x, edge_attr_star, attn_weight = attn( |
|
x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
x = self.drop(x) |
|
|
|
batch = batch[~x_mask] |
|
return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
def __repr__(self): |
|
return ( |
|
f"{self.__class__.__name__}(" |
|
f"x_channels={self.x_channels}, " |
|
f"x_hidden_channels={self.x_hidden_channels}, " |
|
f"vec_in_channels={self.vec_in_channels}, " |
|
f"vec_channels={self.vec_channels}, " |
|
f"vec_hidden_channels={self.vec_hidden_channels}, " |
|
f"num_layers={self.num_layers}, " |
|
f"num_rbf={self.num_rbf}, " |
|
f"rbf_type={self.rbf_type}, " |
|
f"trainable_rbf={self.trainable_rbf}, " |
|
f"activation={self.activation}, " |
|
f"attn_activation={self.attn_activation}, " |
|
f"neighbor_embedding={self.neighbor_embedding}, " |
|
f"num_heads={self.num_heads}, " |
|
f"distance_influence={self.distance_influence}, " |
|
f"cutoff_lower={self.cutoff_lower}, " |
|
f"cutoff_upper={self.cutoff_upper})" |
|
) |
|
|
|
|
|
class eqMSATriStarDropTransformer(nn.Module):
    """The equivariant Transformer architecture. Edge attributes are MSA
    weights and distances; dropout is applied inside the attention layers.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
            (default: :obj:`"expnormunlim"`)
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
            embedding step. (default: :obj:`False`)
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
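
    Example (construction sketch only; ``use_lora`` is the LoRA rank passed to
    the :mod:`loralib` layers, and the forward call follows the same pattern
    as :class:`eqMSATriStarTransformer`)::

        model = eqMSATriStarDropTransformer(
            x_in_channels=1280, x_channels=64, x_hidden_channels=64,
            vec_in_channels=4, vec_channels=16, num_edge_attr=145,
            num_layers=2, num_heads=8, x_use_msa=False,
            drop_out_rate=0.1, use_lora=16, layer_norm=True,
        )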
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnormunlim", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=False, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=True, |
|
triangular_update=True, |
|
ee_channels=None, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
layer_norm=True, |
|
): |
|
super(eqMSATriStarDropTransformer, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.triangular_update = triangular_update |
|
self.use_lora = use_lora |
|
self.layer_norm = layer_norm |
|
|
|
self.distance = DistanceV2( |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
self.msa_encoder = MSAEncoder( |
|
num_species=199, |
|
weighting_schema='spe', |
|
pairwise_type='cov', |
|
) if x_use_msa else None |
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
if x_in_embedding_type == "Linear": |
|
if use_lora is not None: |
|
self.node_x_proj = lora.Linear(x_in_channels, x_channels, r=use_lora) |
|
else: |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
|
elif x_in_embedding_type == "Linear_gelu": |
|
self.node_x_proj = nn.Sequential( |
|
lora.Linear(x_in_channels, x_channels, r=use_lora) if use_lora is not None else nn.Linear(x_in_channels, x_channels), |
|
nn.GELU(), |
|
) |
|
            else:
                # assign the embedding so that it is registered and actually used
                self.node_x_proj = nn.Embedding(x_in_channels, x_channels) if use_lora is None \
                    else lora.Embedding(x_in_channels, x_channels, r=use_lora)
|
self.ee_channels = ee_channels |
|
self.attention_layers = nn.ModuleList() |
|
|
|
self.drop_out_rate = drop_out_rate |
|
self._set_attn_layers() |
|
|
|
|
|
self.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantTriAngularDropMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_in_channels, |
|
vec_hidden_channels=self.vec_channels, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
ee_channels=self.ee_channels, |
|
rbf_channels=self.num_rbf, |
|
triangular_update=self.triangular_update, |
|
drop_out_rate=self.drop_out_rate, |
|
use_lora=self.use_lora, |
|
layer_norm=self.layer_norm, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
|
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
coords = node_vec_attr + pos.unsqueeze(2) |
|
|
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
|
|
|
if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
|
if self.node_x_proj is not None: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
|
|
|
|
|
|
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
del edge_attr |
|
edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
|
|
|
|
|
mask = edge_index_star[0] != edge_index_star[1] |
|
edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
|
del mask, edge_weight_star |
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
|
|
|
attn_weight_layers = [] |
|
for _, attn in enumerate(self.attention_layers): |
|
x, edge_attr_star, attn_weight = attn( |
|
x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
|
|
|
|
|
|
return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
def __repr__(self): |
|
return ( |
|
f"{self.__class__.__name__}(" |
|
f"x_channels={self.x_channels}, " |
|
f"x_hidden_channels={self.x_hidden_channels}, " |
|
f"vec_in_channels={self.vec_in_channels}, " |
|
f"vec_channels={self.vec_channels}, " |
|
f"vec_hidden_channels={self.vec_hidden_channels}, " |
|
f"num_layers={self.num_layers}, " |
|
f"num_rbf={self.num_rbf}, " |
|
f"rbf_type={self.rbf_type}, " |
|
f"trainable_rbf={self.trainable_rbf}, " |
|
f"activation={self.activation}, " |
|
f"attn_activation={self.attn_activation}, " |
|
f"neighbor_embedding={self.neighbor_embedding}, " |
|
f"num_heads={self.num_heads}, " |
|
f"distance_influence={self.distance_influence}, " |
|
f"cutoff_lower={self.cutoff_lower}, " |
|
f"cutoff_upper={self.cutoff_upper})" |
|
) |
|
|
|
|
|
class eqMSATriStarDropGRUTransformer(nn.Module):
    """The equivariant Transformer architecture. Edge attributes are MSA
    weights and distances; dropout is applied inside the attention layers.
|
|
|
Args: |
|
x_channels (int, optional): Hidden embedding size. |
|
            (default: :obj:`5120`)
|
num_layers (int, optional): The number of attention layers. |
|
(default: :obj:`6`) |
|
num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
|
(default: :obj:`50`) |
|
rbf_type (string, optional): The type of radial basis function to use. |
|
            (default: :obj:`"expnormunlim"`)
|
trainable_rbf (bool, optional): Whether to train RBF parameters with |
|
backpropagation. (default: :obj:`True`) |
|
activation (string, optional): The type of activation function to use. |
|
(default: :obj:`"silu"`) |
|
attn_activation (string, optional): The type of activation function to use |
|
inside the attention mechanism. (default: :obj:`"silu"`) |
|
neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
|
            embedding step. (default: :obj:`False`)
|
num_heads (int, optional): Number of attention heads. |
|
(default: :obj:`8`) |
|
distance_influence (string, optional): Where distance information is used inside |
|
the attention mechanism. (default: :obj:`"both"`) |
|
cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
|
(default: :obj:`0.0`) |
|
cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
|
(default: :obj:`5.0`) |
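
    Example (construction sketch only; the forward call takes the same
    ``(x, x_center, x_mask, pos, batch, edge_index, ...)`` arguments as
    :class:`eqMSATriStarGRUTransformer`)::

        model = eqMSATriStarDropGRUTransformer(
            x_channels=64, x_hidden_channels=64, vec_in_channels=4,
            vec_channels=16, num_edge_attr=145, num_layers=2,
            num_heads=8, x_use_msa=False, drop_out_rate=0.1, use_lora=16,
        )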
|
""" |
|
|
|
def __init__( |
|
self, |
|
x_in_channels=None, |
|
x_channels=5120, |
|
x_hidden_channels=1280, |
|
vec_in_channels=4, |
|
vec_channels=128, |
|
vec_hidden_channels=5120, |
|
num_layers=6, |
|
num_edge_attr=145, |
|
num_rbf=50, |
|
rbf_type="expnormunlim", |
|
trainable_rbf=True, |
|
activation="silu", |
|
attn_activation="silu", |
|
neighbor_embedding=False, |
|
num_heads=8, |
|
distance_influence="both", |
|
cutoff_lower=0.0, |
|
cutoff_upper=5.0, |
|
x_in_embedding_type="Linear", |
|
x_use_msa=True, |
|
triangular_update=True, |
|
ee_channels=None, |
|
drop_out_rate=0, |
|
use_lora=None, |
|
): |
|
super(eqMSATriStarDropGRUTransformer, self).__init__() |
|
|
|
assert distance_influence in ["keys", "values", "both", "none"] |
|
assert rbf_type in rbf_class_mapping, ( |
|
f'Unknown RBF type "{rbf_type}". ' |
|
f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
|
) |
|
assert activation in act_class_mapping, ( |
|
f'Unknown activation function "{activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
assert attn_activation in act_class_mapping, ( |
|
f'Unknown attention activation function "{attn_activation}". ' |
|
f'Choose from {", ".join(act_class_mapping.keys())}.' |
|
) |
|
|
|
self.x_in_channels = x_in_channels |
|
self.x_channels = x_channels |
|
self.vec_in_channels = vec_in_channels |
|
self.vec_channels = vec_channels |
|
self.x_hidden_channels = x_hidden_channels |
|
self.vec_hidden_channels = vec_hidden_channels |
|
self.num_layers = num_layers |
|
self.num_rbf = num_rbf |
|
self.num_edge_attr = num_edge_attr |
|
self.rbf_type = rbf_type |
|
self.trainable_rbf = trainable_rbf |
|
self.activation = activation |
|
self.attn_activation = attn_activation |
|
self.neighbor_embedding = neighbor_embedding |
|
self.num_heads = num_heads |
|
self.distance_influence = distance_influence |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.triangular_update = triangular_update |
|
self.use_lora = use_lora |
|
|
|
self.distance = DistanceV2( |
|
return_vecs=True, |
|
loop=True, |
|
) |
|
self.distance_expansion = rbf_class_mapping[rbf_type]( |
|
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
|
) |
|
self.msa_encoder = MSAEncoder( |
|
num_species=199, |
|
weighting_schema='spe', |
|
pairwise_type='cov', |
|
) if x_use_msa else None |
|
|
|
self.node_x_proj = None |
|
if x_in_channels is not None: |
|
if x_in_embedding_type == "Linear": |
|
if use_lora is not None: |
|
self.node_x_proj = lora.Linear(x_in_channels, x_channels, r=use_lora) |
|
else: |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
|
elif x_in_embedding_type == "Linear_gelu": |
|
self.node_x_proj = nn.Sequential( |
|
lora.Linear(x_in_channels, x_channels, r=use_lora) if use_lora is not None else nn.Linear(x_in_channels, x_channels), |
|
nn.GELU(), |
|
) |
|
            else:
                # assign the embedding so that it is registered and actually used
                self.node_x_proj = nn.Embedding(x_in_channels, x_channels) if use_lora is None \
                    else lora.Embedding(x_in_channels, x_channels, r=use_lora)
|
self.ee_channels = ee_channels |
|
self.attention_layers = nn.ModuleList() |
|
|
|
self.drop_out_rate = drop_out_rate |
|
self._set_attn_layers() |
|
|
|
|
|
self.reset_parameters() |
|
|
|
def _set_attn_layers(self): |
|
for _ in range(self.num_layers): |
|
layer = EquivariantTriAngularStarDropMultiHeadAttention( |
|
x_channels=self.x_channels, |
|
x_hidden_channels=self.x_hidden_channels, |
|
vec_channels=self.vec_in_channels, |
|
vec_hidden_channels=self.vec_channels, |
|
edge_attr_channels=self.num_rbf + self.num_edge_attr, |
|
distance_influence=self.distance_influence, |
|
num_heads=self.num_heads, |
|
activation=act_class_mapping[self.activation], |
|
attn_activation=self.attn_activation, |
|
ee_channels=self.ee_channels, |
|
rbf_channels=self.num_rbf, |
|
triangular_update=self.triangular_update, |
|
drop_out_rate=self.drop_out_rate, |
|
use_lora=self.use_lora, |
|
) |
|
self.attention_layers.append(layer) |
|
|
|
def reset_parameters(self): |
|
self.distance_expansion.reset_parameters() |
|
for attn in self.attention_layers: |
|
attn.reset_parameters() |
|
|
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
x_center: Tensor, |
|
x_mask: Tensor, |
|
pos: Tensor, |
|
batch: Tensor, |
|
edge_index: Tensor, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
|
coords = node_vec_attr + pos.unsqueeze(2) |
|
|
|
edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
|
|
|
if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
|
if self.node_x_proj is not None: |
|
x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
|
else: |
|
x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
|
else: |
|
x_msa = None |
|
|
|
|
|
|
|
|
|
|
|
if self.msa_encoder is not None and x_msa is not None: |
|
_, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
|
edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
del edge_attr |
|
edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
|
|
|
|
|
mask = edge_index_star[0] != edge_index_star[1] |
|
edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
|
del mask, edge_weight_star |
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
        # positions where x_mask is False receive the center representation
        # x_center; all other positions keep their (projected) node features
        x = x * x_mask.unsqueeze(1) + x_center * (~x_mask).unsqueeze(1)
|
|
|
attn_weight_layers = [] |
|
for _, attn in enumerate(self.attention_layers): |
|
x, edge_attr_star, attn_weight = attn( |
|
x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
|
if return_attn: |
|
attn_weight_layers.append(attn_weight) |
|
|
|
|
|
batch = batch[~x_mask] |
|
return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
def __repr__(self): |
|
return ( |
|
f"{self.__class__.__name__}(" |
|
f"x_channels={self.x_channels}, " |
|
f"x_hidden_channels={self.x_hidden_channels}, " |
|
f"vec_in_channels={self.vec_in_channels}, " |
|
f"vec_channels={self.vec_channels}, " |
|
f"vec_hidden_channels={self.vec_hidden_channels}, " |
|
f"num_layers={self.num_layers}, " |
|
f"num_rbf={self.num_rbf}, " |
|
f"rbf_type={self.rbf_type}, " |
|
f"trainable_rbf={self.trainable_rbf}, " |
|
f"activation={self.activation}, " |
|
f"attn_activation={self.attn_activation}, " |
|
f"neighbor_embedding={self.neighbor_embedding}, " |
|
f"num_heads={self.num_heads}, " |
|
f"distance_influence={self.distance_influence}, " |
|
f"cutoff_lower={self.cutoff_lower}, " |
|
f"cutoff_upper={self.cutoff_upper})" |
|
) |
|
|
|
|
|
|
|
class eqTriAttnTransformer(nn.Module): |
|
""" |
|
Input a sequence representation and structure, output a new sequence representation and structure |
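
    Example (illustrative sketch only; the batched shapes follow the
    :meth:`forward` docstring, and ``x_in_channels=None`` is chosen so that
    ``x`` is already in the ``x_channels`` embedding space)::

        import torch

        model = eqTriAttnTransformer(x_in_channels=None, x_channels=64,
                                     pairwise_state_dim=32, num_layers=1,
                                     num_heads=8)
        x = torch.randn(2, 50, 64)        # B x L x C sequence features
        pos = torch.randn(2, 50, 4, 3)    # B x L x 4 x 3 backbone coordinates
        s_s, s_z, pos_out, _, _, _ = model(x, pos)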
|
""" |
|
|
|
def __init__(self, |
|
x_in_channels=None, |
|
x_channels=1280, |
|
pairwise_state_dim=128, |
|
num_layers=4, |
|
num_heads=8, |
|
x_in_embedding_type="Embedding", |
|
drop_out_rate=0.1, |
|
x_hidden_channels=None, |
|
vec_channels=None, |
|
vec_in_channels=None, |
|
vec_hidden_channels=None, |
|
num_edge_attr=None, |
|
num_rbf=None, |
|
rbf_type=None, |
|
trainable_rbf=None, |
|
activation=None, |
|
neighbor_embedding=None, |
|
cutoff_lower=None, |
|
cutoff_upper=None, |
|
x_use_msa=False, |
|
use_lora=None, |
|
): |
|
super(eqTriAttnTransformer, self).__init__() |
|
if x_in_channels is not None: |
|
self.node_x_proj = nn.Linear(x_in_channels, x_channels) if x_in_embedding_type == "Linear" \ |
|
else nn.Embedding(x_in_channels, x_channels) |
|
else: |
|
self.node_x_proj = None |
|
assert x_channels % num_heads == 0 \ |
|
and pairwise_state_dim % num_heads == 0, ( |
|
f"The number of hidden channels x_channels ({x_channels}) " |
|
f"and pair-wise channels ({pairwise_state_dim}) " |
|
f"must be evenly divisible by the number of " |
|
f"attention heads ({num_heads})" |
|
) |
|
sequence_head_width = x_channels // num_heads |
|
pairwise_head_width = pairwise_state_dim // num_heads |
|
self.tri_attn_block = nn.ModuleList( |
|
[ |
|
TriangularSelfAttentionBlock( |
|
sequence_state_dim=x_channels, |
|
pairwise_state_dim=pairwise_state_dim, |
|
sequence_head_width=sequence_head_width, |
|
pairwise_head_width=pairwise_head_width, |
|
dropout=drop_out_rate, |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.seq_struct_to_pair = PairFeatureNet( |
|
x_channels, pairwise_state_dim) |
|
|
|
|
|
self.seq_pair_to_output = SeqPairAttentionOutput(seq_state_dim=x_channels, |
|
pairwise_state_dim=pairwise_state_dim, |
|
num_heads=num_heads, |
|
output_dim=x_channels, |
|
dropout=drop_out_rate) |
|
|
|
def reset_parameters(self): |
|
pass |
|
|
|
def forward(self, |
|
x: Tensor, |
|
pos: Tensor, |
|
residx: Tensor = None, |
|
mask: Tensor = None, |
|
batch: Tensor = None, |
|
edge_index: Tensor = None, |
|
edge_index_star: Tensor = None, |
|
edge_attr: Tensor = None, |
|
edge_attr_star: Tensor = None, |
|
node_vec_attr: Tensor = None, |
|
return_attn: bool = False, |
|
): |
|
""" |
|
Inputs: |
|
x: B x L x C tensor of sequence features |
|
pos: B x L x 4 x 3 tensor of [CA, CB, N, O] coordinates |
|
residx: B x L long tensor giving the position in the sequence |
|
mask: B x L boolean tensor indicating valid residues |
|
|
|
Output: |
|
            s_s: B x L x x_channels tensor of updated sequence features
            s_z: B x L x L x pairwise_state_dim tensor of pairwise features
            (the remaining return values are the unchanged input pos and three None placeholders)
|
""" |
|
|
|
if residx is None: |
|
residx = torch.arange( |
|
x.shape[1], device=x.device).repeat(x.shape[0], 1) |
|
if mask is None: |
|
mask = torch.ones((x.shape[0], x.shape[1]), |
|
dtype=torch.bool, device=x.device) |
|
|
|
x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
|
pair_feats = self.seq_struct_to_pair(x, pos, residx, mask) |
|
|
|
s_s = x |
|
s_z = pair_feats |
|
|
|
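        # alternately refine the sequence state s_s and the pairwise state s_z
        # with triangular self-attention blocks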
for block in self.tri_attn_block: |
|
s_s, s_z = block(sequence_state=s_s, |
|
pairwise_state=s_z, |
|
mask=mask.to(torch.float32)) |
|
|
|
s_s = self.seq_pair_to_output( |
|
sequence_state=s_s, pairwise_state=s_z, mask=mask.to(torch.float32)) |
|
|
|
|
|
|
|
return s_s, s_z, pos, None, None, None |
|
|