import sys, copy, math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.modeling_outputs import ModelOutput

from .pooling import *
from .loss import *


@dataclass
class AllOutput(ModelOutput):
    '''per-task losses/outputs and encoder states for sequence a, sequence b (suffix "_b"), and pair-level heads'''
    losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    outputs: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
    contacts: Optional[Tuple[torch.FloatTensor]] = None
    losses_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    outputs_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    hidden_states_b: Optional[Tuple[torch.FloatTensor]] = None
    attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    contacts_b: Optional[Tuple[torch.FloatTensor]] = None
    pair_outputs: Optional[Tuple[torch.FloatTensor]] = None
    pair_losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None


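# Illustrative sketch (not part of the original module): how AllOutput is typically
# filled -- per-task nested dicts for sequence A, the same fields with the "_b" suffix
# for a second sequence, and pair-level heads. The task names below are hypothetical.
#
#   out = AllOutput(
#       losses={"token_level": {"task_a": loss_a}},
#       outputs={"token_level": {"task_a": logits_a}},
#       losses_b={"token_level": {"task_b": loss_b}},
#       pair_losses={"pair_level": {"pair_task": pair_loss}},
#   )
#   total_loss = sum(v for level in (out.losses or {}).values() for v in level.values())
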
def create_pooler(task_level_type, task_level_name, config, args):
    '''
    build the pooling module for one task
    :param task_level_type: task level (e.g. token_level, whole_level, pair_level)
    :param task_level_name: task name within that level
    :param config: model config holding the per-task hidden sizes
    :param args: run args holding the per-task pooling types
    :return: pooling module, or None if no pooling type is configured
    '''
    hidden_size = config.hidden_size[task_level_type][task_level_name]
    pooling_type = args.pooling_type[task_level_type][task_level_name]

    if pooling_type == "max":
        return GlobalMaskMaxPooling1D()
    elif pooling_type == "sum":
        return GlobalMaskSumPooling1D(axis=1)
    elif pooling_type == "avg":
        return GlobalMaskAvgPooling1D()
    elif pooling_type in ("attention", "context_attention"):
        return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "weighted_attention":
        return GlobalMaskWeightedAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "value_attention":
        return GlobalMaskValueAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "transformer":
        copy_config = copy.deepcopy(config)
        copy_config.hidden_size = hidden_size
        return GlobalMaskTransformerPooling1D(copy_config)
    else:
        return None


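# Illustrative usage sketch (not from the original code), assuming config/args expose the
# nested per-level/per-task dicts read above; SimpleNamespace and the "whole_level"/"demo_task"
# keys are placeholders, and the pooler call signature depends on the classes in .pooling.
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(hidden_size={"whole_level": {"demo_task": 1024}})
#   args = SimpleNamespace(pooling_type={"whole_level": {"demo_task": "value_attention"}})
#   pooler = create_pooler("whole_level", "demo_task", config, args)
#   # pooled = pooler(hidden_states, mask)  # pooled shape: [batch, hidden_size]
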
def create_output_loss_lucagplm(task_level_type, task_level_name, config):
    '''build the output/loss components for one task (without the cls module)'''
    if not hasattr(config, "sigmoid"):
        config.sigmoid = {task_level_type: {}}
    elif task_level_type not in config.sigmoid:
        config.sigmoid[task_level_type] = {}
    config.sigmoid[task_level_type][task_level_name] = config.output_mode[task_level_type][task_level_name] \
        not in ["multi_class", "multi-class", "regression"]

    if task_level_name == "prot_contact":
        config.sigmoid[task_level_type][task_level_name] = True
    config.num_labels = config.label_size[task_level_type][task_level_name]
    if task_level_type in ["token_level", "whole_level"]:
        return_types = ["output", "loss"]
    else:
        return_types = ["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
    return create_loss_function(config,
                                task_level_type=task_level_type,
                                task_level_name=task_level_name,
                                sigmoid=config.sigmoid[task_level_type][task_level_name],
                                output_mode=config.output_mode[task_level_type][task_level_name],
                                num_labels=config.num_labels,
                                loss_type=config.loss_type[task_level_type][task_level_name],
                                ignore_index=config.ignore_index,
                                pair_level=(task_level_type == "pair_level"),
                                return_types=return_types)


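# Illustrative note (not from the original code): create_output_loss_lucagplm derives the
# per-task sigmoid flag from output_mode -- "multi_class"/"multi-class"/"regression" keep it
# False, every other mode (and "prot_contact" unconditionally) turns it on. Example, with a
# hypothetical "binary_class" output_mode:
#
#   config.output_mode = {"token_level": {"demo_task": "binary_class"}}
#   ...  # after create_output_loss_lucagplm("token_level", "demo_task", config)
#   assert config.sigmoid["token_level"]["demo_task"] is True
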
def create_output_loss(task_level_type, task_level_name, cls_module, config, args):
    cls = None
    if task_level_type in ["token_level", "whole_level"]:
        cls = cls_module(config)
    # create_output_loss_lucagplm derives the per-task sigmoid flag and label size from config
    dropout, hidden_layer, hidden_act, classifier, output, loss_fct = \
        create_output_loss_lucagplm(task_level_type, task_level_name, config)
    return cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct