|
from abc import abstractmethod, ABCMeta |
|
from typing import Optional, Tuple |
|
import math |
|
import torch |
|
from torch import _dynamo |
|
_dynamo.config.suppress_errors = True |
|
from torch import nn |
|
from torch_geometric.nn import MessagePassing |
|
|
|
from torch_scatter import scatter |
|
import loralib as lora |
|
import gpytorch |
|
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP |
|
from pyro.nn.module import to_pyro_module_ |
|
from data.utils import AA_DICT_HUMAN, ESM_TOKENS |
|
from model.module.utils import act_class_mapping |
|
from .attention import PAEMultiHeadAttentionSoftMaxStarGraph, MultiHeadAttentionSoftMaxStarGraph |
|
from esm.modules import RobertaLMHead |
|
|
|
|
|
class OutputModel(nn.Module, metaclass=ABCMeta): |
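    """Abstract base class for output heads.

    Subclasses are assumed to be driven in three stages: ``pre_reduce`` (a
    per-node transform), ``reduce`` (aggregation over the graph or batch,
    by default a ``scatter`` with ``reduce_op``), and ``post_reduce`` (the
    final projection/activation). ``reduce`` returns a ``(pooled, attention)``
    tuple, with ``attention`` being ``None`` for the default scatter reduction.
    """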
|
def __init__(self, allow_prior_model, reduce_op): |
|
super(OutputModel, self).__init__() |
|
self.allow_prior_model = allow_prior_model |
|
self.reduce_op = reduce_op |
|
|
|
def reset_parameters(self): |
|
pass |
|
|
|
@abstractmethod |
|
def pre_reduce(self, x, v, pos, batch): |
|
return |
|
|
|
def reduce(self, x, edge_index, edge_attr, batch): |
|
return scatter(x, batch, dim=0, reduce=self.reduce_op), None |
|
|
|
def post_reduce(self, x): |
|
return x |
|
|
|
|
|
class ESMScalar(OutputModel): |
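    """Output head that re-uses a pretrained ESM ``RobertaLMHead``.

    The dense layer and the LayerNorm scale are copied from ``lm_head``, and a
    per-token weight/bias of shape ``(len(ESM_TOKENS), x_channels,
    out_channels)`` is initialised from the LM head's output projection.
    ``lm_head`` must be provided; the ``None`` default is only a placeholder.
    """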
|
|
|
|
|
|
|
|
|
def __init__(self, args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
lm_head: RobertaLMHead=None): |
|
|
|
x_channels = args["x_channels"] |
|
out_channels = args["output_dim"] |
|
reduce_op = args["reduce_op"] |
|
super(ESMScalar, self).__init__( |
|
allow_prior_model=allow_prior_model, reduce_op=reduce_op |
|
) |
|
self.activation = act_class_mapping[activation]() |
|
|
|
self.lm_dense = nn.Linear(x_channels, x_channels) |
|
self.lm_dense.weight.data.copy_(lm_head.dense.weight) |
|
self.lm_dense.bias.data.copy_(lm_head.dense.bias) |
|
|
|
self.lm_weight = nn.Parameter(torch.zeros(len(ESM_TOKENS), x_channels, out_channels)) |
|
self.lm_bias = nn.Parameter(torch.zeros(len(ESM_TOKENS), out_channels)) |
|
|
|
|
|
|
|
self.lm_layer_norm = nn.LayerNorm(x_channels) |
|
|
|
        # Initialise every output channel of the per-token head from the
        # pretrained LM head; the LayerNorm bias keeps its default init.
        for i in range(out_channels):
            self.lm_weight.data[:, :, i].copy_(lm_head.weight)
            self.lm_bias.data[:, i].copy_(lm_head.bias)
        self.lm_layer_norm.weight.data.copy_(lm_head.layer_norm.weight)
|
|
|
def gelu(self, x): |
|
"""Implementation of the gelu activation function. |
|
|
|
For information: OpenAI GPT's gelu is slightly different |
|
(and gives slightly different results): |
|
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
|
""" |
|
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) |
|
|
|
def pre_reduce(self, x, v, pos, batch): |
|
return x |
|
|
|
def reduce(self, x, edge_index, edge_attr, batch): |
|
|
|
|
|
center_nodes = torch.unique(edge_index[1]) |
|
x = x[center_nodes] |
|
return x, None |
|
|
|
def post_reduce(self, x, score_mask=None): |
|
|
|
|
|
x = self.lm_dense(x) |
|
|
|
x = self.gelu(x) |
|
|
|
x = self.lm_layer_norm(x) |
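        # x: (batch, x_channels); lm_weight: (num_tokens, x_channels, out_channels)
        # -> per-token scores of shape (batch, num_tokens, out_channels).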
|
|
|
x = torch.einsum('bh,ths->bts', x, self.lm_weight) + self.lm_bias.unsqueeze(0) |
|
|
|
if score_mask is not None: |
|
x = (x * score_mask.unsqueeze(-1)).sum(dim=1) |
|
else: |
|
x = x.sum(dim=1) |
|
|
|
return self.activation(x) |
|
|
|
|
|
class ESMFullGraphScalar(ESMScalar): |
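    """``ESMScalar`` variant whose ``reduce`` pools a dense ``(batch, seq,
    hidden)`` tensor with a mask instead of selecting star-graph centre nodes.
    """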
|
|
|
|
|
|
|
|
|
def __init__(self, args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
lm_head: RobertaLMHead=None): |
|
|
|
super(ESMFullGraphScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
lm_head=lm_head |
|
) |
|
|
|
def reduce(self, x, x_mask): |
|
|
|
|
|
x = (x * x_mask).sum(dim=1) |
|
return x, None |
|
|
|
|
|
class EquivariantNoGraphScalar(OutputModel): |
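    """Scalar head for dense (non-graph) inputs: features are summed over the
    second-to-last dimension instead of being scattered by batch index."""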
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
out_channels = args["output_dim"] |
|
reduce_op = args["reduce_op"] |
|
super(EquivariantNoGraphScalar, self).__init__( |
|
allow_prior_model=allow_prior_model, reduce_op=reduce_op |
|
) |
|
act_class = act_class_mapping[activation] |
|
self.layer_norm = nn.LayerNorm(x_channels) |
|
self.output_network = nn.Sequential( |
|
nn.Linear(x_channels, out_channels), |
|
act_class(), |
|
) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.xavier_uniform_(self.output_network[0].weight) |
|
self.output_network[0].bias.data.fill_(0) |
|
|
|
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch): |
|
return x |
|
|
|
def reduce(self, x, edge_index, edge_attr, batch): |
|
        return x.sum(dim=-2), None
|
|
|
def post_reduce(self, x): |
|
x = self.layer_norm(x) |
|
return self.output_network(x) |
|
|
|
|
|
class EquivariantScalar(OutputModel): |
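    """Single-linear output head. ``pre_reduce`` is the identity and the
    default scatter-based ``reduce`` from ``OutputModel`` is used."""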
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
init_fn='uniform', |
|
): |
|
x_channels = args["x_channels"] |
|
if args["model"] == "pass-forward": |
|
x_channels = args["x_in_channels"] if args["x_in_channels"] is not None else args["x_channels"] |
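            # Assumed width of the concatenated MSA features (see "add_msa").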
|
if args["add_msa"]: |
|
x_channels += 199 |
|
out_channels = args["output_dim"] |
|
reduce_op = args["reduce_op"] |
|
super(EquivariantScalar, self).__init__( |
|
allow_prior_model=allow_prior_model, reduce_op=reduce_op |
|
) |
|
act_class = act_class_mapping[activation] |
|
self.output_network = nn.Sequential( |
|
nn.Linear(x_channels, out_channels), |
|
act_class(), |
|
) |
|
self.init_fn = init_fn |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
if self.init_fn == 'uniform': |
|
nn.init.xavier_uniform_(self.output_network[0].weight) |
|
else: |
|
nn.init.constant_(self.output_network[0].weight, 0) |
|
self.output_network[0].bias.data.fill_(0) |
|
|
|
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch): |
|
return x |
|
|
|
def post_reduce(self, x): |
|
return self.output_network(x) |
|
|
|
|
|
class EquivariantPAEAttnScalar(EquivariantScalar): |
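    """Attention-pooled head whose star-graph attention additionally consumes
    edge weights ``w_ij`` (presumably PAE-derived) and distance features
    ``f_dist_ij`` alongside the categorical edge features ``f_ij``."""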
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
        super(EquivariantPAEAttnScalar, self).__init__(
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
if args["loss_fn"] == "weighted_combined_loss" or args["loss_fn"] == "combined_loss": |
|
use_lora = args["use_lora"] |
|
else: |
|
use_lora = None |
|
input_dic = { |
|
"x_channels": args["x_channels"], |
|
"x_hidden_channels": args["x_hidden_channels"], |
|
"vec_channels": args["vec_channels"], |
|
"vec_hidden_channels": args["vec_hidden_channels"], |
|
"share_kv": args["share_kv"], |
|
"edge_attr_dist_channels": args["num_rbf"], |
|
"edge_attr_channels": args["num_edge_attr"], |
|
"distance_influence": args["distance_influence"], |
|
"num_heads": args["num_heads"], |
|
"activation": act_class_mapping[args["activation"]], |
|
"cutoff_lower": args["cutoff_lower"], |
|
"cutoff_upper": args["cutoff_upper"], |
|
"use_lora": use_lora |
|
} |
|
self.AttnPoolLayers = nn.ModuleList([ |
|
PAEMultiHeadAttentionSoftMaxStarGraph(**input_dic), |
|
]) |
|
|
|
def reduce(self, x, x_center_index, w_ij, f_dist_ij, f_ij, plddt, key_padding_mask): |
|
|
|
w_ij = w_ij[x_center_index].unsqueeze(1) |
|
f_dist_ij = f_dist_ij[x_center_index].unsqueeze(1) |
|
f_ij = f_ij[x_center_index].unsqueeze(1) |
|
for layer in self.AttnPoolLayers: |
|
x, attn = layer(x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, True) |
|
return x, attn |
|
|
|
|
|
class EquivariantAttnScalar(EquivariantScalar): |
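    """Attention-pooled head over the star graph; only the categorical edge
    features ``f_ij`` are passed to the attention layer."""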
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
super(EquivariantAttnScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
if args["loss_fn"] == "weighted_combined_loss" or args["loss_fn"] == "combined_loss": |
|
use_lora = args["use_lora"] |
|
else: |
|
use_lora = None |
|
input_dic = { |
|
"x_channels": args["x_channels"], |
|
"x_hidden_channels": args["x_hidden_channels"], |
|
"vec_channels": args["vec_channels"], |
|
"vec_hidden_channels": args["vec_hidden_channels"], |
|
"share_kv": args["share_kv"], |
|
"edge_attr_dist_channels": args["num_rbf"], |
|
"edge_attr_channels": args["num_edge_attr"], |
|
"distance_influence": args["distance_influence"], |
|
"num_heads": args["num_heads"], |
|
"activation": act_class_mapping[args["activation"]], |
|
"cutoff_lower": args["cutoff_lower"], |
|
"cutoff_upper": args["cutoff_upper"], |
|
"use_lora": use_lora |
|
} |
|
self.AttnPoolLayers = nn.ModuleList([ |
|
MultiHeadAttentionSoftMaxStarGraph(**input_dic), |
|
]) |
|
|
|
def reduce(self, x, x_center_index, w_ij, f_dist_ij, f_ij, plddt, key_padding_mask): |
|
|
|
|
|
|
|
f_ij = f_ij[x_center_index].unsqueeze(1) |
|
for layer in self.AttnPoolLayers: |
|
x, attn = layer(x, x_center_index, None, None, f_ij, key_padding_mask, True) |
|
return x, attn |
|
|
|
|
|
class EquivariantAttnOneSiteScalar(EquivariantAttnScalar): |
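    """Predicts one score per possible amino-acid substitution at the centre
    site; the output is reshaped to ``(batch, len(AA_DICT_HUMAN), output_dim)``.
    Note that ``args["output_dim"]`` is overwritten in place here, unlike the
    mean/variance head below, which works on a copy of ``args``."""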
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
self.output_dim = args["output_dim"] |
|
args["output_dim"] = len(AA_DICT_HUMAN) * self.output_dim |
|
super(EquivariantAttnOneSiteScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
def post_reduce(self, x): |
|
res = self.output_network(x).reshape(-1, len(AA_DICT_HUMAN), self.output_dim) |
|
return res |
|
|
|
|
|
class EquivariantStarPoolScalar(EquivariantScalar): |
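    """Head that pools node features with a ``StarPool`` attention layer and
    then scatters the surviving centre nodes by batch index."""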
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
init_fn='uniform', |
|
): |
|
x_channels = args["x_channels"] |
|
super(EquivariantStarPoolScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
init_fn=init_fn, |
|
) |
|
|
|
if args["loss_fn"] == "weighted_combined_loss" or args["loss_fn"] == "combined_loss": |
|
use_lora = args["use_lora"] |
|
else: |
|
use_lora = None |
|
self.StarPoolLayers = nn.ModuleList([ |
|
StarPool(hidden_channels=x_channels, |
|
edge_channels=args["num_rbf"] + args["num_edge_attr"], |
|
cutoff_lower=args["cutoff_lower"], |
|
cutoff_upper=args["cutoff_upper"], |
|
use_lora=use_lora, |
|
drop_out=args["drop_out"], |
|
ratio=0.5), |
|
]) |
|
|
|
def reduce(self, x, edge_index, edge_attr, batch): |
|
for layer in self.StarPoolLayers: |
|
x, edge_index, edge_attr, batch, attn = layer(x, edge_index, edge_attr, batch) |
|
out = scatter(x, batch, dim=0, reduce=self.reduce_op) |
|
return out, attn |
|
|
|
|
|
class EquivariantStarPoolOneSiteScalar(EquivariantStarPoolScalar): |
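    """Star-pool head with one output per amino acid in ``AA_DICT_HUMAN``;
    like the attention one-site head, it overwrites ``args["output_dim"]``
    in place."""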
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
self.output_dim = args["output_dim"] |
|
args["output_dim"] = len(AA_DICT_HUMAN) * self.output_dim |
|
super(EquivariantStarPoolOneSiteScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
def post_reduce(self, x): |
|
res = self.output_network(x).reshape(-1, len(AA_DICT_HUMAN), self.output_dim) |
|
return res |
|
|
|
|
|
class EquivariantStarPoolMeanVarScalar(EquivariantStarPoolScalar): |
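    """Star-pool head that predicts a (mean, variance-like) pair per target;
    ``post_reduce`` reshapes the output to ``(batch, 2, output_dim)``."""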
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
self.output_dim = args["output_dim"] |
|
|
|
args_copy = args.copy() |
|
args_copy["output_dim"] = 2 * self.output_dim |
|
super(EquivariantStarPoolMeanVarScalar, self).__init__( |
|
args=args_copy, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
def post_reduce(self, x): |
|
|
|
return self.output_network(x).reshape(-1, 2, self.output_dim) |
|
|
|
|
|
class EquivariantStarPoolPyroScalar(EquivariantStarPoolScalar): |
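    """Star-pool head whose output network is converted to a Pyro module so
    its parameters can be treated as random variables."""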
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
allow_prior_model=True, |
|
): |
|
super(EquivariantStarPoolPyroScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
to_pyro_module_(self.output_network) |
|
|
|
def post_reduce(self, x): |
|
return self.output_network(x) |
|
|
|
|
|
|
|
class GaussianProcessLayer(DeepGPLayer): |
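    """Deep-GP layer with a variational strategy, a constant or linear mean,
    and a scaled linear kernel. ``output_dims=None`` yields an unbatched GP."""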
|
def __init__(self, input_dims, output_dims, num_inducing=64, mean_type='constant'): |
|
if output_dims is None: |
|
inducing_points = torch.randn(num_inducing, input_dims) |
|
batch_shape = torch.Size([]) |
|
else: |
|
inducing_points = torch.randn(output_dims, num_inducing, input_dims) |
|
batch_shape = torch.Size([output_dims]) |
|
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution( |
|
num_inducing_points=num_inducing, |
|
batch_shape=batch_shape |
|
) |
|
|
|
variational_strategy = gpytorch.variational.VariationalStrategy( |
|
self, |
|
inducing_points, |
|
variational_distribution, |
|
learn_inducing_locations=True |
|
) |
|
super(GaussianProcessLayer, self).__init__(variational_strategy, input_dims, output_dims) |
|
if mean_type == 'constant': |
|
self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape) |
|
else: |
|
self.mean_module = gpytorch.means.LinearMean(input_dims) |
|
|
|
self.covar_module = gpytorch.kernels.ScaleKernel( |
|
gpytorch.kernels.LinearKernel(num_dimensions=input_dims), |
|
batch_shape=batch_shape, ard_num_dims=None |
|
) |
|
|
|
def forward(self, x): |
|
mean = self.mean_module(x) |
|
covar = self.covar_module(x) |
|
y = gpytorch.distributions.MultivariateNormal(mean, covar) |
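        # Re-wrap as a data-independent distribution with a diagonal covariance
        # (assumes the older gpytorch lazy-tensor API used throughout this file).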
|
y_new = y.to_data_independent_dist() |
|
y_new.lazy_covariance_matrix = gpytorch.lazy.DiagLazyTensor(y.lazy_covariance_matrix.diag()) |
|
return y_new |
|
|
|
def __call__(self, x, *other_inputs, **kwargs): |
|
""" |
|
Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections |
|
easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the concatenation of the first |
|
hidden layer's outputs and the input data to hidden_layer2. |
|
""" |
|
if len(other_inputs): |
|
if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal): |
|
x = x.rsample() |
|
|
|
processed_inputs = [ |
|
inp.unsqueeze(0).expand(gpytorch.settings.num_likelihood_samples.value(), *inp.shape) |
|
for inp in other_inputs |
|
] |
|
|
|
x = torch.cat([x] + processed_inputs, dim=-1) |
|
|
|
|
|
return super().__call__(x, are_samples=True, **kwargs) |
|
|
|
|
|
class EquivariantStarPoolGPScalar(EquivariantStarPoolScalar, DeepGP): |
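    """Star-pool head whose linear output network is replaced by a deep-GP
    layer with a Bernoulli likelihood; inputs are scaled to ``grid_bounds``
    before entering the GP."""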
|
def __init__( |
|
self, |
|
args, |
|
activation="sigmoid", |
|
grid_bounds=(-10., 10.), |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
DeepGP.__init__(self) |
|
EquivariantStarPoolScalar.__init__( |
|
self, |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
self.output_network = GaussianProcessLayer(input_dims=x_channels, output_dims=args["output_dim"], mean_type='linear') |
|
self.grid_bounds = grid_bounds |
|
|
|
self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(self.grid_bounds[0], self.grid_bounds[1]) |
|
self.likelihood = gpytorch.likelihoods.BernoulliLikelihood() |
|
|
|
def reset_parameters(self): |
|
|
|
return |
|
|
|
def post_reduce(self, x): |
|
|
|
x = self.scale_to_bounds(x) |
|
|
|
x = self.output_network(x) |
|
return x |
|
|
|
|
|
class PositiveLinear(nn.Module): |
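    """Bias-free linear layer whose effective weight is ``exp(weight)``,
    guaranteeing strictly positive weights."""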
|
def __init__(self, in_features, out_features): |
|
super(PositiveLinear, self).__init__() |
|
self.in_features = in_features |
|
self.out_features = out_features |
|
self.weight = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=True) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.xavier_uniform_(self.weight) |
|
|
|
def forward(self, x): |
|
return nn.functional.linear(x, self.weight.exp()) |
|
|
|
|
|
class EquivariantRegressionClassificationStarPoolScalar(EquivariantStarPoolScalar): |
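    """Two-headed star-pool output: a pass-through regression head followed by
    a sigmoid classification head; ``post_reduce`` returns the concatenation
    ``[classification, regression]`` along the last dimension."""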
|
def __init__( |
|
self, |
|
args, |
|
activation="pass", |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
super(EquivariantRegressionClassificationStarPoolScalar, self).__init__( |
|
args=args, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
|
|
self.StarPoolLayers = nn.ModuleList([ |
|
StarPool(hidden_channels=x_channels, |
|
edge_channels=args["num_rbf"] + args["num_edge_attr"], |
|
cutoff_lower=args["cutoff_lower"], |
|
cutoff_upper=args["cutoff_upper"], |
|
use_lora=args["use_lora"], |
|
drop_out=args["drop_out"], |
|
ratio=0.5), |
|
]) |
|
self.output_network_1 = nn.Sequential( |
|
nn.Linear(x_channels, args["output_dim_1"]), |
|
act_class_mapping["pass"](), |
|
) |
|
self.output_network_2 = nn.Sequential( |
|
nn.Linear(args["output_dim_1"], args["output_dim_2"]), |
|
act_class_mapping["sigmoid"](), |
|
) |
|
|
|
def reduce(self, x, edge_index, edge_attr, batch): |
|
for layer in self.StarPoolLayers: |
|
x, edge_index, edge_attr, batch, attn = layer(x, edge_index, edge_attr, batch) |
|
out = scatter(x, batch, dim=0, reduce=self.reduce_op) |
|
return out, attn |
|
|
|
def post_reduce(self, x): |
|
x = self.output_network_1(x) |
|
|
|
output = torch.cat((self.output_network_2(x), x), dim=-1) |
|
return output |
|
|
|
|
|
class EquivariantMaskPredictScalar(OutputModel): |
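    """Masked-token prediction head: node features are projected and then
    multiplied by the tied language-model matrix ``lm_weight`` (assumed shape
    ``(x_in_channels, x_channels)``) to produce per-token logits."""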
|
def __init__( |
|
self, |
|
args, |
|
lm_weight, |
|
activation="gelu", |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
out_channels = args["x_channels"] |
|
reduce_op = args["reduce_op"] |
|
super(EquivariantMaskPredictScalar, self).__init__( |
|
allow_prior_model=allow_prior_model, reduce_op=reduce_op |
|
) |
|
act_class = act_class_mapping[activation] |
|
self.output_network = nn.Sequential( |
|
nn.Linear(x_channels, out_channels), |
|
act_class(), |
|
nn.LayerNorm(out_channels), |
|
) |
|
self.lm_weight = lm_weight |
|
self.bias = nn.Parameter(torch.zeros(args["x_in_channels"])) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.xavier_uniform_(self.output_network[0].weight) |
|
self.output_network[0].bias.data.fill_(0) |
|
|
|
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch): |
|
x = self.output_network(x) |
|
x = nn.functional.linear(x, self.lm_weight) + self.bias |
|
return x |
|
|
|
def post_reduce(self, x): |
|
return x |
|
|
|
|
|
class EquivariantMaskPredictLogLogitsScalar(EquivariantMaskPredictScalar): |
|
def __init__( |
|
self, |
|
args, |
|
lm_weight, |
|
activation="gelu", |
|
allow_prior_model=True, |
|
): |
|
super(EquivariantMaskPredictLogLogitsScalar, self).__init__( |
|
args=args, |
|
lm_weight=lm_weight, |
|
activation=activation, |
|
allow_prior_model=allow_prior_model, |
|
) |
|
|
|
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch): |
|
x = self.output_network(x) |
|
x = nn.functional.linear(x, self.lm_weight) + self.bias |
|
x = torch.log_softmax(x, dim=-1) |
|
return x |
|
|
|
|
|
class Scalar(OutputModel): |
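    """Two-layer MLP head applied per node in ``pre_reduce``; the default
    scatter-based ``reduce`` then aggregates per graph."""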
|
def __init__( |
|
self, |
|
args, |
|
activation="silu", |
|
allow_prior_model=True, |
|
): |
|
x_channels = args["x_channels"] |
|
out_channels = args["output_dim"] |
|
reduce_op = args["reduce_op"] |
|
super(Scalar, self).__init__( |
|
allow_prior_model=allow_prior_model, reduce_op=reduce_op |
|
) |
|
act_class = act_class_mapping[activation] |
|
self.output_network = nn.Sequential( |
|
nn.Linear(x_channels, x_channels // 2), |
|
act_class(), |
|
nn.Linear(x_channels // 2, out_channels), |
|
) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.xavier_uniform_(self.output_network[0].weight) |
|
self.output_network[0].bias.data.fill_(0) |
|
nn.init.xavier_uniform_(self.output_network[2].weight) |
|
self.output_network[2].bias.data.fill_(0) |
|
|
|
def pre_reduce(self, x, v: Optional[torch.Tensor], pos, batch): |
|
return self.output_network(x) |
|
|
|
|
|
def gelu(x): |
|
"""Implementation of the gelu activation function. |
|
|
|
For information: OpenAI GPT's gelu is slightly different |
|
(and gives slightly different results): |
|
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
|
""" |
|
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) |
|
|
|
|
|
class StarPool(MessagePassing, metaclass=ABCMeta): |
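    """Message-passing attention pooling over star graphs: queries come from
    centre nodes, keys/values from their neighbours (keys and values share a
    projection), and only the centre nodes ``edge_index[1]`` survive pooling.
    ``ratio`` is accepted for API compatibility but is not used here."""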
|
def __init__(self, |
|
hidden_channels, |
|
edge_channels, |
|
cutoff_lower, |
|
cutoff_upper, |
|
ratio=0.5, |
|
drop_out=0.0, |
|
num_heads=32, |
|
use_lora=None, |
|
non_linearity=torch.tanh): |
|
super(StarPool, self).__init__(aggr="mean") |
|
if use_lora is not None: |
|
self.q_proj = lora.Linear(hidden_channels, hidden_channels, r=use_lora) |
|
self.kv_proj = lora.Linear(hidden_channels, hidden_channels, r=use_lora) |
|
self.dk_proj = lora.Linear(edge_channels, hidden_channels, r=use_lora) |
|
|
|
|
|
else: |
|
self.q_proj = nn.Linear(hidden_channels, hidden_channels) |
|
self.kv_proj = nn.Linear(hidden_channels, hidden_channels) |
|
self.dk_proj = nn.Linear(edge_channels, hidden_channels) |
|
|
|
|
|
self.layernorm_in = nn.LayerNorm(hidden_channels) |
|
self.layernorm_out = nn.LayerNorm(hidden_channels) |
|
self.num_heads = num_heads |
|
self.hidden_channels = hidden_channels |
|
self.x_head_dim = hidden_channels // num_heads |
|
self.node_dim = 0 |
|
self.attn_activation = act_class_mapping["silu"]() |
|
self.drop_out = nn.Dropout(drop_out) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.xavier_uniform_(self.q_proj.weight) |
|
nn.init.xavier_uniform_(self.kv_proj.weight) |
|
|
|
nn.init.xavier_uniform_(self.dk_proj.weight) |
|
self.q_proj.bias.data.fill_(0) |
|
self.kv_proj.bias.data.fill_(0) |
|
|
|
self.dk_proj.bias.data.fill_(0) |
|
|
|
def forward(self, x, edge_index, edge_attr, batch=None): |
|
residue = x |
|
x = self.layernorm_in(x) |
|
if batch is None: |
|
batch = edge_index.new_zeros(x.size(0)) |
|
q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim) |
|
|
|
v = k |
|
dk = self.dk_proj(edge_attr).reshape(-1, self.num_heads, self.x_head_dim) |
|
x, attn = self.propagate( |
|
edge_index = edge_index, |
|
q=q, |
|
k=k, |
|
v=v, |
|
dk=dk, |
|
size=None |
|
) |
|
x = x.reshape(-1, self.hidden_channels) |
|
x = residue + x |
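        # Keep only the centre (target) nodes of each star graph; all other
        # nodes are discarded after message passing.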
|
|
|
center_nodes = torch.unique(edge_index[1]) |
|
perm = center_nodes |
|
x = x[perm] |
|
batch = batch[perm] |
|
residue = residue[perm] |
|
x = self.layernorm_out(x) |
|
x = residue + self.drop_out(x) |
|
|
|
return x, edge_index, edge_attr, batch, attn |
|
|
|
def message(self, q_i, k_j, v_j, dk): |
|
attn = (q_i * k_j * dk).sum(dim=-1) |
|
|
|
attn = self.attn_activation(attn) |
|
|
|
x = v_j * attn.unsqueeze(2) |
|
return x, attn |
|
|
|
def aggregate( |
|
self, |
|
features: Tuple[torch.Tensor, torch.Tensor], |
|
index: torch.Tensor, |
|
ptr: Optional[torch.Tensor], |
|
dim_size: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
x, attn = features |
|
x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
return x, attn |
|
|
|
def update( |
|
self, inputs: Tuple[torch.Tensor, torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
return inputs |
|
|
|
def message_and_aggregate(self, adj_t) -> torch.Tensor: |
|
pass |
|
|
|
def edge_update(self) -> torch.Tensor: |
|
pass |
|
|
|
|
|
def build_output_model(output_model_name, args, **kwargs): |
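    """Factory that maps a configuration name to an output-model instance.

    Illustrative call (the ``args`` keys shown are ones the heads above read;
    the values are assumptions, not a documented schema)::

        head = build_output_model(
            "EquivariantBinaryClassificationScalar",
            args={"x_channels": 128, "output_dim": 1, "reduce_op": "mean",
                  "model": "equivariant-transformer"},
        )
    """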
|
if output_model_name == "EquivariantBinaryClassificationNoGraphScalar": |
|
return EquivariantNoGraphScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "EquivariantBinaryClassificationScalar": |
|
return EquivariantScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "ESMBinaryClassificationScalar": |
|
return ESMScalar(args=args, activation="sigmoid", **kwargs) |
|
elif output_model_name == "ESMFullGraphBinaryClassificationScalar": |
|
return ESMFullGraphScalar(args=args, activation="sigmoid", **kwargs) |
|
elif output_model_name == "EquivariantBinaryClassificationStarPoolScalar": |
|
return EquivariantStarPoolScalar(args=args, activation="sigmoid", init_fn=args["init_fn"]) |
|
elif output_model_name == "EquivariantBinaryClassificationStarPoolMeanVarScalar": |
|
return EquivariantStarPoolMeanVarScalar(args=args, activation="softplus") |
|
elif output_model_name == "EquivariantBinaryClassificationAttnScalar": |
|
return EquivariantAttnScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "EquivariantBinaryClassificationPAEAttnScalar": |
|
return EquivariantPAEAttnScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "EquivariantBinaryClassificationStarPoolOneSiteScalar": |
|
return EquivariantStarPoolOneSiteScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "EquivariantBinaryClassificationAttnOneSiteScalar": |
|
return EquivariantAttnOneSiteScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "EquivariantBinaryClassificationStarPoolGPScalar": |
|
return EquivariantStarPoolGPScalar(args=args, activation="sigmoid") |
|
elif output_model_name == "EquivariantRegressionScalar": |
|
return EquivariantScalar(args=args, activation="pass") |
|
elif output_model_name == "ESMRegressionScalar": |
|
return ESMScalar(args=args, activation="pass", **kwargs) |
|
elif output_model_name == "ESMFullGraphRegressionScalar": |
|
return ESMFullGraphScalar(args=args, activation="pass", **kwargs) |
|
elif output_model_name == "EquivariantRegressionStarPoolScalar": |
|
return EquivariantStarPoolScalar(args=args, activation="pass") |
|
elif output_model_name == "EquivariantRegressionStarPoolMeanVarScalar": |
|
return EquivariantStarPoolMeanVarScalar(args=args, activation="pass") |
|
elif output_model_name == "EquivariantRegressionAttnScalar": |
|
return EquivariantAttnScalar(args=args, activation="pass") |
|
elif output_model_name == "EquivariantRegressionPAEAttnScalar": |
|
return EquivariantPAEAttnScalar(args=args, activation="pass") |
|
elif output_model_name == "EquivariantRegressionStarPoolOneSiteScalar": |
|
return EquivariantStarPoolOneSiteScalar(args=args, activation="pass") |
|
elif output_model_name == "EquivariantRegressionAttnOneSiteScalar": |
|
return EquivariantAttnOneSiteScalar(args=args, activation="pass") |
|
else: |
|
raise NotImplementedError |