from typing import Optional, Tuple, Union

import torch
from torch import nn

from .loss import *
from .model_utils import AllOutput, create_output_loss_lucagplm
from .alphabet import Alphabet
from .modeling_gplm import *
from .lucaone_gplm_config import LucaGPLMConfig
from transformers import PreTrainedModel

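# Usage sketch (embedding inference). The config values and input preparation are
# illustrative assumptions; only the call pattern is defined in this module:
#   config = LucaGPLMConfig(...)                      # hypothetical config values
#   model = LucaGPLM(config)
#   out = model(input_ids=input_ids, return_dict=True)
#   token_repr = out.hidden_states                    # per-token representations
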
class LucaGPLM(PreTrainedModel):
    config_class = LucaGPLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.max_position_embeddings = config.max_position_embeddings
        self.type_vocab_size = config.type_vocab_size
        self.num_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size
        self.attention_heads = config.num_attention_heads
        self.no_position_embeddings = config.no_position_embeddings
        self.no_token_type_embeddings = config.no_token_type_embeddings
        if not isinstance(config.alphabet, Alphabet):
            self.alphabet = Alphabet.from_predefined(config.alphabet)
        else:
            self.alphabet = config.alphabet
        self.alphabet_size = len(self.alphabet)
        self.padding_idx = self.alphabet.padding_idx
        self.mask_idx = self.alphabet.mask_idx
        self.cls_idx = self.alphabet.cls_idx
        self.eos_idx = self.alphabet.eos_idx
        self.prepend_bos = self.alphabet.prepend_bos
        self.append_eos = self.alphabet.append_eos
        self.token_dropout = config.token_dropout
        self.ignore_index = config.ignore_index
        self.use_embed_layer_norm = config.use_embed_layer_norm
        self.use_last_layer_norm = config.use_last_layer_norm
        self.embed_scale = config.embed_scale
        # Embedding-inference mode: the masked-LM, contact, and task heads are not built.
        self.embedding_inference = True
        self._init_submodules()
    def _init_submodules(self):
        # Token embedding
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )
        # Optional absolute position embedding
        self.embed_pos = None
        if not self.no_position_embeddings:
            self.embed_pos = nn.Embedding(self.max_position_embeddings, self.embed_dim)
        # Optional token-type (segment) embedding
        self.embed_type = None
        if not self.no_token_type_embeddings:
            self.embed_type = nn.Embedding(self.type_vocab_size, self.embed_dim)
        if self.use_embed_layer_norm:
            self.embed_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
        else:
            self.embed_layer_norm = None

        # Transformer encoder layers
        self.layers = nn.ModuleList(
            [
                LucaGPLMTransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_lucagplm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.layer_size = len(self.layers)

        # Prediction heads are only needed outside of pure embedding inference.
        if not self.embedding_inference:
            self.contact_head = ContactPredictionHead(
                self.num_layers * self.attention_heads,
                self.prepend_bos,
                self.append_eos,
                eos_idx=self.eos_idx,
            )
        if self.use_last_layer_norm:
            self.last_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
        else:
            self.last_layer_norm = None
        if not self.embedding_inference:
            self.lm_head = RobertaLMHead(
                embed_dim=self.embed_dim,
                output_dim=self.alphabet_size,
                weight=self.embed_tokens.weight,
            )
    def _init_embedding(self, pretrained_token_matrix, token_matrix):
        """
        Remap rows of the pretrained token-embedding matrix into this model's
        vocabulary order (pretrained index -> new index):
        0->2, 1->0, 2->3, 3->1, 4->10, ..., 28->34, 29->36, 30->37, 31->38, 32->4
        """
        token_matrix[2, :] = pretrained_token_matrix[0, :]
        token_matrix[0, :] = pretrained_token_matrix[1, :]
        token_matrix[3, :] = pretrained_token_matrix[2, :]
        token_matrix[1, :] = pretrained_token_matrix[3, :]
        # Pretrained rows 4..28 map to rows 10..34
        for idx in range(10, 35):
            token_matrix[idx, :] = pretrained_token_matrix[idx - 6, :]
        token_matrix[36, :] = pretrained_token_matrix[29, :]
        token_matrix[37, :] = pretrained_token_matrix[30, :]
        token_matrix[38, :] = pretrained_token_matrix[31, :]
        token_matrix[4, :] = pretrained_token_matrix[32, :]
        return token_matrix
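
    # Note: rows 5-9 and 35 (and any index >= 39) have no pretrained counterpart in
    # the mapping above, so they keep whatever initialization embed_tokens already has.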
    def _init_submodules_new(self, pretrained_model_name):
        # Load an ESM checkpoint and copy over all compatible weights, renaming
        # layer-norm parameters to this model's naming scheme.
        from collections import OrderedDict
        from esm import pretrained

        pretrained_model, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
        pretrained_state_dict = pretrained_model.state_dict()
        new_state_dict = OrderedDict()
        our_model_state_dict = {}
        for key, value in self.state_dict().items():
            our_model_state_dict[key] = value
        for name, weight in pretrained_state_dict.items():
            if "final_layer_norm" in name:
                name = name.replace("final_layer_norm", "post_layer_norm")
            elif "self_attn_layer_norm" in name:
                name = name.replace("self_attn_layer_norm", "pre_layer_norm")
            elif "emb_layer_norm_after" in name:
                name = name.replace("emb_layer_norm_after", "last_layer_norm")
            if name.startswith("layers."):
                layer_id = name.split(".")[1]
                if int(layer_id) >= self.num_layers:
                    continue
            if name == "embed_tokens.weight":
                new_state_dict[name] = self._init_embedding(weight, our_model_state_dict[name])
                del our_model_state_dict[name]
            elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
                del our_model_state_dict[name]
                new_state_dict[name] = weight
        # Parameters without a pretrained counterpart keep their current values.
        new_state_dict.update(our_model_state_dict)
        self.load_state_dict(new_state_dict)
    def __calc_loss__(self, task_level_type, output_mode, logits, label, label_size, loss_fct, loss_reduction):
        if output_mode in ["regression"]:
            if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
                # Token-/span-level regression with mean-over-tokens, mean-over-samples reduction
                loss = loss_fct(logits, label)
            else:
                loss = loss_fct(logits.view(-1), label.view(-1))
        elif output_mode in ["multi_label", "multi-label"]:
            if loss_reduction == "meanmean":
                loss = loss_fct(logits, label.float())
            else:
                loss = loss_fct(logits.view(-1, label_size), label.view(-1, label_size).float())
        elif label_size <= 2 or output_mode in ["binary_class", "binary-class"]:
            if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
                loss = loss_fct(logits, label.float())
            else:
                loss = loss_fct(logits.view(-1), label.view(-1).float())
        elif output_mode in ["multi_class", "multi-class"]:
            if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
                loss = loss_fct(logits, label)
            else:
                loss = loss_fct(logits.view(-1, label_size), label.view(-1))
        else:
            raise Exception("Unsupported output_mode=%s" % output_mode)
        return loss
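
    # The concrete loss modules in self.loss_fct[...] are built outside this class
    # (e.g. via create_output_loss_lucagplm). Typically these would be MSE-style losses
    # for "regression", BCE-with-logits for binary/multi-label, and cross-entropy for
    # multi-class, but that mapping is an assumption here, not defined in this file.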
    def __forward__(self,
                    input_ids: Optional[torch.Tensor] = None,
                    attention_mask: Optional[torch.Tensor] = None,
                    token_type_ids: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.Tensor] = None,
                    output_keys: Optional[dict[str, set[str]]] = None,
                    labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
                    repr_layers=[-1],
                    need_head_weights=False,
                    return_contacts=False,
                    use_last_layer_norm=True):
        # Normalize (possibly negative) layer indices into [0, layer_size].
        assert all(-(self.layer_size + 1) <= i <= self.layer_size for i in repr_layers)
        repr_layers = [(i + self.layer_size + 1) % (self.layer_size + 1) for i in repr_layers]

        if return_contacts:
            need_head_weights = True

        assert input_ids.ndim == 2

        if attention_mask is None:
            padding_mask = input_ids.eq(self.padding_idx)
        else:
            padding_mask = attention_mask.eq(self.padding_idx)

        x = self.embed_scale * self.embed_tokens(input_ids)
        if self.embed_pos is not None and position_ids is not None:
            x += self.embed_scale * self.embed_pos(position_ids)
        if self.embed_type is not None and token_type_ids is not None:
            x += self.embed_scale * self.embed_type(token_type_ids)
        if self.embed_layer_norm is not None:
            x = self.embed_layer_norm(x)

        if self.token_dropout:
            # Zero out masked positions, then rescale so the expected activation
            # magnitude matches the training-time masking ratio (0.15 * 0.8).
            x.masked_fill_((input_ids == self.mask_idx).unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) -> (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) -> (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        if self.last_layer_norm is not None and use_last_layer_norm:
            x = self.last_layer_norm(x)
        # (T, B, E) -> (B, T, E)
        x = x.transpose(0, 1)

        # The (normalized) last-layer representation overwrites the last-layer entry.
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x

        representation_matrix = hidden_representations[self.layer_size]

        if not self.embedding_inference:
            lm_mask_logits = self.lm_head(x)

        # [CLS] token embedding serves as the sequence-level representation.
        representation_vector = representation_matrix[:, 0, :]

        logits = {}
        losses = {}
        outputs = {}
        representations = {
            "representation_matrix": representation_matrix,
            "representation_vector": representation_vector
        }

        if need_head_weights:
            # attentions: (B, L, H, T, T)
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            representations["attentions"] = attentions

            if return_contacts and hasattr(self, "contact_head") \
                    and not self.embedding_inference:
                contacts = self.contact_head(input_ids, attentions)
                representations["contacts"] = contacts

        if not self.embedding_inference and output_keys:
            for item in output_keys.items():
                cur_task_level_type = item[0]
                if cur_task_level_type not in logits:
                    logits[cur_task_level_type] = {}
                    outputs[cur_task_level_type] = {}
                for cur_task_level_name in item[1]:
                    if cur_task_level_type == "token_level":
                        cur_logits = lm_mask_logits
                    elif cur_task_level_type == "seq_level":
                        cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_vector)
                        cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_layer is not None:
                            cur_logits = cur_hidden_layer(cur_logits)
                        cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_act is not None:
                            cur_logits = cur_hidden_act(cur_logits)
                        cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
                    elif cur_task_level_type == "span_level":
                        cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_matrix)
                        cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_layer is not None:
                            cur_logits = cur_hidden_layer(cur_logits)
                        cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_act is not None:
                            cur_logits = cur_hidden_act(cur_logits)
                        cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
                    elif cur_task_level_type == "structure_level":
                        cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_matrix)
                        cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_layer is not None:
                            cur_logits = cur_hidden_layer(cur_logits)
                        cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_act is not None:
                            cur_logits = cur_hidden_act(cur_logits)
                        cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
                    logits[cur_task_level_type][cur_task_level_name] = cur_logits
                    if cur_task_level_type in self.output and cur_task_level_name in self.output[cur_task_level_type] \
                            and self.output[cur_task_level_type][cur_task_level_name] is not None:
                        outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
                    else:
                        outputs[cur_task_level_type][cur_task_level_name] = cur_logits
                    if labels is not None and cur_task_level_type in labels and cur_task_level_name in labels[cur_task_level_type]:
                        if cur_task_level_type not in losses:
                            losses[cur_task_level_type] = {}
                        cur_label = labels[cur_task_level_type][cur_task_level_name]
                        cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
                        cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
                        cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
                        cur_loss = self.__calc_loss__(
                            task_level_type=cur_task_level_type,
                            output_mode=cur_output_mode,
                            logits=cur_logits,
                            label=cur_label,
                            label_size=cur_label_size,
                            loss_fct=cur_loss_fct,
                            loss_reduction="meanmean")
                        losses[cur_task_level_type][cur_task_level_name] = cur_loss
        return representations, logits, outputs, losses
    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            global_attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            output_keys: Optional[dict[str, set[str]]] = None,
            labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
            input_ids_b: Optional[torch.Tensor] = None,
            attention_mask_b: Optional[torch.Tensor] = None,
            global_attention_mask_b: Optional[torch.Tensor] = None,
            token_type_ids_b: Optional[torch.Tensor] = None,
            position_ids_b: Optional[torch.Tensor] = None,
            head_mask_b: Optional[torch.Tensor] = None,
            inputs_embeds_b: Optional[torch.Tensor] = None,
            output_keys_b: Optional[dict[str, set[str]]] = None,
            labels_b: Optional[dict[str, dict[str, torch.Tensor]]] = None,
            pair_label: Optional[dict[str, dict[str, torch.Tensor]]] = None,
            pair_output_keys: Optional[dict[str, set[str]]] = None,
            output_hidden_states: Optional[dict[str, set[str]]] = None,
            output_attentions: Optional[dict[str, set[str]]] = None,
            need_head_weights: Optional[bool] = None,
            return_contacts: Optional[bool] = None,
            repr_layers: Optional[list[int]] = None,
            return_dict: Optional[bool] = None,
            use_last_layer_norm: Optional[bool] = True
    ) -> Union[Tuple[torch.Tensor], AllOutput]:
        if return_dict is None and self.config is not None:
            return_dict = self.config.use_return_dict
        if return_dict is None:
            return_dict = False
        if repr_layers is None or len(repr_layers) == 0:
            repr_layers = [-1]
        if return_contacts is None:
            return_contacts = False
        if need_head_weights is None:
            need_head_weights = True
        has_pair = False
        has_pair_b = False
        if input_ids is not None or inputs_embeds is not None:
            encoding, logits, outputs, losses = self.__forward__(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                output_keys=output_keys,
                labels=labels,
                repr_layers=repr_layers,
                need_head_weights=need_head_weights,
                return_contacts=return_contacts,
                use_last_layer_norm=use_last_layer_norm
            )
            has_pair = True
        if input_ids_b is not None or inputs_embeds_b is not None:
            encoding_b, logits_b, outputs_b, losses_b = self.__forward__(
                input_ids=input_ids_b,
                attention_mask=attention_mask_b,
                token_type_ids=token_type_ids_b,
                position_ids=position_ids_b,
                output_keys=output_keys_b,
                labels=labels_b,
                repr_layers=repr_layers,
                need_head_weights=need_head_weights,
                return_contacts=return_contacts,
                use_last_layer_norm=use_last_layer_norm
            )
            has_pair_b = True

        if not self.embedding_inference:
            if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
                cur_representation_vector = encoding["representation_vector"]
                cur_representation_vector_b = encoding_b["representation_vector"]

                pair_logits = {}
                pair_outputs = {}
                for item1 in pair_output_keys.items():
                    cur_task_level_type = item1[0]
                    if cur_task_level_type not in pair_outputs:
                        pair_outputs[cur_task_level_type] = {}
                        pair_logits[cur_task_level_type] = {}
                    for cur_task_level_name in item1[1]:
                        # Pair-level heads operate on the concatenated [CLS] vectors.
                        cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
                            torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
                        )
                        cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
                        if cur_hidden_layer is not None:
                            cur_logits = cur_hidden_layer(cur_logits)
                        cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
                        pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
                        pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)

                if pair_label is not None:
                    pair_loss = {}
                    for item1 in pair_output_keys.items():
                        cur_task_level_type = item1[0]
                        if cur_task_level_type not in pair_label:
                            continue
                        pair_loss[cur_task_level_type] = {}
                        for cur_task_level_name in item1[1]:
                            if cur_task_level_name not in pair_label[cur_task_level_type]:
                                continue
                            cur_label = pair_label[cur_task_level_type][cur_task_level_name]
                            cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
                            cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
                            cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
                            cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
                            cur_loss = self.__calc_loss__(
                                task_level_type=cur_task_level_type,
                                output_mode=cur_output_mode, logits=cur_logits,
                                label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
                                loss_reduction="meanmean")
                            pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss

                    if not return_dict:
                        return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs]] + [[encoding, encoding_b]]
                    return AllOutput(
                        losses=losses,
                        outputs=outputs,
                        hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
                        attentions=encoding["attentions"] if "attentions" in encoding else None,
                        global_attentions=None,
                        contacts=encoding["contacts"] if "contacts" in encoding else None,
                        losses_b=losses_b,
                        outputs_b=outputs_b,
                        hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
                        attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
                        global_attentions_b=None,
                        contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
                        pair_outputs=pair_outputs,
                        pair_losses=pair_loss)
                else:
                    if not return_dict:
                        return [[losses, losses_b], [outputs, outputs_b]] + [[encoding, encoding_b]]
                    return AllOutput(
                        losses=losses,
                        outputs=outputs,
                        hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
                        attentions=encoding["attentions"] if "attentions" in encoding else None,
                        global_attentions=None,
                        contacts=encoding["contacts"] if "contacts" in encoding else None,
                        losses_b=losses_b,
                        outputs_b=outputs_b,
                        hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
                        attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
                        global_attentions_b=None,
                        contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
                    )
            elif has_pair:
                if not return_dict:
                    return [[losses], [outputs], [encoding]]
                return AllOutput(
                    losses=losses,
                    outputs=outputs,
                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
                    attentions=encoding["attentions"] if "attentions" in encoding else None,
                    global_attentions=None,
                    contacts=encoding["contacts"] if "contacts" in encoding else None
                )
            else:
                if not return_dict:
                    return [[losses_b], [outputs_b], [encoding_b]]
                return AllOutput(
                    losses_b=losses_b,
                    outputs_b=outputs_b,
                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
                    global_attentions_b=None,
                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
                )
        else:
            # Embedding-inference mode: no task heads, so only representations are returned.
            if has_pair and has_pair_b:
                if not return_dict:
                    return [[None, None], [None, None]] + [[encoding, encoding_b]]
                return AllOutput(
                    losses=None,
                    outputs=None,
                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
                    attentions=encoding["attentions"] if "attentions" in encoding else None,
                    global_attentions=None,
                    contacts=encoding["contacts"] if "contacts" in encoding else None,
                    losses_b=None,
                    outputs_b=None,
                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
                    global_attentions_b=None,
                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
                )
            elif has_pair:
                if not return_dict:
                    return [[None], [None], [encoding]]
                return AllOutput(
                    losses=None,
                    outputs=None,
                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
                    attentions=encoding["attentions"] if "attentions" in encoding else None,
                    global_attentions=None,
                    contacts=encoding["contacts"] if "contacts" in encoding else None
                )
            else:
                if not return_dict:
                    return [[None], [None], [encoding_b]]
                return AllOutput(
                    losses_b=None,
                    outputs_b=None,
                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
                    global_attentions_b=None,
                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
                )
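
    # predict_contacts relies on self.contact_head, which is only created when
    # embedding_inference is False (see _init_submodules).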
    def predict_contacts(self, input_ids, position_ids=None, token_type_ids=None):
        return self(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            return_contacts=True)["contacts"]