# PreMode/model/module/output.py
from abc import abstractmethod, ABCMeta
from typing import Optional, Tuple
import math
import torch
from torch import _dynamo
_dynamo.config.suppress_errors = True
from torch import nn
from torch_geometric.nn import MessagePassing
# from torch_geometric.nn.pool.topk_pool import topk, filter_adj # Abort SAGPool
from torch_scatter import scatter
import loralib as lora
import gpytorch
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP
from pyro.nn.module import to_pyro_module_
from data.utils import AA_DICT_HUMAN, ESM_TOKENS
from model.module.utils import act_class_mapping
from .attention import PAEMultiHeadAttentionSoftMaxStarGraph, MultiHeadAttentionSoftMaxStarGraph
from esm.modules import RobertaLMHead
class OutputModel(nn.Module, metaclass=ABCMeta):
def __init__(self, allow_prior_model, reduce_op):
super(OutputModel, self).__init__()
self.allow_prior_model = allow_prior_model
self.reduce_op = reduce_op
def reset_parameters(self):
pass
@abstractmethod
def pre_reduce(self, x, v, pos, batch):
return
def reduce(self, x, edge_index, edge_attr, batch):
return scatter(x, batch, dim=0, reduce=self.reduce_op), None
def post_reduce(self, x):
return x
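# Note on the OutputModel interface (a reading of the code above, not an authoritative spec):
# subclasses run a three-stage pipeline -- pre_reduce transforms per-node features,
# reduce pools node features into one vector per graph (here via torch_scatter.scatter
# over `batch`), and post_reduce applies the final head. The second element returned by
# reduce is reserved for attention weights; the base class returns None there.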
class ESMScalar(OutputModel):
# ESMScalar is a special output model for ESM-2:
# it is initialized with the same weights as the ESM-2 language-model head,
# it outputs scores derived from the log probs of all 20 amino acids,
# and it currently only handles a single amino-acid change
def __init__(self, args,
activation="sigmoid",
allow_prior_model=True,
lm_head: RobertaLMHead=None):
# have a language model head weights
x_channels = args["x_channels"]
out_channels = args["output_dim"]
reduce_op = args["reduce_op"]
super(ESMScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
self.activation = act_class_mapping[activation]()
# first is lm dense layer
self.lm_dense = nn.Linear(x_channels, x_channels)
self.lm_dense.weight.data.copy_(lm_head.dense.weight)
self.lm_dense.bias.data.copy_(lm_head.dense.bias)
# next is lm_weight layer
self.lm_weight = nn.Parameter(torch.zeros(len(ESM_TOKENS), x_channels, out_channels))
self.lm_bias = nn.Parameter(torch.zeros(len(ESM_TOKENS), out_channels))
# self.sigmoid_bias = nn.Parameter(torch.ones(out_channels) * -4)
# self.sigmoid_weight = nn.Parameter(torch.ones(out_channels) * -1)
# last is the layer norm
self.lm_layer_norm = nn.LayerNorm(x_channels)
# copy the LM head weight/bias into every output-channel slice of self.lm_weight
for i in range(out_channels):
self.lm_weight[:, :, i].data.copy_(lm_head.weight)
self.lm_bias[:, i].data.copy_(lm_head.bias)
self.lm_layer_norm.weight.data.copy_(lm_head.layer_norm.weight)
def gelu(self, x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different
(and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def pre_reduce(self, x, v, pos, batch):
return x
def reduce(self, x, edge_index, edge_attr, batch):
# return the center node features
# center nodes are the unique targets of the star-graph edges (edge_index[1])
center_nodes = torch.unique(edge_index[1])
x = x[center_nodes]
return x, None
def post_reduce(self, x, score_mask=None):
# apply the language model head
# first dense layer
x = self.lm_dense(x)
# next gelu
x = self.gelu(x)
# next layer norm
x = self.lm_layer_norm(x)
# last linear layer
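# shape sketch (inferred from the einsum below, not stated explicitly in the original):
# x: [batch, x_channels]; lm_weight: [len(ESM_TOKENS), x_channels, out_channels];
# result: [batch, len(ESM_TOKENS), out_channels], i.e. one score per ESM token per output dim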
x = torch.einsum('bh,ths->bts', x, self.lm_weight) + self.lm_bias.unsqueeze(0)
# should apply the score mask here: -1 means reference, 1 means alternative, 0 means not of interest
if score_mask is not None:
x = (x * score_mask.unsqueeze(-1)).sum(dim=1)
else:
x = x.sum(dim=1)
# return self.activation(x * self.sigmoid_weight + self.sigmoid_bias)
return self.activation(x)
class ESMFullGraphScalar(ESMScalar):
# ESMFullGraphScalar is the full-graph variant of ESMScalar:
# it shares the same ESM-2 language-model head weights and log-prob output,
# but reduces over the whole graph via a mask instead of star-graph center nodes;
# it currently only handles a single amino-acid change
def __init__(self, args,
activation="sigmoid",
allow_prior_model=True,
lm_head: RobertaLMHead=None):
# have a language model head weights
super(ESMFullGraphScalar, self).__init__(
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
lm_head=lm_head
)
def reduce(self, x, x_mask):
# return the center node features
# here the center nodes are selected via the mask and summed over the sequence dimension
x = (x * x_mask).sum(dim=1)
return x, None
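# shape assumption for this reduce (inferred, not documented in the original):
# x: [batch, seq_len, channels] and x_mask broadcasts against it (e.g. [batch, seq_len, 1]),
# so the masked sum keeps only the positions of interest in each sequence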
class EquivariantNoGraphScalar(OutputModel):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
x_channels = args["x_channels"]
out_channels = args["output_dim"]
reduce_op = args["reduce_op"]
super(EquivariantNoGraphScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.layer_norm = nn.LayerNorm(x_channels)
self.output_network = nn.Sequential(
nn.Linear(x_channels, out_channels),
act_class(),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch):
return x
def reduce(self, x, edge_index, edge_attr, batch):
return x.sum(dim=-2), None
def post_reduce(self, x):
x = self.layer_norm(x)
return self.output_network(x)
class EquivariantScalar(OutputModel):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
init_fn='uniform',
):
x_channels = args["x_channels"]
if args["model"] == "pass-forward":
x_channels = args["x_in_channels"] if args["x_in_channels"] is not None else args["x_channels"]
if args["add_msa"]:
x_channels += 199
out_channels = args["output_dim"]
reduce_op = args["reduce_op"]
super(EquivariantScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(x_channels, out_channels),
act_class(),
)
self.init_fn = init_fn
self.reset_parameters()
def reset_parameters(self):
if self.init_fn == 'uniform':
nn.init.xavier_uniform_(self.output_network[0].weight)
else:
nn.init.constant_(self.output_network[0].weight, 0)
self.output_network[0].bias.data.fill_(0)
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch):
return x
def post_reduce(self, x):
return self.output_network(x)
class EquivariantPAEAttnScalar(EquivariantScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
x_channels = args["x_channels"]
super(EquivariantPAEAttnScalar, self).__init__(
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
)
# apply a PAE attention pooling layer over the star graph
if args["loss_fn"] == "weighted_combined_loss" or args["loss_fn"] == "combined_loss":
use_lora = args["use_lora"]
else:
use_lora = None
input_dic = {
"x_channels": args["x_channels"],
"x_hidden_channels": args["x_hidden_channels"],
"vec_channels": args["vec_channels"],
"vec_hidden_channels": args["vec_hidden_channels"],
"share_kv": args["share_kv"],
"edge_attr_dist_channels": args["num_rbf"],
"edge_attr_channels": args["num_edge_attr"],
"distance_influence": args["distance_influence"],
"num_heads": args["num_heads"],
"activation": act_class_mapping[args["activation"]],
"cutoff_lower": args["cutoff_lower"],
"cutoff_upper": args["cutoff_upper"],
"use_lora": use_lora
}
self.AttnPoolLayers = nn.ModuleList([
PAEMultiHeadAttentionSoftMaxStarGraph(**input_dic),
])
def reduce(self, x, x_center_index, w_ij, f_dist_ij, f_ij, plddt, key_padding_mask):
# x does not need to be reduced to x_center_index, but w_ij, f_dist_ij and f_ij do
w_ij = w_ij[x_center_index].unsqueeze(1)
f_dist_ij = f_dist_ij[x_center_index].unsqueeze(1)
f_ij = f_ij[x_center_index].unsqueeze(1)
for layer in self.AttnPoolLayers:
x, attn = layer(x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, True)
return x, attn
class EquivariantAttnScalar(EquivariantScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
x_channels = args["x_channels"]
super(EquivariantAttnScalar, self).__init__(
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
)
# apply an attention pooling layer over the star graph
if args["loss_fn"] == "weighted_combined_loss" or args["loss_fn"] == "combined_loss":
use_lora = args["use_lora"]
else:
use_lora = None
input_dic = {
"x_channels": args["x_channels"],
"x_hidden_channels": args["x_hidden_channels"],
"vec_channels": args["vec_channels"],
"vec_hidden_channels": args["vec_hidden_channels"],
"share_kv": args["share_kv"],
"edge_attr_dist_channels": args["num_rbf"],
"edge_attr_channels": args["num_edge_attr"],
"distance_influence": args["distance_influence"],
"num_heads": args["num_heads"],
"activation": act_class_mapping[args["activation"]],
"cutoff_lower": args["cutoff_lower"],
"cutoff_upper": args["cutoff_upper"],
"use_lora": use_lora
}
self.AttnPoolLayers = nn.ModuleList([
MultiHeadAttentionSoftMaxStarGraph(**input_dic),
])
def reduce(self, x, x_center_index, w_ij, f_dist_ij, f_ij, plddt, key_padding_mask):
# x does not need to be reduced to x_center_index; only f_ij is indexed by the center nodes here
# w_ij = w_ij[x_center_index].unsqueeze(1)
# f_dist_ij = f_dist_ij[x_center_index].unsqueeze(1)
f_ij = f_ij[x_center_index].unsqueeze(1)
for layer in self.AttnPoolLayers:
x, attn = layer(x, x_center_index, None, None, f_ij, key_padding_mask, True)
return x, attn
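# In both attention-pooling variants above, each layer lets the center node attend over the
# other nodes of its star graph and returns the pooled center features together with the
# attention weights; the PAE variant additionally feeds the pairwise w_ij / f_dist_ij terms.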
class EquivariantAttnOneSiteScalar(EquivariantAttnScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
self.output_dim = args["output_dim"]
# make a copy of args so the caller's dict is not mutated (same pattern as the MeanVar variant below)
args_copy = args.copy()
args_copy["output_dim"] = len(AA_DICT_HUMAN) * self.output_dim
super(EquivariantAttnOneSiteScalar, self).__init__(
args=args_copy,
activation=activation,
allow_prior_model=allow_prior_model,
)
def post_reduce(self, x):
res = self.output_network(x).reshape(-1, len(AA_DICT_HUMAN), self.output_dim)
return res
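# The reshape yields one output vector per possible amino-acid identity (len(AA_DICT_HUMAN) rows),
# presumably so a single forward pass can score every substitution at one site.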
class EquivariantStarPoolScalar(EquivariantScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
init_fn='uniform',
):
x_channels = args["x_channels"]
super(EquivariantStarPoolScalar, self).__init__(
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
init_fn=init_fn,
)
# apply a StarPool layer
if args["loss_fn"] == "weighted_combined_loss" or args["loss_fn"] == "combined_loss":
use_lora = args["use_lora"]
else:
use_lora = None
self.StarPoolLayers = nn.ModuleList([
StarPool(hidden_channels=x_channels,
edge_channels=args["num_rbf"] + args["num_edge_attr"],
cutoff_lower=args["cutoff_lower"],
cutoff_upper=args["cutoff_upper"],
use_lora=use_lora,
drop_out=args["drop_out"],
ratio=0.5),
])
def reduce(self, x, edge_index, edge_attr, batch):
for layer in self.StarPoolLayers:
x, edge_index, edge_attr, batch, attn = layer(x, edge_index, edge_attr, batch)
out = scatter(x, batch, dim=0, reduce=self.reduce_op)
return out, attn
class EquivariantStarPoolOneSiteScalar(EquivariantStarPoolScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
self.output_dim = args["output_dim"]
# make a copy of args so the caller's dict is not mutated (same pattern as the MeanVar variant below)
args_copy = args.copy()
args_copy["output_dim"] = len(AA_DICT_HUMAN) * self.output_dim
super(EquivariantStarPoolOneSiteScalar, self).__init__(
args=args_copy,
activation=activation,
allow_prior_model=allow_prior_model,
)
def post_reduce(self, x):
res = self.output_network(x).reshape(-1, len(AA_DICT_HUMAN), self.output_dim)
return res
class EquivariantStarPoolMeanVarScalar(EquivariantStarPoolScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
self.output_dim = args["output_dim"]
# make a copy of args
args_copy = args.copy()
args_copy["output_dim"] = 2 * self.output_dim
super(EquivariantStarPoolMeanVarScalar, self).__init__(
args=args_copy,
activation=activation,
allow_prior_model=allow_prior_model,
)
def post_reduce(self, x):
# output mean and variance
return self.output_network(x).reshape(-1, 2, self.output_dim)
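# shape: [batch, 2, output_dim]; per the class name and the comment above, channel 0 is read
# as the mean and channel 1 as the variance (an assumption, the ordering is not documented here)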
class EquivariantStarPoolPyroScalar(EquivariantStarPoolScalar):
def __init__(
self,
args,
activation="sigmoid",
allow_prior_model=True,
):
super(EquivariantStarPoolPyroScalar, self).__init__(
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
)
to_pyro_module_(self.output_network)
def post_reduce(self, x):
return self.output_network(x)
# GP layer basic class
class GaussianProcessLayer(DeepGPLayer):
def __init__(self, input_dims, output_dims, num_inducing=64, mean_type='constant'):
if output_dims is None:
inducing_points = torch.randn(num_inducing, input_dims)
batch_shape = torch.Size([])
else:
inducing_points = torch.randn(output_dims, num_inducing, input_dims)
batch_shape = torch.Size([output_dims])
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
num_inducing_points=num_inducing,
batch_shape=batch_shape
)
# initialize variational strategy
variational_strategy = gpytorch.variational.VariationalStrategy(
self,
inducing_points,
variational_distribution,
learn_inducing_locations=True
)
super(GaussianProcessLayer, self).__init__(variational_strategy, input_dims, output_dims)
if mean_type == 'constant':
self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
else:
self.mean_module = gpytorch.means.LinearMean(input_dims)
# Cannot use RBFKernel here because it forces the diagonal values to be the same
self.covar_module = gpytorch.kernels.ScaleKernel(
gpytorch.kernels.LinearKernel(num_dimensions=input_dims),
batch_shape=batch_shape, ard_num_dims=None
)
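# With ScaleKernel(LinearKernel), the covariance is k(x, x') = outputscale * variance * <x, x'>,
# so diagonal entries scale with ||x||^2 and are not forced to be equal (unlike an RBF kernel,
# whose diagonal is constant) -- which is the point of the comment above.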
def forward(self, x):
mean = self.mean_module(x)
covar = self.covar_module(x)
y = gpytorch.distributions.MultivariateNormal(mean, covar)
y_new = y.to_data_independent_dist()
y_new.lazy_covariance_matrix = gpytorch.lazy.DiagLazyTensor(y.lazy_covariance_matrix.diag())
return y_new
def __call__(self, x, *other_inputs, **kwargs):
"""
Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections
easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the concatenation of the first
hidden layer's outputs and the input data to hidden_layer2.
"""
if len(other_inputs):
if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal):
x = x.rsample()
processed_inputs = [
inp.unsqueeze(0).expand(gpytorch.settings.num_likelihood_samples.value(), *inp.shape)
for inp in other_inputs
]
x = torch.cat([x] + processed_inputs, dim=-1)
# TODO: are_samples is always set to True here; not sure whether that is correct
# return super().__call__(x, are_samples=bool(len(other_inputs)), **kwargs)
return super().__call__(x, are_samples=True, **kwargs)
class EquivariantStarPoolGPScalar(EquivariantStarPoolScalar, DeepGP):
def __init__(
self,
args,
activation="sigmoid",
grid_bounds=(-10., 10.),
allow_prior_model=True,
):
x_channels = args["x_channels"]
DeepGP.__init__(self)
EquivariantStarPoolScalar.__init__(
self,
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
)
# change the output network to GP
self.output_network = GaussianProcessLayer(input_dims=x_channels, output_dims=args["output_dim"], mean_type='linear')
self.grid_bounds = grid_bounds
# This module will scale the NN features so that they're nice values
self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(self.grid_bounds[0], self.grid_bounds[1])
self.likelihood = gpytorch.likelihoods.BernoulliLikelihood()
def reset_parameters(self):
# Do nothing
return
def post_reduce(self, x):
# Scale the input to be between [-10, 10]
x = self.scale_to_bounds(x)
# Get the predicted latent function values
x = self.output_network(x)
return x
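# Note: post_reduce returns the latent GP distribution (a MultivariateNormal), not probabilities;
# self.likelihood (BernoulliLikelihood) stays on the module and is presumably applied by the
# training / evaluation code outside this file.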
class PositiveLinear(nn.Module):
def __init__(self, in_features, out_features):
super(PositiveLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
def forward(self, x):
return nn.functional.linear(x, self.weight.exp())
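# PositiveLinear reparameterizes the weight through exp(), so the effective weights are strictly
# positive and the output is monotonically increasing in each input feature. It does not appear
# to be used elsewhere in this file.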
class EquivariantRegressionClassificationStarPoolScalar(EquivariantStarPoolScalar):
def __init__(
self,
args,
activation="pass",
allow_prior_model=True,
):
x_channels = args["x_channels"]
super(EquivariantRegressionClassificationStarPoolScalar, self).__init__(
args=args,
activation=activation,
allow_prior_model=allow_prior_model,
)
# apply a StarPool layer (SAGPool was abandoned, see the commented import at the top)
self.StarPoolLayers = nn.ModuleList([
StarPool(hidden_channels=x_channels,
edge_channels=args["num_rbf"] + args["num_edge_attr"],
cutoff_lower=args["cutoff_lower"],
cutoff_upper=args["cutoff_upper"],
use_lora=args["use_lora"],
drop_out=args["drop_out"],
ratio=0.5),
])
self.output_network_1 = nn.Sequential(
nn.Linear(x_channels, args["output_dim_1"]),
act_class_mapping["pass"](),
)
self.output_network_2 = nn.Sequential(
nn.Linear(args["output_dim_1"], args["output_dim_2"]),
act_class_mapping["sigmoid"](),
)
def reduce(self, x, edge_index, edge_attr, batch):
for layer in self.StarPoolLayers:
x, edge_index, edge_attr, batch, attn = layer(x, edge_index, edge_attr, batch)
out = scatter(x, batch, dim=0, reduce=self.reduce_op)
return out, attn
def post_reduce(self, x):
x = self.output_network_1(x)
# concatenate the sigmoid classification output with the raw regression output x to get the final output
output = torch.cat((self.output_network_2(x), x), dim=-1)
return output
class EquivariantMaskPredictScalar(OutputModel):
def __init__(
self,
args,
lm_weight,
activation="gelu",
allow_prior_model=True,
):
x_channels = args["x_channels"]
out_channels = args["x_channels"]
reduce_op = args["reduce_op"]
super(EquivariantMaskPredictScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(x_channels, out_channels),
act_class(),
nn.LayerNorm(out_channels),
)
self.lm_weight = lm_weight
self.bias = nn.Parameter(torch.zeros(args["x_in_channels"]))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch):
x = self.output_network(x)
x = nn.functional.linear(x, self.lm_weight) + self.bias
return x
def post_reduce(self, x):
return x
class EquivariantMaskPredictLogLogitsScalar(EquivariantMaskPredictScalar):
def __init__(
self,
args,
lm_weight,
activation="gelu",
allow_prior_model=True,
):
super(EquivariantMaskPredictLogLogitsScalar, self).__init__(
args=args,
lm_weight=lm_weight,
activation=activation,
allow_prior_model=allow_prior_model,
)
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch):
x = self.output_network(x)
x = nn.functional.linear(x, self.lm_weight) + self.bias
x = torch.log_softmax(x, dim=-1)
return x
class Scalar(OutputModel):
def __init__(
self,
args,
activation="silu",
allow_prior_model=True,
):
x_channels = args["x_channels"]
out_channels = args["output_dim"]
reduce_op = args["reduce_op"]
super(Scalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(x_channels, x_channels // 2),
act_class(),
nn.Linear(x_channels // 2, out_channels),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch):
return self.output_network(x)
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different
(and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class StarPool(MessagePassing, metaclass=ABCMeta):
def __init__(self,
hidden_channels,
edge_channels,
cutoff_lower,
cutoff_upper,
ratio=0.5,
drop_out=0.0,
num_heads=32,
use_lora=None,
non_linearity=torch.tanh):
super(StarPool, self).__init__(aggr="mean")
if use_lora is not None:
self.q_proj = lora.Linear(hidden_channels, hidden_channels, r=use_lora)
self.kv_proj = lora.Linear(hidden_channels, hidden_channels, r=use_lora)
self.dk_proj = lora.Linear(edge_channels, hidden_channels, r=use_lora)
# self.fc1 = lora.Linear(hidden_channels, hidden_channels, r=use_lora)
# self.fc2 = lora.Linear(hidden_channels, hidden_channels, r=use_lora)
else:
self.q_proj = nn.Linear(hidden_channels, hidden_channels)
self.kv_proj = nn.Linear(hidden_channels, hidden_channels)
self.dk_proj = nn.Linear(edge_channels, hidden_channels)
# self.fc1 = nn.Linear(hidden_channels, hidden_channels)
# self.fc2 = nn.Linear(hidden_channels, hidden_channels)
self.layernorm_in = nn.LayerNorm(hidden_channels)
self.layernorm_out = nn.LayerNorm(hidden_channels)
self.num_heads = num_heads
self.hidden_channels = hidden_channels
self.x_head_dim = hidden_channels // num_heads
self.node_dim = 0
self.attn_activation = act_class_mapping["silu"]()
self.drop_out = nn.Dropout(drop_out)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.kv_proj.weight)
# nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.dk_proj.weight)
self.q_proj.bias.data.fill_(0)
self.kv_proj.bias.data.fill_(0)
# self.v_proj.bias.data.fill_(0)
self.dk_proj.bias.data.fill_(0)
def forward(self, x, edge_index, edge_attr, batch=None):
residue = x
x = self.layernorm_in(x)
if batch is None:
batch = edge_index.new_zeros(x.size(0))
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
# v = self.v_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
v = k
dk = self.dk_proj(edge_attr).reshape(-1, self.num_heads, self.x_head_dim)
x, attn = self.propagate(
edge_index = edge_index,
q=q,
k=k,
v=v,
dk=dk,
size=None
)
x = x.reshape(-1, self.hidden_channels)
x = residue + x
# pool by keeping only the center nodes of each star graph (no topK ranking is performed)
center_nodes = torch.unique(edge_index[1])
perm = center_nodes
x = x[perm]
batch = batch[perm]
residue = residue[perm]
x = self.layernorm_out(x)
x = residue + self.drop_out(x)
return x, edge_index, edge_attr, batch, attn
def message(self, q_i, k_j, v_j, dk):
attn = (q_i * k_j * dk).sum(dim=-1)
# attention activation function
attn = self.attn_activation(attn)
# update scalar features
x = v_j * attn.unsqueeze(2)
return x, attn
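# per-edge shapes in message (inferred from the projections in forward):
# q_i, k_j, v_j, dk: [num_edges, num_heads, x_head_dim]; attn: [num_edges, num_heads];
# the returned x weights each head's value vector by its SiLU-activated attention score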
def aggregate(
self,
features: Tuple[torch.Tensor, torch.Tensor],
index: torch.Tensor,
ptr: Optional[torch.Tensor],
dim_size: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
x, attn = features
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
return x, attn
def update(
self, inputs: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
return inputs
def message_and_aggregate(self, adj_t) -> torch.Tensor:
pass
def edge_update(self) -> torch.Tensor:
pass
def build_output_model(output_model_name, args, **kwargs):
if output_model_name == "EquivariantBinaryClassificationNoGraphScalar":
return EquivariantNoGraphScalar(args=args, activation="sigmoid")
elif output_model_name == "EquivariantBinaryClassificationScalar":
return EquivariantScalar(args=args, activation="sigmoid")
elif output_model_name == "ESMBinaryClassificationScalar":
return ESMScalar(args=args, activation="sigmoid", **kwargs)
elif output_model_name == "ESMFullGraphBinaryClassificationScalar":
return ESMFullGraphScalar(args=args, activation="sigmoid", **kwargs)
elif output_model_name == "EquivariantBinaryClassificationStarPoolScalar":
return EquivariantStarPoolScalar(args=args, activation="sigmoid", init_fn=args["init_fn"])
elif output_model_name == "EquivariantBinaryClassificationStarPoolMeanVarScalar":
return EquivariantStarPoolMeanVarScalar(args=args, activation="softplus")
elif output_model_name == "EquivariantBinaryClassificationAttnScalar":
return EquivariantAttnScalar(args=args, activation="sigmoid")
elif output_model_name == "EquivariantBinaryClassificationPAEAttnScalar":
return EquivariantPAEAttnScalar(args=args, activation="sigmoid")
elif output_model_name == "EquivariantBinaryClassificationStarPoolOneSiteScalar":
return EquivariantStarPoolOneSiteScalar(args=args, activation="sigmoid")
elif output_model_name == "EquivariantBinaryClassificationAttnOneSiteScalar":
return EquivariantAttnOneSiteScalar(args=args, activation="sigmoid")
elif output_model_name == "EquivariantBinaryClassificationStarPoolGPScalar":
return EquivariantStarPoolGPScalar(args=args, activation="sigmoid")
elif output_model_name == "EquivariantRegressionScalar":
return EquivariantScalar(args=args, activation="pass")
elif output_model_name == "ESMRegressionScalar":
return ESMScalar(args=args, activation="pass", **kwargs)
elif output_model_name == "ESMFullGraphRegressionScalar":
return ESMFullGraphScalar(args=args, activation="pass", **kwargs)
elif output_model_name == "EquivariantRegressionStarPoolScalar":
return EquivariantStarPoolScalar(args=args, activation="pass")
elif output_model_name == "EquivariantRegressionStarPoolMeanVarScalar":
return EquivariantStarPoolMeanVarScalar(args=args, activation="pass")
elif output_model_name == "EquivariantRegressionAttnScalar":
return EquivariantAttnScalar(args=args, activation="pass")
elif output_model_name == "EquivariantRegressionPAEAttnScalar":
return EquivariantPAEAttnScalar(args=args, activation="pass")
elif output_model_name == "EquivariantRegressionStarPoolOneSiteScalar":
return EquivariantStarPoolOneSiteScalar(args=args, activation="pass")
elif output_model_name == "EquivariantRegressionAttnOneSiteScalar":
return EquivariantAttnOneSiteScalar(args=args, activation="pass")
else:
raise NotImplementedError
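# Minimal smoke test for the factory above. The argument values are hypothetical, chosen only to
# satisfy the keys EquivariantNoGraphScalar reads; they are not taken from the original configs.
if __name__ == "__main__":
    args = {"x_channels": 128, "output_dim": 1, "reduce_op": "mean"}
    head = build_output_model("EquivariantBinaryClassificationNoGraphScalar", args)
    x = torch.randn(4, 10, 128)  # [batch, nodes, channels]
    # this output model ignores edge_index / edge_attr / batch and sums over the node dimension
    pooled, _ = head.reduce(x, edge_index=None, edge_attr=None, batch=None)
    print(head.post_reduce(pooled).shape)  # expected: torch.Size([4, 1])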