import collections
import math
from argparse import ArgumentParser
import enum
from os.path import isfile
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import metl.relative_attention as ra


def reset_parameters_helper(m: nn.Module):
    """ helper function for resetting model parameters, meant to be used with model.apply() """

    # modules may implement a public reset_parameters() or a private _reset_parameters(), so check for both
    reset_parameters = getattr(m, "reset_parameters", None)
    reset_parameters_private = getattr(m, "_reset_parameters", None)

    if callable(reset_parameters) and callable(reset_parameters_private):
        raise RuntimeError("Module has both public and private methods for resetting parameters. "
                           "This is unexpected... probably should just call the public one.")

    if callable(reset_parameters):
        m.reset_parameters()

    if callable(reset_parameters_private):
        m._reset_parameters()


class SequentialWithArgs(nn.Sequential):
    def forward(self, x, **kwargs):
        for module in self:
            if isinstance(module, ra.RelativeTransformerEncoder) or isinstance(module, SequentialWithArgs):
                # these modules accept extra keyword arguments, so pass kwargs through
                x = module(x, **kwargs)
            else:
                # standard modules only take the input tensor
                x = module(x)
        return x


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # add a batch dimension so the encoding broadcasts over batch-first inputs
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x, **kwargs):
        # x is expected to be batch-first: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
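
# Illustrative usage sketch (not part of the original module): the positional encoding adds a fixed
# sinusoidal signal to a batch-first embedding tensor and applies dropout; the shape is unchanged.
#   pos_enc = PositionalEncoding(d_model=128, dropout=0.1, max_len=512)
#   emb = torch.zeros(4, 50, 128)           # (batch, seq_len, d_model)
#   out = pos_enc(emb)                      # same shape, (4, 50, 128)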


class ScaledEmbedding(nn.Module):
    """ token embedding with optional scaling by sqrt(embedding_dim), as in the original transformer """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool):
        super(ScaledEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.emb_size = embedding_dim
        self.embed_scale = math.sqrt(self.emb_size)

        self.scale = scale

        self.init_weights()

    def init_weights(self):
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)

    def forward(self, tokens: Tensor, **kwargs):
        if self.scale:
            return self.embedding(tokens.long()) * self.embed_scale
        else:
            return self.embedding(tokens.long())
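
# Illustrative usage sketch (not part of the original module): with scale=True the embedding output
# is multiplied by sqrt(embedding_dim), here sqrt(64) = 8.
#   embedder = ScaledEmbedding(num_embeddings=21, embedding_dim=64, scale=True)
#   tokens = torch.randint(0, 21, (4, 50))  # (batch, seq_len) integer-encoded sequence
#   out = embedder(tokens)                  # (4, 50, 64)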


class FCBlock(nn.Module):
    """ a fully connected block with options for batchnorm and dropout
        can extend in the future with option for different activation, etc """

    def __init__(self,
                 in_features: int,
                 num_hidden_nodes: int = 64,
                 use_batchnorm: bool = False,
                 use_layernorm: bool = False,
                 norm_before_activation: bool = False,
                 use_dropout: bool = False,
                 dropout_rate: float = 0.2,
                 activation: str = "relu"):

        super().__init__()

        if use_batchnorm and use_layernorm:
            raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")

        self.use_batchnorm = use_batchnorm
        self.use_dropout = use_dropout
        self.use_layernorm = use_layernorm
        self.norm_before_activation = norm_before_activation

        self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes)

        self.activation = get_activation_fn(activation, functional=False)

        if use_batchnorm:
            self.norm = nn.BatchNorm1d(num_hidden_nodes)

        if use_layernorm:
            self.norm = nn.LayerNorm(num_hidden_nodes)

        if use_dropout:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, **kwargs):
        x = self.fc(x)

        if (self.use_batchnorm or self.use_layernorm) and self.norm_before_activation:
            x = self.norm(x)

        x = self.activation(x)

        if (self.use_batchnorm or self.use_layernorm) and not self.norm_before_activation:
            x = self.norm(x)

        if self.use_dropout:
            x = self.dropout(x)

        return x
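
# Illustrative usage sketch (not part of the original module): dense layer + activation with optional
# norm/dropout; the output feature size is num_hidden_nodes.
#   block = FCBlock(in_features=128, num_hidden_nodes=64, use_layernorm=True, use_dropout=True)
#   out = block(torch.randn(8, 128))        # (8, 64)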


class TaskSpecificPredictionLayers(nn.Module):
    """ Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input
        into a single output node. All num_tasks outputs are then concatenated into a single tensor. """

    def __init__(self,
                 num_tasks: int,
                 in_features: int,
                 num_hidden_nodes: int = 64,
                 use_batchnorm: bool = False,
                 use_dropout: bool = False,
                 dropout_rate: float = 0.2,
                 activation: str = "relu"):

        super().__init__()

        self.task_specific_pred_layers = nn.ModuleList()
        for i in range(num_tasks):
            layers = [FCBlock(in_features=in_features,
                              num_hidden_nodes=num_hidden_nodes,
                              use_batchnorm=use_batchnorm,
                              use_dropout=use_dropout,
                              dropout_rate=dropout_rate,
                              activation=activation),
                      nn.Linear(in_features=num_hidden_nodes, out_features=1)]
            self.task_specific_pred_layers.append(nn.Sequential(*layers))

    def forward(self, x, **kwargs):
        task_specific_outputs = []
        for layer in self.task_specific_pred_layers:
            task_specific_outputs.append(layer(x))

        output = torch.cat(task_specific_outputs, dim=1)
        return output
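
# Illustrative usage sketch (not part of the original module): one independent small head per task,
# with the per-task outputs concatenated along the feature dimension.
#   heads = TaskSpecificPredictionLayers(num_tasks=3, in_features=128, num_hidden_nodes=64)
#   out = heads(torch.randn(8, 128))        # (8, 3), one column per task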


class GlobalAveragePooling(nn.Module):
    """ helper class for global average pooling """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x, **kwargs):
        return torch.mean(x, dim=self.dim)


class CLSPooling(nn.Module):
    """ helper class for CLS token extraction """

    def __init__(self, cls_position=0):
        super().__init__()
        self.cls_position = cls_position

    def forward(self, x, **kwargs):
        # assumes input is (batch, seq_len, embedding_dim) with the CLS token at cls_position
        return x[:, self.cls_position, :]


class TransformerEncoderWrapper(nn.TransformerEncoder):
    """ wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters,
        so each transformer encoder layer has a different initialization """

    def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
        super().__init__(encoder_layer, num_layers, norm)
        if reset_params:
            self.apply(reset_parameters_helper)


class AttnModel(nn.Module):

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument('--pos_encoding', type=str, default="absolute",
                            choices=["none", "absolute", "relative", "relative_3D"],
                            help="what type of positional encoding to use")
        parser.add_argument('--pos_encoding_dropout', type=float, default=0.1,
                            help="how much dropout to use in positional encoding, for pos_encoding==absolute")
        parser.add_argument('--clipping_threshold', type=int, default=3,
                            help="clipping threshold for relative position embedding, for relative and relative_3D")
        parser.add_argument('--contact_threshold', type=int, default=7,
                            help="threshold, in angstroms, for contact map, for relative_3D")
        parser.add_argument('--embedding_len', type=int, default=128)
        parser.add_argument('--num_heads', type=int, default=2)
        parser.add_argument('--num_hidden', type=int, default=64)
        parser.add_argument('--num_enc_layers', type=int, default=2)
        parser.add_argument('--enc_layer_dropout', type=float, default=0.1)
        parser.add_argument('--use_final_encoder_norm', action="store_true", default=False)

        parser.add_argument('--global_average_pooling', action="store_true", default=False)
        parser.add_argument('--cls_pooling', action="store_true", default=False)

        parser.add_argument('--use_task_specific_layers', action="store_true", default=False,
                            help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer"
                                 " if both flags are set")
        parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
        parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
        parser.add_argument('--final_hidden_size', type=int, default=64)
        parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
        parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
        parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
        parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)

        parser.add_argument('--activation', type=str, default="relu",
                            help="activation function used for all activations in the network")
        return parser

    def __init__(self,
                 num_tasks: int,
                 aa_seq_len: int,
                 num_tokens: int,

                 pos_encoding: str = "absolute",
                 pos_encoding_dropout: float = 0.1,
                 clipping_threshold: int = 3,
                 contact_threshold: int = 7,
                 pdb_fns: List[str] = None,
                 embedding_len: int = 64,
                 num_heads: int = 2,
                 num_hidden: int = 64,
                 num_enc_layers: int = 2,
                 enc_layer_dropout: float = 0.1,
                 use_final_encoder_norm: bool = False,

                 global_average_pooling: bool = True,
                 cls_pooling: bool = False,

                 use_task_specific_layers: bool = False,
                 task_specific_hidden_nodes: int = 64,
                 use_final_hidden_layer: bool = False,
                 final_hidden_size: int = 64,
                 use_final_hidden_layer_norm: bool = False,
                 final_hidden_layer_norm_before_activation: bool = False,
                 use_final_hidden_layer_dropout: bool = False,
                 final_hidden_layer_dropout_rate: float = 0.2,

                 activation: str = "relu",
                 *args, **kwargs):

        super().__init__()

        self.embedding_len = embedding_len
        self.aa_seq_len = aa_seq_len

        # build up the network as an ordered dict of named layers
        layers = collections.OrderedDict()

        # token embedding
        layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True)

        # absolute positional encoding is added directly to the embeddings
        if pos_encoding == "absolute":
            layers["pos_encoder"] = PositionalEncoding(embedding_len, dropout=pos_encoding_dropout, max_len=512)

        # standard PyTorch transformer encoder for no or absolute positional encoding
        if pos_encoding in ["none", "absolute"]:
            encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_len,
                                                             nhead=num_heads,
                                                             dim_feedforward=num_hidden,
                                                             dropout=enc_layer_dropout,
                                                             activation=get_activation_fn(activation),
                                                             norm_first=True,
                                                             batch_first=True)

            encoder_norm = None
            if use_final_encoder_norm:
                encoder_norm = nn.LayerNorm(embedding_len)

            layers["tr_encoder"] = TransformerEncoderWrapper(encoder_layer=encoder_layer,
                                                             num_layers=num_enc_layers,
                                                             norm=encoder_norm)

        # relative attention encoder for relative and relative_3D positional encodings
        elif pos_encoding in ["relative", "relative_3D"]:
            relative_encoder_layer = ra.RelativeTransformerEncoderLayer(d_model=embedding_len,
                                                                        nhead=num_heads,
                                                                        pos_encoding=pos_encoding,
                                                                        clipping_threshold=clipping_threshold,
                                                                        contact_threshold=contact_threshold,
                                                                        pdb_fns=pdb_fns,
                                                                        dim_feedforward=num_hidden,
                                                                        dropout=enc_layer_dropout,
                                                                        activation=get_activation_fn(activation),
                                                                        norm_first=True)

            encoder_norm = None
            if use_final_encoder_norm:
                encoder_norm = nn.LayerNorm(embedding_len)

            layers["tr_encoder"] = ra.RelativeTransformerEncoder(encoder_layer=relative_encoder_layer,
                                                                 num_layers=num_enc_layers,
                                                                 norm=encoder_norm)

        # pooling (or flattening) of the encoder output determines the prediction layer input size
        if global_average_pooling:
            layers["avg_pooling"] = GlobalAveragePooling(dim=1)
            pred_layer_input_features = embedding_len
        elif cls_pooling:
            layers["cls_pooling"] = CLSPooling(cls_position=0)
            pred_layer_input_features = embedding_len
        else:
            layers["flatten"] = nn.Flatten()
            pred_layer_input_features = embedding_len * aa_seq_len

        # prediction layers
        if use_task_specific_layers:
            layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
                                                                in_features=pred_layer_input_features,
                                                                num_hidden_nodes=task_specific_hidden_nodes,
                                                                activation=activation)
        elif use_final_hidden_layer:
            layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
                                    num_hidden_nodes=final_hidden_size,
                                    use_batchnorm=False,
                                    use_layernorm=use_final_hidden_layer_norm,
                                    norm_before_activation=final_hidden_layer_norm_before_activation,
                                    use_dropout=use_final_hidden_layer_dropout,
                                    dropout_rate=final_hidden_layer_dropout_rate,
                                    activation=activation)
            layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
        else:
            layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)

        self.model = SequentialWithArgs(layers)

    def forward(self, x, **kwargs):
        return self.model(x, **kwargs)
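
# Illustrative construction sketch (not part of the original module; argument values are hypothetical):
#   model = AttnModel(num_tasks=1, aa_seq_len=100, num_tokens=21,
#                     pos_encoding="absolute", embedding_len=64, num_heads=2,
#                     num_enc_layers=2, global_average_pooling=True)
#   preds = model(torch.randint(0, 21, (8, 100)))   # (8, 1)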


class Transpose(nn.Module):
    """ helper layer to swap data from (batch, seq, channels) to (batch, channels, seq),
        used in the convolutional networks since pytorch conv layers default to channels-first """

    def __init__(self, dims: Tuple[int, ...] = (1, 2)):
        super().__init__()
        self.dims = dims

    def forward(self, x, **kwargs):
        x = x.transpose(*self.dims).contiguous()
        return x


def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1):
    # standard conv1d output-length formula: floor((L + 2*pad - dilation*(kernel_size-1) - 1) / stride) + 1
    # (the previous expression divided only the constant 1 by stride due to operator precedence,
    #  which matches the correct result only when stride == 1)
    return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1) // stride + 1
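
# Illustrative check (not part of the original module): a length-100 sequence through a kernel-size-7
# "valid" convolution gives 100 - (7 - 1) = 94 output positions.
#   conv1d_out_shape(100, kernel_size=7)              # -> 94
#   conv1d_out_shape(100, kernel_size=7, dilation=2)  # -> 88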


class ConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 dilation: int = 1,
                 padding: str = "same",
                 use_batchnorm: bool = False,
                 use_layernorm: bool = False,
                 norm_before_activation: bool = False,
                 use_dropout: bool = False,
                 dropout_rate: float = 0.2,
                 activation: str = "relu"):

        super().__init__()

        if use_batchnorm and use_layernorm:
            raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")

        self.use_batchnorm = use_batchnorm
        self.use_layernorm = use_layernorm
        self.norm_before_activation = norm_before_activation
        self.use_dropout = use_dropout

        self.conv = nn.Conv1d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding,
                              dilation=dilation)

        self.activation = get_activation_fn(activation, functional=False)

        if use_batchnorm:
            self.norm = nn.BatchNorm1d(out_channels)

        if use_layernorm:
            self.norm = nn.LayerNorm(out_channels)

        if use_dropout:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, **kwargs):
        x = self.conv(x)

        # layernorm expects channels-last, so transpose around the norm when using layernorm
        if self.use_batchnorm and self.norm_before_activation:
            x = self.norm(x)
        elif self.use_layernorm and self.norm_before_activation:
            x = self.norm(x.transpose(1, 2)).transpose(1, 2)

        x = self.activation(x)

        if self.use_batchnorm and not self.norm_before_activation:
            x = self.norm(x)
        elif self.use_layernorm and not self.norm_before_activation:
            x = self.norm(x.transpose(1, 2)).transpose(1, 2)

        if self.use_dropout:
            x = self.dropout(x)

        return x


class ConvModel2(nn.Module):
    """ convolutional source model that supports padded inputs, pooling, etc """

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--use_embedding', action="store_true", default=False)
        parser.add_argument('--embedding_len', type=int, default=128)

        parser.add_argument('--num_conv_layers', type=int, default=1)
        parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
        parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
        parser.add_argument('--dilations', type=int, nargs="+", default=[1])
        parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
        parser.add_argument('--use_conv_layer_norm', action="store_true", default=False)
        parser.add_argument('--conv_layer_norm_before_activation', action="store_true", default=False)
        parser.add_argument('--use_conv_layer_dropout', action="store_true", default=False)
        parser.add_argument('--conv_layer_dropout_rate', type=float, default=0.2)

        parser.add_argument('--global_average_pooling', action="store_true", default=False)

        parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
        parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
        parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
        parser.add_argument('--final_hidden_size', type=int, default=64)
        parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
        parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
        parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
        parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)

        parser.add_argument('--activation', type=str, default="relu",
                            help="activation function used for all activations in the network")

        return parser

    def __init__(self,
                 num_tasks: int,
                 aa_seq_len: int,
                 aa_encoding_len: int,
                 num_tokens: int,

                 use_embedding: bool = False,
                 embedding_len: int = 64,
                 num_conv_layers: int = 1,
                 kernel_sizes: List[int] = (7,),
                 out_channels: List[int] = (128,),
                 dilations: List[int] = (1,),
                 padding: str = "valid",
                 use_conv_layer_norm: bool = False,
                 conv_layer_norm_before_activation: bool = False,
                 use_conv_layer_dropout: bool = False,
                 conv_layer_dropout_rate: float = 0.2,

                 global_average_pooling: bool = True,

                 use_task_specific_layers: bool = False,
                 task_specific_hidden_nodes: int = 64,
                 use_final_hidden_layer: bool = False,
                 final_hidden_size: int = 64,
                 use_final_hidden_layer_norm: bool = False,
                 final_hidden_layer_norm_before_activation: bool = False,
                 use_final_hidden_layer_dropout: bool = False,
                 final_hidden_layer_dropout_rate: float = 0.2,

                 activation: str = "relu",
                 *args, **kwargs):

        super(ConvModel2, self).__init__()

        layers = collections.OrderedDict()

        if use_embedding:
            layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False)

        layers["transpose"] = Transpose(dims=(1, 2))

        for layer_num in range(num_conv_layers):
            if layer_num == 0 and use_embedding:
                in_channels = embedding_len
            elif layer_num == 0 and not use_embedding:
                in_channels = aa_encoding_len
            else:
                in_channels = out_channels[layer_num - 1]

            layers[f"conv{layer_num}"] = ConvBlock(in_channels=in_channels,
                                                   out_channels=out_channels[layer_num],
                                                   kernel_size=kernel_sizes[layer_num],
                                                   dilation=dilations[layer_num],
                                                   padding=padding,
                                                   use_batchnorm=False,
                                                   use_layernorm=use_conv_layer_norm,
                                                   norm_before_activation=conv_layer_norm_before_activation,
                                                   use_dropout=use_conv_layer_dropout,
                                                   dropout_rate=conv_layer_dropout_rate,
                                                   activation=activation)

        if global_average_pooling:
            layers["avg_pooling"] = GlobalAveragePooling(dim=-1)
            pred_layer_input_features = out_channels[-1]
        else:
            layers["flatten"] = nn.Flatten()

            if padding == "valid":
                conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0])
                for layer_num in range(1, num_conv_layers):
                    conv_out_len = conv1d_out_shape(conv_out_len,
                                                    kernel_size=kernel_sizes[layer_num],
                                                    dilation=dilations[layer_num])
                pred_layer_input_features = conv_out_len * out_channels[-1]
            else:
                pred_layer_input_features = aa_seq_len * out_channels[-1]

        if use_task_specific_layers:
            layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
                                                                in_features=pred_layer_input_features,
                                                                num_hidden_nodes=task_specific_hidden_nodes,
                                                                activation=activation)
        elif use_final_hidden_layer:
            layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
                                    num_hidden_nodes=final_hidden_size,
                                    use_batchnorm=False,
                                    use_layernorm=use_final_hidden_layer_norm,
                                    norm_before_activation=final_hidden_layer_norm_before_activation,
                                    use_dropout=use_final_hidden_layer_dropout,
                                    dropout_rate=final_hidden_layer_dropout_rate,
                                    activation=activation)
            layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
        else:
            layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)

        self.model = nn.Sequential(layers)

    def forward(self, x, **kwargs):
        output = self.model(x)
        return output
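
# Illustrative construction sketch (not part of the original module; argument values are hypothetical):
# an encoded input of shape (batch, seq_len, aa_encoding_len) through a single conv layer with
# global average pooling over the sequence dimension.
#   model = ConvModel2(num_tasks=1, aa_seq_len=100, aa_encoding_len=21, num_tokens=21,
#                      kernel_sizes=[7], out_channels=[128], dilations=[1],
#                      padding="valid", global_average_pooling=True)
#   preds = model(torch.randn(8, 100, 21))   # (8, 1)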


class ConvModel(nn.Module):
    """ a convolutional network with convolutional layers followed by a fully connected layer """

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--num_conv_layers', type=int, default=1)
        parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
        parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
        parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
        parser.add_argument('--use_final_hidden_layer', action="store_true",
                            help="whether to use a final hidden layer")
        parser.add_argument('--final_hidden_size', type=int, default=128,
                            help="number of nodes in the final hidden layer")
        parser.add_argument('--use_dropout', action="store_true",
                            help="whether to use dropout in the final hidden layer")
        parser.add_argument('--dropout_rate', type=float, default=0.2,
                            help="dropout rate in the final hidden layer")
        parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
        parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
        return parser

    def __init__(self,
                 num_tasks: int,
                 aa_seq_len: int,
                 aa_encoding_len: int,
                 num_conv_layers: int = 1,
                 kernel_sizes: List[int] = (7,),
                 out_channels: List[int] = (128,),
                 padding: str = "valid",
                 use_final_hidden_layer: bool = True,
                 final_hidden_size: int = 128,
                 use_dropout: bool = False,
                 dropout_rate: float = 0.2,
                 use_task_specific_layers: bool = False,
                 task_specific_hidden_nodes: int = 64,
                 *args, **kwargs):

        super(ConvModel, self).__init__()

        layers = collections.OrderedDict()

        layers["transpose"] = Transpose(dims=(1, 2))

        for layer_num in range(num_conv_layers):
            in_channels = aa_encoding_len if layer_num == 0 else out_channels[layer_num - 1]

            layers["conv{}".format(layer_num)] = nn.Sequential(
                nn.Conv1d(in_channels=in_channels,
                          out_channels=out_channels[layer_num],
                          kernel_size=kernel_sizes[layer_num],
                          padding=padding),
                nn.ReLU()
            )

        layers["flatten"] = nn.Flatten()

        if padding == "valid":
            conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0])
            for layer_num in range(1, num_conv_layers):
                conv_out_len = conv1d_out_shape(conv_out_len, kernel_size=kernel_sizes[layer_num])
            next_dim = conv_out_len * out_channels[-1]
        elif padding == "same":
            next_dim = aa_seq_len * out_channels[-1]
        else:
            raise ValueError("unexpected value for padding: {}".format(padding))

        if use_final_hidden_layer:
            layers["fc1"] = FCBlock(in_features=next_dim,
                                    num_hidden_nodes=final_hidden_size,
                                    use_batchnorm=False,
                                    use_dropout=use_dropout,
                                    dropout_rate=dropout_rate)
            next_dim = final_hidden_size

        if use_task_specific_layers:
            layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
                                                                in_features=next_dim,
                                                                num_hidden_nodes=task_specific_hidden_nodes)
        else:
            layers["prediction"] = nn.Linear(in_features=next_dim, out_features=num_tasks)

        self.model = nn.Sequential(layers)

    def forward(self, x, **kwargs):
        output = self.model(x)
        return output


class FCModel(nn.Module):

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--num_layers', type=int, default=1)
        parser.add_argument('--num_hidden', nargs="+", type=int, default=[128])
        parser.add_argument('--use_batchnorm', action="store_true", default=False)
        parser.add_argument('--use_layernorm', action="store_true", default=False)
        parser.add_argument('--norm_before_activation', action="store_true", default=False)
        parser.add_argument('--use_dropout', action="store_true", default=False)
        parser.add_argument('--dropout_rate', type=float, default=0.2)
        return parser

    def __init__(self,
                 num_tasks: int,
                 seq_encoding_len: int,
                 num_layers: int = 1,
                 num_hidden: List[int] = (128,),
                 use_batchnorm: bool = False,
                 use_layernorm: bool = False,
                 norm_before_activation: bool = False,
                 use_dropout: bool = False,
                 dropout_rate: float = 0.2,
                 activation: str = "relu",
                 *args, **kwargs):
        super().__init__()

        layers = collections.OrderedDict()

        layers["flatten"] = nn.Flatten()

        for layer_num in range(num_layers):
            in_features = seq_encoding_len if layer_num == 0 else num_hidden[layer_num - 1]

            layers["fc{}".format(layer_num)] = FCBlock(in_features=in_features,
                                                       num_hidden_nodes=num_hidden[layer_num],
                                                       use_batchnorm=use_batchnorm,
                                                       use_layernorm=use_layernorm,
                                                       norm_before_activation=norm_before_activation,
                                                       use_dropout=use_dropout,
                                                       dropout_rate=dropout_rate,
                                                       activation=activation)

        in_features = num_hidden[-1] if num_layers > 0 else seq_encoding_len
        layers["output"] = nn.Linear(in_features=in_features, out_features=num_tasks)

        self.model = nn.Sequential(layers)

    def forward(self, x, **kwargs):
        output = self.model(x)
        return output
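
# Illustrative construction sketch (not part of the original module; argument values are hypothetical):
# flattened sequence encoding -> one hidden FC layer -> linear output per task.
#   model = FCModel(num_tasks=2, seq_encoding_len=100 * 21, num_layers=1, num_hidden=[128])
#   preds = model(torch.randn(8, 100, 21))   # flattened to (8, 2100) internally -> (8, 2)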


class LRModel(nn.Module):
    """ a simple linear model """

    def __init__(self, num_tasks, seq_encoding_len, *args, **kwargs):
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(seq_encoding_len, out_features=num_tasks))

    def forward(self, x, **kwargs):
        output = self.model(x)
        return output


class TransferModel(nn.Module):
    """ transfer learning model """

    @staticmethod
    def add_model_specific_args(parent_parser):

        def none_or_int(value: str):
            return None if value.lower() == "none" else int(value)

        p = ArgumentParser(parents=[parent_parser], add_help=False)

        p.add_argument('--pretrained_ckpt_path', type=str, default=None)

        p.add_argument("--backbone_cutoff", type=none_or_int, default=-1,
                       help="where to cut off the backbone. can be a negative int, indexing back from "
                            "pretrained_model.model.model. a value of -1 would chop off the backbone prediction head. "
                            "a value of -2 chops the prediction head and FC layer. a value of -3 chops "
                            "the above, as well as the global average pooling layer. all depends on architecture.")

        p.add_argument("--pred_layer_input_features", type=int, default=None,
                       help="if None, number of features will be determined based on backbone_cutoff and standard "
                            "architecture. otherwise, specify the number of input features for the prediction layer")

        p.add_argument("--top_net_type", type=str, default="linear", choices=["linear", "nonlinear", "sklearn"])
        p.add_argument("--top_net_hidden_nodes", type=int, default=256)
        p.add_argument("--top_net_use_batchnorm", action="store_true")
        p.add_argument("--top_net_use_dropout", action="store_true")
        p.add_argument("--top_net_dropout_rate", type=float, default=0.1)

        return p

    def __init__(self,
                 pretrained_ckpt_path: Optional[str] = None,
                 pretrained_hparams: Optional[dict] = None,
                 backbone_cutoff: Optional[int] = -1,

                 pred_layer_input_features: Optional[int] = None,
                 top_net_type: str = "linear",
                 top_net_hidden_nodes: int = 256,
                 top_net_use_batchnorm: bool = False,
                 top_net_use_dropout: bool = False,
                 top_net_dropout_rate: float = 0.1,
                 *args, **kwargs):

        super().__init__()

        if pretrained_ckpt_path is None and pretrained_hparams is None:
            raise ValueError("Either pretrained_ckpt_path or pretrained_hparams must be specified")

        # pdb_fns, if provided via kwargs, is forwarded to the backbone hyperparameters
        pdb_fns = kwargs["pdb_fns"] if "pdb_fns" in kwargs else None

        if pretrained_hparams is not None:
            # construct the backbone architecture directly from the given hyperparameters
            pretrained_hparams["pdb_fns"] = pdb_fns
            pretrained_model = Model[pretrained_hparams["model_name"]].cls(**pretrained_hparams)
            self.pretrained_hparams = pretrained_hparams
        else:
            raise NotImplementedError("Loading pretrained weights from RosettaTask checkpoint not supported")

        layers = collections.OrderedDict()

        # the backbone is the pretrained model, optionally cut off at backbone_cutoff
        if backbone_cutoff is None:
            layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children()))
        else:
            layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children())[0:backbone_cutoff])

        if top_net_type == "sklearn":
            # when using an sklearn top net, only the backbone is constructed here
            self.model = SequentialWithArgs(layers)
            return

        # determine the number of input features for the top net prediction layer
        if pred_layer_input_features is None:
            if backbone_cutoff is None:
                pred_layer_input_features = self.pretrained_hparams["num_tasks"]
            elif backbone_cutoff == -1:
                pred_layer_input_features = self.pretrained_hparams["final_hidden_size"]
            elif backbone_cutoff == -2:
                pred_layer_input_features = self.pretrained_hparams["embedding_len"]
            elif backbone_cutoff == -3:
                pred_layer_input_features = self.pretrained_hparams["embedding_len"] * kwargs["aa_seq_len"]
            else:
                raise ValueError("can't automatically determine pred_layer_input_features for given backbone_cutoff")

        layers["flatten"] = nn.Flatten(start_dim=1)

        if top_net_type == "linear":
            layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=1)
        elif top_net_type == "nonlinear":
            fc_block = FCBlock(in_features=pred_layer_input_features,
                               num_hidden_nodes=top_net_hidden_nodes,
                               use_batchnorm=top_net_use_batchnorm,
                               use_dropout=top_net_use_dropout,
                               dropout_rate=top_net_dropout_rate)

            pred_layer = nn.Linear(in_features=top_net_hidden_nodes, out_features=1)

            layers["prediction"] = SequentialWithArgs(fc_block, pred_layer)
        else:
            raise ValueError("Unexpected type of top net layer: {}".format(top_net_type))

        self.model = SequentialWithArgs(layers)

    def forward(self, x, **kwargs):
        return self.model(x, **kwargs)
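
# Illustrative construction sketch (not part of the original module; the hparams dict below is hypothetical
# and must match the hyperparameters the backbone was pretrained with):
#   backbone_hparams = {"model_name": "transformer_encoder", "num_tasks": 55, "aa_seq_len": 100,
#                       "num_tokens": 21, "embedding_len": 128, "use_final_hidden_layer": True,
#                       "final_hidden_size": 64, "global_average_pooling": True}
#   model = TransferModel(pretrained_hparams=backbone_hparams, backbone_cutoff=-1, top_net_type="linear")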


def get_activation_fn(activation, functional=True):
    if activation == "relu":
        return F.relu if functional else nn.ReLU()
    elif activation == "gelu":
        return F.gelu if functional else nn.GELU()
    elif activation in ("silu", "silo", "swish"):
        # "silo" is kept as an accepted alias (likely a typo for "silu" in existing configs)
        return F.silu if functional else nn.SiLU()
    elif activation == "leaky_relu" or activation == "lrelu":
        return F.leaky_relu if functional else nn.LeakyReLU()
    else:
        raise RuntimeError("unknown activation: {}".format(activation))
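
# Illustrative usage (not part of the original module): functional=True returns the function itself
# (as passed to the transformer encoder layers above), functional=False returns an nn.Module instance.
#   get_activation_fn("gelu")                    # -> torch.nn.functional.gelu
#   get_activation_fn("gelu", functional=False)  # -> nn.GELU()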


class Model(enum.Enum):
    def __new__(cls, *args, **kwds):
        value = len(cls.__members__) + 1
        obj = object.__new__(cls)
        obj._value_ = value
        return obj

    def __init__(self, cls, transfer_model):
        self.cls = cls
        self.transfer_model = transfer_model

    linear = LRModel, False
    fully_connected = FCModel, False
    cnn = ConvModel, False
    cnn2 = ConvModel2, False
    transformer_encoder = AttnModel, False
    transfer_model = TransferModel, True
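
# Illustrative usage (not part of the original module): the enum maps model names to classes, so a model
# can be constructed from its name string.
#   model_cls = Model["transformer_encoder"].cls    # -> AttnModel
#   Model["transfer_model"].transfer_model          # -> True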


def main():
    pass


if __name__ == "__main__":
    main()