# LucaOne/model_utils.py
#!/usr/bin/env python
# encoding: utf-8
import sys
import copy
import math
from typing import Optional, Tuple, Dict
from dataclasses import dataclass

import torch
from transformers.modeling_outputs import ModelOutput

from .pooling import *
from .loss import *


@dataclass
class AllOutput(ModelOutput):
    losses: Optional[Dict[str, Dict[str, torch.FloatTensor]]] = None
    outputs: Optional[Dict[str, Dict[str, torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
    contacts: Optional[Tuple[torch.FloatTensor]] = None
    losses_b: Optional[Dict[str, Dict[str, torch.FloatTensor]]] = None
    outputs_b: Optional[Dict[str, Dict[str, torch.FloatTensor]]] = None
    hidden_states_b: Optional[Tuple[torch.FloatTensor]] = None
    attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    contacts_b: Optional[Tuple[torch.FloatTensor]] = None
    pair_outputs: Optional[Tuple[torch.FloatTensor]] = None
    pair_losses: Optional[Dict[str, Dict[str, torch.FloatTensor]]] = None
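
# Usage sketch (illustrative, not part of the original file): AllOutput behaves like any
# transformers ModelOutput, so populated fields are readable by attribute or by key, and
# fields left as None are dropped from the mapping. The `_b`-suffixed fields presumably
# hold the second input of a pair-level task; "gene_mask" below is a made-up task name.
#
#   out = AllOutput(
#       losses={"token_level": {"gene_mask": torch.tensor(0.5)}},
#       hidden_states=(torch.zeros(2, 8, 16),),
#   )
#   assert out["losses"] is out.losses
#   assert "attentions" not in out.keys()  # unset (None) fields are dropped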


def create_pooler(task_level_type, task_level_name, config, args):
    '''
    build the pooling module for one task head
    :param task_level_type: task granularity, e.g. "token_level", "whole_level", or "pair_level"
    :param task_level_name: concrete task name under that granularity
    :param config: model config whose hidden_size is a nested dict keyed by level type and name
    :param args: run args whose pooling_type is a nested dict keyed the same way
    :return: a pooling module matching the configured pooling type, or None if the type is unknown
    '''
    hidden_size = config.hidden_size[task_level_type][task_level_name]
    pooling_type = args.pooling_type[task_level_type][task_level_name]
    if pooling_type == "max":
        return GlobalMaskMaxPooling1D()
    elif pooling_type == "sum":
        return GlobalMaskSumPooling1D(axis=1)
    elif pooling_type == "avg":
        return GlobalMaskAvgPooling1D()
    elif pooling_type in ("attention", "context_attention"):
        # "attention" is an alias for "context_attention"
        return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "weighted_attention":
        return GlobalMaskWeightedAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "value_attention":
        return GlobalMaskValueAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "transformer":
        # the transformer pooler needs its own config copy with the task-specific hidden size
        copy_config = copy.deepcopy(config)
        copy_config.hidden_size = hidden_size
        return GlobalMaskTransformerPooling1D(copy_config)
    else:
        return None
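
# Usage sketch (names are illustrative; in LucaOne the nested dicts come from the model
# config and the run args):
#
#   # config.hidden_size and args.pooling_type are nested dicts with the same layout:
#   #   config.hidden_size = {"token_level": {"gene_mask": 1280}, ...}
#   #   args.pooling_type  = {"token_level": {"gene_mask": "value_attention"}, ...}
#   pooler = create_pooler("token_level", "gene_mask", config, args)
#   # the GlobalMask* poolers presumably reduce (batch, seq_len, hidden) to (batch, hidden),
#   # using the mask to ignore padding positions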


def create_output_loss_lucagplm(task_level_type, task_level_name, config):
    '''
    build the output layer and loss function for one task head (the cls module is not built here)
    '''
    if not hasattr(config, "sigmoid"):
        config.sigmoid = {task_level_type: {}}
    elif task_level_type not in config.sigmoid:
        config.sigmoid[task_level_type] = {}
    # sigmoid for binary/multi-label heads; multi-class and regression heads do without
    config.sigmoid[task_level_type][task_level_name] = \
        config.output_mode[task_level_type][task_level_name] not in ["multi_class", "multi-class", "regression"]
    # special case: contact prediction needs a sigmoid; whether structure prediction
    # needs one as well still has to be thought through
    if task_level_name == "prot_contact":
        config.sigmoid[task_level_type][task_level_name] = True
    config.num_labels = config.label_size[task_level_type][task_level_name]
    if task_level_type in ["token_level", "whole_level"]:
        return_types = ["output", "loss"]
    else:
        return_types = ["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
    return create_loss_function(config,
                                task_level_type=task_level_type,
                                task_level_name=task_level_name,
                                sigmoid=config.sigmoid[task_level_type][task_level_name],
                                output_mode=config.output_mode[task_level_type][task_level_name],
                                num_labels=config.num_labels,
                                loss_type=config.loss_type[task_level_type][task_level_name],
                                ignore_index=config.ignore_index,
                                pair_level=(task_level_type == "pair_level"),
                                return_types=return_types)
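
# Usage sketch (illustrative): the sigmoid flag is derived from the output mode --
# False for multi-class and regression heads, True otherwise, and forced to True for
# "prot_contact". What create_loss_function hands back depends on return_types; for
# token_level/whole_level only the output layer and loss function are requested, so
# presumably:
#
#   output, loss_fct = create_output_loss_lucagplm("token_level", "gene_mask", config)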


def create_output_loss(task_level_type, task_level_name, cls_module, config, args):
    cls = None
    if task_level_type in ["token_level", "whole_level"]:
        cls = cls_module(config)
    dropout, hidden_layer, hidden_act, classifier, output, loss_fct = \
        create_output_loss_lucagplm(task_level_type, task_level_name, config)
    return cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct
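
# Usage sketch (illustrative; `SomeClsHead` stands for whatever head class the caller
# passes as cls_module -- it is only instantiated for token_level/whole_level tasks):
#
#   cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct = \
#       create_output_loss("token_level", "gene_mask", SomeClsHead, config, args)
#   # cls is None for every other task level; the remaining pieces feed the task head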