Yuanfei committed
Commit 96c0ca2 · verified · 1 Parent(s): 729354d

Upload LucaGPLM

alphabet.py ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import sys
+ import itertools
+ from typing import Sequence, List
+
+ from .batch_converter import BatchConverter
+
+ gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']
+
+ prot_standard_toks = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']
+
+ gene_prot_standard_toks = ['1', '2', '3', '4', '5', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']
+
+ gene_prot_prepend_toks = ['[PAD]', '[UNK]']
+
+ gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
+
+
+ class Alphabet(object):
+     def __init__(
+             self,
+             standard_toks: Sequence[str],
+             prepend_toks: Sequence[str] = gene_prot_prepend_toks,
+             append_toks: Sequence[str] = gene_prot_append_toks,
+             prepend_bos: bool = True,
+             append_eos: bool = True
+     ):
+         self.standard_toks = list(standard_toks)
+         self.prepend_toks = list(prepend_toks)
+         self.append_toks = list(append_toks)
+         self.prepend_bos = prepend_bos
+         self.append_eos = append_eos
+
+         # vocabulary order: prepend tokens, then append tokens, then standard tokens
+         self.all_toks = list(self.prepend_toks)
+         self.all_toks.extend(self.append_toks)
+         self.all_toks.extend(self.standard_toks)
+
+         self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
+
+         self.unk_idx = self.tok_to_idx["[UNK]"]
+         self.padding_idx = self.get_idx("[PAD]")
+         self.pad_token_id = self.padding_idx
+         self.cls_idx = self.get_idx("[CLS]")
+         self.mask_idx = self.get_idx("[MASK]")
+         self.eos_idx = self.get_idx("[SEP]")
+         self.all_special_tokens = prepend_toks + append_toks
+         self.all_special_token_idx_list = [self.tok_to_idx[v] for v in self.all_special_tokens]
+         self.unique_no_split_tokens = self.all_toks
+         self.vocab_size = self.__len__()
+
+     def __len__(self):
+         return len(self.all_toks)
+
+     def get_idx(self, tok):
+         return self.tok_to_idx.get(tok, self.unk_idx)
+
+     def get_tok(self, ind):
+         return self.all_toks[ind]
+
+     def to_dict(self):
+         return self.tok_to_idx.copy()
+
+     def get_batch_converter(self, no_position_embeddings, no_token_type_embeddings, truncation_seq_length: int = None, ignore_index: int = -100, mlm_probability=0.15):
+         return BatchConverter(self,
+                               no_position_embeddings=no_position_embeddings,
+                               no_token_type_embeddings=no_token_type_embeddings,
+                               truncation_seq_length=truncation_seq_length,
+                               ignore_index=ignore_index,
+                               mlm_probability=mlm_probability)
+
+     @classmethod
+     def from_predefined(cls, name: str):
+         if name.lower() == "prot":
+             standard_toks = prot_standard_toks
+         elif name.lower() == "gene":
+             standard_toks = gene_standard_toks
+         elif name.lower() in ["gene_prot", "prot_gene"]:
+             standard_toks = gene_prot_standard_toks
+         else:
+             raise Exception("Unsupported tokenizer name: %s" % name)
+
+         prepend_toks = gene_prot_prepend_toks
+         append_toks = gene_prot_append_toks
+         prepend_bos = True
+         append_eos = True
+
+         return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos)
+
+     @classmethod
+     def from_pretrained(cls, dir_path):
+         import os, pickle
+         return pickle.load(open(os.path.join(dir_path, "alphabet.pkl"), "rb"))
+
+     def save_pretrained(self, save_dir):
+         import os, pickle
+         with open(os.path.join(save_dir, "alphabet.pkl"), 'wb') as outp:
+             pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)
+
+     def _tokenize(self, text) -> List[str]:
+         return text.split()
+
+     def tokenize(self, text, **kwargs) -> List[str]:
+         def split_on_token(tok, text):
+             result = []
+             split_text = text.split(tok)
+             for i, sub_text in enumerate(split_text):
+                 if i < len(split_text) - 1:
+                     sub_text = sub_text.rstrip()
+                 if i > 0:
+                     sub_text = sub_text.lstrip()
+
+                 if i == 0 and not sub_text:
+                     result.append(tok)
+                 elif i == len(split_text) - 1:
+                     if sub_text:
+                         result.append(sub_text)
+                     else:
+                         pass
+                 else:
+                     if sub_text:
+                         result.append(sub_text)
+                     result.append(tok)
+             return result
+
+         def split_on_tokens(tok_list, text):
+             if not text.strip():
+                 return []
+             tokenized_text = []
+             text_list = [text]
+             for tok in tok_list:
+                 tokenized_text = []
+                 for sub_text in text_list:
+                     if sub_text not in self.unique_no_split_tokens:
+                         tokenized_text.extend(split_on_token(tok, sub_text))
+                     else:
+                         tokenized_text.append(sub_text)
+                 text_list = tokenized_text
+
+             return list(
+                 itertools.chain.from_iterable(
+                     (
+                         self._tokenize(token)
+                         if token not in self.unique_no_split_tokens
+                         else [token]
+                         for token in tokenized_text
+                     )
+                 )
+             )
+
+         no_split_token = self.unique_no_split_tokens
+         tokenized_text = split_on_tokens(no_split_token, text)
+         return tokenized_text
+
+     def encode(self, text):
+         return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
+
+
+ if __name__ == "__main__":
+     alphabet = Alphabet.from_predefined("gene_prot")
+     from src.utils import gene_seq_replace
+     print(alphabet.encode(gene_seq_replace("gttgtttggtagctaggagcctgactacatggcttcaaggctaaatggccacaggtgcccaggctatttggcttgctggaggcttcattcat")))
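As a quick sanity check of the Alphabet class above, the following minimal sketch (an illustration, not part of the commit; it assumes the class is importable from this module) exercises the vocabulary layout and the character-level tokenizer. Because every vocabulary entry is in unique_no_split_tokens, tokenize() ends up splitting an input string token by token:

    alphabet = Alphabet.from_predefined("prot")
    print(alphabet.vocab_size)    # 2 prepend + 3 append + 29 standard tokens = 34
    print(alphabet.padding_idx)   # 0, because '[PAD]' heads prepend_toks
    ids = alphabet.encode("MKT")  # splits per character, since each residue is a no-split token
    print([alphabet.get_tok(i) for i in ids])  # ['M', 'K', 'T']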
alphabet_atom.py ADDED
@@ -0,0 +1,132 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from typing import Sequence, List
+
+ atom_standard_toks = ['C', 'N', 'O', 'S', 'H', 'Cl', 'F', 'Br', 'I',
+                       'Si', 'P', 'B', 'Na', 'K', 'Al', 'Ca', 'Sn', 'As',
+                       'Hg', 'Fe', 'Zn', 'Cr', 'Se', 'Gd', 'Au', 'Li'
+                       ]
+
+ atom_prepend_toks = ['[PAD]', '[UNK]', '[CLS]']
+
+ atom_append_toks = ['[SEP]', '[MASK]']
+
+
+ class AlphabetAtom(object):
+     def __init__(
+             self,
+             standard_toks: Sequence[str] = atom_standard_toks,
+             prepend_toks: Sequence[str] = atom_prepend_toks,
+             append_toks: Sequence[str] = atom_append_toks,
+             prepend_bos: bool = True,
+             append_eos: bool = True
+     ):
+         self.standard_toks = list(standard_toks)
+         self.prepend_toks = list(prepend_toks)
+         self.append_toks = list(append_toks)
+         self.prepend_bos = prepend_bos
+         self.append_eos = append_eos
+
+         self.all_toks = list(self.prepend_toks)
+         self.all_toks.extend(self.append_toks)
+         self.all_toks.extend(self.standard_toks)
+
+         self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
+
+         self.unk_idx = self.tok_to_idx["[UNK]"]
+         self.padding_idx = self.get_idx("[PAD]")
+         self.pad_idx = self.get_idx("[PAD]")
+         self.pad_token_id = self.padding_idx
+         self.cls_idx = self.get_idx("[CLS]")
+         self.mask_idx = self.get_idx("[MASK]")
+         self.eos_idx = self.get_idx("[SEP]")
+         self.all_special_tokens = prepend_toks + append_toks
+         self.all_special_token_idx_list = [self.tok_to_idx[v] for v in self.all_special_tokens]
+         self.unique_no_split_tokens = self.all_toks
+         self.vocab_size = self.__len__()
+
+     def __len__(self):
+         return len(self.all_toks)
+
+     def get_idx(self, tok):
+         return self.tok_to_idx.get(tok, self.unk_idx)
+
+     def get_tok(self, ind):
+         return self.all_toks[ind]
+
+     def to_dict(self):
+         return self.tok_to_idx.copy()
+
+     def get_batch_converter(self, task_level_type, label_size, output_mode, no_position_embeddings,
+                             no_token_type_embeddings, truncation_seq_length: int = None, ignore_index: int = -100, mlm_probability=0.15):
+         # not implemented yet; the intended call is kept below for reference
+         '''
+         return BatchConverter(
+             task_level_type,
+             label_size,
+             output_mode,
+             seq_subword=False,
+             seq_tokenizer=self,
+             no_position_embeddings=no_position_embeddings,
+             no_token_type_embeddings=no_token_type_embeddings,
+             truncation_seq_length=truncation_seq_length,
+             truncation_matrix_length=truncation_seq_length,
+             ignore_index=ignore_index,
+             mlm_probability=mlm_probability,
+             prepend_bos=self.prepend_bos,
+             append_eos=self.append_eos)
+         '''
+         pass
+
+     @classmethod
+     def smiles_2_atom_seq(cls, smi):
+         mol = Chem.MolFromSmiles(smi)
+         mol = AllChem.AddHs(mol)
+         atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]  # after adding explicit hydrogens
+         return atoms
+
+     @classmethod
+     def from_predefined(cls, name: str = "atom_v1"):
+         if name.lower() == "atom_v1":
+             standard_toks = atom_standard_toks
+         else:
+             raise Exception("Unsupported tokenizer name: %s" % name)
+
+         prepend_toks = atom_prepend_toks
+         append_toks = atom_append_toks
+         prepend_bos = True
+         append_eos = True
+
+         return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos)
+
+     @classmethod
+     def from_pretrained(cls, dir_path):
+         import os, pickle
+         return pickle.load(open(os.path.join(dir_path, "alphabet_atom.pkl"), "rb"))
+
+     def save_pretrained(self, save_dir):
+         import os, pickle
+         with open(os.path.join(save_dir, "alphabet_atom.pkl"), 'wb') as outp:
+             pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)
+
+     def tokenize(self, smi, prepend_bos, append_eos) -> List[str]:
+         seq = AlphabetAtom.smiles_2_atom_seq(smi)
+         if prepend_bos:
+             seq = [self.get_tok(self.cls_idx)] + seq
+         if append_eos:
+             seq = seq + [self.get_tok(self.eos_idx)]
+         return seq
+
+     def encode(self, atom_list, prepend_bos, append_eos):
+         idx_list = [self.get_idx(tok) for tok in atom_list]
+         if prepend_bos:
+             idx_list = [self.cls_idx] + idx_list
+         if append_eos:
+             idx_list = idx_list + [self.eos_idx]
+         return idx_list
+
+     def encode_smi(self, smi, prepend_bos, append_eos):
+         atom_list = self.smiles_2_atom_seq(smi)
+         return self.encode(atom_list, prepend_bos, append_eos)
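A similar minimal sketch for AlphabetAtom (again an illustration, not part of the commit; it requires RDKit and assumes the class is importable). smiles_2_atom_seq adds explicit hydrogens before listing atom symbols, so ethanol expands from 3 heavy atoms to 9 tokens:

    tokenizer = AlphabetAtom.from_predefined("atom_v1")
    atoms = AlphabetAtom.smiles_2_atom_seq("CCO")  # ethanol
    print(atoms)  # ['C', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H'] (hydrogens appended after heavy atoms)
    ids = tokenizer.encode_smi("CCO", prepend_bos=True, append_eos=True)
    print(ids[0] == tokenizer.cls_idx, ids[-1] == tokenizer.eos_idx)  # True True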
batch_converter.py ADDED
@@ -0,0 +1,1365 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import sys
+ import torch
+ from typing import Sequence
+
+ from .alphabet_atom import AlphabetAtom
+ from .utils import gene_seq_replace
+
+ class BatchConverter(object):
+
+     def __init__(self,
+                  task_level_type,
+                  label_size,
+                  output_mode,
+                  seq_subword,
+                  seq_tokenizer,
+                  no_position_embeddings,
+                  no_token_type_embeddings,
+                  truncation_seq_length: int = None,
+                  truncation_matrix_length: int = None,
+                  atom_tokenizer: AlphabetAtom = None,
+                  atom_truncation_seq_length: int = None,
+                  atom_truncation_matrix_length: int = None,
+                  ignore_index: int = -100,
+                  padding_idx: int = 0,
+                  unk_idx: int = 1,
+                  cls_idx: int = 2,
+                  eos_idx: int = 3,
+                  mask_idx: int = 4,
+                  non_ignore: bool = False,
+                  mlm_probability=0.15,
+                  prepend_bos=None,
+                  append_eos=None,
+                  **kwargs):
+         print("------BatchConverter------")
+         print("BatchConverter, kwargs:")
+         print(kwargs)
+         self.task_level_type = task_level_type
+         self.label_size = label_size
+         self.output_mode = output_mode
+         self.seq_tokenizer = seq_tokenizer
+         self.seq_subword = seq_subword
+         self.ignore_index = ignore_index
+         self.non_ignore = non_ignore
+         self.mlm_probability = mlm_probability
+         self.truncation_seq_length = truncation_seq_length
+         self.truncation_matrix_length = truncation_matrix_length
+
+         # a subword tokenizer always includes the two special tokens
+         if prepend_bos is None:
+             if seq_subword is not None:
+                 self.prepend_bos = True
+             else:
+                 self.prepend_bos = False
+         else:
+             self.prepend_bos = prepend_bos
+         if append_eos is None:
+             if seq_subword is not None:
+                 self.append_eos = True
+             else:
+                 self.append_eos = False
+         else:
+             self.append_eos = append_eos
+
+         self.padding_idx = padding_idx
+         self.unk_idx = unk_idx
+         self.cls_idx = cls_idx
+         self.eos_idx = eos_idx
+         self.mask_idx = mask_idx
+         if self.seq_tokenizer is None:
+             self.append_len = 0
+         else:
+             if hasattr(seq_tokenizer, "prepend_bos"):
+                 self.prepend_bos = self.seq_tokenizer.prepend_bos
+             if hasattr(seq_tokenizer, "append_eos"):
+                 self.append_eos = self.seq_tokenizer.append_eos
+             if hasattr(seq_tokenizer, "padding_idx"):
+                 self.padding_idx = self.seq_tokenizer.padding_idx
+             if hasattr(seq_tokenizer, "unk_idx"):
+                 self.unk_idx = self.seq_tokenizer.unk_idx
+             if hasattr(seq_tokenizer, "cls_idx"):
+                 self.cls_idx = self.seq_tokenizer.cls_idx
+             if hasattr(seq_tokenizer, "eos_idx"):
+                 self.eos_idx = self.seq_tokenizer.eos_idx
+             if hasattr(seq_tokenizer, "mask_idx"):
+                 self.mask_idx = self.seq_tokenizer.mask_idx
+             if hasattr(seq_tokenizer, "all_special_token_idx_list"):
+                 self.all_special_token_idx_list = self.seq_tokenizer.all_special_token_idx_list
+             else:
+                 self.all_special_token_idx_list = [self.padding_idx, self.unk_idx, self.cls_idx, self.eos_idx, self.mask_idx]
+             self.append_len = int(self.prepend_bos) + int(self.append_eos)
+
+         # for atom
+         self.atom_tokenizer = atom_tokenizer
+         self.atom_truncation_seq_length = atom_truncation_seq_length
+         self.atom_truncation_matrix_length = atom_truncation_matrix_length
+         self.atom_prepend_bos = False
+         self.atom_append_eos = False
+         self.atom_padding_idx = padding_idx
+         self.atom_unk_idx = unk_idx
+         self.atom_cls_idx = cls_idx
+         self.atom_eos_idx = eos_idx
+         self.atom_mask_idx = mask_idx
+         if self.atom_tokenizer is None:
+             self.atom_append_len = 0
+         else:
+             if hasattr(atom_tokenizer, "padding_idx"):
+                 self.atom_padding_idx = self.atom_tokenizer.padding_idx
+             elif hasattr(atom_tokenizer, "pad_idx"):
+                 self.atom_padding_idx = self.atom_tokenizer.pad_idx
+             elif hasattr(atom_tokenizer, "pad_token_id"):
+                 self.atom_padding_idx = self.atom_tokenizer.pad_token_id
+
+             if hasattr(atom_tokenizer, "unk_idx"):
+                 self.atom_unk_idx = self.atom_tokenizer.unk_idx
+             elif hasattr(atom_tokenizer, "unk_token_id"):
+                 self.atom_unk_idx = self.atom_tokenizer.unk_token_id
+
+             if hasattr(atom_tokenizer, "cls_idx"):
+                 self.atom_cls_idx = self.atom_tokenizer.cls_idx
+             elif hasattr(atom_tokenizer, "cls_token_id"):
+                 self.atom_cls_idx = self.atom_tokenizer.cls_token_id
+             elif hasattr(atom_tokenizer, "bos_idx"):
+                 self.atom_cls_idx = self.atom_tokenizer.bos_idx
+             elif hasattr(atom_tokenizer, "bos_token_id"):
+                 self.atom_cls_idx = self.atom_tokenizer.bos_token_id
+
+             if hasattr(atom_tokenizer, "eos_idx"):
+                 self.atom_eos_idx = self.atom_tokenizer.eos_idx
+             elif hasattr(atom_tokenizer, "eos_token_id"):
+                 self.atom_eos_idx = self.atom_tokenizer.eos_token_id
+             elif hasattr(atom_tokenizer, "sep_token_id"):
+                 self.atom_eos_idx = self.atom_tokenizer.sep_token_id
+
+             if hasattr(atom_tokenizer, "mask_idx"):
+                 self.atom_mask_idx = self.atom_tokenizer.mask_idx
+             elif hasattr(atom_tokenizer, "mask_token_id"):
+                 self.atom_mask_idx = self.atom_tokenizer.mask_token_id
+             if hasattr(atom_tokenizer, "all_special_token_idx_list"):
+                 self.atom_all_special_token_idx_list = self.atom_tokenizer.all_special_token_idx_list
+             else:
+                 self.atom_all_special_token_idx_list = [self.atom_padding_idx, self.atom_unk_idx, self.atom_cls_idx, self.atom_eos_idx, self.atom_mask_idx]
+             self.atom_append_len = int(self.atom_prepend_bos) + int(self.atom_append_eos)
+
+         print("BatchConverter: prepend_bos=%r, append_eos=%r" % (self.prepend_bos, self.append_eos))
+         print("BatchConverter: atom_prepend_bos=%r, atom_append_eos=%r" % (self.atom_prepend_bos, self.atom_append_eos))
+         self.matrix_add_special_token = False
+         if "matrix_add_special_token" in kwargs and kwargs["matrix_add_special_token"]:
+             self.matrix_add_special_token = kwargs["matrix_add_special_token"]
+         if self.matrix_add_special_token:
+             self.prepend_bos = True
+             self.append_eos = True
+             self.atom_prepend_bos = True
+             self.atom_append_eos = True
+             self.append_len = int(self.prepend_bos) + int(self.append_eos)
+             self.atom_append_len = int(self.atom_prepend_bos) + int(self.atom_append_eos)
+
+         # lengths after subtracting the special tokens
+         if self.truncation_seq_length:
+             self.truncation_seq_length -= self.append_len
+         if self.truncation_matrix_length:
+             self.truncation_matrix_length -= self.append_len
+         # lengths after subtracting the special tokens
+         if self.atom_truncation_seq_length:
+             self.atom_truncation_seq_length -= self.atom_append_len
+         if self.atom_truncation_matrix_length:
+             self.atom_truncation_matrix_length -= self.atom_append_len
+
+         self.input_type = None
+         if "input_type" in kwargs and kwargs["input_type"]:
+             self.input_type = kwargs["input_type"]
+
+         if "max_sentence_length" in kwargs and kwargs["max_sentence_length"]:
+             self.max_sentence_length = kwargs["max_sentence_length"] - self.append_len
+             print("BatchConverter: self.max_sentence_length=%d" % self.max_sentence_length)
+             if atom_tokenizer is not None:
+                 self.atom_max_sentence_length = kwargs["max_sentence_length"] - self.atom_append_len
+                 print("BatchConverter: self.atom_max_sentence_length=%d" % self.atom_max_sentence_length)
+         if "max_sentences" in kwargs and kwargs["max_sentences"]:
+             self.max_sentences = kwargs["max_sentences"]
+             print("BatchConverter: self.max_sentences=%d" % self.max_sentences)
+         self.trunc_type = "right"
+         if "trunc_type" in kwargs and kwargs["trunc_type"]:
+             self.trunc_type = kwargs["trunc_type"]
+         print("BatchConverter: self.trunc_type=%s" % self.trunc_type)
+
+         self.no_position_embeddings = no_position_embeddings
+         self.no_token_type_embeddings = no_token_type_embeddings
+         print("BatchConverter: prepend_bos=%r, append_eos=%r" % (self.prepend_bos, self.append_eos))
+         print("BatchConverter: atom_prepend_bos=%r, atom_append_eos=%r" % (self.atom_prepend_bos, self.atom_append_eos))
+         print("-" * 50)
+
+     def __parse_label__(self, max_length, task_level_type, label_size, output_mode, label):
+         if isinstance(label, str):
+             label = eval(label)
+         '''
+         print("label:")
+         print(label)
+         '''
+         # this must be the padded length
+         cur_len = max_length
+         if task_level_type in ["token_level", "structure_level"]:
+             if output_mode in ["multi_label", "multi-label"]:
+                 # N * seq_len * label_size
+                 new_label = []
+                 for _ in range(cur_len):
+                     tmp = []
+                     for _ in range(label_size):
+                         tmp.append(0 if self.non_ignore else self.ignore_index)
+                     new_label.append(tmp)
+             else:
+                 # N * seq_len
+                 new_label = []
+                 for _ in range(cur_len):
+                     new_label.append(0 if self.non_ignore else self.ignore_index)
+             if label is not None and len(label) > 0:
+                 begin_idx = 0
+                 end_idx = cur_len
+                 if self.prepend_bos:
+                     begin_idx = 1
+                 if self.append_eos:
+                     end_idx = cur_len - 1
+                 for idx, item in enumerate(label):
+                     idx += begin_idx
+                     if idx >= end_idx:
+                         break
+                     if output_mode in ["multi_label", "multi-label"]:
+                         for v in item:
+                             new_label[idx][v] = 1
+                     else:
+                         new_label[idx] = item
+         elif task_level_type == "span_level":
+             if output_mode in ["multi_label", "multi-label"]:
+                 # N * seq_len * label_size
+                 new_label = []
+                 for _ in range(cur_len):
+                     tmp = []
+                     for _ in range(label_size):
+                         tmp.append(0 if self.non_ignore else self.ignore_index)
+                     new_label.append(tmp)
+             else:
+                 # N * seq_len
+                 new_label = []
+                 for _ in range(cur_len):
+                     new_label.append(0 if self.non_ignore else self.ignore_index)
+             if label is not None and len(label) > 0:
+                 begin_idx = 0
+                 end_idx = cur_len
+                 if self.prepend_bos:
+                     begin_idx = 1
+                 if self.append_eos:
+                     end_idx = cur_len - 1
+                 for item in label:
+                     for idx in range(item[0], item[1] + 1, 1):
+                         idx += begin_idx
+                         if idx >= end_idx:
+                             break
+                         if output_mode in ["multi_label", "multi-label"]:
+                             new_label[idx][item[2]] = 1
+                         else:
+                             new_label[idx] = item[2]
+         elif task_level_type in ["seq_level"]:
+             if output_mode in ["multi_label", "multi-label"]:
+                 # N * label_size
+                 new_label = []
+                 for _ in range(label_size):
+                     new_label.append(0 if self.non_ignore else self.ignore_index)
+             else:
+                 # N * 1
+                 new_label = [0 if self.non_ignore else self.ignore_index]
+             if output_mode in ["multi_label", "multi-label"]:
+                 if label is not None and len(label) > 0:
+                     for v in label:
+                         new_label[int(v)] = 1
+             else:
+                 if label is not None and len(str(label)) > 0:
+                     if isinstance(label, str):
+                         new_label = [int(label)]
+                     elif isinstance(label, list):
+                         new_label = [int(label[0])]
+                     else:
+                         new_label = [label]
+         else:
+             raise Exception("Unsupported task_level_type=%s" % task_level_type)
+         return new_label
+
+     def __atom_parse_label__(self, max_length, task_level_type, label_size, output_mode, label):
+         if isinstance(label, str):
+             label = eval(label)
+         '''
+         print("label:")
+         print(label)
+         '''
+         # this must be the padded length
+         cur_len = max_length
+         if task_level_type in ["token_level", "structure_level"]:
+             if output_mode in ["multi_label", "multi-label"]:
+                 # N * seq_len * label_size
+                 new_label = []
+                 for _ in range(cur_len):
+                     tmp = []
+                     for _ in range(label_size):
+                         tmp.append(0 if self.non_ignore else self.ignore_index)
+                     new_label.append(tmp)
+             else:
+                 # N * seq_len
+                 new_label = []
+                 for _ in range(cur_len):
+                     new_label.append(0 if self.non_ignore else self.ignore_index)
+             if label is not None and len(label) > 0:
+                 begin_idx = 0
+                 end_idx = cur_len
+                 if self.atom_prepend_bos:
+                     begin_idx = 1
+                 if self.atom_append_eos:
+                     end_idx = cur_len - 1
+                 for idx, item in enumerate(label):
+                     idx += begin_idx
+                     if idx >= end_idx:
+                         break
+                     if output_mode in ["multi_label", "multi-label"]:
+                         for v in item:
+                             new_label[idx][v] = 1
+                     else:
+                         new_label[idx] = item
+         elif task_level_type == "span_level":
+             if output_mode in ["multi_label", "multi-label"]:
+                 # N * seq_len * label_size
+                 new_label = []
+                 for _ in range(cur_len):
+                     tmp = []
+                     for _ in range(label_size):
+                         tmp.append(0 if self.non_ignore else self.ignore_index)
+                     new_label.append(tmp)
+             else:
+                 # N * seq_len
+                 new_label = []
+                 for _ in range(cur_len):
+                     new_label.append(0 if self.non_ignore else self.ignore_index)
+             if label is not None and len(label) > 0:
+                 begin_idx = 0
+                 end_idx = cur_len
+                 if self.atom_prepend_bos:
+                     begin_idx = 1
+                 if self.atom_append_eos:
+                     end_idx = cur_len - 1
+                 for item in label:
+                     for idx in range(item[0], item[1] + 1, 1):
+                         idx += begin_idx
+                         if idx >= end_idx:
+                             break
+                         if output_mode in ["multi_label", "multi-label"]:
+                             new_label[idx][item[2]] = 1
+                         else:
+                             new_label[idx] = item[2]
+         elif task_level_type in ["seq_level"]:
+             if output_mode in ["multi_label", "multi-label"]:
+                 # N * label_size
+                 new_label = []
+                 for _ in range(label_size):
+                     new_label.append(0 if self.non_ignore else self.ignore_index)
+             else:
+                 # N * 1
+                 new_label = [0 if self.non_ignore else self.ignore_index]
+             if output_mode in ["multi_label", "multi-label"]:
+                 if label is not None and len(label) > 0:
+                     for v in label:
+                         new_label[int(v)] = 1
+             else:
+                 if label is not None and len(str(label)) > 0:
+                     if isinstance(label, str):
+                         new_label = [int(label)]
+                     elif isinstance(label, list):
+                         new_label = [int(label[0])]
+                     else:
+                         new_label = [label]
+         else:
+             raise Exception("Unsupported task_level_type=%s" % task_level_type)
+
+         return new_label
+
+     def __mask_tokens__(self, input_ids):
+         labels = input_ids.clone()
+         probability_matrix = torch.full(labels.shape, self.mlm_probability)
+
+         # 1 at special-token positions
+         special_tokens_mask = [
+             1 if v in self.all_special_token_idx_list else 0 for v in labels.tolist()
+         ]
+         special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+         # zero out the masking probability at special-token positions
+         probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+
+         # sample masked positions among the non-special tokens
+         masked_indices = torch.bernoulli(probability_matrix).bool()
+         # non-masked positions are set to the ignore index (-100)
+         labels[~masked_indices] = self.ignore_index  # We only compute loss on masked tokens
+
+         # 80% of the time, we replace masked input tokens with alphabet.mask_token ([MASK])
+         indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+         input_ids[indices_replaced] = self.mask_idx
+
+         # 10% of the time, we replace masked input tokens with a random word
+         indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+         random_words = torch.randint(len(self.seq_tokenizer), labels.shape, dtype=torch.long)
+         input_ids[indices_random] = random_words[indices_random]
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return input_ids, labels
+
+     def __atom_mask_tokens__(self, input_ids):
+         labels = input_ids.clone()
+         probability_matrix = torch.full(labels.shape, self.mlm_probability)
+
+         # 1 at special-token positions
+         special_tokens_mask = [
+             1 if v in self.atom_all_special_token_idx_list else 0 for v in labels.tolist()
+         ]
+         special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+         # zero out the masking probability at special-token positions
+         probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+
+         # sample masked positions among the non-special tokens
+         masked_indices = torch.bernoulli(probability_matrix).bool()
+         # non-masked positions are set to the ignore index (-100)
+         labels[~masked_indices] = self.ignore_index  # We only compute loss on masked tokens
+
+         # 80% of the time, we replace masked input tokens with alphabet.mask_token ([MASK])
+         indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+         input_ids[indices_replaced] = self.atom_mask_idx
+
+         # 10% of the time, we replace masked input tokens with a random word
+         indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+         random_words = torch.randint(len(self.atom_tokenizer), labels.shape, dtype=torch.long)
+         input_ids[indices_random] = random_words[indices_random]
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return input_ids, labels
+
+     def __seq_encode__(self, batch_size, seqs):
+         '''
+         This function does not add the special tokens [CLS] and [SEP].
+         :param batch_size:
+         :param seqs:
+         :return:
+         '''
+         if self.seq_subword:
+             seq_encoded_list = []
+             for seq_str in seqs:
+                 seq_to_list = self.seq_subword.process_line(seq_str.upper()).split(" ")
+                 seq = " ".join(seq_to_list)
+                 inputs = self.seq_tokenizer.encode_plus(
+                     seq,
+                     None,
+                     add_special_tokens=False,
+                     max_length=self.truncation_seq_length,
+                     truncation=True
+                 )
+                 seq_encoded_list.append(inputs["input_ids"])
+         else:
+             seq_encoded_list = [self.seq_tokenizer.encode(seq_str.upper()) for seq_str in seqs]
+         # this length already excludes the special tokens to be added
+         if self.truncation_seq_length:
+             seq_encoded_list = [encoded[:self.truncation_seq_length] for encoded in seq_encoded_list]
+         max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
+         max_len = max_len + int(self.prepend_bos) + int(self.append_eos)
+         # for input
+         input_ids = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+             ),
+             dtype=torch.int64,
+         )
+         input_ids.fill_(self.padding_idx)
+
+         position_ids = None
+         if not self.no_position_embeddings:
+             position_ids = torch.empty(
+                 (
+                     batch_size,
+                     max_len,
+                 ),
+                 dtype=torch.int64,
+             )
+             position_ids.fill_(self.padding_idx)
+
+         token_type_ids = None
+         if not self.no_token_type_embeddings:
+             token_type_ids = torch.empty(
+                 (
+                     batch_size,
+                     max_len,
+                 ),
+                 dtype=torch.int64,
+             )
+             token_type_ids.fill_(self.padding_idx)
+         attention_masks = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+             ),
+             dtype=torch.int64,
+         )
+         attention_masks.fill_(0)
+
+         return seq_encoded_list, input_ids, position_ids, token_type_ids, attention_masks, max_len
+
+     def __multi_seq_encode__(self, batch_size, seqs):
+         '''
+         This function encodes multiple sentences per sample; [CLS] and [SEP] are added to every sentence.
+         :param batch_size:
+         :param seqs:
+         :return:
+         '''
+         assert hasattr(self, "max_sentences") and hasattr(self, "max_sentence_length")
+         max_sentence_len = 0
+         max_sentence_num = 0
+         if self.seq_subword:
+             seq_encoded_list = []
+             for cur_sample_seqs in seqs:
+                 cur_seq_encoded_list = []
+                 if len(cur_sample_seqs) > self.max_sentences:
+                     # at most max_sentences sequences per sample
+                     if self.trunc_type == "left":
+                         cur_sample_seqs = cur_sample_seqs[-self.max_sentences:]
+                     else:
+                         cur_sample_seqs = cur_sample_seqs[:self.max_sentences]
+                 if max_sentence_num < len(cur_sample_seqs):
+                     max_sentence_num = len(cur_sample_seqs)
+                 for seq_idx, seq_str in enumerate(cur_sample_seqs):
+                     seq_to_list = self.seq_subword.process_line(seq_str.upper()).split(" ")
+                     seq = " ".join(seq_to_list)
+                     inputs = self.seq_tokenizer.encode_plus(
+                         seq,
+                         None,
+                         add_special_tokens=False,
+                         max_length=self.max_sentence_length,
+                         truncation=True
+                     )
+                     if self.prepend_bos:
+                         inputs["input_ids"] = [self.cls_idx] + inputs["input_ids"]
+                     if self.append_eos:
+                         inputs["input_ids"] = inputs["input_ids"] + [self.eos_idx]
+                     if max_sentence_len < len(inputs["input_ids"]):
+                         max_sentence_len = len(inputs["input_ids"])
+                     cur_seq_encoded_list.append(inputs["input_ids"])
+                 seq_encoded_list.append(cur_seq_encoded_list)
+         else:
+             seq_encoded_list = []
+             for cur_sample_seqs in seqs:
+                 cur_seq_encoded_list = []
+                 if len(cur_sample_seqs) > self.max_sentences:
+                     # at most max_sentences sequences per sample
+                     if self.trunc_type == "left":
+                         cur_sample_seqs = cur_sample_seqs[-self.max_sentences:]
+                     else:
+                         cur_sample_seqs = cur_sample_seqs[:self.max_sentences]
+                 if max_sentence_num < len(cur_sample_seqs):
+                     max_sentence_num = len(cur_sample_seqs)
+                 for seq_idx, seq_str in enumerate(cur_sample_seqs):
+                     if len(seq_str) > self.max_sentence_length:
+                         if self.trunc_type == "left":
+                             seq_str = seq_str[-self.max_sentence_length:]
+                         else:
+                             seq_str = seq_str[:self.max_sentence_length]
+
+                     inputs = self.seq_tokenizer.encode(seq_str.upper())
+                     # print("len:%d, %s" % (len(seq_str), seq_str.upper()))
+                     if self.prepend_bos:
+                         inputs = [self.cls_idx] + inputs
+                     if self.append_eos:
+                         inputs = inputs + [self.eos_idx]
+                     # print("inputs:%d, " % len(inputs), inputs)
+                     cur_seq_encoded_list.append(inputs)
+                     if max_sentence_len < len(inputs):
+                         max_sentence_len = len(inputs)
+                 seq_encoded_list.append(cur_seq_encoded_list)
+         # for input
+         input_ids = torch.empty(
+             (
+                 batch_size,
+                 max_sentence_num,
+                 max_sentence_len,
+             ),
+             dtype=torch.int64,
+         )
+         input_ids.fill_(self.padding_idx)
+
+         position_ids = None
+         if not self.no_position_embeddings:
+             position_ids = torch.empty(
+                 (
+                     batch_size,
+                     max_sentence_num,
+                     max_sentence_len
+                 ),
+                 dtype=torch.int64,
+             )
+             position_ids.fill_(self.padding_idx)
+
+         token_type_ids = None
+         if not self.no_token_type_embeddings:
+             token_type_ids = torch.empty(
+                 (
+                     batch_size,
+                     max_sentence_num,
+                     max_sentence_len
+                 ),
+                 dtype=torch.int64,
+             )
+             token_type_ids.fill_(self.padding_idx)
+         attention_masks = torch.empty(
+             (
+                 batch_size,
+                 max_sentence_num,
+                 max_sentence_len
+             ),
+             dtype=torch.int64,
+         )
+         attention_masks.fill_(0)
+
+         return seq_encoded_list, input_ids, position_ids, token_type_ids, attention_masks, max_sentence_num, max_sentence_len
+
+     def __atom_seq_encode__(self, batch_size, seqs):
+         '''
+         This function does not add the special tokens [CLS] and [SEP].
+         :param batch_size:
+         :param seqs:
+         :return:
+         '''
+         seq_encoded_list = []
+         for seq_idx, cur_seq in enumerate(seqs):
+             if isinstance(cur_seq, str):  # smiles
+                 cur_seq_encoded = self.atom_tokenizer.encode_smi(cur_seq,
+                                                                  prepend_bos=False,
+                                                                  append_eos=False)
+             elif isinstance(cur_seq, list):  # atom list
+                 cur_seq_encoded = self.atom_tokenizer.encode(cur_seq,
+                                                              prepend_bos=False,
+                                                              append_eos=False)
+             else:
+                 raise Exception("Unsupported molecule input type:", type(cur_seq))
+             # this length already excludes the special tokens to be added
+             if self.atom_truncation_seq_length:
+                 cur_seq_encoded = cur_seq_encoded[:self.atom_truncation_seq_length]
+             seq_encoded_list.append(cur_seq_encoded)
+         max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
+         max_len = max_len + int(self.atom_prepend_bos) + int(self.atom_append_eos)
+         # for input
+         input_ids = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+             ),
+             dtype=torch.int64,
+         )
+         input_ids.fill_(self.atom_padding_idx)
+
+         position_ids = None
+         if not self.no_position_embeddings:
+             position_ids = torch.empty(
+                 (
+                     batch_size,
+                     max_len,
+                 ),
+                 dtype=torch.int64,
+             )
+             position_ids.fill_(self.atom_padding_idx)
+
+         token_type_ids = None
+         if not self.no_token_type_embeddings:
+             token_type_ids = torch.empty(
+                 (
+                     batch_size,
+                     max_len,
+                 ),
+                 dtype=torch.int64,
+             )
+             token_type_ids.fill_(self.atom_padding_idx)
+         attention_masks = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+             ),
+             dtype=torch.int64,
+         )
+         attention_masks.fill_(0)
+
+         return seq_encoded_list, input_ids, position_ids, token_type_ids, attention_masks, max_len
+
+     def __vector_encode__(self, batch_size, vectors):
+         embedding_vector_dim = vectors[0].shape[0]
+         filled_vectors = torch.empty(
+             (
+                 batch_size,
+                 embedding_vector_dim
+             ),
+             dtype=torch.float32,
+         )
+         filled_vectors.fill_(0.0)
+         return filled_vectors, 1
+
+     def __atom_vector_encode__(self, batch_size, vectors):
+         return self.__vector_encode__(batch_size, vectors)
+
+     def __multi_vector_encode__(self, batch_size, vectors):
+         embedding_vector_dim = vectors[0][0].shape[0]
+         filled_vectors = torch.empty(
+             (
+                 batch_size,
+                 self.max_sentences,
+                 embedding_vector_dim
+             ),
+             dtype=torch.float32,
+         )
+         filled_vectors.fill_(0.0)
+         return filled_vectors, self.max_sentences, 1
+
+     def __matrix_encode__(self, batch_size, matrices):
+         '''
+         This function does not add embedding vectors for the special tokens [CLS] and [SEP].
+         :param batch_size:
+         :param matrices:
+         :return:
+         '''
+         max_len = max(matrix.shape[0] for matrix in matrices)
+         if self.matrix_add_special_token:
+             max_len -= 2
+         if self.truncation_matrix_length:
+             max_len = min(max_len, self.truncation_matrix_length)
+         if self.matrix_add_special_token:
+             max_len += 2
+         else:
+             max_len = max_len + int(self.prepend_bos) + int(self.append_eos)
+         embedding_vector_dim = matrices[0].shape[1]
+         # for input
+         filled_matrices = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+                 embedding_vector_dim
+             ),
+             dtype=torch.float32,
+         )
+         filled_matrices.fill_(0.0)
+         attention_masks = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+             ),
+             dtype=torch.int64,
+         )
+         attention_masks.fill_(0)
+         return filled_matrices, attention_masks, max_len
+
+     def __atom_matrix_encode__(self, batch_size, matrices):
+         '''
+         This function does not add embedding vectors for the special tokens [CLS] and [SEP].
+         :param batch_size:
+         :param matrices:
+         :return:
+         '''
+         max_len = max(matrix.shape[0] for matrix in matrices)
+         if self.matrix_add_special_token:
+             max_len -= 2
+         if self.atom_truncation_matrix_length:
+             max_len = min(max_len, self.atom_truncation_matrix_length)
+         if self.matrix_add_special_token:
+             max_len += 2
+         else:
+             max_len = max_len + int(self.atom_prepend_bos) + int(self.atom_append_eos)
+         embedding_vector_dim = matrices[0].shape[1]
+         # for input
+         filled_matrices = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+                 embedding_vector_dim
+             ),
+             dtype=torch.float32,
+         )
+         filled_matrices.fill_(0.0)
+         attention_masks = torch.empty(
+             (
+                 batch_size,
+                 max_len,
+             ),
+             dtype=torch.int64,
+         )
+         attention_masks.fill_(0)
+         return filled_matrices, attention_masks, max_len
+
+     def __multi_matrix_encode__(self, batch_size, matrices):
+         '''
+         This function does not add embedding vectors for the special tokens [CLS] and [SEP].
+         :param batch_size:
+         :param matrices:
+         :return:
+         '''
+         max_sentence_num = max(len(cur_matrix) for cur_matrix in matrices)
+         max_sentence_num = min(max_sentence_num, self.max_sentences)
+         if self.trunc_type == "left":
+             max_sentence_len = max(max(matrix.shape[0] for matrix in cur_matrix[-max_sentence_num:]) for cur_matrix in matrices)
+         else:
+             max_sentence_len = max(max(matrix.shape[0] for matrix in cur_matrix[:max_sentence_num]) for cur_matrix in matrices)
+         # print("encoder max_sentence_num:%d, max_sentence_len: %d" % (max_sentence_num, max_sentence_len))
+         if self.matrix_add_special_token:
+             max_sentence_len -= 2
+         max_sentence_len = min(max_sentence_len, self.max_sentence_length)
+         # print("encoder max_sentence_num:%d, max_sentence_len: %d" % (max_sentence_num, max_sentence_len))
+         if self.matrix_add_special_token:
+             max_sentence_len += 2
+         else:
+             max_sentence_len = max_sentence_len + int(self.prepend_bos) + int(self.append_eos)
+         # print("encoder max_sentence_num:%d, max_sentence_len: %d" % (max_sentence_num, max_sentence_len))
+         # print("self.max_sentence_length: %d" % self.max_sentence_length)
+         # print("max_sentence_len: %d" % max_sentence_len)
+         embedding_vector_dim = matrices[0][0].shape[1]
+         # for input
+         filled_matrices = torch.empty(
+             (
+                 batch_size,
+                 max_sentence_num,
+                 max_sentence_len,
+                 embedding_vector_dim
+             ),
+             dtype=torch.float32,
+         )
+         filled_matrices.fill_(0.0)
+         attention_masks = torch.empty(
+             (
+                 batch_size,
+                 max_sentence_num,
+                 max_sentence_len
+             ),
+             dtype=torch.int64,
+         )
+         attention_masks.fill_(0)
+         return filled_matrices, attention_masks, max_sentence_num, max_sentence_len
+
+     def __call_single__(self, batch_size, seq_types, seqs, vectors, matrices, labels):
+         max_length = sys.maxsize
+         input_ids, position_ids, token_type_ids, seq_attention_masks = None, None, None, None
+         seq_part_of_input = False
+         molecule_flag = False
+         multi_seq_flag = False
+         if seqs:
+             new_seqs = []
+             for seq_idx, seq_type in enumerate(seq_types):
+                 if seq_type == "gene":
+                     new_seqs.append(gene_seq_replace(seqs[seq_idx].upper()))
+                 elif seq_type == "molecule":
+                     if isinstance(seqs[seq_idx], str):
+                         new_seqs.append(AlphabetAtom.smiles_2_atom_seq(seqs[seq_idx]))
+                     else:
+                         new_seqs.append(seqs[seq_idx])
+                     molecule_flag = True
+                 elif seq_type == "multi_gene":
+                     new_seqs.append([gene_seq_replace(seq).upper() for seq in seqs[seq_idx].split(",")])
+                     multi_seq_flag = True
+                 elif seq_type == "multi_prot":
+                     new_seqs.append([seq.upper() for seq in seqs[seq_idx].split(",")])
+                     multi_seq_flag = True
+                 else:
+                     new_seqs.append(seqs[seq_idx].upper())
+             if molecule_flag:
+                 # seq_encoded_list carries no special tokens; input_ids reserves slots for them
+                 # according to the flags, and seq_max_length already includes their length
+                 seq_encoded_list, input_ids, position_ids, token_type_ids, seq_attention_masks, seq_max_length = self.__atom_seq_encode__(
+                     batch_size=batch_size, seqs=new_seqs)
+
+             elif multi_seq_flag:
+                 # special tokens are added to seq_encoded_list and input_ids according to the flags;
+                 # seq_max_len already includes their length
+                 seq_encoded_list, input_ids, position_ids, token_type_ids, seq_attention_masks, seq_max_num, seq_max_len = self.__multi_seq_encode__(
+                     batch_size=batch_size, seqs=new_seqs)
+                 '''
+                 print("seq_max_num: %d" % seq_max_num)
+                 print("seq_max_len: %d" % seq_max_len)
+                 print(input_ids.shape)
+                 print("len(seq_encoded_list): %d" % len(seq_encoded_list))
+                 for input_id in input_ids:
+                     print(len(input_id))
+                     for matrix in input_id:
+                         print(matrix.shape)
+                     print("*****")
+                 '''
+             else:
+                 # seq_encoded_list carries no special tokens; input_ids reserves slots for them
+                 # according to the flags, and seq_max_length already includes their length
+                 seq_encoded_list, input_ids, position_ids, token_type_ids, seq_attention_masks, seq_max_length = self.__seq_encode__(
+                     batch_size=batch_size, seqs=new_seqs)
+             if multi_seq_flag:
+                 max_length = min(max_length, seq_max_num * seq_max_len)
+             else:
+                 max_length = min(max_length, seq_max_length)
+             seq_part_of_input = True
+
+         encoded_vectors = None
+         vector_part_of_input = False
+         if vectors is not None and len(vectors) > 0:
+             if multi_seq_flag:
+                 encoded_vectors, vector_max_num, vector_max_len = self.__multi_vector_encode__(batch_size=batch_size, vectors=vectors)
+             elif molecule_flag:
+                 encoded_vectors, vector_max_length = self.__atom_vector_encode__(batch_size=batch_size, vectors=vectors)
+             else:
+                 encoded_vectors, vector_max_length = self.__vector_encode__(batch_size=batch_size, vectors=vectors)
+             # max_length = min(max_length, vector_max_length)
+             vector_part_of_input = True
+
+         encoded_matrices, matrix_attention_masks = None, None
+         matrix_part_of_input = False
+         # print("multi_seq_flag:", multi_seq_flag)
+         if matrices is not None and len(matrices) > 0:
+             if multi_seq_flag:
+                 # padded according to the flags: number of sentences, and whether the special-token length is added
+                 encoded_matrices, matrix_attention_masks, matrix_max_num, matrix_max_len = self.__multi_matrix_encode__(
+                     batch_size=batch_size,
+                     matrices=matrices)
+                 '''
+                 print("matrix_max_num: %d" % matrix_max_num)
+                 print("matrix_max_len: %d" % matrix_max_len)
+                 print(encoded_matrices.shape)
+                 print("len(matrices): %d" % len(matrices))
+                 for matrix_array in matrices:
+                     print(len(matrix_array))
+                     for matrix in matrix_array:
+                         print(matrix.shape)
+                     print("*****")
+                 '''
+             elif molecule_flag:
+                 # padded according to the flags: number of sentences, and whether the special-token length is added
+                 encoded_matrices, matrix_attention_masks, matrix_max_length = self.__atom_matrix_encode__(batch_size=batch_size,
+                                                                                                           matrices=matrices)
+             else:
+                 # padded according to the flags: number of sentences, and whether the special-token length is added
+                 encoded_matrices, matrix_attention_masks, matrix_max_length = self.__matrix_encode__(batch_size=batch_size,
+                                                                                                      matrices=matrices)
+             if multi_seq_flag:
+                 max_length = min(max_length, matrix_max_num * matrix_max_len)
+             else:
+                 max_length = min(max_length, matrix_max_length)
+             matrix_part_of_input = True
+         has_label = False
+         if labels:
+             has_label = True
+
+         new_labels = []
+         num_sentences = 1
+         sentence_length = 1
+         for sample_idx in range(batch_size):
+             # seq
+             if seq_part_of_input:
+                 if multi_seq_flag:
+                     # cls_idx has already been added
+                     pass
+                 elif not molecule_flag and self.prepend_bos:
+                     input_ids[sample_idx, 0] = self.cls_idx
+                 elif molecule_flag and self.atom_prepend_bos:
+                     input_ids[sample_idx, 0] = self.atom_cls_idx
+
+                 seq_encoded = seq_encoded_list[sample_idx]
+                 real_seq_len = len(seq_encoded)
+
+                 # seq_tensor = torch.tensor(seq_encoded, dtype=torch.int64)
+                 # print("seq_encoded:")
+                 # print(seq_encoded)
+                 if multi_seq_flag:
+                     cur_seq_num = min(len(seq_encoded), seq_max_num)
+                     if len(seq_encoded) > cur_seq_num:
+                         if self.trunc_type == "left":
+                             seq_encoded = seq_encoded[-cur_seq_num:]
+                         else:
+                             seq_encoded = seq_encoded[:cur_seq_num]
+                     if num_sentences < cur_seq_num:
+                         num_sentences = cur_seq_num
+                     # print("cur_seq_num: %d" % len(seq_encoded))
+                     for seq_idx in range(cur_seq_num):
+                         cur_seq = seq_encoded[seq_idx]
+                         cur_seq_len = min(len(cur_seq), seq_max_len)
+                         '''
+                         print("cur_seq:")
+                         print(cur_seq_len)
+                         print("input_ids:")
+                         print(input_ids.shape)
+                         '''
+                         input_ids[sample_idx, seq_idx, :cur_seq_len] = torch.tensor(cur_seq[:cur_seq_len], dtype=torch.int64)
+                         seq_attention_masks[sample_idx, seq_idx, :cur_seq_len] = 1
+                         if cur_seq_len > sentence_length:
+                             sentence_length = cur_seq_len
+                 elif molecule_flag:
+                     seq_tensor = torch.tensor(seq_encoded, dtype=torch.int64)
+                     input_ids[sample_idx, int(self.atom_prepend_bos): real_seq_len + int(self.atom_prepend_bos)] = seq_tensor
+                     cur_sentence_length = int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)
+                     if cur_sentence_length > sentence_length:
+                         sentence_length = cur_sentence_length
+                 else:
+                     seq_tensor = torch.tensor(seq_encoded, dtype=torch.int64)
+                     input_ids[sample_idx, int(self.prepend_bos): real_seq_len + int(self.prepend_bos)] = seq_tensor
+                     cur_sentence_length = int(self.prepend_bos) + real_seq_len + int(self.append_eos)
+                     if cur_sentence_length > sentence_length:
+                         sentence_length = cur_sentence_length
+
+                 if multi_seq_flag:
+                     # eos_idx has already been added
+                     pass
+                 elif not molecule_flag and self.append_eos:
+                     input_ids[sample_idx, real_seq_len + int(self.prepend_bos)] = self.eos_idx
+                 elif molecule_flag and self.atom_append_eos:
+                     input_ids[sample_idx, real_seq_len + int(self.atom_prepend_bos)] = self.atom_eos_idx
+
+                 if multi_seq_flag:
+                     cur_len = num_sentences * sentence_length
+                 elif molecule_flag:
+                     cur_len = int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)
+                 else:
+                     cur_len = int(self.prepend_bos) + real_seq_len + int(self.append_eos)
+
+                 if not self.no_position_embeddings:
+                     if multi_seq_flag:
+                         for pos_idx in range(0, cur_len):
+                             position_ids[sample_idx, pos_idx // sentence_length, pos_idx % sentence_length] = pos_idx % sentence_length
+                     else:
+                         for pos_idx in range(0, cur_len):
+                             position_ids[sample_idx, pos_idx] = pos_idx
+
+                 if not self.no_token_type_embeddings:
+                     seq_type = seq_types[sample_idx]
+                     if seq_type == "gene":
+                         type_value = 0
+                     else:
+                         type_value = 1
+                     if multi_seq_flag:
+                         for pos_idx in range(0, cur_len):
+                             token_type_ids[sample_idx, pos_idx // sentence_length, pos_idx % sentence_length] = type_value
+                     else:
+                         for pos_idx in range(0, cur_len):
+                             token_type_ids[sample_idx, pos_idx] = type_value
+
+                 if multi_seq_flag:
+                     pass
+                 else:
+                     seq_attention_masks[sample_idx, 0: cur_len] = 1
+
+             # vector
+             if vector_part_of_input:
+                 if multi_seq_flag:
+                     cur_vector_num = min(len(vectors[sample_idx]), vector_max_num)
+                     if num_sentences < cur_vector_num:
+                         num_sentences = cur_vector_num
+                     for vector_idx in range(cur_vector_num):
+                         encoded_vectors[sample_idx, vector_idx, :] = torch.tensor(vectors[sample_idx][vector_idx], dtype=torch.float32)
+                 else:
+                     encoded_vectors[sample_idx, :] = torch.tensor(vectors[sample_idx], dtype=torch.float32)
+
+             # matrix
+             if matrix_part_of_input:
+                 '''
+                 matrix_encoded = matrices[sample_idx]
+                 if self.matrix_add_special_token:
+                     real_seq_len = matrix_encoded.shape[0] - 2
+                 else:
+                     real_seq_len = matrix_encoded.shape[0]
+                 if multi_seq_flag:
+                     pass
+                 elif molecule_flag:
+                     # real_seq_len = real_seq_len - int(self.atom_prepend_bos) - int(self.atom_append_eos)
+                     real_seq_len = min(real_seq_len, self.atom_truncation_matrix_length)
+                 else:
+                     # real_seq_len = real_seq_len - int(self.prepend_bos) - int(self.append_eos)
+                     real_seq_len = min(real_seq_len, self.truncation_matrix_length)
+                 # print("real_seq_len: %d" % real_seq_len)
+                 '''
+                 if multi_seq_flag:
+                     # multi-sequence matrices
+                     matrix_encoded_list = matrices[sample_idx]
+                     cur_matrix_num = min(len(matrix_encoded_list), matrix_max_num)
+                     if len(matrix_encoded_list) > cur_matrix_num:
+                         if self.trunc_type == "left":
+                             matrix_encoded_list = matrix_encoded_list[-cur_matrix_num:]
+                         else:
+                             matrix_encoded_list = matrix_encoded_list[:cur_matrix_num]
+                     if num_sentences < cur_matrix_num:
+                         num_sentences = cur_matrix_num
+                     # print("matrix_encoded_list: %d" % len(matrix_encoded_list))
+                     for matrix_idx in range(cur_matrix_num):
+                         # print("matrix_idx: %d" % matrix_idx)
+                         cur_matrix = matrix_encoded_list[matrix_idx]
+                         cur_matrix = torch.tensor(cur_matrix, dtype=torch.float32)
+                         cur_matrix_len = min(cur_matrix.shape[0], matrix_max_len)
+                         if self.matrix_add_special_token:
+                             encoded_matrices[sample_idx, matrix_idx, 0: cur_matrix_len - 1] = cur_matrix[0: cur_matrix_len - 1]
+                             encoded_matrices[sample_idx, matrix_idx, cur_matrix_len - 1] = cur_matrix[-1]
+                             matrix_attention_masks[sample_idx, matrix_idx, 0: cur_matrix_len] = 1
+                         else:
+                             encoded_matrices[sample_idx, matrix_idx, int(self.prepend_bos): cur_matrix_len + int(self.prepend_bos)] = cur_matrix[0: cur_matrix_len]
+                             matrix_attention_masks[sample_idx, matrix_idx, 0: int(self.prepend_bos) + cur_matrix_len + int(self.append_eos)] = 1
+                             cur_matrix_len = int(self.prepend_bos) + cur_matrix_len + int(self.append_eos)
+                         if sentence_length < cur_matrix_len:
+                             sentence_length = cur_matrix_len
+                 else:
+                     matrix_encoded = matrices[sample_idx]
+                     if self.matrix_add_special_token:
+                         real_seq_len = matrix_encoded.shape[0] - 2
+                     else:
+                         real_seq_len = matrix_encoded.shape[0]
+                     if molecule_flag:
+                         # real_seq_len = real_seq_len - int(self.atom_prepend_bos) - int(self.atom_append_eos)
+                         real_seq_len = min(real_seq_len, self.atom_truncation_matrix_length)
+                         matrix = torch.tensor(matrix_encoded, dtype=torch.float32)
+                         if self.matrix_add_special_token:
+                             encoded_matrices[sample_idx, 0: real_seq_len + 2] \
+                                 = matrix[0: real_seq_len + 2]
+                             matrix_attention_masks[sample_idx, 0: real_seq_len + 2] = 1
+                             cur_sentence_length = real_seq_len + 2
+                         else:
+                             encoded_matrices[sample_idx, int(self.atom_prepend_bos): real_seq_len + int(self.atom_prepend_bos)] \
+                                 = matrix[0: real_seq_len]
+                             # matrix_attention_masks[sample_idx, int(self.atom_prepend_bos): real_seq_len + int(self.atom_prepend_bos)] = 1
+                             matrix_attention_masks[sample_idx, 0: int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)] = 1
+                             cur_sentence_length = int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)
+                         if cur_sentence_length > sentence_length:
+                             sentence_length = cur_sentence_length
+                     else:
+                         # real_seq_len = real_seq_len - int(self.prepend_bos) - int(self.append_eos)
+                         real_seq_len = min(real_seq_len, self.truncation_matrix_length)
+                         matrix = torch.tensor(matrix_encoded, dtype=torch.float32)
+                         if self.matrix_add_special_token:
+                             encoded_matrices[sample_idx, 0: real_seq_len + 2] = matrix[0: real_seq_len + 2]
+                             matrix_attention_masks[sample_idx, 0: real_seq_len + 2] = 1
+                             cur_sentence_length = real_seq_len + 2
+                         else:
+                             encoded_matrices[sample_idx, int(self.prepend_bos): real_seq_len + int(self.prepend_bos)] = matrix[0: real_seq_len]
+                             # matrix_attention_masks[sample_idx, int(self.prepend_bos): real_seq_len + int(self.prepend_bos)] = 1
+                             matrix_attention_masks[sample_idx, 0: int(self.prepend_bos) + real_seq_len + int(self.append_eos)] = 1
+                             cur_sentence_length = int(self.prepend_bos) + real_seq_len + int(self.append_eos)
+                         if cur_sentence_length > sentence_length:
+                             sentence_length = cur_sentence_length
+
+             if has_label:
+                 if multi_seq_flag:
+                     # to do
+                     new_labels.append(
+                         self.__parse_label__(max_length, self.task_level_type,
+                                              self.label_size, self.output_mode, labels[sample_idx]))
+                 elif molecule_flag:
+                     new_labels.append(
+                         self.__atom_parse_label__(max_length, self.task_level_type,
+                                                   self.label_size, self.output_mode, labels[sample_idx]))
+                 else:
+                     new_labels.append(
+                         self.__parse_label__(max_length, self.task_level_type,
+                                              self.label_size, self.output_mode, labels[sample_idx]))
+         if new_labels is not None and new_labels:
+             if self.output_mode in ["regression"]:
+                 labels = torch.tensor(new_labels, dtype=torch.float32)
+             else:
+                 labels = torch.tensor(new_labels, dtype=torch.int64)
+         else:
+             labels = None
+         '''
+         print(input_ids.shape)
+         print("encoded_matrices:")
+         print(encoded_matrices.shape)
+         print("num_sentences:%d" % num_sentences)
+         print("sentence_length:%d" % sentence_length)
+         if labels is not None:
+             print("labels:")
+             print(labels.shape)
+         '''
+
+         if multi_seq_flag:
+             if seq_part_of_input:
+                 input_ids = torch.reshape(input_ids, (input_ids.shape[0], -1))
+             if matrix_part_of_input:
+                 encoded_matrices = torch.reshape(encoded_matrices, (encoded_matrices.shape[0], -1, encoded_matrices.shape[-1]))
+             if position_ids is not None:
+                 position_ids = torch.reshape(position_ids, (position_ids.shape[0], -1))
+             if token_type_ids is not None:
+                 token_type_ids = torch.reshape(token_type_ids, (token_type_ids.shape[0], -1))
+             if seq_attention_masks is not None:
+                 seq_attention_masks = torch.reshape(seq_attention_masks, (seq_attention_masks.shape[0], -1))
+             if matrix_attention_masks is not None:
+                 matrix_attention_masks = torch.reshape(matrix_attention_masks, (matrix_attention_masks.shape[0], -1))
+         '''
+         print(input_ids.shape)
+         print("encoded_matrices:")
+         print(encoded_matrices.shape)
+         print("num_sentences:%d" % num_sentences)
+         print("sentence_length:%d" % sentence_length)
+         if labels is not None:
+             print("labels:")
+             print(labels.shape)
+         print("-" * 50)
+         '''
+
+         return input_ids, \
+             position_ids, \
+             token_type_ids, \
+             seq_attention_masks, \
+             encoded_vectors, \
+             encoded_matrices, \
+             matrix_attention_masks, \
+             num_sentences, \
+             sentence_length, \
+             labels
+
1207
+ def __call__(self, raw_batch: Sequence[dict]):
1208
+ batch_size = len(raw_batch)
1209
+ # pair
1210
+ if "seq_id_a" in raw_batch[0] and "seq_id_b" in raw_batch[0]:
1211
+ res = {}
1212
+ # seq_ids_a = []
1213
+ seq_types_a = []
1214
+ seqs_a = []
1215
+ vectors_a = []
1216
+ matrices_a = []
1217
+
1218
+ # seq_ids_b = []
1219
+ seq_types_b = []
1220
+ seqs_b = []
1221
+ vectors_b = []
1222
+ matrices_b = []
1223
+
1224
+ labels = []
1225
+ for item in raw_batch:
1226
+ # seq_ids_a.append(item["seq_id_a"])
1227
+ seq_types_a.append(item["seq_type_a"])
1228
+ if item["seq_a"] is not None:
1229
+ seqs_a.append(item["seq_a"])
1230
+ if item["vector_a"] is not None:
1231
+ vectors_a.append(item["vector_a"])
1232
+ if item["matrix_a"] is not None:
1233
+ matrices_a.append(item["matrix_a"])
1234
+
1235
+ # seq_ids_b.append(item["seq_id_b"])
1236
+ seq_types_b.append(item["seq_type_b"])
1237
+ if item["seq_b"] is not None:
1238
+ seqs_b.append(item["seq_b"])
1239
+ if item["vector_b"] is not None:
1240
+ vectors_b.append(item["vector_b"])
1241
+ if item["matrix_b"] is not None:
1242
+ matrices_b.append(item["matrix_b"])
1243
+ if "label" in item and item["label"] is not None:
1244
+ labels.append(item["label"])
1245
+ input_ids_a, position_ids_a, token_type_ids_a, seq_attention_masks_a, encoded_vectors_a, encoded_matrices_a, matrix_attention_masks_a, num_sentences_a, sentence_length_a, labels \
1246
+ = self.__call_single__(batch_size, seq_types_a, seqs_a, vectors_a, matrices_a, labels)
1247
+ if not hasattr(self, "max_sentences") or self.max_sentences is None:
1248
+ res.update({
1249
+ "input_ids_a": input_ids_a,
1250
+ "position_ids_a": position_ids_a,
1251
+ "token_type_ids_a": token_type_ids_a,
1252
+ "seq_attention_masks_a": seq_attention_masks_a,
1253
+ "vectors_a": encoded_vectors_a,
1254
+ "matrices_a": encoded_matrices_a,
1255
+ "matrix_attention_masks_a": matrix_attention_masks_a,
1256
+ "labels": labels if labels is not None and len(labels) > 0 else None
1257
+ })
1258
+ else:
1259
+ res.update({
1260
+ "input_ids_a": input_ids_a,
1261
+ "position_ids_a": position_ids_a,
1262
+ "token_type_ids_a": token_type_ids_a,
1263
+ "seq_attention_masks_a": seq_attention_masks_a,
1264
+ "vectors_a": encoded_vectors_a,
1265
+ "matrices_a": encoded_matrices_a,
1266
+ "matrix_attention_masks_a": matrix_attention_masks_a,
1267
+ "num_sentences_a": num_sentences_a,
1268
+ "sentence_length_a": sentence_length_a,
1269
+ "labels": labels if labels is not None and len(labels) > 0 else None
1270
+ })
1271
+ input_ids_b, position_ids_b, token_type_ids_b, seq_attention_masks_b, encoded_vectors_b, encoded_matrices_b, matrix_attention_masks_b, num_sentences_b, sentence_length_b, _ \
1272
+ = self.__call_single__(batch_size, seq_types_b, seqs_b, vectors_b, matrices_b, labels=None)
1273
+ if not hasattr(self, "max_sentences") or self.max_sentences is None:
1274
+ res.update({
1275
+ "input_ids_b": input_ids_b,
1276
+ "position_ids_b": position_ids_b,
1277
+ "token_type_ids_b": token_type_ids_b,
1278
+ "seq_attention_masks_b": seq_attention_masks_b,
1279
+ "vectors_b": encoded_vectors_b,
1280
+ "matrices_b": encoded_matrices_b,
1281
+ "matrix_attention_masks_b": matrix_attention_masks_b
1282
+ })
1283
+ else:
1284
+ res.update({
1285
+ "input_ids_b": input_ids_b,
1286
+ "position_ids_b": position_ids_b,
1287
+ "token_type_ids_b": token_type_ids_b,
1288
+ "seq_attention_masks_b": seq_attention_masks_b,
1289
+ "vectors_b": encoded_vectors_b,
1290
+ "matrices_b": encoded_matrices_b,
1291
+ "num_sentences_b": num_sentences_b,
1292
+ "sentence_length_b": sentence_length_b,
1293
+ "matrix_attention_masks_b": matrix_attention_masks_b
1294
+ })
1295
+ return res
1296
+ else:
1297
+ res = {}
1298
+ # seq_ids = []
1299
+ seq_types = []
1300
+ seqs = []
1301
+ vectors = []
1302
+ matrices = []
1303
+ labels = []
1304
+ for item in raw_batch:
1305
+ # seq_ids.append(item["seq_id"])
1306
+ seq_types.append(item["seq_type"])
1307
+ if item["seq"] is not None:
1308
+ seqs.append(item["seq"])
1309
+ if item["vector"] is not None:
1310
+ vectors.append(item["vector"])
1311
+ if item["matrix"] is not None:
1312
+ matrices.append(item["matrix"])
1313
+ if item["label"] is not None:
1314
+ labels.append(item["label"])
1315
+ '''
1316
+ print("seqs:")
1317
+ print(seqs)
1318
+ print([len(seq) for seq in seqs])
1319
+ print("matrices:")
1320
+ print(matrices)
1321
+ print([matrix.shape for matrix in matrices])
1322
+ print("labels:")
1323
+ print(labels)
1324
+ print([len(eval(label)) for label in labels])
1325
+ '''
1326
+ input_ids, position_ids, token_type_ids, seq_attention_masks, encoded_vectors, encoded_matrices, matrix_attention_masks, num_sentences, sentence_length, labels = self.__call_single__(
1327
+ batch_size, seq_types, seqs, vectors, matrices, labels=labels)
1328
+
1329
+ if not hasattr(self, "max_sentences") or self.max_sentences is None:
1330
+ res.update({
1331
+ "input_ids": input_ids,
1332
+ "position_ids": position_ids,
1333
+ "token_type_ids": token_type_ids,
1334
+ "seq_attention_masks": seq_attention_masks,
1335
+ "vectors": encoded_vectors,
1336
+ "matrices": encoded_matrices,
1337
+ "matrix_attention_masks": matrix_attention_masks,
1338
+ "labels": labels if labels is not None and len(labels) > 0 else None
1339
+ })
1340
+ else:
1341
+ res.update({
1342
+ "input_ids": input_ids,
1343
+ "position_ids": position_ids,
1344
+ "token_type_ids": token_type_ids,
1345
+ "seq_attention_masks": seq_attention_masks,
1346
+ "vectors": encoded_vectors,
1347
+ "matrices": encoded_matrices,
1348
+ "matrix_attention_masks": matrix_attention_masks,
1349
+ "num_sentences": num_sentences,
1350
+ "sentence_length": sentence_length,
1351
+ "labels": labels if labels is not None and len(labels) > 0 else None
1352
+ })
1353
+
1354
+ '''
1355
+ for item in res.items():
1356
+ key_name = item[0]
1357
+ print(key_name, ":")
1358
+ if item[1] is not None:
1359
+ print(item[1])
1360
+ print(item[1].shape)
1361
+ else:
1362
+ print("None")
1363
+ '''
1364
+ return res
1365
+
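+ # Illustrative usage (values below are hypothetical; the dict keys match
+ # those consumed by __call__ above):
+ # converter = alphabet.get_batch_converter(no_position_embeddings=True,
+ # no_token_type_embeddings=False)
+ # batch = converter([{"seq_id": "P1", "seq_type": "prot", "seq": "MKTAYIAKQR",
+ # "vector": None, "matrix": None, "label": None}])
+ # batch["input_ids"] and batch["seq_attention_masks"] are padded tensors.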
classification_loss.py ADDED
@@ -0,0 +1,296 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ '''
+ @license: (C) Copyright 2021, Hey.
+ @author: Hey
+ @email: [email protected]
+ @tel: 137****6540
+ @datetime: 2023/5/3 20:35
+ @project: LucaOne
+ @file: classification_loss.py
+ @desc: classification loss functions
+ '''
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .masked_loss import _MaskedLoss
+
+ class MaskedFocalLoss(_MaskedLoss):
+ """Masked FocalLoss"""
+ def __init__(self, alpha=1, gamma=2, normalization=False, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = FocalLoss(alpha=alpha, gamma=gamma, normalization=normalization, reduction='none')
+
+
+ class FocalLoss(nn.Module):
+ '''
+ Focal loss
+ '''
+ def __init__(self, alpha=1, gamma=2, normalization=False, reduction="mean"):
+ super(FocalLoss, self).__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.normalization = normalization
+ self.reduction = reduction
+
+ def forward(self, inputs, targets):
+ if self.normalization:
+ '''
+ reduction: the operation applied to the output loss; one of 'none', 'mean', or 'sum':
+ 'none' applies no reduction to the loss,
+ 'mean' returns the mean of the loss,
+ 'sum' returns the sum of the loss; the default is 'mean'
+ '''
+ bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+ probs = torch.sigmoid(inputs)
+ else:
+ bce = F.binary_cross_entropy(inputs, targets, reduction='none')
+ probs = inputs
+ pt = targets * probs + (1 - targets) * (1 - probs)
+ modulate = 1 if self.gamma is None else (1 - pt) ** self.gamma
+
+ focal_loss = modulate * bce
+
+ if self.alpha is not None:
+ assert 0 <= self.alpha <= 1
+ alpha_weights = targets * self.alpha + (1 - targets) * (1 - self.alpha)
+ focal_loss *= alpha_weights
+ if self.reduction == "mean":
+ # global mean
+ return torch.mean(focal_loss)
+ if self.reduction in ["summean", "meansum"]:
+ # sum within each sample, then mean across samples
+ return torch.mean(torch.sum(focal_loss, dim=1))
+ elif self.reduction == "sum":
+ return torch.sum(focal_loss, dim=1)
+ else:
+ return focal_loss
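+ # Focal loss down-weights easy examples: FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).
+ # Illustrative only (tensors here are hypothetical):
+ # loss_fct = FocalLoss(alpha=0.7, gamma=2.0, normalization=True, reduction="mean")
+ # logits = torch.randn(4, 10); targets = torch.randint(0, 2, (4, 10)).float()
+ # loss = loss_fct(logits, targets) # normalization=True expects raw logits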
+
+
+ class MaskedMultiLabelCCE(_MaskedLoss):
+ """Masked MultiLabel CCE"""
+ def __init__(self, normalization=False, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = MultiLabelCCE(normalization=normalization, reduction='none')
+
+
+ class MultiLabelCCE(nn.Module):
+ '''
+ Multi Label CCE
+ '''
+ def __init__(self, normalization=False, reduction='mean'):
+ super(MultiLabelCCE, self).__init__()
+ self.normalization = normalization
+ self.reduction = reduction
+
+ def forward(self, inputs, targets):
+ """
+ Cross entropy for multi-label classification.
+ Note: y_true and y_pred have the same shape, and the elements of y_true are either 0 or 1;
+ 1 indicates that the corresponding class is a target class, 0 that it is a non-target class.
+ """
+ if self.normalization:
+ y_pred = torch.softmax(inputs, dim=-1)
+ else:
+ y_pred = inputs
+ y_true = targets
+ y_pred = (1 - 2 * y_true) * y_pred
+ y_pred_neg = y_pred - y_true * 1e12
+ y_pred_pos = y_pred - (1 - y_true) * 1e12
+ zeros = torch.zeros_like(y_pred[..., :1])
+ y_pred_neg = torch.cat((y_pred_neg, zeros), dim=-1)
+ y_pred_pos = torch.cat((y_pred_pos, zeros), dim=-1)
+ neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
+ pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
+ if self.reduction == 'mean':
+ return torch.mean(neg_loss + pos_loss)
+ elif self.reduction == 'sum':
+ return torch.sum(neg_loss + pos_loss)
+ else:
+ return neg_loss + pos_loss
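+ # The formulation above is the softmax-style multi-label cross entropy:
+ # loss = log(1 + sum_{i in negatives} e^{s_i}) + log(1 + sum_{j in positives} e^{-s_j});
+ # the appended zero column supplies the "+1" inside each logsumexp.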
+
+
+ class MaskedAsymmetricLoss(_MaskedLoss):
+ """Masked AsymmetricLoss"""
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = AsymmetricLoss(gamma_neg, gamma_pos, clip, eps, disable_torch_grad_focal_loss)
+
+
+ class AsymmetricLoss(nn.Module):
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
+ super(AsymmetricLoss, self).__init__()
+
+ self.gamma_neg = gamma_neg
+ self.gamma_pos = gamma_pos
+ self.clip = clip
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+ self.eps = eps
+
+ def forward(self, x, y):
+ """
+ Parameters
+ ----------
+ x: input logits
+ y: targets (multi-label binarized vector)
+ """
+
+ # Calculating Probabilities
+ x_sigmoid = torch.sigmoid(x)
+ xs_pos = x_sigmoid
+ xs_neg = 1 - x_sigmoid
+
+ # Asymmetric Clipping
+ if self.clip is not None and self.clip > 0:
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+ # Basic CE calculation
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+ loss = los_pos + los_neg
+
+ # Asymmetric Focusing
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
+ if self.disable_torch_grad_focal_loss:
+ torch.set_grad_enabled(False)
+ pt0 = xs_pos * y
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
+ pt = pt0 + pt1
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+ if self.disable_torch_grad_focal_loss:
+ torch.set_grad_enabled(True)
+ loss *= one_sided_w
+
+ return -loss.sum()
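+ # Asymmetric loss applies separate focusing exponents to positives (gamma_pos)
+ # and negatives (gamma_neg) and shifts negative probabilities by `clip` before
+ # the log; note that this variant returns the summed (not mean) loss.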
+
+
+ class MaskedAsymmetricLossOptimized(_MaskedLoss):
+ """Masked AsymmetricLossOptimized"""
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = AsymmetricLossOptimized(gamma_neg, gamma_pos, clip, eps, disable_torch_grad_focal_loss)
+
+
+ class AsymmetricLossOptimized(nn.Module):
+ '''
+ Notice - optimized version, minimizes memory allocation and gpu uploading,
+ favors inplace operations
+ '''
+
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
+ super(AsymmetricLossOptimized, self).__init__()
+
+ self.gamma_neg = gamma_neg
+ self.gamma_pos = gamma_pos
+ self.clip = clip
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+ self.eps = eps
+
+ # prevent memory allocation and gpu uploading every iteration, and encourage inplace operations
+ self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
+
+ def forward(self, x, y):
+ """
+ Parameters
+ ----------
+ x: input logits
+ y: targets (multi-label binarized vector)
+ """
+
+ self.targets = y
+ self.anti_targets = 1 - y
+
+ # Calculating Probabilities
+ self.xs_pos = torch.sigmoid(x)
+ self.xs_neg = 1.0 - self.xs_pos
+
+ # Asymmetric Clipping
+ if self.clip is not None and self.clip > 0:
+ self.xs_neg.add_(self.clip).clamp_(max=1)
+
+ # Basic CE calculation
+ self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
+ self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
+
+ # Asymmetric Focusing
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
+ if self.disable_torch_grad_focal_loss:
+ torch.set_grad_enabled(False)
+ self.xs_pos = self.xs_pos * self.targets
+ self.xs_neg = self.xs_neg * self.anti_targets
+ self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
+ self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
+ if self.disable_torch_grad_focal_loss:
+ torch.set_grad_enabled(True)
+ self.loss *= self.asymmetric_w
+
+ return -self.loss.sum()
+
+
+ class MaskedASLSingleLabel(_MaskedLoss):
+ """Masked ASLSingleLabel loss"""
+ def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = ASLSingleLabel(gamma_pos, gamma_neg, eps, reduction='none')
+
+
+ class ASLSingleLabel(nn.Module):
+ '''
+ This loss is intended for single-label classification problems (multi-class)
+ '''
+ def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'):
+ super(ASLSingleLabel, self).__init__()
+
+ self.eps = eps
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
+ self.targets_classes = []
+ self.gamma_pos = gamma_pos
+ self.gamma_neg = gamma_neg
+ self.reduction = reduction
+
+ def forward(self, inputs, target):
+ '''
+ "inputs" dimensions: (batch_size, number_classes)
+ "target" dimensions: (batch_size)
+ '''
+ num_classes = inputs.size()[-1]
+ log_preds = self.logsoftmax(inputs)
+ self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
+
+ # ASL weights
+ targets = self.targets_classes
+ anti_targets = 1 - targets
+ xs_pos = torch.exp(log_preds)
+ xs_neg = 1 - xs_pos
+ xs_pos = xs_pos * targets
+ xs_neg = xs_neg * anti_targets
+ asymmetric_w = torch.pow(1 - xs_pos - xs_neg, self.gamma_pos * targets + self.gamma_neg * anti_targets)
+ log_preds = log_preds * asymmetric_w
+
+ if self.eps > 0:
+ # label smoothing
+ self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
+
+ # loss calculation
+ loss = - self.targets_classes.mul(log_preds)
+
+ loss = loss.sum(dim=-1)
+ if self.reduction == 'mean':
+ loss = loss.mean()
+
+ return loss
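+ # Illustrative only (tensors here are hypothetical):
+ # loss_fct = ASLSingleLabel(gamma_pos=0, gamma_neg=4, eps=0.1, reduction="mean")
+ # logits = torch.randn(8, 5); target = torch.randint(0, 5, (8,))
+ # loss = loss_fct(logits, target) # eps > 0 additionally applies label smoothing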
+
+
+ class MaskedBCEWithLogitsLoss(_MaskedLoss):
+ """Masked BCEWithLogits loss"""
+ def __init__(self, pos_weight=None, weight=None, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, weight=weight, reduction='none')
+
+
+ class MaskedCrossEntropyLoss(_MaskedLoss):
+ """Masked CrossEntropy loss"""
+ def __init__(self, weight=None, reduction='mean', ignore_nans=True, ignore_value=-100):
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ self.criterion = nn.CrossEntropyLoss(weight=weight, reduction='none', ignore_index=ignore_value)
config.json ADDED
@@ -0,0 +1,71 @@
+ {
+ "alphabet": "gene_prot",
+ "architectures": [
+ "LucaGPLM"
+ ],
+ "attention_probs_dropout_prob": 0.0,
+ "auto_map": {
+ "AutoConfig": "lucaone_gplm_config.LucaGPLMConfig",
+ "AutoModel": "lucaone_gplm.LucaGPLM"
+ },
+ "bos_token_id": 2,
+ "classifier_dropout": 0.0,
+ "classifier_dropout_prob": 0.0,
+ "classifier_hidden_act": "gelu",
+ "embed_scale": 1.0,
+ "eos_token_id": 3,
+ "gene_mask_classifier_output_size": 2048,
+ "gene_mask_label_num": 39,
+ "gene_taxonomy_classifier_output_size": 2048,
+ "gene_taxonomy_label_num": 735,
+ "gene_type_classifier_output_size": 128,
+ "gene_type_label_num": 8,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "hidden_size": 2560,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1",
+ "2": "LABEL_2"
+ },
+ "ignore_index": -100,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1,
+ "LABEL_2": 2
+ },
+ "mask_token_id": 4,
+ "max_position_embeddings": 1280,
+ "model_type": "lucagplm",
+ "no_position_embeddings": true,
+ "no_token_type_embeddings": false,
+ "num_attention_heads": 40,
+ "num_hidden_layers": 20,
+ "pad_token_id": 0,
+ "prot_contact_classifier_output_size": 3072,
+ "prot_domain_classifier_output_size": 10240,
+ "prot_domain_label_num": 13717,
+ "prot_homo_classifier_output_size": 4096,
+ "prot_homo_label_num": 3443,
+ "prot_keyword_classifier_output_size": 2048,
+ "prot_keyword_label_num": 1179,
+ "prot_mask_classifier_output_size": 2048,
+ "prot_mask_label_num": 39,
+ "prot_secondary_classifier_output_size": 3072,
+ "prot_site_classifier_output_size": 1024,
+ "prot_site_label_num": 946,
+ "prot_structure_classifier_output_size": 128,
+ "prot_structure_label_num": 3,
+ "prot_taxonomy_classifier_output_size": 2048,
+ "prot_taxonomy_label_num": 2196,
+ "sep_token_id": 3,
+ "token_dropout": false,
+ "torch_dtype": "float32",
+ "trans_classifier_output_size": 128,
+ "transformers_version": "4.29.0",
+ "type_vocab_size": 2,
+ "unk_token_id": 1,
+ "use_embed_layer_norm": false,
+ "use_last_layer_norm": true,
+ "vocab_size": 39
+ }
file_operator.py ADDED
@@ -0,0 +1,230 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import csv, sys
+ import io, textwrap, itertools
+ from Bio import SeqIO
+ from Bio.Seq import Seq
+ from Bio.SeqRecord import SeqRecord
+ csv.field_size_limit(sys.maxsize)
+
+
+ common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}
+
+ # not {'O', 'U', 'Z', 'J', 'B'}
+ # Common amino acids
+ common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}
+
+
+ def clean_seq(protein_id, seq):
+ seq = seq.upper()
+ new_seq = ""
+ has_invalid_char = False
+ invalid_char_set = set()
+ for ch in seq:
+ if 'A' <= ch <= 'Z' and ch not in ['J']:
+ new_seq += ch
+ else:
+ invalid_char_set.add(ch)
+ has_invalid_char = True
+ if has_invalid_char:
+ print("id: %s. Seq: %s" % (protein_id, seq))
+ print("invalid char set:", invalid_char_set)
+ return new_seq
+
+
+ def file_reader(filename, header=True, header_filter=True):
+ if filename.endswith(".fa") or filename.endswith(".fas") or filename.endswith(".fasta"):
+ return fasta_reader(filename)
+ elif filename.endswith(".csv"):
+ return csv_reader(filename, header=True, header_filter=True)
+ elif filename.endswith(".tsv"):
+ return tsv_reader(filename, header=True, header_filter=True)
+ else:
+ return txt_reader(filename, header=header, header_filter=header_filter)
+
+
+ def txt_reader(handle, header=True, header_filter=True):
+ '''
+ text reader, suitable for large files
+ :param handle: file path or handle
+ :param header: whether the file has a header row
+ :param header_filter: whether to drop the header row from the results
+ :return:
+ '''
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
+ try:
+ cnt = 0
+ for line in handle:
+ cnt += 1
+ if header and header_filter and cnt == 1:
+ continue
+ yield line.strip()
+ except Exception as e:
+ # re-raise the original error; raising StopIteration inside a
+ # generator becomes a RuntimeError under PEP 479
+ raise e
+ finally:
+ if not handle.closed:
+ handle.close()
+
+ def tsv_reader(handle, header=True, header_filter=True):
+ '''
+ tsv reader, suitable for large files
+ :param handle: file path or handle
+ :param header: whether the file has a header row
+ :param header_filter: whether to drop the header row from the results
+ :return:
+ '''
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
+ try:
+ reader = csv.reader(handle, delimiter="\t")
+ cnt = 0
+ for row in reader:
+ cnt += 1
+ if header and header_filter and cnt == 1:
+ continue
+ yield row
+ except Exception as e:
+ raise e
+ finally:
+ if not handle.closed:
+ handle.close()
+
+
+ def csv_reader(handle, header=True, header_filter=True):
+ '''
+ csv reader, suitable for large files
+ :param handle: file path or handle
+ :param header: whether the file has a header row
+ :param header_filter: whether to drop the header row from the results
+ :return:
+ '''
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
+ try:
+ # data = csv.reader((line.replace('\0','') for line in data_initial), delimiter=",")
+ # reader = csv.reader(handle)
+ reader = csv.reader((line.replace('\0', '') for line in handle))
+ cnt = 0
+ for row in reader:
+ cnt += 1
+ if header and header_filter and cnt == 1:
+ continue
+ yield row
+ except Exception as e:
+ raise e
+ finally:
+ if not handle.closed:
+ handle.close()
+
+
+ def txt_writer(dataset, handle, header=None):
+ '''
+ txt writer
+ :param dataset: data rows
+ :param handle: output file path
+ :param header: header
+ :return:
+ '''
+ '''
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'w')
+ try:
+ if header:
+ if isinstance(header, list):
+ handle.write(",".join(header) + "\n")
+ else:
+ handle.write(header + "\n")
+ print("header: %s" % header)
+ for row in dataset:
+ handle.write(str(row) + "\n")
+ except Exception as e:
+ raise e
+ finally:
+ if not handle.closed:
+ handle.close()
+ '''
+ with open(handle, "w") as wfp:
+ if header:
+ if isinstance(header, list):
+ wfp.write(",".join(header) + "\n")
+ else:
+ wfp.write(header + "\n")
+ for row in dataset:
+ wfp.write(str(row) + "\n")
+
+
+ def csv_writer(dataset, handle, header):
+ '''
+ csv writer, suitable for large files
+ :param dataset: data rows
+ :param handle: output file path or handle
+ :param header: header
+ :return:
+ '''
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'w')
+ try:
+ writer = csv.writer(handle)
+ if header:
+ writer.writerow(header)
+ for row in dataset:
+ writer.writerow(row)
+ except Exception as e:
+ raise e
+ finally:
+ if not handle.closed:
+ handle.close()
+
+
+ def fasta_reader(handle, width=None):
+ """
+ Reads a FASTA file, yielding (header, sequence) pairs for each sequence recovered; suitable for large files
+ args:
+ :handle (str, pathlib.Path, or file pointer) - fasta to read from
+ :width (int or None) - formats the sequence to have at most `width` characters per line.
+ If <= 0, processed as None. If None, there is no max width.
+ yields:
+ :(header, sequence) tuples
+ returns:
+ :None
+ """
+ FASTA_STOP_CODON = "*"
+
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
+ width = width if isinstance(width, int) and width > 0 else None
+ try:
+ header = None
+ for is_header, group in itertools.groupby(handle, lambda line: line.startswith(">")):
+ if is_header:
+ header = group.__next__().strip()
+ else:
+ seq = ''.join(line.strip() for line in group).strip().rstrip(FASTA_STOP_CODON)
+ if width is not None:
+ seq = textwrap.fill(seq, width)
+ yield header, seq
+ except Exception as e:
+ raise e
+ finally:
+ if not handle.closed:
+ handle.close()
+
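+ # Illustrative only (the path is hypothetical):
+ # for header, seq in fasta_reader("example.fasta"):
+ # print(header, len(seq))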
+
+
+ def write_fasta(filepath, sequences):
+ '''
+ write a fasta file
+ :param filepath: save path
+ :param sequences: fasta sequences (each item: [id, seq])
+ :return:
+ '''
+
+ if sequences:
+ with open(filepath, "w") as output_handle:
+ if len(sequences[0]) > 1 and isinstance(sequences[0][0], str):
+ for row in sequences:
+ protein_id = row[0]
+ seq = row[1]
+ sequence = SeqRecord(Seq(seq), id=protein_id[1:] if protein_id and protein_id[0] == ">" else protein_id, description="")
+ SeqIO.write(sequence, output_handle, "fasta")
+ else:
+ for sequence in sequences:
+ SeqIO.write(sequence, output_handle, "fasta")
+
+
loss.py ADDED
@@ -0,0 +1,224 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ '''
+ @license: (C) Copyright 2021, Hey.
+ @author: Hey
+ @email: [email protected]
+ @tel: 137****6540
+ @datetime: 2023/5/3 20:35
+ @project: LucaOne
+ @file: loss.py
+ @desc: loss
+ '''
+ import torch, math
+ import torch.nn as nn
+
+ from .classification_loss import *
+ from .regression_loss import *
+
+
+
+ class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
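+ # This is the tanh approximation of GELU:
+ # GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))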
+
+
+ def create_activate(activate_func):
+ # returns None when activate_func is falsy (callers check for None)
+ if activate_func:
+ activate_func = activate_func.lower()
+ if activate_func == "tanh":
+ return nn.Tanh()
+ elif activate_func == "relu":
+ return nn.ReLU()
+ elif activate_func == "leakyrelu":
+ return nn.LeakyReLU()
+ elif activate_func == "gelu":
+ return nn.GELU()
+ elif activate_func == "gelu_new":
+ return NewGELUActivation()
+ else:
+ return nn.Tanh()
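+ # Illustrative: create_activate("gelu_new") returns NewGELUActivation(); an
+ # unrecognized name falls back to nn.Tanh(), and a falsy name returns None.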
+
+
+ def create_loss_function(config,
+ args,
+ task_level_type,
+ task_level_name,
+ sigmoid,
+ output_mode,
+ num_labels,
+ loss_type,
+ ignore_index=-100,
+ pair_level=False,
+ return_types=["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
+ ):
+ '''
+ create the output layer and loss layer
+ :param task_level_name:
+ :param task_level_type:
+ :param pair_level:
+ :param config:
+ :param args:
+ :param sigmoid:
+ :param output_mode:
+ :param num_labels:
+ :param loss_type:
+ :param ignore_index:
+ :param return_types:
+ :return:
+ '''
+ dropout, hidden_layer, hidden_act, classifier, output, loss_fct = None, None, None, None, None, None
+ if "dropout" in return_types:
+ if hasattr(config, "classifier_dropout_prob"):
+ dropout = nn.Dropout(config.classifier_dropout_prob)
+ elif hasattr(config, "dropout_prob"):
+ dropout = nn.Dropout(config.dropout_prob)
+ else:
+ dropout = nn.Dropout(0.1)
+
+ if pair_level:
+ hidden_size = 2 * config.hidden_size
+ else:
+ hidden_size = config.hidden_size
+ if "hidden_layer" in return_types:
+ if isinstance(args.classifier_size, int):
+ hidden_layer_size = args.classifier_size
+ else:
+ hidden_layer_size = args.classifier_size[task_level_type][task_level_name]
+ hidden_layer = nn.Linear(hidden_size, hidden_layer_size, bias=True)
+ hidden_size = hidden_layer_size
+
+ if "hidden_act" in return_types:
+ if hasattr(args, "classifier_hidden_act"):
+ hidden_act = create_activate(args.classifier_hidden_act)
+ elif hasattr(config, "classifier_hidden_act"):
+ hidden_act = create_activate(config.classifier_hidden_act)
+
+ if "classifier" in return_types:
+ if sigmoid:
+ if output_mode in ["binary_class", "binary-class"]:
+ classifier = nn.Linear(hidden_size, 1, bias=True)
+ else:
+ classifier = nn.Linear(hidden_size, num_labels, bias=True)
+ else:
+ classifier = nn.Linear(hidden_size, num_labels, bias=True)
+ if "output" in return_types:
+ if sigmoid or output_mode in ["multi_label", "multi-label", "binary_class", "binary-class"]:
+ output = nn.Sigmoid()
+ elif output_mode in ["multi_class", "multi-class"]:
+ output = nn.Softmax(dim=-1)
+ else:
+ output = None
+
+ if "loss" in return_types:
+ # positive weight
+ if hasattr(args, "pos_weight") and args.pos_weight:
+ pos_weight = args.pos_weight
+ elif hasattr(config, "pos_weight") and config.pos_weight:
+ pos_weight = config.pos_weight
+ else:
+ pos_weight = None
+
+ if hasattr(args, "weight") and args.weight is not None:
+ weight = args.weight
+ elif hasattr(config, "weight") and config.weight is not None:
+ weight = config.weight
+ else:
+ weight = None
+
+ reduction = config.loss_reduction if hasattr(config, "loss_reduction") else "meanmean"
+ if output_mode in ["regression"]:
+ if loss_type == "l2":
+ loss_fct = MaskedMSELoss(reduction=reduction, ignore_nans=True,
+ ignore_value=ignore_index * 1.0 if ignore_index else None)
+ elif loss_type == "l1":
+ loss_fct = MaskedL1Loss(reduction=reduction, ignore_nans=True,
+ ignore_value=ignore_index * 1.0 if ignore_index else None)
+ elif output_mode in ["multi_label", "multi-label"]:
+ if loss_type == "bce":
+ if pos_weight:
+ if isinstance(pos_weight, str) or isinstance(pos_weight, int):
+ pos_weight = [float(pos_weight)] * num_labels
+ elif isinstance(pos_weight, float):
+ pos_weight = [pos_weight] * num_labels
+ pos_weight = torch.tensor(pos_weight, dtype=torch.float32).to(args.device)
+ print("multi_label pos_weight:")
+ print(pos_weight)
+ assert pos_weight.ndim == 1 and pos_weight.shape[0] == num_labels
+ print("multi_label reduction:")
+ print(reduction)
+ loss_fct = MaskedBCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction,
+ ignore_nans=True, ignore_value=ignore_index)
+ else:
+ loss_fct = MaskedBCEWithLogitsLoss(reduction=reduction,
+ ignore_nans=True, ignore_value=ignore_index)
+ elif loss_type == "asl":
+ loss_fct = MaskedAsymmetricLossOptimized(gamma_neg=args.asl_gamma_neg if hasattr(args, "asl_gamma_neg") else 4.0,
+ gamma_pos=args.asl_gamma_pos if hasattr(args, "asl_gamma_pos") else 1.0,
+ clip=args.clip if hasattr(args, "clip") else 0.05,
+ eps=args.eps if hasattr(args, "eps") else 1e-8,
+ disable_torch_grad_focal_loss=args.disable_torch_grad_focal_loss if hasattr(args, "disable_torch_grad_focal_loss") else False,
+ reduction=reduction,
+ ignore_nans=True,
+ ignore_value=ignore_index)
+ elif loss_type == "focal_loss":
+ loss_fct = MaskedFocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 0.7,
+ gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 2.0,
+ normalization=True,
+ reduction=reduction,
+ ignore_nans=True,
+ ignore_value=ignore_index)
+ elif loss_type == "multilabel_cce":
+ loss_fct = MaskedMultiLabelCCE(normalization=True,
+ reduction=reduction,
+ ignore_nans=True,
+ ignore_value=ignore_index)
+ elif output_mode in ["binary_class", "binary-class"]:
+ if loss_type == "bce":
+ if pos_weight:
+ if isinstance(pos_weight, int) or isinstance(pos_weight, str):
+ pos_weight = torch.tensor([float(pos_weight)], dtype=torch.float32).to(args.device)
+ elif isinstance(pos_weight, float):
+ pos_weight = torch.tensor([pos_weight], dtype=torch.float32).to(args.device)
+ print("binary_class pos_weight:")
+ print(pos_weight)
+ assert pos_weight.ndim == 1 and pos_weight.shape[0] == 1
+ loss_fct = MaskedBCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction, ignore_nans=True,
+ ignore_value=ignore_index)
+ else:
+ loss_fct = MaskedBCEWithLogitsLoss(reduction=reduction, ignore_nans=True, ignore_value=ignore_index)
+ elif loss_type == "focal_loss":
+ loss_fct = MaskedFocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 0.7,
+ gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 2.0,
+ normalization=True,
+ reduction=reduction,
+ ignore_nans=True,
+ ignore_value=ignore_index)
+ elif output_mode in ["multi_class", "multi-class"]:
+ if weight:
+ # [1, 1, 1, 1, 1, ...] length: num_labels
+ if isinstance(weight, str) or isinstance(weight, int):
+ weight = [float(weight)] * num_labels
+ if isinstance(weight, float):
+ weight = [weight] * num_labels
+ weight = torch.tensor(weight, dtype=torch.float32).to(args.device)
+ print("multi_class weight:")
+ print(weight)
+ assert weight.ndim == 1 and weight.shape[0] == num_labels
+ if ignore_index is None:
+ loss_fct = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
+ else:
+ loss_fct = MaskedCrossEntropyLoss(weight=weight, reduction=reduction, ignore_nans=True, ignore_value=ignore_index)
+ else:
+ if ignore_index is None:
+ loss_fct = nn.CrossEntropyLoss(reduction=reduction)
+ else:
+ loss_fct = MaskedCrossEntropyLoss(reduction=reduction, ignore_nans=True, ignore_value=ignore_index)
+ else:
+ raise Exception("Unsupported output mode: %s." % output_mode)
+
+ return dropout, hidden_layer, hidden_act, classifier, output, loss_fct
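+ # Illustrative call (config/args stand in for the real training objects):
+ # dropout, hidden_layer, hidden_act, classifier, output, loss_fct = create_loss_function(
+ # config, args, task_level_type="seq_level", task_level_name="prot_taxonomy",
+ # sigmoid=False, output_mode="multi_class", num_labels=2196,
+ # loss_type="cce", ignore_index=-100)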
lucaone_gplm.py ADDED
@@ -0,0 +1,572 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ from .loss import *
+ from .model_utils import AllOutput, create_output_loss_lucagplm
+ from .alphabet import Alphabet
+ from .modeling_gplm import *
+ from .lucaone_gplm_config import LucaGPLMConfig
+ from transformers import PreTrainedModel
+
+ class LucaGPLM(PreTrainedModel):
+ config_class = LucaGPLMConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.max_position_embeddings = config.max_position_embeddings
+ self.type_vocab_size = config.type_vocab_size
+ self.num_layers = config.num_hidden_layers
+ self.embed_dim = config.hidden_size
+ self.attention_heads = config.num_attention_heads
+ self.no_position_embeddings = config.no_position_embeddings
+ self.no_token_type_embeddings = config.no_token_type_embeddings
+ if not isinstance(config.alphabet, Alphabet):
+ self.alphabet = Alphabet.from_predefined(config.alphabet)
+ else:
+ self.alphabet = config.alphabet
+ self.alphabet_size = len(self.alphabet)
+ self.padding_idx = self.alphabet.padding_idx
+ self.mask_idx = self.alphabet.mask_idx
+ self.cls_idx = self.alphabet.cls_idx
+ self.eos_idx = self.alphabet.eos_idx
+ self.prepend_bos = self.alphabet.prepend_bos
+ self.append_eos = self.alphabet.append_eos
+ self.token_dropout = config.token_dropout
+ self.ignore_index = config.ignore_index
+ self.use_embed_layer_norm = config.use_embed_layer_norm
+ self.use_last_layer_norm = config.use_last_layer_norm
+ self.embed_scale = config.embed_scale
+ self._init_submodules()
+
+ def _init_submodules(self):
+ # normal_(0, 1)
+ self.embed_tokens = nn.Embedding(
+ self.alphabet_size,
+ self.embed_dim,
+ padding_idx=self.padding_idx,
+ )
+ self.embed_pos = None
+ if not self.no_position_embeddings:
+ self.embed_pos = nn.Embedding(self.max_position_embeddings, self.embed_dim)
+ self.embed_type = None
+ if not self.no_token_type_embeddings:
+ self.embed_type = nn.Embedding(self.type_vocab_size, self.embed_dim)
+ if self.use_embed_layer_norm:
+ self.embed_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
+ else:
+ self.embed_layer_norm = None
+
+ self.layers = nn.ModuleList(
+ [
+ LucaGPLMTransformerLayer(
+ self.embed_dim,
+ 4 * self.embed_dim,
+ self.attention_heads,
+ add_bias_kv=False,
+ use_lucagplm1b_layer_norm=True,
+ use_rotary_embeddings=True,
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ self.layer_size = len(self.layers)
+
+ self.contact_head = ContactPredictionHead(
+ self.num_layers * self.attention_heads,
+ self.prepend_bos,
+ self.append_eos,
+ eos_idx=self.eos_idx,
+ )
+ if self.use_last_layer_norm:
+ self.last_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
+ else:
+ self.last_layer_norm = None
+
+ self.lm_head = RobertaLMHead(
+ embed_dim=self.embed_dim,
+ output_dim=self.alphabet_size,
+ weight=self.embed_tokens.weight,
+ )
+
+ def _init_embedding(self, pretrained_token_matrix, token_matrix):
+ '''
+ 0->2
+ 1->0
+ 2->3
+ 3->1
+ 4->10
+ ...
+ 28->34
+ 29->36
+ 30->37
+ 31->38
+ 32->4
+ '''
+ print("Loading existing pretrained embedding vectors:")
+ token_matrix[2, :] = pretrained_token_matrix[0, :]
+ token_matrix[0, :] = pretrained_token_matrix[1, :]
+ token_matrix[3, :] = pretrained_token_matrix[2, :]
+ token_matrix[1, :] = pretrained_token_matrix[3, :]
+ for idx in range(10, 35):
+ token_matrix[idx, :] = pretrained_token_matrix[idx - 6, :]
+ token_matrix[36, :] = pretrained_token_matrix[29, :]
+ token_matrix[37, :] = pretrained_token_matrix[30, :]
+ token_matrix[38, :] = pretrained_token_matrix[31, :]
+ token_matrix[4, :] = pretrained_token_matrix[32, :]
+ return token_matrix
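+ # The remapping above aligns an ESM-style vocabulary order
+ # ([CLS], [PAD], [SEP], [UNK], amino acids, ..., [MASK]) with this model's
+ # alphabet ([PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, standard tokens
+ # from index 10), copying each pretrained embedding row into its new index.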
+
+ def _init_submodules_new(self, pretrained_model_name):
+ print("Loading existing weights from the pretrained model:")
+ from esm import pretrained
+ from collections import OrderedDict
+ pretrained, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
+ pretrained_state_dict = pretrained.state_dict()
+ new_state_dict = OrderedDict()
+ our_model_state_dict = {}
+ for key, value in self.state_dict().items():
+ our_model_state_dict[key] = value
+ for name, weight in pretrained_state_dict.items():
+ if "final_layer_norm" in name:
+ name = name.replace("final_layer_norm", "post_layer_norm")
+ elif "self_attn_layer_norm" in name:
+ name = name.replace("self_attn_layer_norm", "pre_layer_norm")
+ elif "emb_layer_norm_after" in name:
+ name = name.replace("emb_layer_norm_after", "last_layer_norm")
+ if name.startswith("layers."):
+ layer_id = name.split(".")[1]
+ if int(layer_id) >= self.num_layers:
+ continue
+ if name == "embed_tokens.weight":
+ new_state_dict[name] = self._init_embedding(weight, our_model_state_dict[name])
+ del our_model_state_dict[name]
+ elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
+ del our_model_state_dict[name]
+ new_state_dict[name] = weight
+
+ print("Layer names loaded from the pretrained model:")
+ print(new_state_dict.keys())
+ print("Layer names not found in the pretrained model:")
+ print(our_model_state_dict.keys())
+ new_state_dict.update(our_model_state_dict)
+ self.load_state_dict(new_state_dict)
+
+ def __calc_loss__(self, task_level_type, output_mode, logits, label, label_size, loss_fct, loss_reduction):
+ '''
+ if label_size <= 2 or output_mode in ["binary_class", "binary-class"]:
+ loss = loss_fct(logits.view(-1), label.view(-1).float())
+ elif output_mode in ["multi_label", "multi-label"]:
+ loss = loss_fct(logits.view(-1, label_size), label.view(-1, label_size).float())
+ elif output_mode in ["multi_class", "multi-class"]:
+ loss = loss_fct(logits.view(-1, label_size), label.view(-1))
+ else:
+ loss = loss_fct(logits.view(-1), label.view(-1))
+ return loss
+ '''
+ '''
+ print(task_level_type, output_mode, label_size, loss_fct, loss_reduction)
+ print("logits:")
+ print(logits.shape)
+ print("label:")
+ print(label.shape)
+ '''
+ if output_mode in ["regression"]:
+ if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
+ # structure-level regression
+ # logits: N, seq_len, 3
+ # label: N, seq_len, 3
+ loss = loss_fct(logits, label)
+ else:
+ # structure-level regression
+ # logits: N * seq_len * 3
+ # label: N * seq_len * 3
+ loss = loss_fct(logits.view(-1), label.view(-1))
+ elif output_mode in ["multi_label", "multi-label"]:
+ # only for seq-level
+ if loss_reduction == "meanmean":
+ # logits: N, label_size
+ # label: N, label_size
+ loss = loss_fct(logits, label.float())
+ else:
+ # logits: N, label_size
+ # label: N, label_size
+ loss = loss_fct(logits.view(-1, label_size), label.view(-1, label_size).float())
+ elif label_size <= 2 or output_mode in ["binary_class", "binary-class"]:
+ if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
+ # token-level & meanmean
+ # logits: N, seq_len, 1
+ # label: N, seq_len
+ loss = loss_fct(logits, label.float())
+ else:
+ # seq-level || token-level
+ # logits: N
+ # label: N
+ loss = loss_fct(logits.view(-1), label.view(-1).float())
+ elif output_mode in ["multi_class", "multi-class"]:
+ if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
+ # token-level
+ # logits: N, seq_len, label_size
+ # label: N, seq_len
+ loss = loss_fct(logits, label)
+ else:
+ # token-level
+ # logits: N * seq_len, label_size
+ # label: N * seq_len
+ # seq-level
+ # logits: N, label_size
+ # label: N
+ loss = loss_fct(logits.view(-1, label_size), label.view(-1))
+ else:
+ raise Exception("Unsupported output_mode=%s" % output_mode)
+ return loss
+
+ def __forword__(self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_keys: Optional[dict[str, set[str]]] = None,
+ labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
+ repr_layers=[-1],
+ need_head_weights=False,
+ return_contacts=False,
+ use_last_layer_norm=True):
+ assert all(-(self.layer_size + 1) <= i <= self.layer_size for i in repr_layers)
+ repr_layers = [(i + self.layer_size + 1) % (self.layer_size + 1) for i in repr_layers]
+
+ if return_contacts:
+ need_head_weights = True
+
+ assert input_ids.ndim == 2
+ # build the padding mask dynamically: (B, seq_len), True at padded positions
+ if attention_mask is None:
+ padding_mask = input_ids.eq(self.padding_idx)
+ else:
+ padding_mask = attention_mask.eq(self.padding_idx)
+
+ x = self.embed_scale * self.embed_tokens(input_ids)
+ if self.embed_pos is not None and position_ids is not None:
+ x += self.embed_scale * self.embed_pos(position_ids)
+ if self.embed_type is not None and token_type_ids is not None:
+ x += self.embed_scale * self.embed_type(token_type_ids)
+ if self.embed_layer_norm is not None:
+ x = self.embed_layer_norm(x)
+ # Token dropout
+ if self.token_dropout:
+ x.masked_fill_((input_ids == self.mask_idx).unsqueeze(-1), 0.0)
+ # x: B x L x C
+ mask_ratio_train = 0.15 * 0.8
+ src_lengths = (~padding_mask).sum(-1)
+ mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
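+ # ESM-style token-dropout rescaling: masked embeddings were zeroed above, so
+ # activations are rescaled by (1 - 0.15 * 0.8) / (1 - observed mask ratio) to
+ # keep their expected magnitude consistent between training and inference.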
+
+ # apply the padding mask
+ if padding_mask is not None:
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
+
+ # which layers to include in the return value
+ repr_layers = set(repr_layers)
+ hidden_representations = {}
+ # 0: the embedding layer
+ if 0 in repr_layers:
+ hidden_representations[0] = x
+
+ # whether attention head weights are needed
+ if need_head_weights:
+ attn_weights = []
+
+ # (B, L, E) => (L, B, E)
+ x = x.transpose(0, 1)
+
+ if not padding_mask.any():
+ padding_mask = None
+
+ for layer_idx, layer in enumerate(self.layers):
+ x, attn = layer(
+ x,
+ self_attn_padding_mask=padding_mask,
+ need_head_weights=need_head_weights,
+ )
+ if (layer_idx + 1) in repr_layers:
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
+ if need_head_weights:
+ # (H, B, L, L) => (B, H, L, L)
+ attn_weights.append(attn.transpose(1, 0))
+
+ # (L, B, E)
+ if self.last_layer_norm is not None and use_last_layer_norm:
+ # apply an extra layer norm to the last hidden layer
+ x = self.last_layer_norm(x)
+ x = x.transpose(0, 1) # (L, B, E) => (B, L, E)
+
+ # last hidden representation should have layer norm applied
+ if (layer_idx + 1) in repr_layers:
+ hidden_representations[layer_idx + 1] = x
+ # the last layer serves as the representation matrix
+ # (B, L, E)
+ representation_matrix = hidden_representations[self.layer_size]
+ # masked LM task
+ # B * Seq_len * vocab_size
+ lm_mask_logits = self.lm_head(x)
+ # the [CLS] position of the representation matrix serves as the representation vector
+ # (B, E)
+ representation_vector = representation_matrix[:, 0, :]
+
+ logits = {}
+ losses = {}
+ outputs = {}
+ representations = {
+ "representation_matrix": representation_matrix,
+ "representation_vector": representation_vector
+ }
+ # attention weights of every layer
+ if need_head_weights:
+ # attentions: B x Layers x H x L x L
+ attentions = torch.stack(attn_weights, 1)
+ if padding_mask is not None:
+ attention_mask = 1 - padding_mask.type_as(attentions)
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
+ attentions = attentions * attention_mask[:, None, None, :, :]
+ representations["attentions"] = attentions
+ # predict the contact matrix
+ if return_contacts:
+ contacts = self.contact_head(input_ids, attentions)
+ representations["contacts"] = contacts
+ '''
+ print("output_keys:")
+ print(output_keys)
+ '''
+ if output_keys:
+ for item in output_keys.items():
+ cur_task_level_type = item[0]
+ if cur_task_level_type not in logits:
+ logits[cur_task_level_type] = {}
+ outputs[cur_task_level_type] = {}
+ for cur_task_level_name in item[1]:
+ if cur_task_level_type == "token_level":
+ cur_logits = lm_mask_logits
+ elif cur_task_level_type == "seq_level":
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_vector)
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
+ if cur_hidden_layer is not None:
+ cur_logits = cur_hidden_layer(cur_logits)
+ cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
+ if cur_hidden_act is not None:
+ cur_logits = cur_hidden_act(cur_logits)
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
+ elif cur_task_level_type == "span_level":
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_matrix)
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
+ if cur_hidden_layer is not None:
+ cur_logits = cur_hidden_layer(cur_logits)
+ cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
+ if cur_hidden_act is not None:
+ cur_logits = cur_hidden_act(cur_logits)
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
+ elif cur_task_level_type == "structure_level":
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_matrix)
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
+ if cur_hidden_layer is not None:
+ cur_logits = cur_hidden_layer(cur_logits)
+ cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
+ if cur_hidden_act is not None:
+ cur_logits = cur_hidden_act(cur_logits)
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
+ logits[cur_task_level_type][cur_task_level_name] = cur_logits
+ if cur_task_level_type in self.output and cur_task_level_name in self.output[cur_task_level_type] \
+ and self.output[cur_task_level_type][cur_task_level_name] is not None:
+ outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
+ else:
+ outputs[cur_task_level_type][cur_task_level_name] = cur_logits
+ if labels is not None and cur_task_level_type in labels and cur_task_level_name in labels[cur_task_level_type]:
+ if cur_task_level_type not in losses:
+ losses[cur_task_level_type] = {}
+ cur_label = labels[cur_task_level_type][cur_task_level_name]
+ cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
+ cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
+ cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
+ cur_loss = self.__calc_loss__(
+ task_level_type=cur_task_level_type,
+ output_mode=cur_output_mode,
+ logits=cur_logits,
+ label=cur_label,
+ label_size=cur_label_size,
+ loss_fct=cur_loss_fct,
+ loss_reduction="meanmean")
+ losses[cur_task_level_type][cur_task_level_name] = cur_loss
+ return representations, logits, outputs, losses
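+ # Illustrative: __forword__ returns (representations, logits, outputs, losses);
+ # representations["representation_matrix"] holds the per-token embeddings (B, L, E)
+ # and representations["representation_vector"] the [CLS] vector (B, E).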
398
+
399
+ def forward(
400
+ self,
401
+ input_ids: Optional[torch.Tensor] = None,
402
+ attention_mask: Optional[torch.Tensor] = None,
403
+ global_attention_mask: Optional[torch.Tensor] = None,
404
+ token_type_ids: Optional[torch.Tensor] = None,
405
+ position_ids: Optional[torch.Tensor] = None,
406
+ head_mask: Optional[torch.Tensor] = None,
407
+ inputs_embeds: Optional[torch.Tensor] = None,
408
+ output_keys: Optional[dict[str, set[str]]] = None,
409
+ labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
410
+ input_ids_b: Optional[torch.Tensor] = None,
411
+ attention_mask_b: Optional[torch.Tensor] = None,
412
+ global_attention_mask_b: Optional[torch.Tensor] = None,
413
+ token_type_ids_b: Optional[torch.Tensor] = None,
414
+ position_ids_b: Optional[torch.Tensor] = None,
415
+ head_mask_b: Optional[torch.Tensor] = None,
416
+ inputs_embeds_b: Optional[torch.Tensor] = None,
417
+ output_keys_b: Optional[dict[str, set[str]]] = None,
418
+ labels_b: Optional[dict[str, dict[str, torch.Tensor]]] = None,
419
+ pair_label: Optional[dict[str, dict[str, torch.Tensor]]] = None,
420
+ pair_output_keys: Optional[dict[str, set[str]]] = None,
421
+ output_hidden_states: Optional[dict[str, set[str]]] = None,
422
+ output_attentions: Optional[dict[str, set[str]]] = None,
423
+ need_head_weights: Optional[bool] = None,
424
+ return_contacts: Optional[bool] = None,
425
+ repr_layers: Optional[list[int]] = None,
426
+ return_dict: Optional[bool] = None,
427
+ use_last_layer_norm: Optional[bool] = True
428
+ ) -> Union[Tuple[torch.Tensor], AllOutput]:
429
+ if return_dict is None and self.config is not None:
430
+ return_dict = self.config.use_return_dict
431
+ if return_dict is None:
432
+ return_dict = False
433
+ if repr_layers is None or len(repr_layers) == 0:
434
+ repr_layers = [-1]
435
+ if return_contacts is None:
436
+ return_contacts = False
437
+ if need_head_weights is None:
438
+ need_head_weights = True
439
+ has_pair = False
440
+ has_pair_b = False
441
+ if input_ids is not None or inputs_embeds is not None:
442
+ encoding, logits, outputs, losses = self.__forword__(
443
+ input_ids=input_ids,
444
+ attention_mask=attention_mask,
445
+ token_type_ids=token_type_ids,
446
+ position_ids=position_ids,
447
+ output_keys=output_keys,
448
+ labels=labels,
449
+ repr_layers=repr_layers,
450
+ need_head_weights=need_head_weights,
451
+ return_contacts=return_contacts,
452
+ use_last_layer_norm=use_last_layer_norm
453
+ )
454
+ has_pair = True
455
+ if input_ids_b is not None or inputs_embeds_b is not None:
456
+ encoding_b, logits_b, outputs_b, losses_b = self.__forword__(
457
+ input_ids=input_ids_b,
458
+ attention_mask=attention_mask_b,
459
+ token_type_ids=token_type_ids_b,
460
+ position_ids=position_ids_b,
461
+ output_keys=output_keys_b,
462
+ labels=labels_b,
463
+ repr_layers=repr_layers,
464
+ need_head_weights=need_head_weights,
465
+ return_contacts=return_contacts,
466
+ use_last_layer_norm=use_last_layer_norm
467
+ )
468
+ has_pair_b = True
469
+ if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
470
+ cur_representation_vector = encoding["representation_vector"]
471
+ cur_representation_vector_b = encoding_b["representation_vector"]
472
+
473
+ pair_logits = {}
474
+ pair_outputs = {}
475
+ for cur_task_level_type, cur_task_level_names in pair_output_keys.items():
476
+ if cur_task_level_type not in pair_outputs:
477
+ pair_outputs[cur_task_level_type] = {}
478
+ pair_logits[cur_task_level_type] = {}
479
+ for cur_task_level_name in cur_task_level_names:
481
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
482
+ torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
483
+ )
484
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
485
+ if cur_hidden_layer is not None:
486
+ cur_logits = cur_hidden_layer(cur_logits)
487
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
488
+ pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
489
+ pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
490
+
491
+ pair_loss = {}
492
+ if pair_label is not None:
493
+ for cur_task_level_type, cur_task_level_names in pair_output_keys.items():
494
+ if cur_task_level_type not in pair_label:
495
+ continue
496
+ pair_loss[cur_task_level_type] = {}
497
+ for cur_task_level_name in cur_task_level_names:
500
+ if cur_task_level_name not in pair_label[cur_task_level_type]:
501
+ continue
502
+ cur_label = pair_label[cur_task_level_type][cur_task_level_name]
503
+ cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
504
+ cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
505
+ cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
506
+ cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
507
+ cur_loss = self.__calc_loss__(
508
+ task_level_type=cur_task_level_type,
509
+ output_mode=cur_output_mode, logits=cur_logits,
510
+ label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
511
+ loss_reduction="meanmean")
512
+ pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss
513
+
514
+ if not return_dict:
515
+ return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs], [encoding, encoding_b]]
516
+ return AllOutput(
517
+ losses=losses,
518
+ outputs=outputs,
519
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
520
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
521
+ global_attentions=None,
522
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
523
+ losses_b=losses_b,
524
+ outputs_b=outputs_b,
525
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
526
+ attentions_b=encoding_b["attentions"] if "hidden_states" in encoding_b else None,
527
+ global_attentions_b=None,
528
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
529
+ pair_outputs=pair_outputs,
530
+ pair_losses=pair_loss)
531
+ else:
532
+ if not return_dict:
533
+ return [[losses, losses_b], [outputs, outputs_b], [encoding, encoding_b]]
534
+ return AllOutput(
535
+ losses=losses,
536
+ outputs=outputs,
537
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
538
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
539
+ global_attentions=None,
540
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
541
+ losses_b=losses_b,
542
+ outputs_b=outputs_b,
543
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
544
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
545
+ global_attentions_b=None,
546
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
547
+ )
548
+ elif has_pair:
549
+ if not return_dict:
550
+ return [[losses], [outputs], [encoding]]
551
+ return AllOutput(
552
+ losses=losses,
553
+ outputs=outputs,
554
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
555
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
556
+ global_attentions=None,
557
+ contacts=encoding["contacts"] if "contacts" in encoding else None
558
+ )
559
+ else:
560
+ if not return_dict:
561
+ return [[losses_b], [outputs_b], [encoding_b]]
562
+ return AllOutput(
563
+ losses_b=losses_b,
564
+ outputs_b=outputs_b,
565
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
566
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
567
+ global_attentions_b=None,
568
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
569
+ )
570
+
571
+ def predict_contacts(self, input_ids, position_ids=None, token_type_ids=None):
572
+ return self(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, return_contacts=True)["contacts"]
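A minimal usage sketch for the paired forward path above (illustrative, not part of this commit): `model` is assumed to be an already-constructed instance of this class, and the toy ids/shapes are placeholders. With both `input_ids` and `input_ids_b` supplied and no `pair_output_keys`, `forward` runs the encoder twice and returns an `AllOutput` whose `_b` fields carry the second sequence's results; without labels, no losses are computed.

import torch

batch_size, seq_len = 2, 16
input_ids = torch.randint(5, 39, (batch_size, seq_len))    # toy token ids
input_ids_b = torch.randint(5, 39, (batch_size, seq_len))

out = model(input_ids=input_ids, input_ids_b=input_ids_b, return_dict=True)
print(type(out).__name__)                                  # AllOutput
print(out.hidden_states is not None, out.hidden_states_b is not None)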
lucaone_gplm_config.py ADDED
@@ -0,0 +1,49 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+ class LucaGPLMConfig(PretrainedConfig):
7
+ model_type = "lucagplm"
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size=-1,
12
+ pad_token_id=0,
13
+ max_position_embeddings: int = 4096,
14
+ type_vocab_size: int = 2,
15
+ num_hidden_layers: int = 24,
16
+ hidden_size: int = 1280,
17
+ num_attention_heads: int = 20,
18
+ no_position_embeddings: bool = False,
19
+ no_token_type_embeddings: bool = False,
20
+ alphabet: str = "gene_prot",
21
+ token_dropout: bool = True,
22
+ attention_probs_dropout_prob=0.1,
23
+ hidden_dropout_prob=0.1,
24
+ classifier_dropout_prob=0.1,
25
+ use_embed_layer_norm=True,
26
+ use_last_layer_norm=True,
27
+ embed_scale=1.0,
28
+ ignore_index=-100,
29
+ **kwargs
30
+ ):
31
+
32
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
33
+ self.alphabet = alphabet
34
+ self.vocab_size = vocab_size
35
+ self.max_position_embeddings = max_position_embeddings
36
+ self.type_vocab_size = type_vocab_size
37
+ self.no_token_type_embeddings = no_token_type_embeddings
38
+ self.no_position_embeddings = no_position_embeddings
39
+ self.num_hidden_layers = num_hidden_layers
40
+ self.hidden_size = hidden_size
41
+ self.num_attention_heads = num_attention_heads
42
+ self.token_dropout = token_dropout
43
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
44
+ self.hidden_dropout_prob = hidden_dropout_prob
45
+ self.classifier_dropout_prob = classifier_dropout_prob
46
+ self.ignore_index = ignore_index
47
+ self.use_embed_layer_norm = use_embed_layer_norm
48
+ self.use_last_layer_norm = use_last_layer_norm
49
+ self.embed_scale = embed_scale
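Illustrative only (not part of this commit): constructing the config above with a smaller geometry for a quick local test. The sizes are arbitrary; 39 matches the gene_prot alphabet (2 prepend + 3 append + 34 standard tokens).

config = LucaGPLMConfig(
    vocab_size=39,             # e.g. the gene_prot Alphabet size
    num_hidden_layers=4,       # reduced from the 24-layer default
    hidden_size=320,           # must stay divisible by num_attention_heads
    num_attention_heads=20,
    alphabet="gene_prot",
)
print(config.max_position_embeddings)   # 4096 by default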
masked_loss.py ADDED
@@ -0,0 +1,159 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: [email protected]
7
+ @tel: 137****6540
8
+ @datetime: 2023/6/28 10:25
9
+ @project: LucaOne
10
+ @file: masked_loss.py
11
+ @desc: masked loss
12
+ '''
13
+ import warnings
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
+ class _MaskedLoss(nn.Module):
19
+ """Base class for masked losses"""
20
+
21
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
22
+ super().__init__()
23
+ self.reduction = reduction
24
+ self.ignore_nans = ignore_nans
25
+ self.ignore_value = ignore_value
26
+
27
+ def forward(self, pred, target, mask=None):
28
+ """Compute a loss between pred and target for given mask.
29
+ Note that this implementation is faster than loss(pred[mask], target[mask])
30
+ for a given loss, and is nan-proof."""
31
+ '''
32
+ if not (target.size() == pred.size()):
33
+ warnings.warn(
34
+ "Using a target size ({}) that is different to the pred size ({}). "
35
+ "This will likely lead to incorrect results due to broadcasting. "
36
+ "Please ensure they have the same size.".format(
37
+ target.size(), pred.size()),
38
+ stacklevel=2,
39
+ )
40
+ '''
41
+ # use a local copy so forward() never mutates the configured reduction
+ reduction = self.reduction
+ if mask is None and self.ignore_value is not None:
42
+ mask = target != self.ignore_value
43
+ elif mask is None:
44
+ mask = torch.ones_like(target, dtype=bool)
45
+ target_proxy = target
46
+ if self.ignore_nans:
47
+ target_proxy = target.clone()
48
+ nans = torch.isnan(target)
49
+ if nans.any():
50
+ with torch.no_grad():
51
+ mask = mask & ~nans
52
+ target_proxy[nans] = 0
53
+ # full_loss = self.criterion(pred, target_proxy)
54
+ # print("mask shape")
55
+ # print(mask.shape)
56
+ if reduction == 'meanmean' and pred.ndim == 3 and pred.shape[-1] == 1:
57
+ # token-level binary classification
58
+ # pred: n , seq_len, 1 -> n * seq_len
59
+ # target: n, seq_len -> n * seq_len
60
+ full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
61
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
62
+ # print("ok1")
63
+ elif reduction == 'meanmean' and pred.ndim == 3:
64
+ if target.ndim == 3:
65
+ # token-level regression
66
+ # pred: n , seq_len, label_size -> n * seq_len * label_size
67
+ # target: n, seq_len, label_size -> n * seq_len * label_size
68
+ full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
69
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1], pred.shape[-1]))
70
+ # print("ok21")
71
+ else:
72
+ # token-level multi classification
73
+ # pred: n , seq_len, label_size -> n * seq_len, label_size
74
+ # target: n, seq_len -> n * seq_len
75
+ full_loss = self.criterion(pred.view(-1, pred.shape[-1]), target_proxy.view(-1))
76
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
77
+ # print("ok22")
78
+ elif reduction == 'meanmean' and pred.ndim == 2 and target.ndim == 2:
79
+ # seq-level multi label
80
+ # pred: n , label_size -> n * label_size
81
+ # target: n, label_size -> n * label_size
82
+ full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
83
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
84
+ # print("ok3")
85
+ elif reduction == 'meanmean':
86
+ # fall back to a plain mean for this call without mutating the module
+ reduction = "mean"
87
+ full_loss = self.criterion(pred, target_proxy)
88
+ # print("ok4")
89
+ else:
90
+ full_loss = self.criterion(pred, target_proxy)
91
+ # print("ok5")
92
+
93
+ full_loss[~mask] = 0
94
+ '''
95
+ if not mask.any():
96
+ warnings.warn("Evaluation mask is False everywhere, this might lead to incorrect results.")
97
+ print(full_loss.sum(), mask.to(full_loss.dtype).sum())
98
+ '''
99
+ if reduction == 'none':
100
+ return full_loss
101
+ if reduction == 'sum':
102
+ return full_loss.sum()
103
+ if reduction == 'mean':
104
+ '''
105
+ print("mask:")
106
+ print(mask.to(full_loss.dtype).sum(dim=-1))
107
+ print(mask.to(full_loss.dtype).sum())
108
+ '''
109
+ return full_loss.sum() / (mask.to(full_loss.dtype).sum() + 1e-12)
110
+ if reduction == 'meanmean':
111
+ if mask.ndim == 3:
112
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
113
+ '''
114
+ print("mask:")
115
+ print(mask_sum)
116
+ '''
117
+ full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
118
+ mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
119
+ # print(mask_sum)
120
+ full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
121
+ mask_sum = mask_sum.to(torch.bool).sum()
122
+ # print(mask_sum)
123
+ loss = full_loss.sum() / (mask_sum + 1e-12)
124
+ else:
125
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
126
+ '''
127
+ print("mask:")
128
+ print(mask_sum)
129
+ print(mask_sum.to(torch.bool).sum())
130
+ '''
131
+ loss = torch.sum(full_loss.sum(dim=-1) / (mask_sum + 1e-12)) / (mask_sum.to(torch.bool).sum() + 1e-12)
132
+ # print(full_loss.sum() / (mask.to(full_loss.dtype).sum() + 1e-12), loss)
133
+ return loss
134
+ if self.reduction in ["summean", "meansum"]:
135
+ if mask.ndim == 3:
136
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
137
+ '''
138
+ print("mask:")
139
+ print(mask_sum)
140
+ '''
141
+ full_loss = full_loss.sum(dim=-1)
142
+ mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
143
+ # print(mask_sum)
144
+ full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
145
+ mask_sum = mask_sum.to(torch.bool).sum()
146
+ # print(mask_sum)
147
+ loss = full_loss.sum() / (mask_sum + 1e-12)
148
+ else:
149
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
150
+ '''
151
+ print("mask:")
152
+ print(mask_sum)
153
+ print(mask_sum.to(torch.bool).sum())
154
+ '''
155
+ loss = full_loss.sum() / (mask_sum.to(torch.bool).sum() + 1e-12)
156
+ return loss
157
+ return full_loss
158
+
159
+
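A minimal sketch (not part of this commit) of how _MaskedLoss is intended to be specialized: a subclass supplies `self.criterion` with reduction='none', and the base forward() handles ignore-value masking, NaN handling, and the final reduction. MaskedMSELoss is a hypothetical name for illustration.

class MaskedMSELoss(_MaskedLoss):
    def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
        super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
        # element-wise criterion; the base class reduces over the valid mask
        self.criterion = nn.MSELoss(reduction='none')

pred = torch.randn(4, 8)
target = torch.randn(4, 8)
target[0, :3] = -100.0          # positions flagged with ignore_value are masked out
loss = MaskedMSELoss()(pred, target)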
metrics.py ADDED
@@ -0,0 +1,549 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.**@**.com
7
+ @tel: 137****6540
8
+ @datetime: 2022/11/26 21:05
9
+ @project: LucaOne
10
+ @file: metrics.py
11
+ @desc: metrics for binary classification or multi-class classification
12
+ '''
13
+ import csv
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ plt.rcParams.update({'font.size': 18})
17
+ plt.rcParams['axes.unicode_minus'] = False
18
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, \
19
+ average_precision_score, confusion_matrix, mean_absolute_error, mean_squared_error, r2_score
20
+
21
+
22
+ def topk_accuracy_score(targets, probs, k=3):
23
+ '''
24
+ top-k accuracy
25
+ :param targets: 1d-array of true class indices (n_samples, )
26
+ :param probs: 2d-array of class probabilities (n_samples, m_classes)
27
+ :param k: number of top-ranked predictions to consider
28
+ :return: top-k accuracy score
29
+ '''
30
+ # obtain top-k label
31
+ max_k_preds = probs.argsort(axis=1)[:, -k:][:, ::-1]
32
+ a_real = np.resize(targets, (targets.shape[0], 1))
33
+ # obtain the match result
34
+ match_array = np.logical_or.reduce(max_k_preds == a_real, axis=1)
35
+ topk_acc_score = match_array.sum() / match_array.shape[0]
36
+ return topk_acc_score
37
+
38
+
39
+ def multi_class_acc(targets, probs, threshold=0.5):
40
+ if targets.ndim == 2:
41
+ targets = np.argmax(targets, axis=1)
42
+ preds = np.argmax(probs, axis=1)
43
+ return accuracy_score(targets, preds)
44
+
45
+
46
+ def multi_class_precision(targets, probs, average='macro'):
47
+ if targets.ndim == 2:
48
+ targets = np.argmax(targets, axis=1)
49
+ preds = np.argmax(probs, axis=1)
50
+ return precision_score(targets, preds, average=average)
51
+
52
+
53
+ def multi_class_recall(targets, probs, average='macro'):
54
+ if targets.ndim == 2:
55
+ targets = np.argmax(targets, axis=1)
56
+ preds = np.argmax(probs, axis=1)
57
+ return recall_score(targets, preds, average=average)
58
+
59
+
60
+ def multi_class_f1(targets, probs, average='macro'):
61
+ if targets.ndim == 2:
62
+ targets = np.argmax(targets, axis=1)
63
+ preds = np.argmax(probs, axis=1)
64
+ return f1_score(targets, preds, average=average)
65
+
66
+
67
+ def multi_class_roc_auc(targets, probs, average='macro'):
68
+ if targets.ndim == 2:
69
+ targets = np.argmax(targets, axis=1)
70
+ return roc_auc_score(targets, probs, average=average, multi_class='ovr')
71
+
72
+
73
+ def multi_class_pr_auc(targets, probs, average='macro'):
74
+ if targets.ndim == 2:
75
+ targets = np.argmax(targets, axis=1)
76
+ z = probs.shape[1]
77
+ new_targets = np.eye(z)[targets]
78
+ pr_auc = average_precision_score(new_targets, probs, average=average)
79
+ return pr_auc
80
+
81
+
82
+ def metrics_multi_class(targets, probs, average="macro"):
83
+ '''
84
+ metrics of multi-class classification
85
+ :param targets: 1d-array class index (n_samples, )
86
+ :param probs: 2d-array probability (n_samples, m_classes)
87
+ :return:
88
+ '''
89
+ if targets.ndim == 2 and targets.shape[1] > 1:
90
+ targets = np.argmax(targets, axis=1)
91
+ elif targets.ndim == 2 and targets.shape[1] == 1:
92
+ targets = np.squeeze(targets, axis=1)
93
+
94
+ preds = np.argmax(probs, axis=1)
95
+ acc = accuracy_score(targets, preds)
96
+ prec = precision_score(targets, preds, average=average)
97
+ recall = recall_score(targets, preds, average=average)
98
+ f1 = f1_score(targets, preds, average=average)
99
+ result = {
100
+ "acc": round(float(acc), 6),
101
+ "prec": round(float(prec), 6),
102
+ "recall": round(float(recall), 6),
103
+ "f1": round(float(f1), 6)
104
+ }
105
+ result.update({
106
+ "top2_acc": round(float(topk_accuracy_score(targets, probs, k=2)), 6),
107
+ "top3_acc": round(float(topk_accuracy_score(targets, probs, k=3)), 6),
108
+ "top5_acc": round(float(topk_accuracy_score(targets, probs, k=5)), 6),
109
+ "top10_acc": round(float(topk_accuracy_score(targets, probs, k=10)), 6)
110
+ })
111
+ try:
112
+ roc_auc = roc_auc_score(targets, probs, average=average, multi_class='ovr')
113
+ result.update({
114
+ "roc_auc": round(float(roc_auc), 6)
115
+ })
116
+ except Exception as e:
117
+ pass
118
+ try:
119
+ z = probs.shape[1]
120
+ new_targets = np.eye(z)[targets]
121
+ pr_auc = average_precision_score(new_targets, probs, average=average)
122
+ result.update({
123
+ "pr_auc": round(float(pr_auc), 6),
124
+ })
125
+ except Exception as e:
126
+ pass
127
+ return result
128
+
129
+
130
+ def metrics_multi_class_for_pred(targets, preds, probs=None, average="macro", savepath=None):
131
+ '''
132
+ metrics for multi-class classification
133
+ :param targets: 1d-array class index (n_samples, )
134
+ :param preds: 1d-array class index (n_samples, )
135
+ :return:
136
+ '''
137
+ if targets.ndim == 2 and targets.shape[1] > 1:
138
+ targets = np.argmax(targets, axis=1)
139
+ elif targets.ndim == 2 and targets.shape[1] == 1:
140
+ targets = np.squeeze(targets, axis=1)
141
+
142
+ acc = accuracy_score(targets, preds)
143
+ prec = precision_score(targets, preds, average=average)
144
+ recall = recall_score(targets, preds, average=average)
145
+ f1 = f1_score(y_true=targets, y_pred=preds, average=average)
146
+ result = {
147
+ "acc": round(float(acc), 6),
148
+ "prec": round(float(prec), 6),
149
+ "recall": round(float(recall), 6),
150
+ "f1": round(float(f1), 6)
151
+ }
152
+ try:
153
+ roc_auc = roc_auc_score(targets, probs, average=average, multi_class='ovr')
154
+ result.update({
155
+ "roc_auc": round(float(roc_auc), 6)
156
+ })
157
+ except Exception as e:
158
+ pass
159
+ try:
160
+ z = probs.shape[1]
161
+ new_targets = np.eye(z)[targets]
162
+ pr_auc = average_precision_score(new_targets, probs, average=average)
163
+ result.update({
164
+ "pr_auc": round(float(pr_auc), 6),
165
+ })
166
+ except Exception as e:
167
+ pass
168
+ return result
169
+
170
+
171
+ def metrics_regression(targets, preds):
172
+ '''
173
+ metrics for regression
174
+ :param targets: 1d-array of true values (n_samples, )
175
+ :param preds: 1d-array of predicted values (n_samples, )
176
+ :return:
177
+ '''
178
+ mae = mean_absolute_error(targets, preds)
179
+ mse = mean_squared_error(targets, preds)
180
+ r2 = r2_score(targets, preds)
181
+ return {
182
+ "mae": round(float(mae), 6),
183
+ "mse": round(float(mse), 6),
184
+ "r2": round(float(r2), 6)
185
+ }
186
+
187
+
188
+ def transform(targets, probs, threshold):
189
+ '''
190
+ normalize targets/probs and derive hard predictions for binary classification
191
+ :param targets: 1d-array class index (n_samples, )
192
+ :param probs: 1d-array of positive-class probabilities (n_samples, )
193
+ :param threshold: 0-1 probability threshold
194
+ :return:
195
+ '''
196
+ if targets.ndim == 2:
197
+ if targets.shape[1] == 2: # [[0, 1], [1, 0]]
198
+ targets = np.argmax(targets, axis=1)
199
+ else: # [[1], [0]]
200
+ targets = targets.flatten()
201
+ if probs.ndim == 2:
202
+ if probs.shape[1] == 2: # [[0.1, 0.9], [0.9, 0.1]]
203
+ preds = np.argmax(probs, axis=1)
204
+ probs = probs[:, 1].flatten()
205
+ else: # [[0.9], [0.1]]
206
+ preds = (probs >= threshold).astype(int).flatten()
207
+ probs = probs.flatten()
208
+ else:
209
+ preds = (probs >= threshold).astype(int)
210
+ return targets, probs, preds
211
+
212
+
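# (Editor's illustrative note, not part of this commit.) transform() accepts the
# three input layouts documented above and always returns (targets, probs, preds);
# e.g. for two-column probabilities the positive-class column is kept:
#   transform(np.array([0, 1]), np.array([[0.8, 0.2], [0.3, 0.7]]), 0.5)
#   -> (array([0, 1]), array([0.2, 0.7]), array([0, 1]))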
213
+ def binary_acc(targets, probs, threshold=0.5):
214
+ targets, probs, preds = transform(targets, probs, threshold)
215
+ return accuracy_score(targets, preds)
216
+
217
+
218
+ def binary_precision(targets, probs, threshold=0.5, average='binary'):
219
+ targets, probs, preds = transform(targets, probs, threshold)
220
+ return precision_score(targets, preds, average=average)
221
+
222
+
223
+ def binary_recall(targets, probs, threshold=0.5, average='binary'):
224
+ targets, probs, preds = transform(targets, probs, threshold)
225
+ return recall_score(targets, preds, average=average)
226
+
227
+
228
+ def binary_f1(targets, probs, threshold=0.5, average='binary'):
229
+ targets, probs, preds = transform(targets, probs, threshold)
230
+ return f1_score(targets, preds, average=average)
231
+
232
+
233
+ def binary_roc_auc(targets, probs, threshold=0.5, average='macro'):
234
+ targets, probs, preds = transform(targets, probs, threshold)
235
+ return roc_auc_score(targets, probs, average=average)
236
+
237
+
238
+ def binary_pr_auc(targets, probs, threshold=0.5, average='macro'):
239
+ targets, probs, preds = transform(targets, probs, threshold)
240
+ return average_precision_score(targets, probs, average=average)
241
+
242
+
243
+ def binary_confusion_matrix(targets, probs, threshold=0.5, savepath=None):
244
+ targets, probs, preds = transform(targets, probs, threshold)
245
+ cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
246
+ plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
247
+ tn, fp, fn, tp = cm_obj.ravel()
248
+ cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
249
+ return cm
250
+
251
+
252
+ def metrics_binary(targets, probs, threshold=0.5, average="binary", savepath=None):
253
+ '''
254
+ metrics for binary classification
255
+ :param targets: 1d-array class index (n_samples, )
256
+ :param probs: 1d-array of positive-class probabilities (n_samples, )
257
+ :param threshold: 0-1 prob threshold
258
+ :return:
259
+ '''
260
+ targets, probs, preds = transform(targets, probs, threshold)
274
+ acc = accuracy_score(targets, preds)
275
+ prec = precision_score(targets, preds, average=average)
276
+ recall = recall_score(targets, preds, average=average)
277
+ f1 = f1_score(targets, preds, average=average)
278
+ result = {
279
+ "acc": round(float(acc), 6),
280
+ "prec": round(float(prec), 6),
281
+ "recall": round(float(recall), 6),
282
+ "f1": round(float(f1), 6)
283
+ }
284
+ try:
285
+ roc_auc = roc_auc_score(targets, probs, average="macro")
286
+ result.update({
287
+ "roc_auc": round(float(roc_auc), 6)
288
+ })
289
+ except Exception as e:
290
+ pass
291
+ try:
292
+ pr_auc = average_precision_score(targets, probs, average="macro")
293
+ result.update({
294
+ "pr_auc": round(float(pr_auc), 6)
295
+ })
296
+ except Exception as e:
297
+ pass
298
+ try:
299
+ cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
300
+ plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
301
+ tn, fp, fn, tp = cm_obj.ravel()
302
+ cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
303
+ result.update({
304
+ "confusion_matrix": cm
305
+ })
306
+ except Exception as e:
307
+ pass
308
+ # add mcc
309
+ try:
310
+ tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
311
+ mcc = (tn*tp - fp*fn) / (((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) ** 0.5)
312
+ result.update({
313
+ "mcc": round(mcc, 6)
314
+ })
315
+ except Exception as e:
316
+ pass
317
+ return result
318
+
319
+
320
+ def metrics_binary_for_pred(targets, preds, probs=None, average="binary", savepath=None):
321
+ '''
322
+ metrics for binary classification
323
+ :param targets: 1d-array class index (n_samples, )
324
+ :param preds: 1d-array of predicted class indices (n_samples, )
325
+ :return:
326
+ '''
327
+ if targets.ndim == 2:
328
+ if targets.shape[1] == 2: # [[1, 0], [0, 1]
329
+ targets = np.argmax(targets, axis=1)
330
+ else: # [[1], [0]]
331
+ targets = targets.flatten()
332
+ if preds.ndim == 2:
333
+ if preds.shape[1] == 2: # [[0.9, 0.1], [0.1, 0.9]]
334
+ preds = np.argmax(preds, axis=1)
335
+ else: # [[0], [1]]
336
+ preds = preds.flatten()
337
+ cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
338
+ plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
339
+ tn, fp, fn, tp = cm_obj.ravel()
340
+ cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
341
+ if len(np.unique(targets)) > 1:
342
+ acc = accuracy_score(targets, preds)
343
+ prec = precision_score(targets, preds, average=average)
344
+ recall = recall_score(targets, preds, average=average)
345
+ f1 = f1_score(y_true=targets, y_pred=preds, average=average)
346
+ result = {
347
+ "acc": round(float(acc), 6),
348
+ "prec": round(float(prec), 6),
349
+ "recall": round(float(recall), 6),
350
+ "f1": round(float(f1), 6)
351
+ }
352
+ else:
354
+ result = {
355
+ "acc": round(float((cm["tp"] + cm["tn"]) / (cm["tp"] + cm["tn"] + cm["fp"] + cm["fn"])), 6),
356
+ "prec": round(float(cm["tp"]/(cm["tp"] + cm["fp"]) if cm["tp"] + cm["fp"] > 0 else 1.0), 6),
357
+ "recall": round(float(cm["tp"]/(cm["tp"] + cm["fn"]) if cm["tp"] + cm["fn"] > 0 else 1.0), 6),
358
+ }
359
+ result["f1"] = 2 * result["prec"] * result["recall"] / (result["prec"] + result["recall"])
360
+
361
+ try:
362
+ pr_auc = average_precision_score(targets, probs, average="macro")
363
+ result.update({
364
+ "pr_auc": round(float(pr_auc), 6)
365
+ })
366
+ except Exception as e:
367
+ pass
368
+ try:
369
+ roc_auc = roc_auc_score(targets, probs, average="macro")
370
+ result.update({
371
+ "roc_auc": round(float(roc_auc), 6)
372
+ })
373
+ except Exception as e:
374
+ pass
375
+ try:
376
+ tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
377
+ mcc = (tn*tp - fp*fn) / (((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) ** 0.5)
378
+ result.update({
379
+ "mcc": round(mcc, 6)
380
+ })
381
+ except Exception as e:
382
+ pass
383
+ result.update({
384
+ "confusion_matrix": cm
385
+ })
386
+ return result
387
+
388
+
389
+ def write_error_samples_multi_class(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets, probs,
390
+ use_other_diags=False, use_other_operas=False, use_checkin_department=False):
391
+ '''
392
+ write the bad cases of multi-class classification
393
+ :param filepath:
394
+ :param samples:
395
+ :param input_indexs:
396
+ :param input_id_2_names:
397
+ :param output_id_2_name:
398
+ :param targets:
399
+ :param probs:
400
+ :param use_other_diags:
401
+ :param use_other_operas:
402
+ :param use_checkin_department:
403
+ :return:
404
+ '''
405
+ targets = np.argmax(targets, axis=1)
406
+ preds = np.argmax(probs, axis=1)
407
+ with open(filepath, "w") as fp:
408
+ writer = csv.writer(fp)
409
+ writer.writerow(["score", "y_true", "y_pred", "inputs"])
410
+ for i in range(len(targets)):
411
+ target = targets[i]
412
+ pred = preds[i]
413
+ score = 1
414
+ if target != pred:
415
+ score = 0
416
+ if output_id_2_name:
417
+ target_label = output_id_2_name[target]
418
+ pred_label = output_id_2_name[pred]
419
+ else:
420
+ target_label = target
421
+ pred_label = pred
422
+ sample = samples[i]
423
+ if input_id_2_names:
424
+ new_sample = []
425
+ for idx, input_index in enumerate(input_indexs):
426
+ if input_index == 3 and not use_checkin_department:
427
+ input_index = 12
428
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
429
+ if (input_index == 6 and use_other_diags) or (input_index == 8 and use_other_operas) or (input_index == 10 and use_other_diags):
430
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
431
+ else:
432
+ new_sample = sample
433
+ row = [score, target_label, pred_label, new_sample]
434
+ writer.writerow(row)
435
+
436
+
437
+ def write_error_samples_binary(filepath, samples, input_indexs, input_id_2_names, targets, probs, threshold=0.5,
438
+ use_other_diags=False, use_other_operas=False, use_checkin_department=False):
439
+ '''
440
+ write bad cases of binary classification
441
+ :param filepath:
442
+ :param samples:
443
+ :param input_indexs:
444
+ :param input_id_2_names:
445
+ :param targets:
446
+ :param probs:
447
+ :param threshold:
448
+ :param use_other_diags:
449
+ :param use_other_operas:
450
+ :param use_checkin_department:
451
+ :return:
452
+ '''
453
+ with open(filepath, "w") as fp:
454
+ writer = csv.writer(fp)
455
+ writer.writerow(["score", "y_true", "y_pred", "inputs"])
456
+ for i in range(len(targets)):
457
+ target = targets[i][0]
458
+ if target != 1:
459
+ target = 0
460
+ prob = probs[i][0]
461
+ if prob >= threshold:
462
+ pred = 1
463
+ else:
464
+ pred = 0
465
+ score = 1
466
+ if target != pred:
467
+ score = 0
468
+ target_label = "True" if target == 1 else "False"
469
+ pred_label = "True" if target == 1 else "False"
470
+ sample = samples[i]
471
+ if input_id_2_names:
472
+ new_sample = []
473
+ for idx, input_index in enumerate(input_indexs):
474
+ if input_index == 3 and not use_checkin_department:
475
+ input_index = 12
476
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
477
+ if (input_index == 6 and use_other_diags) or (input_index == 8 and use_other_operas) or (input_index == 10 and use_other_diags):
478
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
479
+ else:
480
+ new_sample = sample
481
+ row = [score, target_label, pred_label, new_sample]
482
+ writer.writerow(row)
483
+
484
+
485
+ def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
486
+ '''
487
+ :param targets: ground truth
488
+ :param preds: predicted class indices
489
+ :param cm: confusion matrix
490
+ :param savepath: path to save the confusion matrix figure
491
+ '''
492
+
493
+ plt.figure(figsize=(40, 20), dpi=100)
494
+ if cm is None:
495
+ cm = confusion_matrix(targets, preds, labels=[0, 1])
496
+
497
+ plt.matshow(cm, cmap=plt.cm.Oranges)
498
+ plt.colorbar()
499
+
500
+ for x in range(len(cm)):
501
+ for y in range(len(cm)):
502
+ plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
503
+ plt.ylabel('True')
504
+ plt.xlabel('Prediction')
505
+ if savepath:
506
+ plt.savefig(savepath, dpi=100)
507
+ else:
508
+ plt.show()
509
+ plt.close("all")
510
+
511
+
512
+ if __name__ == "__main__":
513
+ '''multi_class'''
514
+ targets = np.array([0, 1, 2, 1, 3])
515
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.55, 0.25, 0.1], [0.4, 0.25, 0.35, 0]])
516
+ print(metrics_multi_class(targets, probs))
517
+
518
+ targets = np.array([0, 1, 2, 3, 3])
519
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.25, 0.25, 0.4], [0.1, 0.25, 0.25, 0.4]])
520
+ print(metrics_multi_class(targets, probs))
521
+ targets = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1]])
522
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.25, 0.25, 0.4], [0.1, 0.25, 0.25, 0.4]])
523
+ print(metrics_multi_class(targets, probs))
524
+
525
+ '''binary'''
526
+ targets = np.array([0, 0, 1, 1])
527
+ probs = np.array([[0.1], [0.1], [0.1], [0.9]])
528
+ print(metrics_binary(targets, probs))
529
+
530
+ targets = np.array([[0], [0], [1], [1]])
531
+ probs = np.array([[0.1], [0.1], [0.1], [0.9]])
532
+ print(metrics_binary(targets, probs))
533
+
534
+ targets = np.array([0, 0, 1, 1])
535
+ probs = np.array([[0.1, 0.1, 0.1, 0.9]])
536
+ print(metrics_binary(targets, probs))
537
+
538
+ targets = np.array([0, 0, 1, 1])
539
+ probs = np.array([0.1, 0.1, 0.1, 0.9])
540
+ print(metrics_binary(targets, probs))
541
+
542
+ targets = np.array([0, 1, 2, 1, 3])
543
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.55, 0.25, 0.1], [0.4, 0.25, 0.25, 0.1]])
544
+ z = probs.shape[1]
545
+ # print(z)
546
+ print(np.eye(z))
547
+ new_targets = np.eye(z)[targets]
548
+ print(new_targets)
549
+
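One more illustrative check (not part of this commit) for topk_accuracy_score defined above: the true class of the third sample is only the second-ranked prediction, so top-1 accuracy is 2/3 while top-2 accuracy is 1.0.

targets = np.array([0, 1, 1])
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.5, 0.4, 0.1]])
print(topk_accuracy_score(targets, probs, k=1))   # ~0.6667
print(topk_accuracy_score(targets, probs, k=2))   # 1.0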
model_utils.py ADDED
@@ -0,0 +1,99 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ from typing import Optional, Tuple
5
+ from dataclasses import dataclass
6
+ from transformers.modeling_outputs import ModelOutput
7
+ import sys, copy, math
8
+
9
+ from .pooling import *
10
+ from .loss import *
11
+
12
+ @dataclass
13
+ class AllOutput(ModelOutput):
14
+ losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
15
+ outputs: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
16
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
17
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
18
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
19
+ global_attentions: Optional[Tuple[torch.FloatTensor]] = None
20
+ contacts: Optional[Tuple[torch.FloatTensor]] = None
21
+ losses_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
22
+ outputs_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
23
+ hidden_states_b: Optional[Tuple[torch.FloatTensor]] = None
24
+ attentions_b: Optional[Tuple[torch.FloatTensor]] = None
25
+ cross_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
26
+ global_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
27
+ contacts_b: Optional[Tuple[torch.FloatTensor]] = None
28
+ pair_outputs: Optional[Tuple[torch.FloatTensor]] = None
29
+ pair_losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
30
+
31
+
32
+ def create_pooler(task_level_type, task_level_name, config, args):
33
+ '''
34
+ pooler building
35
+ :param task_level_type: task level type (e.g. token_level, whole_level)
36
+ :param task_level_name: task name under that level
37
+ :param config: model config holding the per-task hidden_size
38
+ :param args: run args holding the per-task pooling_type
39
+ :return: a pooling module, or None for an unknown pooling_type
40
+ '''
41
+ hidden_size = config.hidden_size[task_level_type][task_level_name]
42
+ pooling_type = args.pooling_type[task_level_type][task_level_name]
43
+
44
+ if pooling_type == "max":
45
+ return GlobalMaskMaxPooling1D()
46
+ elif pooling_type == "sum":
47
+ return GlobalMaskSumPooling1D(axis=1)
48
+ elif pooling_type == "avg":
49
+ return GlobalMaskAvgPooling1D()
50
+ elif pooling_type == "attention":
51
+ return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
52
+ elif pooling_type == "context_attention":
53
+ return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
54
+ elif pooling_type == "weighted_attention":
55
+ return GlobalMaskWeightedAttentionPooling1D(embed_size=hidden_size)
56
+ elif pooling_type == "value_attention":
57
+ return GlobalMaskValueAttentionPooling1D(embed_size=hidden_size)
58
+ elif pooling_type == "transformer":
59
+ copy_config = copy.deepcopy(config)
60
+ copy_config.hidden_size = hidden_size
61
+ return GlobalMaskTransformerPooling1D(copy_config)
62
+ else:
63
+ return None
64
+
65
+
66
+ def create_output_loss_lucagplm(task_level_type, task_level_name, config):
67
+ '''build the per-task output/loss modules (no cls module here)'''
68
+ if not hasattr(config, "sigmoid"):
69
+ config.sigmoid = {task_level_type: {}}
70
+ elif task_level_type not in config.sigmoid:
71
+ config.sigmoid[task_level_type] = {}
72
+ config.sigmoid[task_level_type][task_level_name] = False if config.output_mode[task_level_type][task_level_name] \
73
+ in ["multi_class", "multi-class", "regression"] else True
74
+ # special case: contact prediction needs a sigmoid; whether structure tasks need one is still to be decided
75
+ if task_level_name == "prot_contact":
76
+ config.sigmoid[task_level_type][task_level_name] = True
77
+ config.num_labels = config.label_size[task_level_type][task_level_name]
78
+ if task_level_type in ["token_level", "whole_level"]:
79
+ return_types = ["output", "loss"]
80
+ else:
81
+ return_types = ["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
82
+ return create_loss_function(config,
83
+ task_level_type=task_level_type,
84
+ task_level_name=task_level_name,
85
+ sigmoid=config.sigmoid[task_level_type][task_level_name],
86
+ output_mode=config.output_mode[task_level_type][task_level_name],
87
+ num_labels=config.num_labels,
88
+ loss_type=config.loss_type[task_level_type][task_level_name],
89
+ ignore_index=config.ignore_index,
90
+ pair_level=True if task_level_type == "pair_level" else False,
91
+ return_types=return_types)
92
+
93
+
94
+ def create_output_loss(task_level_type, task_level_name, cls_module, config, args):
95
+ cls = None
96
+ if task_level_type in ["token_level", "whole_level"]:
97
+ cls = cls_module(config)
98
+ dropout, hidden_layer, hidden_act, classifier, output, loss_fct = create_output_loss_lucagplm(task_level_type, task_level_name, config)
99
+ return cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct
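A minimal sketch of how create_pooler above resolves its nested lookups (illustrative, not part of this commit): config.hidden_size and args.pooling_type are nested dicts keyed by task level type and task name; the SimpleNamespace stand-ins and the "seq_level"/"gene_taxonomy" names below are made up for demonstration.

from types import SimpleNamespace

config = SimpleNamespace(hidden_size={"seq_level": {"gene_taxonomy": 1280}})
args = SimpleNamespace(pooling_type={"seq_level": {"gene_taxonomy": "value_attention"}})
pooler = create_pooler("seq_level", "gene_taxonomy", config, args)
# resolves to GlobalMaskValueAttentionPooling1D(embed_size=1280)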
modeling_bert.py ADDED
@@ -0,0 +1,1917 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.**@**.com
7
+ @tel: 137****6540
8
+ @datetime: 2022/12/2 09:38
9
+ @project: LucaOneTasks
10
+ @file: modeling_bert
11
+ @desc: transformer layers
12
+ '''
13
+ import math
14
+ import os
15
+ import warnings
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from packaging import version
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ CausalLMOutputWithCrossAttentions,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ NextSentencePredictorOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from transformers.utils import (
41
+ ModelOutput,
42
+ add_code_sample_docstrings,
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers.models.bert.configuration_bert import BertConfig
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
54
+ _CONFIG_FOR_DOC = "BertConfig"
55
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
56
+
57
+ # TokenClassification docstring
58
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
59
+ _TOKEN_CLASS_EXPECTED_OUTPUT = (
60
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
61
+ )
62
+ _TOKEN_CLASS_EXPECTED_LOSS = 0.01
63
+
64
+ # QuestionAnswering docstring
65
+ _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
66
+ _QA_EXPECTED_OUTPUT = "'a nice puppet'"
67
+ _QA_EXPECTED_LOSS = 7.41
68
+ _QA_TARGET_START_INDEX = 14
69
+ _QA_TARGET_END_INDEX = 15
70
+
71
+ # SequenceClassification docstring
72
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
73
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
74
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
75
+
76
+
77
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
78
+ "bert-base-uncased",
79
+ "bert-large-uncased",
80
+ "bert-base-cased",
81
+ "bert-large-cased",
82
+ "bert-base-multilingual-uncased",
83
+ "bert-base-multilingual-cased",
84
+ "bert-base-chinese",
85
+ "bert-base-german-cased",
86
+ "bert-large-uncased-whole-word-masking",
87
+ "bert-large-cased-whole-word-masking",
88
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
89
+ "bert-large-cased-whole-word-masking-finetuned-squad",
90
+ "bert-base-cased-finetuned-mrpc",
91
+ "bert-base-german-dbmdz-cased",
92
+ "bert-base-german-dbmdz-uncased",
93
+ "cl-tohoku/bert-base-japanese",
94
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
95
+ "cl-tohoku/bert-base-japanese-char",
96
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
97
+ "TurkuNLP/bert-base-finnish-cased-v1",
98
+ "TurkuNLP/bert-base-finnish-uncased-v1",
99
+ "wietsedv/bert-base-dutch-cased",
100
+ # See all BERT models at https://huggingface.co/models?filter=bert
101
+ ]
102
+
103
+
104
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
105
+ """Load tf checkpoints in a pytorch model."""
106
+ try:
107
+ import re
108
+
109
+ import numpy as np
110
+ import tensorflow as tf
111
+ except ImportError:
112
+ logger.error(
113
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
114
+ "https://www.tensorflow.org/install/ for installation instructions."
115
+ )
116
+ raise
117
+ tf_path = os.path.abspath(tf_checkpoint_path)
118
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
119
+ # Load weights from TF model
120
+ init_vars = tf.train.list_variables(tf_path)
121
+ names = []
122
+ arrays = []
123
+ for name, shape in init_vars:
124
+ logger.info(f"Loading TF weight {name} with shape {shape}")
125
+ array = tf.train.load_variable(tf_path, name)
126
+ names.append(name)
127
+ arrays.append(array)
128
+
129
+ for name, array in zip(names, arrays):
130
+ name = name.split("/")
131
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
132
+ # which are not required for using pretrained model
133
+ if any(
134
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
135
+ for n in name
136
+ ):
137
+ logger.info(f"Skipping {'/'.join(name)}")
138
+ continue
139
+ pointer = model
140
+ for m_name in name:
141
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
142
+ scope_names = re.split(r"_(\d+)", m_name)
143
+ else:
144
+ scope_names = [m_name]
145
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
146
+ pointer = getattr(pointer, "weight")
147
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
148
+ pointer = getattr(pointer, "bias")
149
+ elif scope_names[0] == "output_weights":
150
+ pointer = getattr(pointer, "weight")
151
+ elif scope_names[0] == "squad":
152
+ pointer = getattr(pointer, "classifier")
153
+ else:
154
+ try:
155
+ pointer = getattr(pointer, scope_names[0])
156
+ except AttributeError:
157
+ logger.info(f"Skipping {'/'.join(name)}")
158
+ continue
159
+ if len(scope_names) >= 2:
160
+ num = int(scope_names[1])
161
+ pointer = pointer[num]
162
+ if m_name[-11:] == "_embeddings":
163
+ pointer = getattr(pointer, "weight")
164
+ elif m_name == "kernel":
165
+ array = np.transpose(array)
166
+ try:
167
+ if pointer.shape != array.shape:
168
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
169
+ except ValueError as e:
170
+ e.args += (pointer.shape, array.shape)
171
+ raise
172
+ logger.info(f"Initialize PyTorch weight {name}")
173
+ pointer.data = torch.from_numpy(array)
174
+ return model
175
+
176
+
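# (Editor's illustrative note, not part of this commit.) A typical call to the
# TF->PyTorch converter above, mirroring upstream transformers usage; the
# checkpoint paths and the BertModel class (defined later in this file) are
# assumptions:
#   config = BertConfig.from_json_file("bert_config.json")
#   model = BertModel(config)
#   load_tf_weights_in_bert(model, config, "model.ckpt")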
177
+ class BertEmbeddings(nn.Module):
178
+ """Construct the embeddings from word, position and token_type embeddings."""
179
+
180
+ def __init__(self, config):
181
+ super().__init__()
182
+ if hasattr(config, "no_token_embeddings"):
183
+ self.no_token_embeddings = config.no_token_embeddings
184
+ else:
185
+ self.no_token_embeddings = False
186
+ if not self.no_token_embeddings:
187
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
188
+ if hasattr(config, "no_position_embeddings"):
189
+ self.no_position_embeddings = config.no_position_embeddings
190
+ else:
191
+ self.no_position_embeddings = False
192
+ if hasattr(config, "no_token_type_embeddings"):
193
+ self.no_token_type_embeddings = config.no_token_type_embeddings
194
+ else:
195
+ self.no_token_type_embeddings = False
196
+ if not self.no_position_embeddings:
197
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
198
+ if not self.no_token_type_embeddings:
199
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
200
+
201
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
202
+ # any TensorFlow checkpoint file
203
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
204
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
205
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
206
+ if not self.no_position_embeddings:
207
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
208
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
209
+ if not self.no_token_type_embeddings and not self.no_position_embeddings:
210
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
211
+ self.register_buffer(
212
+ "token_type_ids",
213
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
214
+ persistent=False,
215
+ )
216
+
217
+ def forward(
218
+ self,
219
+ input_ids: Optional[torch.LongTensor] = None,
220
+ token_type_ids: Optional[torch.LongTensor] = None,
221
+ position_ids: Optional[torch.LongTensor] = None,
222
+ inputs_embeds: Optional[torch.FloatTensor] = None,
223
+ past_key_values_length: int = 0,
224
+ ) -> torch.Tensor:
225
+ if input_ids is not None:
226
+ input_shape = input_ids.size()
227
+ else:
228
+ input_shape = inputs_embeds.size()[:-1]
229
+
230
+ seq_length = input_shape[1]
231
+
232
+ if not self.no_position_embeddings and position_ids is None:
233
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
234
+
235
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
236
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
237
+ # issue #5664
238
+ if not self.no_token_type_embeddings and token_type_ids is None:
239
+ if hasattr(self, "token_type_ids"):
240
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
241
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
242
+ token_type_ids = buffered_token_type_ids_expanded
243
+ else:
244
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device if input_ids is not None else inputs_embeds.device)
245
+ if self.no_token_embeddings and inputs_embeds is None:
246
+ raise Exception("The model has not token_embeddings layer, the inputs_embeds cannot None")
247
+
248
+ if inputs_embeds is None:
249
+ inputs_embeds = self.word_embeddings(input_ids)
250
+ embeddings = inputs_embeds
251
+
252
+ if not self.no_token_type_embeddings:
253
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
254
+ embeddings += token_type_embeddings
255
+
256
+ if not self.no_position_embeddings and self.position_embedding_type == "absolute":
257
+ position_embeddings = self.position_embeddings(position_ids)
258
+ embeddings += position_embeddings
259
+
260
+ embeddings = self.LayerNorm(embeddings)
261
+ embeddings = self.dropout(embeddings)
262
+ return embeddings
263
+
264
+
265
+ class BertSelfAttention(nn.Module):
266
+ def __init__(self, config, position_embedding_type=None):
267
+ super().__init__()
268
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
269
+ raise ValueError(
270
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
271
+ f"heads ({config.num_attention_heads})"
272
+ )
273
+
274
+ self.num_attention_heads = config.num_attention_heads
275
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
276
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
277
+
278
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
279
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
280
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
281
+
282
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
283
+ self.position_embedding_type = position_embedding_type or getattr(
284
+ config, "position_embedding_type", "absolute"
285
+ )
286
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
287
+ self.max_position_embeddings = config.max_position_embeddings
288
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
289
+
290
+ self.is_decoder = config.is_decoder
291
+
292
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
293
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
294
+ x = x.view(new_x_shape)
295
+ return x.permute(0, 2, 1, 3)
296
+
297
+ def forward(
298
+ self,
299
+ hidden_states: torch.Tensor,
300
+ attention_mask: Optional[torch.FloatTensor] = None,
301
+ head_mask: Optional[torch.FloatTensor] = None,
302
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
303
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
304
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
305
+ output_attentions: Optional[bool] = False,
306
+ ) -> Tuple[torch.Tensor]:
307
+ mixed_query_layer = self.query(hidden_states)
308
+
309
+ # If this is instantiated as a cross-attention module, the keys
310
+ # and values come from an encoder; the attention mask needs to be
311
+ # such that the encoder's padding tokens are not attended to.
312
+ is_cross_attention = encoder_hidden_states is not None
313
+
314
+ if is_cross_attention and past_key_value is not None:
315
+ # reuse k,v, cross_attentions
316
+ key_layer = past_key_value[0]
317
+ value_layer = past_key_value[1]
318
+ attention_mask = encoder_attention_mask
319
+ elif is_cross_attention:
320
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
321
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
322
+ attention_mask = encoder_attention_mask
323
+ elif past_key_value is not None:
324
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
325
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
326
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
327
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
328
+ else:
329
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
330
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
331
+
332
+ query_layer = self.transpose_for_scores(mixed_query_layer)
333
+
334
+ if self.is_decoder:
335
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
336
+ # Further calls to cross_attention layer can then reuse all cross-attention
337
+ # key/value_states (first "if" case)
338
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
339
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
340
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
341
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
342
+ past_key_value = (key_layer, value_layer)
343
+
344
+ # Take the dot product between "query" and "key" to get the raw attention scores.
345
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
346
+
347
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
348
+ seq_length = hidden_states.size()[1]
349
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
350
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
351
+ distance = position_ids_l - position_ids_r
352
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
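+ # distance ranges over [-(seq_length - 1), seq_length - 1]; adding
+ # max_position_embeddings - 1 shifts it into the valid embedding-index
+ # range [0, 2 * max_position_embeddings - 2].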
353
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
354
+
355
+ if self.position_embedding_type == "relative_key":
356
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
357
+ attention_scores = attention_scores + relative_position_scores
358
+ elif self.position_embedding_type == "relative_key_query":
359
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
360
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
361
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
362
+
363
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
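+ # Scaling by sqrt(head_size) keeps the logits' variance roughly independent
+ # of the head dimension, preventing the softmax from saturating.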
364
+ if attention_mask is not None:
365
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
366
+ attention_scores = attention_scores + attention_mask
367
+
368
+ # Normalize the attention scores to probabilities.
369
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
370
+
371
+ # This is actually dropping out entire tokens to attend to, which might
372
+ # seem a bit unusual, but is taken from the original Transformer paper.
373
+ attention_probs = self.dropout(attention_probs)
374
+
375
+ # Mask heads if we want to
376
+ if head_mask is not None:
377
+ attention_probs = attention_probs * head_mask
378
+
379
+ context_layer = torch.matmul(attention_probs, value_layer)
380
+
381
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
382
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
383
+ context_layer = context_layer.view(new_context_layer_shape)
384
+
385
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
386
+
387
+ if self.is_decoder:
388
+ outputs = outputs + (past_key_value,)
389
+ return outputs
390
+
391
+
392
+ class BertSelfOutput(nn.Module):
393
+ def __init__(self, config):
394
+ super().__init__()
395
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
396
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
397
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
398
+
399
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
400
+ hidden_states = self.dense(hidden_states)
401
+ hidden_states = self.dropout(hidden_states)
402
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
403
+ return hidden_states
404
+
405
+
406
+ class BertAttention(nn.Module):
407
+ def __init__(self, config, position_embedding_type=None):
408
+ super().__init__()
409
+ self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
410
+ self.output = BertSelfOutput(config)
411
+ self.pruned_heads = set()
412
+
413
+ def prune_heads(self, heads):
414
+ if len(heads) == 0:
415
+ return
416
+ heads, index = find_pruneable_heads_and_indices(
417
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
418
+ )
419
+
420
+ # Prune linear layers
421
+ self.self.query = prune_linear_layer(self.self.query, index)
422
+ self.self.key = prune_linear_layer(self.self.key, index)
423
+ self.self.value = prune_linear_layer(self.self.value, index)
424
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
425
+
426
+ # Update hyper params and store pruned heads
427
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
428
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
429
+ self.pruned_heads = self.pruned_heads.union(heads)
430
+
431
+ def forward(
432
+ self,
433
+ hidden_states: torch.Tensor,
434
+ attention_mask: Optional[torch.FloatTensor] = None,
435
+ head_mask: Optional[torch.FloatTensor] = None,
436
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
437
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
438
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
439
+ output_attentions: Optional[bool] = False,
440
+ ) -> Tuple[torch.Tensor]:
441
+ self_outputs = self.self(
442
+ hidden_states,
443
+ attention_mask,
444
+ head_mask,
445
+ encoder_hidden_states,
446
+ encoder_attention_mask,
447
+ past_key_value,
448
+ output_attentions,
449
+ )
450
+ attention_output = self.output(self_outputs[0], hidden_states)
451
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
452
+ return outputs
453
+
454
+
455
+ class BertIntermediate(nn.Module):
456
+ def __init__(self, config):
457
+ super().__init__()
458
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
459
+ if isinstance(config.hidden_act, str):
460
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
461
+ else:
462
+ self.intermediate_act_fn = config.hidden_act
463
+
464
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
465
+ hidden_states = self.dense(hidden_states)
466
+ hidden_states = self.intermediate_act_fn(hidden_states)
467
+ return hidden_states
468
+
469
+
470
+ class BertOutput(nn.Module):
471
+ def __init__(self, config):
472
+ super().__init__()
473
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
474
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
475
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
476
+
477
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
478
+ hidden_states = self.dense(hidden_states)
479
+ hidden_states = self.dropout(hidden_states)
480
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
481
+ return hidden_states
482
+
483
+
484
+ class BertLayer(nn.Module):
485
+ def __init__(self, config):
486
+ super().__init__()
487
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
488
+ self.seq_len_dim = 1
489
+ self.attention = BertAttention(config)
490
+ self.is_decoder = config.is_decoder
491
+ self.add_cross_attention = config.add_cross_attention
492
+ if self.add_cross_attention:
493
+ if not self.is_decoder:
494
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
495
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
496
+ self.intermediate = BertIntermediate(config)
497
+ self.output = BertOutput(config)
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ attention_mask: Optional[torch.FloatTensor] = None,
503
+ head_mask: Optional[torch.FloatTensor] = None,
504
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
505
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
506
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
507
+ output_attentions: Optional[bool] = False,
508
+ ) -> Tuple[torch.Tensor]:
509
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
510
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
511
+ self_attention_outputs = self.attention(
512
+ hidden_states,
513
+ attention_mask,
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ past_key_value=self_attn_past_key_value,
517
+ )
518
+ attention_output = self_attention_outputs[0]
519
+
520
+ # if decoder, the last output is tuple of self-attn cache
521
+ if self.is_decoder:
522
+ outputs = self_attention_outputs[1:-1]
523
+ present_key_value = self_attention_outputs[-1]
524
+ else:
525
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
526
+
527
+ cross_attn_present_key_value = None
528
+ if self.is_decoder and encoder_hidden_states is not None:
529
+ if not hasattr(self, "crossattention"):
530
+ raise ValueError(
531
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
532
+ " by setting `config.add_cross_attention=True`"
533
+ )
534
+
535
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
536
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
537
+ cross_attention_outputs = self.crossattention(
538
+ attention_output,
539
+ attention_mask,
540
+ head_mask,
541
+ encoder_hidden_states,
542
+ encoder_attention_mask,
543
+ cross_attn_past_key_value,
544
+ output_attentions,
545
+ )
546
+ attention_output = cross_attention_outputs[0]
547
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
548
+
549
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
550
+ cross_attn_present_key_value = cross_attention_outputs[-1]
551
+ present_key_value = present_key_value + cross_attn_present_key_value
552
+
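+ # apply_chunking_to_forward optionally splits the feed-forward pass into
+ # chunks along the sequence dimension (seq_len_dim=1), trading extra compute
+ # for a smaller peak memory footprint when chunk_size_feed_forward > 0.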
553
+ layer_output = apply_chunking_to_forward(
554
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
555
+ )
556
+ outputs = (layer_output,) + outputs
557
+
558
+ # if decoder, return the attn key/values as the last output
559
+ if self.is_decoder:
560
+ outputs = outputs + (present_key_value,)
561
+
562
+ return outputs
563
+
564
+ def feed_forward_chunk(self, attention_output):
565
+ intermediate_output = self.intermediate(attention_output)
566
+ layer_output = self.output(intermediate_output, attention_output)
567
+ return layer_output
568
+
569
+
570
+ class BertEncoder(nn.Module):
571
+ def __init__(self, config):
572
+ super().__init__()
573
+ self.config = config
574
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
575
+ self.gradient_checkpointing = False
576
+
577
+ def forward(
578
+ self,
579
+ hidden_states: torch.Tensor,
580
+ attention_mask: Optional[torch.FloatTensor] = None,
581
+ head_mask: Optional[torch.FloatTensor] = None,
582
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
583
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
584
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
585
+ use_cache: Optional[bool] = None,
586
+ output_attentions: Optional[bool] = False,
587
+ output_hidden_states: Optional[bool] = False,
588
+ return_dict: Optional[bool] = True,
589
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
590
+ all_hidden_states = () if output_hidden_states else None
591
+ all_self_attentions = () if output_attentions else None
592
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
593
+
594
+ next_decoder_cache = () if use_cache else None
595
+ for i, layer_module in enumerate(self.layer):
596
+ if output_hidden_states:
597
+ all_hidden_states = all_hidden_states + (hidden_states,)
598
+
599
+ layer_head_mask = head_mask[i] if head_mask is not None else None
600
+ past_key_value = past_key_values[i] if past_key_values is not None else None
601
+
602
+ if self.gradient_checkpointing and self.training:
603
+
604
+ if use_cache:
605
+ logger.warning(
606
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
607
+ )
608
+ use_cache = False
609
+
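+ # torch.utils.checkpoint re-invokes the forward with only the positional
+ # tensor arguments, so the non-tensor arguments (past_key_value,
+ # output_attentions) are captured in a closure instead of passed through.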
610
+ def create_custom_forward(module):
611
+ def custom_forward(*inputs):
612
+ return module(*inputs, past_key_value, output_attentions)
613
+
614
+ return custom_forward
615
+
616
+ layer_outputs = torch.utils.checkpoint.checkpoint(
617
+ create_custom_forward(layer_module),
618
+ hidden_states,
619
+ attention_mask,
620
+ layer_head_mask,
621
+ encoder_hidden_states,
622
+ encoder_attention_mask,
623
+ )
624
+ else:
625
+ layer_outputs = layer_module(
626
+ hidden_states,
627
+ attention_mask,
628
+ layer_head_mask,
629
+ encoder_hidden_states,
630
+ encoder_attention_mask,
631
+ past_key_value,
632
+ output_attentions,
633
+ )
634
+
635
+ hidden_states = layer_outputs[0]
636
+ if use_cache:
637
+ next_decoder_cache += (layer_outputs[-1],)
638
+ if output_attentions:
639
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
640
+ if self.config.add_cross_attention:
641
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
642
+
643
+ if output_hidden_states:
644
+ all_hidden_states = all_hidden_states + (hidden_states,)
645
+
646
+ if not return_dict:
647
+ return tuple(
648
+ v
649
+ for v in [
650
+ hidden_states,
651
+ next_decoder_cache,
652
+ all_hidden_states,
653
+ all_self_attentions,
654
+ all_cross_attentions,
655
+ ]
656
+ if v is not None
657
+ )
658
+ return BaseModelOutputWithPastAndCrossAttentions(
659
+ last_hidden_state=hidden_states,
660
+ past_key_values=next_decoder_cache,
661
+ hidden_states=all_hidden_states,
662
+ attentions=all_self_attentions,
663
+ cross_attentions=all_cross_attentions,
664
+ )
665
+
666
+
667
+ class BertPooler(nn.Module):
668
+ def __init__(self, config):
669
+ super().__init__()
670
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
671
+ self.activation = nn.Tanh()
672
+
673
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
674
+ # We "pool" the model by simply taking the hidden state corresponding
675
+ # to the first token.
676
+ first_token_tensor = hidden_states[:, 0]
677
+ pooled_output = self.dense(first_token_tensor)
678
+ pooled_output = self.activation(pooled_output)
679
+ return pooled_output
680
+
681
+
682
+ class BertPredictionHeadTransform(nn.Module):
683
+ def __init__(self, config):
684
+ super().__init__()
685
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
686
+ if isinstance(config.hidden_act, str):
687
+ self.transform_act_fn = ACT2FN[config.hidden_act]
688
+ else:
689
+ self.transform_act_fn = config.hidden_act
690
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
691
+
692
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
693
+ hidden_states = self.dense(hidden_states)
694
+ hidden_states = self.transform_act_fn(hidden_states)
695
+ hidden_states = self.LayerNorm(hidden_states)
696
+ return hidden_states
697
+
698
+
699
+ class BertLMPredictionHead(nn.Module):
700
+ def __init__(self, config):
701
+ super().__init__()
702
+ self.transform = BertPredictionHeadTransform(config)
703
+
704
+ # The output weights are the same as the input embeddings, but there is
705
+ # an output-only bias for each token.
706
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
707
+
708
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
709
+
710
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
711
+ self.decoder.bias = self.bias
712
+
713
+ def forward(self, hidden_states):
714
+ hidden_states = self.transform(hidden_states)
715
+ hidden_states = self.decoder(hidden_states)
716
+ return hidden_states
717
+
718
+
719
+ class BertOnlyMLMHead(nn.Module):
720
+ def __init__(self, config):
721
+ super().__init__()
722
+ self.predictions = BertLMPredictionHead(config)
723
+
724
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
725
+ prediction_scores = self.predictions(sequence_output)
726
+ return prediction_scores
727
+
728
+
729
+ class BertOnlyNSPHead(nn.Module):
730
+ def __init__(self, config):
731
+ super().__init__()
732
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
733
+
734
+ def forward(self, pooled_output):
735
+ seq_relationship_score = self.seq_relationship(pooled_output)
736
+ return seq_relationship_score
737
+
738
+
739
+ class BertPreTrainingHeads(nn.Module):
740
+ def __init__(self, config):
741
+ super().__init__()
742
+ self.predictions = BertLMPredictionHead(config)
743
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
744
+
745
+ def forward(self, sequence_output, pooled_output):
746
+ prediction_scores = self.predictions(sequence_output)
747
+ seq_relationship_score = self.seq_relationship(pooled_output)
748
+ return prediction_scores, seq_relationship_score
749
+
750
+
751
+ class BertPreTrainedModel(PreTrainedModel):
752
+ """
753
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
754
+ models.
755
+ """
756
+
757
+ config_class = BertConfig
758
+ load_tf_weights = load_tf_weights_in_bert
759
+ base_model_prefix = "bert"
760
+ supports_gradient_checkpointing = True
761
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
762
+
763
+ def _init_weights(self, module):
764
+ """Initialize the weights"""
765
+ if isinstance(module, nn.Linear):
766
+ # Slightly different from the TF version which uses truncated_normal for initialization
767
+ # cf https://github.com/pytorch/pytorch/pull/5617
768
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
769
+ if module.bias is not None:
770
+ module.bias.data.zero_()
771
+ elif isinstance(module, nn.Embedding):
772
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
773
+ if module.padding_idx is not None:
774
+ module.weight.data[module.padding_idx].zero_()
775
+ elif isinstance(module, nn.LayerNorm):
776
+ module.bias.data.zero_()
777
+ module.weight.data.fill_(1.0)
778
+
779
+ def _set_gradient_checkpointing(self, module, value=False):
780
+ if isinstance(module, BertEncoder):
781
+ module.gradient_checkpointing = value
782
+
783
+
784
+ @dataclass
785
+ class BertForPreTrainingOutput(ModelOutput):
786
+ """
787
+ Output type of [`BertForPreTraining`].
788
+
789
+ Args:
790
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
791
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
792
+ (classification) loss.
793
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
794
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
795
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
796
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
797
+ before SoftMax).
798
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
799
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
800
+ shape `(batch_size, sequence_length, hidden_size)`.
801
+
802
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
803
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
804
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
805
+ sequence_length)`.
806
+
807
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
808
+ heads.
809
+ """
810
+
811
+ loss: Optional[torch.FloatTensor] = None
812
+ prediction_logits: torch.FloatTensor = None
813
+ seq_relationship_logits: torch.FloatTensor = None
814
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
815
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
816
+
817
+
818
+ BERT_START_DOCSTRING = r"""
819
+
820
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
821
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
822
+ etc.)
823
+
824
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
825
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
826
+ and behavior.
827
+
828
+ Parameters:
829
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
830
+ Initializing with a config file does not load the weights associated with the model, only the
831
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
832
+ """
833
+
834
+ BERT_INPUTS_DOCSTRING = r"""
835
+ Args:
836
+ input_ids (`torch.LongTensor` of shape `({0})`):
837
+ Indices of input sequence tokens in the vocabulary.
838
+
839
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
840
+ [`PreTrainedTokenizer.__call__`] for details.
841
+
842
+ [What are input IDs?](../glossary#input-ids)
843
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
844
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
845
+
846
+ - 1 for tokens that are **not masked**,
847
+ - 0 for tokens that are **masked**.
848
+
849
+ [What are attention masks?](../glossary#attention-mask)
850
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
851
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
852
+ 1]`:
853
+
854
+ - 0 corresponds to a *sentence A* token,
855
+ - 1 corresponds to a *sentence B* token.
856
+
857
+ [What are token type IDs?](../glossary#token-type-ids)
858
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
859
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
860
+ config.max_position_embeddings - 1]`.
861
+
862
+ [What are position IDs?](../glossary#position-ids)
863
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
864
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
865
+
866
+ - 1 indicates the head is **not masked**,
867
+ - 0 indicates the head is **masked**.
868
+
869
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
870
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
871
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
872
+ model's internal embedding lookup matrix.
873
+ output_attentions (`bool`, *optional*):
874
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
875
+ tensors for more detail.
876
+ output_hidden_states (`bool`, *optional*):
877
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
878
+ more detail.
879
+ return_dict (`bool`, *optional*):
880
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
881
+ """
882
+
883
+
884
+ @add_start_docstrings(
885
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
886
+ BERT_START_DOCSTRING,
887
+ )
888
+ class BertModel(BertPreTrainedModel):
889
+ """
890
+
891
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
892
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
893
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
894
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
895
+
896
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
897
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
898
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
899
+ """
900
+
901
+ def __init__(self, config, use_pretrained_embedding=False, add_pooling_layer=True):
902
+ super().__init__(config)
903
+ self.config = config
904
+ self.use_pretrained_embedding = use_pretrained_embedding
905
+ self.add_pooling_layer = add_pooling_layer
906
+
907
+ self.embeddings = nn.Linear(config.embedding_input_size, config.hidden_size) if use_pretrained_embedding else BertEmbeddings(config)
908
+
909
+ self.encoder = BertEncoder(config)
910
+
911
+ self.pooler = BertPooler(config) if add_pooling_layer else None
912
+
913
+ # Initialize weights and apply final processing
914
+ self.post_init()
915
+
916
+ def get_input_embeddings(self):
917
+ return self.embeddings.word_embeddings
918
+
919
+ def set_input_embeddings(self, value):
920
+ self.embeddings.word_embeddings = value
921
+
922
+ def _prune_heads(self, heads_to_prune):
923
+ """
924
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
925
+ class PreTrainedModel
926
+ """
927
+ for layer, heads in heads_to_prune.items():
928
+ self.encoder.layer[layer].attention.prune_heads(heads)
929
+
930
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
931
+ @add_code_sample_docstrings(
932
+ processor_class=_TOKENIZER_FOR_DOC,
933
+ checkpoint=_CHECKPOINT_FOR_DOC,
934
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
935
+ config_class=_CONFIG_FOR_DOC,
936
+ )
937
+ def forward(
938
+ self,
939
+ input_ids: Optional[torch.Tensor] = None,
940
+ attention_mask: Optional[torch.Tensor] = None,
941
+ token_type_ids: Optional[torch.Tensor] = None,
942
+ position_ids: Optional[torch.Tensor] = None,
943
+ head_mask: Optional[torch.Tensor] = None,
944
+ inputs_embeds: Optional[torch.Tensor] = None,
945
+ encoder_hidden_states: Optional[torch.Tensor] = None,
946
+ encoder_attention_mask: Optional[torch.Tensor] = None,
947
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
948
+ use_cache: Optional[bool] = None,
949
+ output_attentions: Optional[bool] = None,
950
+ output_hidden_states: Optional[bool] = None,
951
+ return_dict: Optional[bool] = None,
952
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
953
+ r"""
954
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
955
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
956
+ the model is configured as a decoder.
957
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
958
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
959
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
960
+
961
+ - 1 for tokens that are **not masked**,
962
+ - 0 for tokens that are **masked**.
963
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
964
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
965
+
966
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
967
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
968
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
969
+ use_cache (`bool`, *optional*):
970
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
971
+ `past_key_values`).
972
+ """
973
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
974
+ output_hidden_states = (
975
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
976
+ )
977
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
978
+
979
+ if self.config.is_decoder:
980
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
981
+ else:
982
+ use_cache = False
983
+
984
+ if input_ids is not None and inputs_embeds is not None:
985
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
986
+ elif input_ids is not None:
987
+ input_shape = input_ids.size()
988
+ elif inputs_embeds is not None:
989
+ input_shape = inputs_embeds.size()[:-1]
990
+ else:
991
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
992
+
993
+ batch_size, seq_length = input_shape
994
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
995
+
996
+ # past_key_values_length
997
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
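+ # Cached keys have shape (batch, num_heads, past_seq_len, head_size), so
+ # dim 2 gives the number of positions already processed.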
998
+
999
+ if attention_mask is None:
1000
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
1001
+
1002
+ if token_type_ids is None:
1003
+ if hasattr(self.embeddings, "token_type_ids"):
1004
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1005
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1006
+ token_type_ids = buffered_token_type_ids_expanded
1007
+ else:
1008
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1009
+
1010
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1011
+ # ourselves in which case we just need to make it broadcastable to all heads.
1012
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
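+ # get_extended_attention_mask turns the {0, 1} padding mask into an additive
+ # mask broadcastable to [batch_size, num_heads, seq_length, seq_length]:
+ # 0.0 for positions to keep and a large negative value for masked positions.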
1013
+
1014
+ # If a 2D or 3D attention mask is provided for the cross-attention
1015
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
1016
+ if self.config.is_decoder and encoder_hidden_states is not None:
1017
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1018
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1019
+ if encoder_attention_mask is None:
1020
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1021
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1022
+ else:
1023
+ encoder_extended_attention_mask = None
1024
+
1025
+ # Prepare head mask if needed
1026
+ # 1.0 in head_mask indicates we keep the head
1027
+ # attention_probs has shape bsz x n_heads x N x N
1028
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1029
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1030
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1031
+
1032
+ if self.use_pretrained_embedding:
1033
+ embedding_output = self.embeddings(inputs_embeds)
1034
+ else:
1035
+ embedding_output = self.embeddings(
1036
+ input_ids=input_ids,
1037
+ position_ids=position_ids,
1038
+ token_type_ids=token_type_ids,
1039
+ inputs_embeds=inputs_embeds,
1040
+ past_key_values_length=past_key_values_length,
1041
+ )
1042
+ encoder_outputs = self.encoder(
1043
+ embedding_output,
1044
+ attention_mask=extended_attention_mask,
1045
+ head_mask=head_mask,
1046
+ encoder_hidden_states=encoder_hidden_states,
1047
+ encoder_attention_mask=encoder_extended_attention_mask,
1048
+ past_key_values=past_key_values,
1049
+ use_cache=use_cache,
1050
+ output_attentions=output_attentions,
1051
+ output_hidden_states=output_hidden_states,
1052
+ return_dict=return_dict,
1053
+ )
1054
+ sequence_output = encoder_outputs[0]
1055
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1056
+
1057
+ if not return_dict:
1058
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1059
+
1060
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1061
+ last_hidden_state=sequence_output,
1062
+ pooler_output=pooled_output,
1063
+ past_key_values=encoder_outputs.past_key_values,
1064
+ hidden_states=encoder_outputs.hidden_states,
1065
+ attentions=encoder_outputs.attentions,
1066
+ cross_attentions=encoder_outputs.cross_attentions,
1067
+ )
1068
+
1069
+
1070
+ @add_start_docstrings(
1071
+ """
1072
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1073
+ sentence prediction (classification)` head.
1074
+ """,
1075
+ BERT_START_DOCSTRING,
1076
+ )
1077
+ class BertForPreTraining(BertPreTrainedModel):
1078
+ def __init__(self, config):
1079
+ super().__init__(config)
1080
+
1081
+ self.bert = BertModel(config)
1082
+ self.cls = BertPreTrainingHeads(config)
1083
+
1084
+ # Initialize weights and apply final processing
1085
+ self.post_init()
1086
+
1087
+ def get_output_embeddings(self):
1088
+ return self.cls.predictions.decoder
1089
+
1090
+ def set_output_embeddings(self, new_embeddings):
1091
+ self.cls.predictions.decoder = new_embeddings
1092
+
1093
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1094
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1095
+ def forward(
1096
+ self,
1097
+ input_ids: Optional[torch.Tensor] = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ token_type_ids: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.Tensor] = None,
1101
+ head_mask: Optional[torch.Tensor] = None,
1102
+ inputs_embeds: Optional[torch.Tensor] = None,
1103
+ labels: Optional[torch.Tensor] = None,
1104
+ next_sentence_label: Optional[torch.Tensor] = None,
1105
+ output_attentions: Optional[bool] = None,
1106
+ output_hidden_states: Optional[bool] = None,
1107
+ return_dict: Optional[bool] = None,
1108
+ ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
1109
+ r"""
1110
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1111
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1112
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked),
1113
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1114
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1115
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
1116
+ pair (see `input_ids` docstring). Indices should be in `[0, 1]`:
1117
+
1118
+ - 0 indicates sequence B is a continuation of sequence A,
1119
+ - 1 indicates sequence B is a random sequence.
1120
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1121
+ Used to hide legacy arguments that have been deprecated.
1122
+
1123
+ Returns:
1124
+
1125
+ Example:
1126
+
1127
+ ```python
1128
+ >>> from transformers import BertTokenizer, BertForPreTraining
1129
+ >>> import torch
1130
+
1131
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1132
+ >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
1133
+
1134
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1135
+ >>> outputs = model(**inputs)
1136
+
1137
+ >>> prediction_logits = outputs.prediction_logits
1138
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1139
+ ```
1140
+ """
1141
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1142
+
1143
+ outputs = self.bert(
1144
+ input_ids,
1145
+ attention_mask=attention_mask,
1146
+ token_type_ids=token_type_ids,
1147
+ position_ids=position_ids,
1148
+ head_mask=head_mask,
1149
+ inputs_embeds=inputs_embeds,
1150
+ output_attentions=output_attentions,
1151
+ output_hidden_states=output_hidden_states,
1152
+ return_dict=return_dict,
1153
+ )
1154
+
1155
+ sequence_output, pooled_output = outputs[:2]
1156
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1157
+
1158
+ total_loss = None
1159
+ if labels is not None and next_sentence_label is not None:
1160
+ loss_fct = CrossEntropyLoss()
1161
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1162
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1163
+ total_loss = masked_lm_loss + next_sentence_loss
1164
+
1165
+ if not return_dict:
1166
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1167
+ return ((total_loss,) + output) if total_loss is not None else output
1168
+
1169
+ return BertForPreTrainingOutput(
1170
+ loss=total_loss,
1171
+ prediction_logits=prediction_scores,
1172
+ seq_relationship_logits=seq_relationship_score,
1173
+ hidden_states=outputs.hidden_states,
1174
+ attentions=outputs.attentions,
1175
+ )
1176
+
1177
+
1178
+ @add_start_docstrings(
1179
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
1180
+ )
1181
+ class BertLMHeadModel(BertPreTrainedModel):
1182
+
1183
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1184
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1185
+
1186
+ def __init__(self, config):
1187
+ super().__init__(config)
1188
+
1189
+ if not config.is_decoder:
1190
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1191
+
1192
+ self.bert = BertModel(config, add_pooling_layer=False)
1193
+ self.cls = BertOnlyMLMHead(config)
1194
+
1195
+ # Initialize weights and apply final processing
1196
+ self.post_init()
1197
+
1198
+ def get_output_embeddings(self):
1199
+ return self.cls.predictions.decoder
1200
+
1201
+ def set_output_embeddings(self, new_embeddings):
1202
+ self.cls.predictions.decoder = new_embeddings
1203
+
1204
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1205
+ @add_code_sample_docstrings(
1206
+ processor_class=_TOKENIZER_FOR_DOC,
1207
+ checkpoint=_CHECKPOINT_FOR_DOC,
1208
+ output_type=CausalLMOutputWithCrossAttentions,
1209
+ config_class=_CONFIG_FOR_DOC,
1210
+ )
1211
+ def forward(
1212
+ self,
1213
+ input_ids: Optional[torch.Tensor] = None,
1214
+ attention_mask: Optional[torch.Tensor] = None,
1215
+ token_type_ids: Optional[torch.Tensor] = None,
1216
+ position_ids: Optional[torch.Tensor] = None,
1217
+ head_mask: Optional[torch.Tensor] = None,
1218
+ inputs_embeds: Optional[torch.Tensor] = None,
1219
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1220
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1221
+ labels: Optional[torch.Tensor] = None,
1222
+ past_key_values: Optional[List[torch.Tensor]] = None,
1223
+ use_cache: Optional[bool] = None,
1224
+ output_attentions: Optional[bool] = None,
1225
+ output_hidden_states: Optional[bool] = None,
1226
+ return_dict: Optional[bool] = None,
1227
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1228
+ r"""
1229
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1230
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1231
+ the model is configured as a decoder.
1232
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1233
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1234
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1235
+
1236
+ - 1 for tokens that are **not masked**,
1237
+ - 0 for tokens that are **masked**.
1238
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1239
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1240
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
1241
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1242
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1243
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1244
+
1245
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1246
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1247
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1248
+ use_cache (`bool`, *optional*):
1249
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1250
+ `past_key_values`).
1251
+ """
1252
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1253
+ if labels is not None:
1254
+ use_cache = False
1255
+
1256
+ outputs = self.bert(
1257
+ input_ids,
1258
+ attention_mask=attention_mask,
1259
+ token_type_ids=token_type_ids,
1260
+ position_ids=position_ids,
1261
+ head_mask=head_mask,
1262
+ inputs_embeds=inputs_embeds,
1263
+ encoder_hidden_states=encoder_hidden_states,
1264
+ encoder_attention_mask=encoder_attention_mask,
1265
+ past_key_values=past_key_values,
1266
+ use_cache=use_cache,
1267
+ output_attentions=output_attentions,
1268
+ output_hidden_states=output_hidden_states,
1269
+ return_dict=return_dict,
1270
+ )
1271
+
1272
+ sequence_output = outputs[0]
1273
+ prediction_scores = self.cls(sequence_output)
1274
+
1275
+ lm_loss = None
1276
+ if labels is not None:
1277
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1278
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1279
+ labels = labels[:, 1:].contiguous()
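+ # e.g. for tokens [t0, t1, t2]: scores at positions [t0, t1] are compared
+ # against labels [t1, t2], i.e. each position predicts the next token.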
1280
+ loss_fct = CrossEntropyLoss()
1281
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1282
+
1283
+ if not return_dict:
1284
+ output = (prediction_scores,) + outputs[2:]
1285
+ return ((lm_loss,) + output) if lm_loss is not None else output
1286
+
1287
+ return CausalLMOutputWithCrossAttentions(
1288
+ loss=lm_loss,
1289
+ logits=prediction_scores,
1290
+ past_key_values=outputs.past_key_values,
1291
+ hidden_states=outputs.hidden_states,
1292
+ attentions=outputs.attentions,
1293
+ cross_attentions=outputs.cross_attentions,
1294
+ )
1295
+
1296
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1297
+ input_shape = input_ids.shape
1298
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1299
+ if attention_mask is None:
1300
+ attention_mask = input_ids.new_ones(input_shape)
1301
+
1302
+ # cut decoder_input_ids if past is used
1303
+ if past is not None:
1304
+ input_ids = input_ids[:, -1:]
1305
+
1306
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1307
+
1308
+ def _reorder_cache(self, past, beam_idx):
1309
+ reordered_past = ()
1310
+ for layer_past in past:
1311
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1312
+ return reordered_past
1313
+
1314
+
1315
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
1316
+ class BertForMaskedLM(BertPreTrainedModel):
1317
+
1318
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1319
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1320
+
1321
+ def __init__(self, config):
1322
+ super().__init__(config)
1323
+
1324
+ if config.is_decoder:
1325
+ logger.warning(
1326
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1327
+ "bi-directional self-attention."
1328
+ )
1329
+
1330
+ self.bert = BertModel(config, add_pooling_layer=False)
1331
+ self.cls = BertOnlyMLMHead(config)
1332
+
1333
+ # Initialize weights and apply final processing
1334
+ self.post_init()
1335
+
1336
+ def get_output_embeddings(self):
1337
+ return self.cls.predictions.decoder
1338
+
1339
+ def set_output_embeddings(self, new_embeddings):
1340
+ self.cls.predictions.decoder = new_embeddings
1341
+
1342
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1343
+ @add_code_sample_docstrings(
1344
+ processor_class=_TOKENIZER_FOR_DOC,
1345
+ checkpoint=_CHECKPOINT_FOR_DOC,
1346
+ output_type=MaskedLMOutput,
1347
+ config_class=_CONFIG_FOR_DOC,
1348
+ expected_output="'paris'",
1349
+ expected_loss=0.88,
1350
+ )
1351
+ def forward(
1352
+ self,
1353
+ input_ids: Optional[torch.Tensor] = None,
1354
+ attention_mask: Optional[torch.Tensor] = None,
1355
+ token_type_ids: Optional[torch.Tensor] = None,
1356
+ position_ids: Optional[torch.Tensor] = None,
1357
+ head_mask: Optional[torch.Tensor] = None,
1358
+ inputs_embeds: Optional[torch.Tensor] = None,
1359
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1360
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1361
+ labels: Optional[torch.Tensor] = None,
1362
+ output_attentions: Optional[bool] = None,
1363
+ output_hidden_states: Optional[bool] = None,
1364
+ return_dict: Optional[bool] = None,
1365
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1366
+ r"""
1367
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1368
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1369
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the
1370
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1371
+ """
1372
+
1373
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1374
+
1375
+ outputs = self.bert(
1376
+ input_ids,
1377
+ attention_mask=attention_mask,
1378
+ token_type_ids=token_type_ids,
1379
+ position_ids=position_ids,
1380
+ head_mask=head_mask,
1381
+ inputs_embeds=inputs_embeds,
1382
+ encoder_hidden_states=encoder_hidden_states,
1383
+ encoder_attention_mask=encoder_attention_mask,
1384
+ output_attentions=output_attentions,
1385
+ output_hidden_states=output_hidden_states,
1386
+ return_dict=return_dict,
1387
+ )
1388
+
1389
+ sequence_output = outputs[0]
1390
+ prediction_scores = self.cls(sequence_output)
1391
+
1392
+ masked_lm_loss = None
1393
+ if labels is not None:
1394
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1395
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1396
+
1397
+ if not return_dict:
1398
+ output = (prediction_scores,) + outputs[2:]
1399
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1400
+
1401
+ return MaskedLMOutput(
1402
+ loss=masked_lm_loss,
1403
+ logits=prediction_scores,
1404
+ hidden_states=outputs.hidden_states,
1405
+ attentions=outputs.attentions,
1406
+ )
1407
+
1408
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1409
+ input_shape = input_ids.shape
1410
+ effective_batch_size = input_shape[0]
1411
+
1412
+ # add a dummy token
1413
+ if self.config.pad_token_id is None:
1414
+ raise ValueError("The PAD token should be defined for generation")
1415
+
1416
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1417
+ dummy_token = torch.full(
1418
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1419
+ )
1420
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1421
+
1422
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1423
+
1424
+
1425
+ @add_start_docstrings(
1426
+ """Bert Model with a `next sentence prediction (classification)` head on top.""",
1427
+ BERT_START_DOCSTRING,
1428
+ )
1429
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1430
+ def __init__(self, config):
1431
+ super().__init__(config)
1432
+
1433
+ self.bert = BertModel(config)
1434
+ self.cls = BertOnlyNSPHead(config)
1435
+
1436
+ # Initialize weights and apply final processing
1437
+ self.post_init()
1438
+
1439
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1440
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1441
+ def forward(
1442
+ self,
1443
+ input_ids: Optional[torch.Tensor] = None,
1444
+ attention_mask: Optional[torch.Tensor] = None,
1445
+ token_type_ids: Optional[torch.Tensor] = None,
1446
+ position_ids: Optional[torch.Tensor] = None,
1447
+ head_mask: Optional[torch.Tensor] = None,
1448
+ inputs_embeds: Optional[torch.Tensor] = None,
1449
+ labels: Optional[torch.Tensor] = None,
1450
+ output_attentions: Optional[bool] = None,
1451
+ output_hidden_states: Optional[bool] = None,
1452
+ return_dict: Optional[bool] = None,
1453
+ **kwargs,
1454
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1455
+ r"""
1456
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1457
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1458
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1459
+
1460
+ - 0 indicates sequence B is a continuation of sequence A,
1461
+ - 1 indicates sequence B is a random sequence.
1462
+
1463
+ Returns:
1464
+
1465
+ Example:
1466
+
1467
+ ```python
1468
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1469
+ >>> import torch
1470
+
1471
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1472
+ >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
1473
+
1474
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1475
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1476
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1477
+
1478
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1479
+ >>> logits = outputs.logits
1480
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1481
+ ```
1482
+ """
1483
+
1484
+ if "next_sentence_label" in kwargs:
1485
+ warnings.warn(
1486
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1487
+ " `labels` instead.",
1488
+ FutureWarning,
1489
+ )
1490
+ labels = kwargs.pop("next_sentence_label")
1491
+
1492
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1493
+
1494
+ outputs = self.bert(
1495
+ input_ids,
1496
+ attention_mask=attention_mask,
1497
+ token_type_ids=token_type_ids,
1498
+ position_ids=position_ids,
1499
+ head_mask=head_mask,
1500
+ inputs_embeds=inputs_embeds,
1501
+ output_attentions=output_attentions,
1502
+ output_hidden_states=output_hidden_states,
1503
+ return_dict=return_dict,
1504
+ )
1505
+
1506
+ pooled_output = outputs[1]
1507
+
1508
+ seq_relationship_scores = self.cls(pooled_output)
1509
+
1510
+ next_sentence_loss = None
1511
+ if labels is not None:
1512
+ loss_fct = CrossEntropyLoss()
1513
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1514
+
1515
+ if not return_dict:
1516
+ output = (seq_relationship_scores,) + outputs[2:]
1517
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1518
+
1519
+ return NextSentencePredictorOutput(
1520
+ loss=next_sentence_loss,
1521
+ logits=seq_relationship_scores,
1522
+ hidden_states=outputs.hidden_states,
1523
+ attentions=outputs.attentions,
1524
+ )
1525
+
1526
+
1527
+ @add_start_docstrings(
+     """
+     Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+     output) e.g. for GLUE tasks.
+     """,
+     BERT_START_DOCSTRING,
+ )
+ class BertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.bert = BertModel(config)
+         classifier_dropout_prob = (
+             config.classifier_dropout_prob if config.classifier_dropout_prob is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @add_code_sample_docstrings(
+         processor_class=_TOKENIZER_FOR_DOC,
+         checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
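A minimal standalone sketch (not part of this commit; the tensors and helper name are made up) of how the `problem_type` dispatch above selects a loss: regression for a single label, cross-entropy for integer labels, BCE-with-logits for multi-hot float labels.

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def infer_problem_type(num_labels, labels):
    # mirrors the dispatch in BertForSequenceClassification.forward above
    if num_labels == 1:
        return "regression"
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"

logits = torch.randn(4, 3)
hard_labels = torch.tensor([0, 2, 1, 0])                   # -> cross-entropy
multi_hot = (torch.rand(4, 3) > 0.5).float()               # -> BCE-with-logits
print(infer_problem_type(3, hard_labels))                  # single_label_classification
print(CrossEntropyLoss()(logits, hard_labels).item())
print(BCEWithLogitsLoss()(logits, multi_hot).item())
print(MSELoss()(torch.randn(4), torch.randn(4)).item())    # the num_labels == 1 case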
+ @add_start_docstrings(
+     """
+     Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+     softmax) e.g. for RocStories/SWAG tasks.
+     """,
+     BERT_START_DOCSTRING,
+ )
+ class BertForMultipleChoice(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = BertModel(config)
+         classifier_dropout_prob = (
+             config.classifier_dropout_prob if config.classifier_dropout_prob is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+     @add_code_sample_docstrings(
+         processor_class=_TOKENIZER_FOR_DOC,
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=MultipleChoiceModelOutput,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+         reshaped_logits = logits.view(-1, num_choices)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+
+         if not return_dict:
+             output = (reshaped_logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
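A hedged sketch (mine, not from this repo) of the multiple-choice reshaping used above: `(batch, num_choices, seq_len)` is flattened for the encoder, and the per-choice scores are folded back before softmax.

import torch

batch, num_choices, seq_len, hidden = 2, 4, 7, 16
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat = input_ids.view(-1, input_ids.size(-1))          # (batch * num_choices, seq_len)
pooled = torch.randn(batch * num_choices, hidden)      # stand-in for the BERT pooled output
logits = torch.nn.Linear(hidden, 1)(pooled)            # one score per choice
reshaped = logits.view(-1, num_choices)                # (batch, num_choices)
print(flat.shape, reshaped.shape)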
+ @add_start_docstrings(
+     """
+     Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+     Named-Entity-Recognition (NER) tasks.
+     """,
+     BERT_START_DOCSTRING,
+ )
+ class BertForTokenClassification(BertPreTrainedModel):
+
+     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = BertModel(config, add_pooling_layer=False)
+         classifier_dropout_prob = (
+             config.classifier_dropout_prob if config.classifier_dropout_prob is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @add_code_sample_docstrings(
+         processor_class=_TOKENIZER_FOR_DOC,
+         checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
+         output_type=TokenClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         sequence_output = self.dropout(sequence_output)
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
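An illustrative sketch (assumption: `-100` marks positions to skip, which is `CrossEntropyLoss`'s default `ignore_index`) of the flattened token-level loss computed above.

import torch
from torch.nn import CrossEntropyLoss

num_labels, seq_len = 5, 6
logits = torch.randn(2, seq_len, num_labels)
labels = torch.tensor([[0, 1, 2, -100, -100, -100],
                       [3, 4, 0, 1, -100, -100]])
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(loss)  # the -100 (e.g. padded) positions contribute nothing to the loss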
+ @add_start_docstrings(
+     """
+     Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
+     layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+     """,
+     BERT_START_DOCSTRING,
+ )
+ class BertForQuestionAnswering(BertPreTrainedModel):
+
+     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @add_code_sample_docstrings(
+         processor_class=_TOKENIZER_FOR_DOC,
+         checkpoint=_CHECKPOINT_FOR_QA,
+         output_type=QuestionAnsweringModelOutput,
+         config_class=_CONFIG_FOR_DOC,
+         qa_target_start_index=_QA_TARGET_START_INDEX,
+         qa_target_end_index=_QA_TARGET_END_INDEX,
+         expected_output=_QA_EXPECTED_OUTPUT,
+         expected_loss=_QA_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         start_positions: Optional[torch.Tensor] = None,
+         end_positions: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+         r"""
+         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the start of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the end of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1).contiguous()
+         end_logits = end_logits.squeeze(-1).contiguous()
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, squeeze the extra dimension added by gathering
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+             # sometimes the start/end positions are outside our model inputs; we ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions = start_positions.clamp(0, ignored_index)
+             end_positions = end_positions.clamp(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             output = (start_logits, end_logits) + outputs[2:]
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return QuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
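A small sketch (illustrative, not the author's code) of the span head above: one linear layer yields two logits per token, which are split into start/end scores, and out-of-range gold positions are clamped to an index the loss then ignores.

import torch

seq_len, hidden = 10, 16
sequence_output = torch.randn(1, seq_len, hidden)
qa_outputs = torch.nn.Linear(hidden, 2)
start_logits, end_logits = qa_outputs(sequence_output).split(1, dim=-1)
start_logits = start_logits.squeeze(-1)  # (1, seq_len)
end_logits = end_logits.squeeze(-1)
ignored_index = start_logits.size(1)
start_positions = torch.tensor([42]).clamp(0, ignored_index)
print(start_logits.shape, start_positions)  # clamped to 10, which the loss ignores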
modeling_gplm.py ADDED
@@ -0,0 +1,1225 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import math
+ from typing import Dict, Optional, Sequence, Tuple, List, Union
+ import uuid
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+ from torch.nn import Parameter
+
+
+ def gelu(x):
+     """Exact (erf-based) implementation of the GELU activation function.
+     For reference, OpenAI GPT's tanh approximation is:
+     0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+     """
+     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
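A quick numerical check (illustrative only): the erf form used above and the tanh approximation quoted in the docstring agree closely over a typical activation range.

import math
import torch

x = torch.linspace(-4, 4, steps=9)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print((exact - approx).abs().max())  # small, well under 1e-2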
+ def symmetrize(x):
+     "Make layer symmetric in final two dimensions, used for contact prediction."
+     return x + x.transpose(-1, -2)
+
+
+ def apc(x):
+     "Perform average product correction (APC), used for contact prediction."
+     a1 = x.sum(-1, keepdims=True)
+     a2 = x.sum(-2, keepdims=True)
+     a12 = x.sum((-1, -2), keepdims=True)
+
+     avg = a1 * a2
+     avg.div_(a12)  # in-place to reduce memory
+     normalized = x - avg
+     return normalized
+
+
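A toy example (not in the source): APC subtracts the row/column "background" signal from a symmetric attention map before contact regression; the corrected map sums to zero by construction.

import torch

attn = torch.rand(1, 8, 8)
sym = attn + attn.transpose(-1, -2)      # symmetrize
a1 = sym.sum(-1, keepdims=True)
a2 = sym.sum(-2, keepdims=True)
a12 = sym.sum((-1, -2), keepdims=True)
corrected = sym - a1 * a2 / a12          # same math as apc(), but out of place
print(corrected.shape, corrected.sum())  # sum is ~0 by construction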
+ class LucaGPLM1LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-12, affine=True):
+         """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
+         super().__init__()
+         self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
+         self.eps = eps
+         self.affine = bool(affine)
+         if self.affine:
+             self.weight = nn.Parameter(torch.ones(hidden_size))
+             self.bias = nn.Parameter(torch.zeros(hidden_size))
+         else:
+             self.weight, self.bias = None, None
+
+     def forward(self, x):
+         dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
+         means = x.mean(dims, keepdim=True)
+         x_zeromean = x - means
+         variances = x_zeromean.pow(2).mean(dims, keepdim=True)
+         x = x_zeromean / torch.sqrt(variances + self.eps)
+         if self.affine:
+             x = (self.weight * x) + self.bias
+         return x
+
+
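A sanity-check sketch (assumes the class defined just above is in scope): with a matching eps, this hand-rolled LayerNorm agrees with `torch.nn.LayerNorm` on the last dimension, since both use the biased variance.

import torch

ln = LucaGPLM1LayerNorm(32, eps=1e-5)
ref = torch.nn.LayerNorm(32, eps=1e-5)
x = torch.randn(2, 7, 32)
print(torch.allclose(ln(x), ref(x), atol=1e-5))  # True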
+ try:
+     # Optimized LayerNorm
+     from apex.normalization import FusedLayerNorm as _FusedLayerNorm
+
+     class LucaGPLM1bLayerNorm(_FusedLayerNorm):
+         @torch.jit.unused
+         def forward(self, x):
+             if not x.is_cuda:
+                 return super().forward(x)
+             else:
+                 with torch.cuda.device(x.device):
+                     return super().forward(x)
+
+ except ImportError as e:
+     print("import apex err:", e)
+     from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
+
+
+ class LucaGPLMTransformerLayer(nn.Module):
+     """LucaGPLM Transformer layer block."""
+
+     def __init__(
+         self,
+         embed_dim,
+         ffn_embed_dim,
+         attention_heads,
+         add_bias_kv=True,
+         use_lucagplm1b_layer_norm=False,
+         use_rotary_embeddings: bool = False,
+     ):
+         '''
+         Transformer-Encoder layer
+         :param embed_dim: token embedding dim
+         :param ffn_embed_dim: fully connected layer dim
+         :param attention_heads: heads num
+         :param add_bias_kv: whether the key-value projections add a bias
+         :param use_lucagplm1b_layer_norm: whether to use the lucagplm-1b layer norm
+         :param use_rotary_embeddings: whether to use rotary embeddings
+         '''
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.ffn_embed_dim = ffn_embed_dim
+         self.attention_heads = attention_heads
+         self.use_rotary_embeddings = use_rotary_embeddings
+         self._init_submodules(add_bias_kv, use_lucagplm1b_layer_norm)
+
+     def _init_submodules(self, add_bias_kv, use_lucagplm1b_layer_norm):
+         LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
+
+         # pre layer norm
+         self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
+
+         self.self_attn = LucaGPLMMultiheadAttention(
+             self.embed_dim,
+             self.attention_heads,
+             add_bias_kv=add_bias_kv,
+             add_zero_attn=False,
+             use_rotary_embeddings=self.use_rotary_embeddings,
+         )
+
+         # post layer norm
+         self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
+
+         # dimension increase by the fully connected layer
+         self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
+
+         # dimension reduction by the fully connected layer
+         self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         x,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+         need_head_weights=False
+     ):
+         residual = x
+         x = self.pre_layer_norm(x)
+         x, attn = self.self_attn(
+             query=x,
+             key=x,
+             value=x,
+             key_padding_mask=self_attn_padding_mask,
+             need_weights=True,
+             need_head_weights=need_head_weights,
+             attn_mask=self_attn_mask,
+         )
+         x = residual + x
+
+         residual = x
+         x = self.post_layer_norm(x)
+         x = gelu(self.fc1(x))
+         x = self.fc2(x)
+         x = residual + x
+
+         return x, attn
+
+
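A usage sketch (shapes are my assumptions, following the ESM-style `(seq_len, batch, embed_dim)` layout): the layer is pre-LN, so each sub-block runs LayerNorm, then the sublayer, then a residual add.

import torch

layer = LucaGPLMTransformerLayer(embed_dim=64, ffn_embed_dim=256, attention_heads=8)
x = torch.randn(12, 2, 64)                       # (seq_len, batch, embed_dim)
out, attn = layer(x, need_head_weights=True)
print(out.shape)   # torch.Size([12, 2, 64])
print(attn.shape)  # (heads, batch, query_len, key_len); key axis has one extra bias-kv slot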
+ class AxialTransformerLayer(nn.Module):
+     """Implements an Axial MSA Transformer block."""
+
+     def __init__(
+         self,
+         embedding_dim: int = 768,
+         ffn_embedding_dim: int = 3072,
+         num_attention_heads: int = 8,
+         dropout: float = 0.1,
+         attention_dropout: float = 0.1,
+         activation_dropout: float = 0.1,
+         max_tokens_per_msa: int = 2**14,
+     ) -> None:
+         super().__init__()
+
+         # Initialize parameters
+         self.embedding_dim = embedding_dim
+         self.dropout_prob = dropout
+
+         row_self_attention = RowSelfAttention(
+             embedding_dim,
+             num_attention_heads,
+             dropout=dropout,
+             max_tokens_per_msa=max_tokens_per_msa,
+         )
+
+         column_self_attention = ColumnSelfAttention(
+             embedding_dim,
+             num_attention_heads,
+             dropout=dropout,
+             max_tokens_per_msa=max_tokens_per_msa,
+         )
+
+         feed_forward_layer = FeedForwardNetwork(
+             embedding_dim,
+             ffn_embedding_dim,
+             activation_dropout=activation_dropout,
+             max_tokens_per_msa=max_tokens_per_msa,
+         )
+
+         self.row_self_attention = self.build_residual(row_self_attention)
+         self.column_self_attention = self.build_residual(column_self_attention)
+         self.feed_forward_layer = self.build_residual(feed_forward_layer)
+
+     def build_residual(self, layer: nn.Module):
+         return NormalizedResidualBlock(
+             layer,
+             self.embedding_dim,
+             self.dropout_prob,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         self_attn_mask: Optional[torch.Tensor] = None,
+         self_attn_padding_mask: Optional[torch.Tensor] = None,
+         need_head_weights: bool = False,
+     ):
+         """
+         LayerNorm is applied either before or after the self-attention/ffn
+         modules similar to the original Transformer implementation.
+         """
+         x, row_attn = self.row_self_attention(
+             x,
+             self_attn_mask=self_attn_mask,
+             self_attn_padding_mask=self_attn_padding_mask,
+         )
+         x, column_attn = self.column_self_attention(
+             x,
+             self_attn_mask=self_attn_mask,
+             self_attn_padding_mask=self_attn_padding_mask,
+         )
+         x = self.feed_forward_layer(x)
+         if need_head_weights:
+             return x, column_attn, row_attn
+         else:
+             return x
+
+
+ class LearnedPositionalEmbedding(nn.Embedding):
+     """
+     This module learns positional embeddings up to a fixed maximum size.
+     Padding ids are ignored by either offsetting based on padding_idx
+     or by setting padding_idx to None and ensuring that the appropriate
+     position ids are passed to the forward function.
+     """
+
+     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
+         if padding_idx is not None:
+             num_embeddings_ = num_embeddings + padding_idx + 1
+         else:
+             num_embeddings_ = num_embeddings
+         super().__init__(num_embeddings_, embedding_dim, padding_idx)
+         self.max_positions = num_embeddings
+
+     def forward(self, input: torch.Tensor):
+         """Input is expected to be of size [bsz x seqlen]."""
+         if input.size(1) > self.max_positions:
+             raise ValueError(
+                 f"Sequence length {input.size(1)} above maximum "
+                 f"sequence length of {self.max_positions}"
+             )
+         mask = input.ne(self.padding_idx).int()
+         positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
+         return F.embedding(
+             positions,
+             self.weight,
+             self.padding_idx,
+             self.max_norm,
+             self.norm_type,
+             self.scale_grad_by_freq,
+             self.sparse,
+         )
+
+
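A standalone illustration of the padding-aware position ids computed above: positions count only non-pad tokens and are offset past `padding_idx`, so pad slots keep the padding embedding.

import torch

padding_idx = 1
tokens = torch.tensor([[5, 6, 7, padding_idx, padding_idx]])
mask = tokens.ne(padding_idx).int()
positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(positions)  # tensor([[2, 3, 4, 1, 1]]) -- pads stay at padding_idx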
+ class SinusoidalPositionalEmbedding(nn.Module):
+     def __init__(self, embed_dim, padding_idx, learned=False):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.padding_idx = padding_idx
+         self.register_buffer("_float_tensor", torch.FloatTensor(1))
+         self.weights = None
+
+     def forward(self, x):
+         bsz, seq_len = x.shape
+         max_pos = self.padding_idx + 1 + seq_len
+         if self.weights is None or max_pos > self.weights.size(0):
+             self.weights = self.get_embedding(max_pos)
+         self.weights = self.weights.type_as(self._float_tensor)
+
+         positions = self.make_positions(x)
+         return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+     def make_positions(self, x):
+         mask = x.ne(self.padding_idx)
+         range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
+         positions = range_buf.expand_as(x)
+         return positions * mask.long() + self.padding_idx * (1 - mask.long())
+
+     def get_embedding(self, num_embeddings):
+         half_dim = self.embed_dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+         emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+         if self.embed_dim % 2 == 1:
+             # zero pad
+             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+         if self.padding_idx is not None:
+             emb[self.padding_idx, :] = 0
+         return emb
+
+
+ class RobertaLMHead(nn.Module):
+     """Head for masked language modeling."""
+
+     def __init__(self, embed_dim, output_dim, weight):
+         super().__init__()
+         self.dense = nn.Linear(embed_dim, embed_dim)
+         self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
+         self.weight = weight
+         self.bias = nn.Parameter(torch.zeros(output_dim))
+
+     def forward(self, features):
+         x = self.dense(features)
+         x = gelu(x)
+         x = self.layer_norm(x)
+         # project back to size of vocabulary with bias
+         x = F.linear(x, self.weight) + self.bias
+         return x
+
+
+ class ContactPredictionHead(nn.Module):
+     """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+     def __init__(
+         self,
+         in_features: int,
+         prepend_bos: bool,
+         append_eos: bool,
+         bias=True,
+         eos_idx: Optional[int] = None,
+     ):
+         super().__init__()
+         self.in_features = in_features
+         self.prepend_bos = prepend_bos
+         self.append_eos = append_eos
+         if append_eos and eos_idx is None:
+             raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
+         self.eos_idx = eos_idx
+         self.regression = nn.Linear(in_features, 1, bias)
+         self.activation = nn.Sigmoid()
+
+     def forward(self, tokens, attentions):
+         # remove eos token attentions
+         if self.append_eos:
+             eos_mask = tokens.ne(self.eos_idx).to(attentions)
+             eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
+             attentions = attentions * eos_mask[:, None, None, :, :]
+             attentions = attentions[..., :-1, :-1]
+         # remove cls token attentions
+         if self.prepend_bos:
+             attentions = attentions[..., 1:, 1:]
+         batch_size, layers, heads, seqlen, _ = attentions.size()
+         attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
+
+         # features: B x C x T x T
+         attentions = attentions.to(
+             self.regression.weight.device
+         )  # attentions always float32, may need to convert to float16
+         attentions = apc(symmetrize(attentions))
+         attentions = attentions.permute(0, 2, 3, 1)
+         return self.activation(self.regression(attentions).squeeze(3))
+
+
+ class NormalizedResidualBlock(nn.Module):
+     def __init__(
+         self,
+         layer: nn.Module,
+         embedding_dim: int,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+
+         self.layer = layer
+         self.dropout_module = nn.Dropout(
+             dropout,
+         )
+         self.layer_norm = LucaGPLM1bLayerNorm(self.embedding_dim)
+
+     def forward(self, x, *args, **kwargs):
+         residual = x
+         x = self.layer_norm(x)
+         outputs = self.layer(x, *args, **kwargs)
+         if isinstance(outputs, tuple):
+             x, *out = outputs
+         else:
+             x = outputs
+             out = None
+
+         x = self.dropout_module(x)
+         x = residual + x
+
+         if out is not None:
+             return (x,) + tuple(out)
+         else:
+             return x
+
+
+ class FeedForwardNetwork(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         ffn_embedding_dim: int,
+         activation_dropout: float = 0.1,
+         max_tokens_per_msa: int = 2**14,
+     ):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.ffn_embedding_dim = ffn_embedding_dim
+         self.max_tokens_per_msa = max_tokens_per_msa
+         self.activation_fn = nn.GELU()
+         self.activation_dropout_module = nn.Dropout(
+             activation_dropout,
+         )
+         self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
+         self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
+
+     def forward(self, x):
+         x = self.activation_fn(self.fc1(x))
+         x = self.activation_dropout_module(x)
+         x = self.fc2(x)
+         return x
+
+
+ class RowSelfAttention(nn.Module):
+     """Compute self-attention over rows of a 2D input."""
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         dropout=0.0,
+         max_tokens_per_msa: int = 2 ** 16,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+         self.scaling = self.head_dim ** -0.5
+         self.max_tokens_per_msa = max_tokens_per_msa
+         self.attn_shape = "hnij"
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+         self.dropout_module = nn.Dropout(dropout)
+
+     def align_scaling(self, q):
+         num_rows = q.size(0)
+         return self.scaling / math.sqrt(num_rows)
+
+     def _batched_forward(
+         self,
+         x,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         max_rows = max(1, self.max_tokens_per_msa // num_cols)
+         attns = 0
+         scaling = self.align_scaling(x)
+         for start in range(0, num_rows, max_rows):
+             attn_weights = self.compute_attention_weights(
+                 x[start : start + max_rows],
+                 scaling,
+                 self_attn_mask=self_attn_mask,
+                 self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
+                 if self_attn_padding_mask is not None
+                 else None,
+             )
+             attns += attn_weights
+         attn_probs = attns.softmax(-1)
+         attn_probs = self.dropout_module(attn_probs)
+
+         outputs = []
+         for start in range(0, num_rows, max_rows):
+             output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
+             outputs.append(output)
+
+         output = torch.cat(outputs, 0)
+         return output, attn_probs
+
+     def compute_attention_weights(
+         self,
+         x,
+         scaling: float,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+         k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+         q *= scaling
+         if self_attn_padding_mask is not None:
+             # Zero out any padded aligned positions - this is important since
+             # we take a sum across the alignment axis.
+             q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
+
+         attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
+
+         if self_attn_mask is not None:
+             raise NotImplementedError
+             # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
+
+         if self_attn_padding_mask is not None:
+             attn_weights = attn_weights.masked_fill(
+                 self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
+                 -10000,
+             )
+
+         return attn_weights
+
+     def compute_attention_update(
+         self,
+         x,
+         attn_probs,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+         context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
+         context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
+         output = self.out_proj(context)
+         return output
+
+     def forward(
+         self,
+         x,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
+             return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
+         else:
+             scaling = self.align_scaling(x)
+             attn_weights = self.compute_attention_weights(
+                 x, scaling, self_attn_mask, self_attn_padding_mask
+             )
+             attn_probs = attn_weights.softmax(-1)
+             attn_probs = self.dropout_module(attn_probs)
+             output = self.compute_attention_update(x, attn_probs)
+             return output, attn_probs
+
+
+ class ColumnSelfAttention(nn.Module):
+     """Compute self-attention over columns of a 2D input."""
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         dropout=0.0,
+         max_tokens_per_msa: int = 2 ** 16,
+     ):
+         super().__init__()
+
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+         self.scaling = self.head_dim ** -0.5
+         self.max_tokens_per_msa = max_tokens_per_msa
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+         self.dropout_module = nn.Dropout(dropout)
+
+     def _batched_forward(
+         self,
+         x,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         max_cols = max(1, self.max_tokens_per_msa // num_rows)
+         outputs = []
+         attns = []
+         for start in range(0, num_cols, max_cols):
+             output, attn = self(
+                 x[:, start : start + max_cols],
+                 self_attn_mask=self_attn_mask,
+                 self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
+                 if self_attn_padding_mask is not None
+                 else None,
+             )
+             outputs.append(output)
+             attns.append(attn)
+         output = torch.cat(outputs, 1)
+         attns = torch.cat(attns, 1)
+         return output, attns
+
+     def compute_attention_update(
+         self,
+         x,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         if num_rows == 1:
+             # if there is only 1 position, this is equivalent and doesn't break with padding
+             attn_probs = torch.ones(
+                 self.num_heads,
+                 num_cols,
+                 batch_size,
+                 num_rows,
+                 num_rows,
+                 device=x.device,
+                 dtype=x.dtype,
+             )
+             output = self.out_proj(self.v_proj(x))
+         else:
+             q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+             k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+             v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
+             q *= self.scaling
+
+             attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
+
+             if self_attn_mask is not None:
+                 raise NotImplementedError
+             if self_attn_padding_mask is not None:
+                 attn_weights = attn_weights.masked_fill(
+                     self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
+                     -10000,
+                 )
+
+             attn_probs = attn_weights.softmax(-1)
+             attn_probs = self.dropout_module(attn_probs)
+             context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
+             context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
+             output = self.out_proj(context)
+         return output, attn_probs
+
+     def forward(
+         self,
+         x,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+     ):
+         num_rows, num_cols, batch_size, embed_dim = x.size()
+         # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
+         if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
+             return self._batched_forward(
+                 x,
+                 self_attn_mask,
+                 self_attn_padding_mask,
+             )
+         else:
+             return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
+
+
+ def utils_softmax(x, dim: int, onnx_trace: bool = False):
+     if onnx_trace:
+         return F.softmax(x.float(), dim=dim)
+     else:
+         return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+ class FairseqIncrementalState(object):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.init_incremental_state()
+
+     def init_incremental_state(self):
+         self._incremental_state_id = str(uuid.uuid4())
+
+     def _get_full_incremental_state_key(self, key: str) -> str:
+         return "{}.{}".format(self._incremental_state_id, key)
+
+     def get_incremental_state(
+         self,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+         key: str,
+     ) -> Optional[Dict[str, Optional[Tensor]]]:
+         """Helper for getting incremental state for an nn.Module."""
+         full_key = self._get_full_incremental_state_key(key)
+         if incremental_state is None or full_key not in incremental_state:
+             return None
+         return incremental_state[full_key]
+
+     def set_incremental_state(
+         self,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+         key: str,
+         value: Dict[str, Optional[Tensor]],
+     ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
+         """Helper for setting incremental state for an nn.Module."""
+         if incremental_state is not None:
+             full_key = self._get_full_incremental_state_key(key)
+             incremental_state[full_key] = value
+         return incremental_state
+
+
+ def with_incremental_state(cls):
+     cls.__bases__ = (FairseqIncrementalState,) + tuple(
+         b for b in cls.__bases__ if b != FairseqIncrementalState
+     )
+     return cls
+
+
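A toy walkthrough (the `Dummy` module is hypothetical) of the incremental-state helpers above: the decorator splices `FairseqIncrementalState` into the class bases, and each instance namespaces its cache keys with a per-instance UUID.

import torch

@with_incremental_state
class Dummy(torch.nn.Module):
    def __init__(self):
        super().__init__()  # also runs init_incremental_state() via the MRO

state = {}
m = Dummy()
m.set_incremental_state(state, "attn_state", {"prev_key": torch.zeros(1, 2, 3)})
print(list(state))  # ['<uuid>.attn_state']
print(m.get_incremental_state(state, "attn_state")["prev_key"].shape)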
+ @with_incremental_state
+ class LucaGPLMMultiheadAttention(nn.Module):
+     """Multi-headed attention.
+
+     See "Attention Is All You Need" for more details.
+     """
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         kdim=None,
+         vdim=None,
+         dropout=0.0,
+         bias=True,
+         add_bias_kv: bool = False,
+         add_zero_attn: bool = False,
+         self_attention: bool = False,
+         encoder_decoder_attention: bool = False,
+         use_rotary_embeddings: bool = False,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.kdim = kdim if kdim is not None else embed_dim
+         self.vdim = vdim if vdim is not None else embed_dim
+         self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+         assert (
+             self.head_dim * num_heads == self.embed_dim
+         ), "embed_dim must be divisible by num_heads"
+         self.scaling = self.head_dim**-0.5
+
+         self.self_attention = self_attention
+         self.encoder_decoder_attention = encoder_decoder_attention
+
+         assert not self.self_attention or self.qkv_same_dim, (
+             "Self-attention requires query, key and " "value to be of the same size"
+         )
+
+         self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
+         self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
+         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         if add_bias_kv:
+             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+             self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+         else:
+             self.bias_k = self.bias_v = None
+
+         self.add_zero_attn = add_zero_attn
+
+         self.reset_parameters()
+
+         self.onnx_trace = False
+         self.rot_emb = None
+         if use_rotary_embeddings:
+             self.rot_emb = RotaryEmbedding(dim=self.head_dim)
+
+         self.enable_torch_version = False
+         if hasattr(F, "multi_head_attention_forward"):
+             self.enable_torch_version = True
+         else:
+             self.enable_torch_version = False
+
+     def prepare_for_onnx_export_(self):
+         self.onnx_trace = True
+
+     def reset_parameters(self):
+         '''
+         if self.qkv_same_dim:
+             # Empirically observed the convergence to be much better with
+             # the scaled initialization
+             nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+             nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+             nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+         else:
+             nn.init.xavier_uniform_(self.k_proj.weight)
+             nn.init.xavier_uniform_(self.v_proj.weight)
+             nn.init.xavier_uniform_(self.q_proj.weight)
+         '''
+         nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
+         nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
+         nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
+
+         nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
+         # nn.init.xavier_uniform_(self.out_proj.weight)
+         if self.out_proj.bias is not None:
+             nn.init.constant_(self.out_proj.bias, 0.0)
+         if self.bias_k is not None:
+             nn.init.xavier_normal_(self.bias_k)
+         if self.bias_v is not None:
+             nn.init.xavier_normal_(self.bias_v)
+
+     def forward(
+         self,
+         query,
+         key: Optional[Tensor],
+         value: Optional[Tensor],
+         key_padding_mask: Optional[Tensor] = None,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+         need_weights: bool = True,
+         static_kv: bool = False,
+         attn_mask: Optional[Tensor] = None,
+         before_softmax: bool = False,
+         need_head_weights: bool = False,
+     ) -> Tuple[Tensor, Optional[Tensor]]:
+         """Input shape: Time x Batch x Channel
+
+         Args:
+             key_padding_mask (ByteTensor, optional): mask to exclude
+                 keys that are pads, of shape `(batch, src_len)`, where
+                 padding elements are indicated by 1s.
+             need_weights (bool, optional): return the attention weights,
+                 averaged over heads (default: True).
+             attn_mask (ByteTensor, optional): typically used to
+                 implement causal attention, where the mask prevents the
+                 attention from looking forward in time (default: None).
+             before_softmax (bool, optional): return the raw attention
+                 weights and values before the attention softmax.
+             need_head_weights (bool, optional): return the attention
+                 weights for each head. Implies *need_weights*. Default:
+                 return the average attention weights over all heads.
+         """
+         if need_head_weights:
+             need_weights = True
+
+         tgt_len, bsz, embed_dim = query.size()
+         assert embed_dim == self.embed_dim
+         assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+         if (
+             not self.rot_emb
+             and self.enable_torch_version
+             and not self.onnx_trace
+             and incremental_state is None
+             and not static_kv
+             # A workaround for quantization to work. Otherwise JIT compilation
+             # treats bias in linear module as method.
+             and not torch.jit.is_scripting()
+             and not need_head_weights
+         ):
+             assert key is not None and value is not None
+             return F.multi_head_attention_forward(
+                 query,
+                 key,
+                 value,
+                 self.embed_dim,
+                 self.num_heads,
+                 torch.empty([0]),
+                 torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+                 self.bias_k,
+                 self.bias_v,
+                 self.add_zero_attn,
+                 self.dropout,
+                 self.out_proj.weight,
+                 self.out_proj.bias,
+                 self.training,
+                 key_padding_mask,
+                 need_weights,
+                 attn_mask,
+                 use_separate_proj_weight=True,
+                 q_proj_weight=self.q_proj.weight,
+                 k_proj_weight=self.k_proj.weight,
+                 v_proj_weight=self.v_proj.weight,
+             )
+         if incremental_state is not None:
+             saved_state = self._get_input_buffer(incremental_state)
+             if saved_state is not None and "prev_key" in saved_state:
+                 # previous time steps are cached - no need to recompute
+                 # key and value if they are static
+                 if static_kv:
+                     assert self.encoder_decoder_attention and not self.self_attention
+                     key = value = None
+         else:
+             saved_state = None
+
+         if self.self_attention:
+             q = self.q_proj(query)
+             k = self.k_proj(query)
+             v = self.v_proj(query)
+         elif self.encoder_decoder_attention:
+             # encoder-decoder attention
+             q = self.q_proj(query)
+             if key is None:
+                 assert value is None
+                 k = v = None
+             else:
+                 k = self.k_proj(key)
+                 v = self.v_proj(key)
+
+         else:
+             assert key is not None and value is not None
+             q = self.q_proj(query)
+             k = self.k_proj(key)
+             v = self.v_proj(value)
+         q *= self.scaling
+
+         if self.bias_k is not None:
+             assert self.bias_v is not None
+             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+             if attn_mask is not None:
+                 attn_mask = torch.cat(
+                     [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                 )
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [
+                         key_padding_mask,
+                         key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+                     ],
+                     dim=1,
+                 )
+
+         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if k is not None:
+             k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if v is not None:
+             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+         if saved_state is not None:
+             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+             if "prev_key" in saved_state:
+                 _prev_key = saved_state["prev_key"]
+                 assert _prev_key is not None
+                 prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     k = prev_key
+                 else:
+                     assert k is not None
+                     k = torch.cat([prev_key, k], dim=1)
+             if "prev_value" in saved_state:
+                 _prev_value = saved_state["prev_value"]
+                 assert _prev_value is not None
+                 prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     v = prev_value
+                 else:
+                     assert v is not None
+                     v = torch.cat([prev_value, v], dim=1)
+             prev_key_padding_mask: Optional[Tensor] = None
+             if "prev_key_padding_mask" in saved_state:
+                 prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+             assert k is not None and v is not None
+             key_padding_mask = LucaGPLMMultiheadAttention._append_prev_key_padding_mask(
+                 key_padding_mask=key_padding_mask,
+                 prev_key_padding_mask=prev_key_padding_mask,
+                 batch_size=bsz,
+                 src_len=k.size(1),
+                 static_kv=static_kv,
+             )
+
+             saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+             saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+             saved_state["prev_key_padding_mask"] = key_padding_mask
+             # In this branch incremental_state is never None
+             assert incremental_state is not None
+             incremental_state = self._set_input_buffer(incremental_state, saved_state)
+         assert k is not None
+         src_len = k.size(1)
+
+         # This is part of a workaround to get around fork/join parallelism
+         # not supporting Optional types.
+         if key_padding_mask is not None and key_padding_mask.dim() == 0:
+             key_padding_mask = None
+
+         if key_padding_mask is not None:
+             assert key_padding_mask.size(0) == bsz
+             assert key_padding_mask.size(1) == src_len
+
+         if self.add_zero_attn:
+             assert v is not None
+             src_len += 1
+             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+             if attn_mask is not None:
+                 attn_mask = torch.cat(
+                     [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                 )
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [
+                         key_padding_mask,
+                         torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
+                     ],
+                     dim=1,
+                 )
+
+         if self.rot_emb:
+             q, k = self.rot_emb(q, k)
+
+         attn_weights = torch.bmm(q, k.transpose(1, 2))
+         attn_weights = LucaGPLMMultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+         if attn_mask is not None:
+             attn_mask = attn_mask.unsqueeze(0)
+             if self.onnx_trace:
+                 attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+             attn_weights += attn_mask
+
+         if key_padding_mask is not None:
+             # don't attend to padding symbols
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if before_softmax:
+             return attn_weights, v
+
+         attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
+         attn_weights = attn_weights_float.type_as(attn_weights)
+         attn_probs = F.dropout(
+             attn_weights_float.type_as(attn_weights),
+             p=self.dropout,
+             training=self.training,
+         )
+         assert v is not None
+         attn = torch.bmm(attn_probs, v)
+         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+         if self.onnx_trace and attn.size(1) == 1:
+             # when ONNX tracing a single decoder step (sequence length == 1)
+             # the transpose is a no-op copy before view, thus unnecessary
+             attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
+         else:
+             attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+         attn = self.out_proj(attn)
+         attn_weights: Optional[Tensor] = None
+         if need_weights:
+             attn_weights = attn_weights_float.view(
+                 bsz, self.num_heads, tgt_len, src_len
+             ).type_as(attn).transpose(1, 0)
+             if not need_head_weights:
+                 # average attention weights over heads
+                 attn_weights = attn_weights.mean(dim=0)
+
+         return attn, attn_weights
+
+     @staticmethod
+     def _append_prev_key_padding_mask(
+         key_padding_mask: Optional[Tensor],
+         prev_key_padding_mask: Optional[Tensor],
+         batch_size: int,
+         src_len: int,
+         static_kv: bool,
+     ) -> Optional[Tensor]:
+         # saved key padding masks have shape (bsz, seq_len)
+         if prev_key_padding_mask is not None and static_kv:
+             new_key_padding_mask = prev_key_padding_mask
+         elif prev_key_padding_mask is not None and key_padding_mask is not None:
+             new_key_padding_mask = torch.cat(
+                 [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+             )
+         # During incremental decoding, as the padding token enters and
+         # leaves the frame, there will be a time when prev or current
+         # is None
+         elif prev_key_padding_mask is not None:
+             filler = torch.zeros(
+                 (batch_size, src_len - prev_key_padding_mask.size(1)),
+                 device=prev_key_padding_mask.device,
+             )
+             new_key_padding_mask = torch.cat(
+                 [prev_key_padding_mask.float(), filler.float()], dim=1
+             )
+         elif key_padding_mask is not None:
+             filler = torch.zeros(
+                 (batch_size, src_len - key_padding_mask.size(1)),
+                 device=key_padding_mask.device,
+             )
+             new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+         else:
+             new_key_padding_mask = prev_key_padding_mask
+         return new_key_padding_mask
+
+     @torch.jit.export
+     def reorder_incremental_state(
+         self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
+     ):
+         """Reorder buffered internal state (for incremental generation)."""
+         input_buffer = self._get_input_buffer(incremental_state)
+         if input_buffer is not None:
+             for k in input_buffer.keys():
+                 input_buffer_k = input_buffer[k]
+                 if input_buffer_k is not None:
+                     if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
+                         0
+                     ):
+                         break
+                     input_buffer[k] = input_buffer_k.index_select(0, new_order)
+             incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+         return incremental_state
+
+     def _get_input_buffer(
+         self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+     ) -> Dict[str, Optional[Tensor]]:
+         result = self.get_incremental_state(incremental_state, "attn_state")
+         if result is not None:
+             return result
+         else:
+             empty_result: Dict[str, Optional[Tensor]] = {}
+             return empty_result
+
+     def _set_input_buffer(
+         self,
+         incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+         buffer: Dict[str, Optional[Tensor]],
+     ):
+         return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+     @staticmethod
+     def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
+         return attn_weights
+
+     def upgrade_state_dict_named(self, state_dict, name):
+         prefix = name + "." if name != "" else ""
+         items_to_add = {}
+         keys_to_remove = []
+         for k in state_dict.keys():
+             if k.endswith(prefix + "in_proj_weight"):
+                 # in_proj_weight used to be q + k + v with same dimensions
+                 dim = int(state_dict[k].shape[0] / 3)
+                 items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
+                 items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
+                 items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
+
+                 keys_to_remove.append(k)
+
+                 k_bias = prefix + "in_proj_bias"
+                 if k_bias in state_dict.keys():
+                     dim = int(state_dict[k].shape[0] / 3)
+                     items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
+                     items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
+                     items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
+
+                     keys_to_remove.append(prefix + "in_proj_bias")
+
+         for k in keys_to_remove:
+             del state_dict[k]
+
+         for key, value in items_to_add.items():
+             state_dict[key] = value
+
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(x, cos, sin):
+     cos = cos[:, : x.shape[-2], :]
+     sin = sin[:, : x.shape[-2], :]
+
+     return (x * cos) + (rotate_half(x) * sin)
+
+
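A small check (illustrative; the frequency table is built the same way as in the class below): rotary embeddings rotate channel pairs, so per-token norms are preserved while relative phase encodes position.

import torch

dim, seq = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None, :, :]
q = torch.randn(1, seq, dim)
q_rot = apply_rotary_pos_emb(q, emb.cos(), emb.sin())
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True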
+ class RotaryEmbedding(torch.nn.Module):
1176
+ """
1177
+ The rotary position embeddings from RoFormer_ (Su et. al).
1178
+ A crucial insight from the method is that the query and keys are
1179
+ transformed by rotation matrices which depend on the relative positions.
1180
+ Other implementations are available in the Rotary Transformer repo_ and in
1181
+ GPT-NeoX_, GPT-NeoX was an inspiration
1182
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
1183
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
1184
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
1185
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
1186
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
1187
+ """
1188
+
1189
+ def __init__(self, dim: int, *_, **__):
1190
+ super().__init__()
1191
+ # Generate and save the inverse frequency buffer (non trainable)
1192
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1193
+ self.register_buffer("inv_freq", inv_freq)
1194
+
1195
+ self._seq_len_cached = None
1196
+ self._cos_cached = None
1197
+ self._sin_cached = None
1198
+
1199
+ def _update_cos_sin_tables(self, x, seq_dimension=1):
1200
+ seq_len = x.shape[seq_dimension]
1201
+
1202
+ # Reset the tables if the sequence length has changed,
1203
+ # or if we're on a new device (possibly due to tracing for instance)
1204
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1205
+ self._seq_len_cached = seq_len
1206
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
1207
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1208
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1209
+
1210
+ self._cos_cached = emb.cos()[None, :, :]
1211
+ self._sin_cached = emb.sin()[None, :, :]
1212
+
1213
+ return self._cos_cached, self._sin_cached
1214
+
1215
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1216
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
1217
+
1218
+ return (
1219
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1220
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1221
+ )
1222
+
1223
+
1224
+
1225
+
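A minimal usage sketch of the rotary embedding above (the shapes are assumptions for illustration: batch 2, 8 positions, head dimension 16; the module rotates q and k instead of adding positional vectors):

import torch

rot = RotaryEmbedding(dim=16)
q = torch.randn(2, 8, 16)   # (batch, seq_len, head_dim)
k = torch.randn(2, 8, 16)
q_rot, k_rot = rot(q, k)    # same shapes; relative positions are now encoded in the q/k dot products
assert q_rot.shape == q.shape and k_rot.shape == k.shape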
multi_label_metrics.py ADDED
@@ -0,0 +1,536 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ '''
+ @license: (C) Copyright 2021, Hey.
+ @author: Hey
+ @email: sanyuan.**@**.com
+ @tel: 137****6540
+ @datetime: 2022/11/26 21:05
+ @project: LucaOne
+ @file: multi_label_metrics.py
+ @desc: metrics for multi-label classification
+ '''
+ import csv
+ import numpy as np
+ import torch
+ from sklearn.metrics import roc_auc_score, average_precision_score
+
+
+ def multi_label_acc(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     acc_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
+         if union_len == 0:
+             acc_list.append(1.0)
+         else:
+             # acc
+             acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
+             acc_list.append(acc)
+     return round(sum(acc_list)/len(acc_list), 6) if len(acc_list) > 0 else 0
+
+
+ def multi_label_precision(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     prec_list = []
+
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         target_len = len(target_relevant)
+         predict_len = len(pred_relevant)
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
+         if union_len == 0:
+             prec_list.append(1.0)
+         else:
+             # precision
+             prec = 0.0
+             if predict_len > 0:
+                 prec = intersection_len / predict_len
+             prec_list.append(prec)
+
+     return round(sum(prec_list)/len(prec_list), 6) if len(prec_list) > 0 else 0
+
+
+ def multi_label_recall(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     recall_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         target_len = len(target_relevant)
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
+         if union_len == 0:
+             recall_list.append(1.0)
+         else:
+             # recall
+             if target_len > 0:
+                 recall = intersection_len / target_len
+             else:
+                 recall = 1.0
+             recall_list.append(recall)
+     return round(sum(recall_list)/len(recall_list), 6) if len(recall_list) > 0 else 0
+
+
+ def multi_label_jaccard(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     jaccard_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
+         if union_len == 0:
+             jaccard_list.append(1.0)
+         else:
+             # jaccard sim
+             jac = intersection_len / union_len
+             jaccard_list.append(jac)
+     return round(sum(jaccard_list)/len(jaccard_list), 6) if len(jaccard_list) > 0 else 0
+
+
+ def multi_label_f1(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     f1_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         target_len = len(target_relevant)
+         predict_len = len(pred_relevant)
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
+         if union_len == 0:
+             f1_list.append(1.0)
+         else:
+             # precision
+             prec = 0.0
+             if predict_len > 0:
+                 prec = intersection_len / predict_len
+             # recall
+             if target_len > 0:
+                 recall = intersection_len / target_len
+             else:
+                 recall = 1.0
+             # f1
+             if prec + recall == 0:
+                 f1 = 0.0
+             else:
+                 f1 = 2.0 * prec * recall / (prec + recall)
+             f1_list.append(f1)
+     return round(sum(f1_list)/len(f1_list), 6) if len(f1_list) > 0 else 0
+
+
+ def multi_label_roc_auc(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     roc_auc_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         if union_len == 0:
+             roc_auc_list.append(1.0)
+         else:
+             # roc_auc
+             if len(np.unique(targets[idx, :])) > 1:
+                 roc_auc = roc_auc_macro(targets[idx, :], probs[idx, :])
+                 roc_auc_list.append(roc_auc)
+     return round(sum(roc_auc_list)/len(roc_auc_list), 6) if len(roc_auc_list) > 0 else 0
+
+
+ def multi_label_pr_auc(targets, probs, threshold=0.5):
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     pr_auc_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         if union_len == 0:
+             pr_auc_list.append(1.0)
+         else:
+             # pr_auc
+             if len(np.unique(targets[idx, :])) > 1:
+                 pr_auc = pr_auc_macro(targets[idx, :], probs[idx, :])
+                 pr_auc_list.append(pr_auc)
+     return round(sum(pr_auc_list)/len(pr_auc_list), 6) if len(pr_auc_list) > 0 else 0
+
+
+ def metrics_multi_label(targets, probs, threshold=0.5):
+     '''
+     metrics of multi-label classification
+     calculate metrics from the ground-truth matrix and the predicted probability matrix
+     :param targets: true 0-1 indicator matrix (n_samples, n_labels)
+     :param probs: 0~1 probability matrix (n_samples, n_labels)
+     :param threshold: negative-positive threshold
+     :return: some metrics
+     '''
+     targets_relevant = relevant_indexes(targets)
+     preds_relevant = relevant_indexes((probs >= threshold).astype(int))
+     acc_list = []
+     prec_list = []
+     recall_list = []
+     jaccard_list = []
+     f1_list = []
+     roc_auc_list = []
+     pr_auc_list = []
+     for idx in range(targets.shape[0]):
+         target_relevant = targets_relevant[idx]
+         pred_relevant = preds_relevant[idx]
+         target_len = len(target_relevant)
+         predict_len = len(pred_relevant)
+         union_len = len(set(target_relevant).union(set(pred_relevant)))
+         intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
+         if union_len == 0:
+             acc_list.append(1.0)
+             prec_list.append(1.0)
+             recall_list.append(1.0)
+             roc_auc_list.append(1.0)
+             jaccard_list.append(1.0)
+             f1_list.append(1.0)
+             pr_auc_list.append(1.0)
+         else:
+             # acc
+             acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
+             acc_list.append(acc)
+
+             # precision
+             prec = 0.0
+             if predict_len > 0:
+                 prec = intersection_len / predict_len
+             prec_list.append(prec)
+
+             # recall
+             if target_len > 0:
+                 recall = intersection_len / target_len
+             else:
+                 recall = 1.0
+             recall_list.append(recall)
+
+             # jaccard sim
+             jac = intersection_len / union_len
+             jaccard_list.append(jac)
+
+             # f1
+             if prec + recall == 0:
+                 f1 = 0.0
+             else:
+                 f1 = 2.0 * prec * recall / (prec + recall)
+             f1_list.append(f1)
+
+             # roc_auc and pr_auc (only defined when both classes occur in the row)
+             if len(np.unique(targets[idx, :])) > 1:
+                 roc_auc = roc_auc_macro(targets[idx, :], probs[idx, :])
+                 roc_auc_list.append(roc_auc)
+                 pr_auc = pr_auc_macro(targets[idx, :], probs[idx, :])
+                 pr_auc_list.append(pr_auc)
+
+     f_max_value, p_max_value, r_max_value, t_max_value, preds_max_value = f_max(targets, probs)
+     return {
+         "acc": round(float(sum(acc_list)/len(acc_list)), 6) if len(acc_list) > 0 else 0,
+         "jaccard": round(float(sum(jaccard_list)/len(jaccard_list)), 6) if len(jaccard_list) > 0 else 0,
+         "prec": round(float(sum(prec_list)/len(prec_list)), 6) if len(prec_list) > 0 else 0,
+         "recall": round(float(sum(recall_list)/len(recall_list)), 6) if len(recall_list) > 0 else 0,
+         "f1": round(float(sum(f1_list)/len(f1_list)), 6) if len(f1_list) > 0 else 0,
+         "pr_auc": round(float(sum(pr_auc_list)/len(pr_auc_list)), 6) if len(pr_auc_list) > 0 else 0,
+         "roc_auc": round(float(sum(roc_auc_list)/len(roc_auc_list)), 6) if len(roc_auc_list) > 0 else 0,
+         "fmax": round(float(f_max_value), 6),
+         "pmax": round(float(p_max_value), 6),
+         "rmax": round(float(r_max_value), 6),
+         "tmax": round(float(t_max_value), 6)
+     }
+
+
+ def f_max(targets, probs, gos=None):
+     '''
+     f-max for multi-label classification
+     :param targets: true 0-1 indicator matrix (n_samples, n_labels)
+     :param probs: 0~1 probability matrix (n_samples, n_labels)
+     :param gos:
+     :return: fmax, p_max (precision max), r_max (recall max), t_max (classification threshold), preds_max (0-1 indicator matrix)
+     '''
+     preds_max = None
+     f_max = 0
+     p_max = 0
+     r_max = 0
+     t_max = 0
+     # from 0.01 to 1 (100 thresholds)
+     for tt in range(1, 101):
+         threshold = tt / 100.0
+         preds = (probs > threshold).astype(np.int32)
+         p = 0.0
+         r = 0.0
+         total = 0
+         p_total = 0
+         for i in range(preds.shape[0]):
+             tp = np.sum(preds[i, :] * targets[i, :])
+             fp = np.sum(preds[i, :]) - tp
+             fn = np.sum(targets[i, :]) - tp
+             if gos:
+                 fn += gos[i]
+
+             if tp == 0 and fp == 0 and fn == 0:
+                 continue
+             total += 1
+             if tp != 0:
+                 p_total += 1
+                 precision = tp / (1.0 * (tp + fp))
+                 recall = tp / (1.0 * (tp + fn))
+                 p += precision
+                 r += recall
+
+         if total > 0 and p_total > 0:
+             r /= total
+             p /= p_total
+             if p + r > 0:
+                 f = 2 * p * r / (p + r)
+                 if f_max < f:
+                     f_max = f
+                     p_max = p
+                     r_max = r
+                     t_max = threshold
+                     preds_max = preds
+
+     return f_max, p_max, r_max, t_max, preds_max
+
+
308
+ def metrics_multi_label_for_pred(targets, preds, savepath=None):
309
+ '''
310
+ metrics for multi-label classification
311
+ cal metrics for true matrix to predict
312
+ :param targets: true 0-1 indicator matrix (n_samples, n_labels)
313
+ :param preds: preds 0~1 indicator matrix (n_samples, n_labels)
314
+ :return: some metrics
315
+ '''
316
+ targets_relevant = relevant_indexes(targets)
317
+ preds_relevant = relevant_indexes(preds)
318
+ acc_list = []
319
+ prec_list = []
320
+ recall_list = []
321
+ jaccard_list = []
322
+ f1_list = []
323
+ for idx in range(targets.shape[0]):
324
+ target_relevant = targets_relevant[idx]
325
+ pred_relevant = preds_relevant[idx]
326
+
327
+ target_len = len(target_relevant)
328
+ predict_len = len(pred_relevant)
329
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
330
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
331
+ acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
332
+ prec = 0.0
333
+ if predict_len > 0:
334
+ prec = intersection_len / predict_len
335
+ recall = 0
336
+ if target_len > 0:
337
+ recall = intersection_len / target_len
338
+ else:
339
+ print(targets[idx])
340
+ jac = intersection_len / union_len
341
+ if prec + recall == 0:
342
+ f1 = 0.0
343
+ else:
344
+ f1 = 2.0 * prec * recall / (prec + recall)
345
+
346
+ acc_list.append(acc)
347
+ prec_list.append(prec)
348
+ recall_list.append(recall)
349
+ jaccard_list.append(jac)
350
+ f1_list.append(f1)
351
+
352
+ return {
353
+ "acc": round(sum(acc_list)/targets.shape[0], 6),
354
+ "jaccard": round(sum(jaccard_list)/targets.shape[0], 6),
355
+ "prec": round(sum(prec_list)/targets.shape[0], 6),
356
+ "recall": round(sum(recall_list)/targets.shape[0], 6),
357
+ "f1": round(sum(f1_list)/targets.shape[0], 6)
358
+ }
359
+
360
+
361
+ def label_id_2_array(label_ids, label_size):
362
+ '''
363
+ building 0-1 indicator array for multi-label classification
364
+ :param label_ids:
365
+ :param label_size:
366
+ :return:
367
+ '''
368
+ arr = np.zeros(label_size)
369
+ arr[label_ids] = 1
370
+ return arr
371
+
372
+
373
+ def relevant_indexes(matrix):
374
+ '''
375
+ Which positions in the multi-label are labeled as 1
376
+ :param matrix:
377
+ :return:
378
+ '''
379
+ if torch.is_tensor(matrix):
380
+ matrix = matrix.detach().cpu().numpy()
381
+ relevants = []
382
+ shape = matrix.shape
383
+ if matrix.ndim == 3:
384
+
385
+ for x in range(shape[0]):
386
+ relevant_x = []
387
+ for y in range(shape[1]):
388
+ relevant_y = []
389
+ for z in range(shape[2]):
390
+ if matrix[x, y, z] == 1:
391
+ relevant_y.append(int(z))
392
+ relevant_x.append(relevant_y)
393
+ relevants.append(relevant_x)
394
+ elif matrix.ndim == 2:
395
+ for row in range(shape[0]):
396
+ relevant = []
397
+ for col in range(shape[1]):
398
+ if matrix[row, col] == 1:
399
+ relevant.append(int(col))
400
+ relevants.append(relevant)
401
+ else:
402
+ for idx in range(matrix.shape[0]):
403
+ if matrix[idx] == 1:
404
+ relevants.append(int(idx))
405
+ return relevants
406
+
407
+
408
+ def irrelevant_indexes(matrix):
409
+ '''
410
+ Which positions in the multi-label label are 0
411
+ :param matrix:
412
+ :return:
413
+ '''
414
+ if torch.is_tensor(matrix):
415
+ matrix = matrix.detach().cpu().numpy()
416
+
417
+ irrelevants = []
418
+ if matrix.ndim == 3:
419
+ for x in range(matrix.shape[0]):
420
+ irrelevant_x = []
421
+ for y in range(matrix.shape[1]):
422
+ irrelevant_y = []
423
+ for z in range(matrix.shape[2]):
424
+ if matrix[x, y, z] == 0:
425
+ irrelevant_y.append(int(z))
426
+ irrelevant_x.append(irrelevant_y)
427
+ irrelevants.append(irrelevant_x)
428
+ elif matrix.ndim == 2:
429
+ for row in range(matrix.shape[0]):
430
+ irrelevant = []
431
+ for col in range(matrix.shape[1]):
432
+ if matrix[row, col] == 1:
433
+ irrelevant.append(int(col))
434
+ irrelevants.append(irrelevant)
435
+ else:
436
+ for idx in range(matrix.shape[0]):
437
+ if matrix[idx] == 1:
438
+ irrelevants.append(int(idx))
439
+
440
+ return irrelevants
441
+
442
+
443
+ def prob_2_pred(prob, threshold):
444
+ '''
445
+ Probabilities converted to 0-1 predicted labels
446
+ :param prob:
447
+ :param threshold:
448
+ :return:
449
+ '''
450
+ if torch.is_tensor(prob):
451
+ prob = prob.detach().cpu().numpy()
452
+
453
+ if isinstance(prob, (np.ndarray, np.generic)):
454
+ return (prob >= threshold).astype(int)
455
+
456
+
457
+ def roc_auc_macro(target, prob):
458
+ '''
459
+ macro roc auc
460
+ :param target:
461
+ :param prob:
462
+ :return:
463
+ '''
464
+ return roc_auc_score(target, prob, average="macro")
465
+
466
+
467
+ def pr_auc_macro(target, prob):
468
+ '''
469
+ macro pr-auc
470
+ :param target:
471
+ :param prob:
472
+ :return:
473
+ '''
474
+ return average_precision_score(target, prob, average="macro")
475
+
476
+
477
+ def write_error_samples_multi_label(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets,
478
+ probs, threshold=0.5,
479
+ use_other_diags=False, use_other_operas=False, use_checkin_department=False):
480
+ '''
481
+ writer bad cases for multi-label classification
482
+ :param filepath:
483
+ :param samples:
484
+ :param input_indexs:
485
+ :param input_id_2_names:
486
+ :param output_id_2_name:
487
+ :param targets:
488
+ :param probs:
489
+ :param threshold:
490
+ :param use_other_diags:
491
+ :param use_other_operas:
492
+ :param use_checkin_department:
493
+ :return:
494
+ '''
495
+ preds = prob_2_pred(probs, threshold=threshold)
496
+ targets_relevant = relevant_indexes(targets)
497
+ preds_relevant = relevant_indexes(preds)
498
+ with open(filepath, "w") as fp:
499
+ writer = csv.writer(fp)
500
+ writer.writerow(["score", "y_true", "y_pred", "inputs"])
501
+ for i in range(len(targets_relevant)):
502
+ target = set(targets_relevant[i])
503
+ pred = set(preds_relevant[i])
504
+ jacc = len(target.intersection(pred))/(len(target.union(pred)))
505
+ if output_id_2_name:
506
+ target_labels = [output_id_2_name[v] for v in target]
507
+ pred_labels = [output_id_2_name[v] for v in pred]
508
+ else:
509
+ target_labels = target
510
+ pred_labels = pred
511
+ sample = samples[i]
512
+ if input_id_2_names:
513
+ new_sample = []
514
+ for idx, input_index in enumerate(input_indexs):
515
+ if input_index == 3 and not use_checkin_department:
516
+ input_index = 12
517
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
518
+ if input_index == 6 and use_other_diags or input_index == 8 and use_other_operas or input_index == 10 and use_other_diags:
519
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
520
+ else:
521
+ new_sample = sample
522
+ row = [jacc, target_labels, pred_labels, new_sample]
523
+ writer.writerow(row)
524
+
525
+
526
+ if __name__ == "__main__":
527
+ '''multi_label'''
528
+ probs = np.array([[0.6, 0.1, 0.1], [0.8, 0.3, 0.8], [0.8, 0.1, 0.1], [0.8, 0.1, 0.1]])
529
+ targets = np.array([[1, 1, 0], [1, 0, 1], [1, 0, 0], [0, 0, 1]])
530
+ print(metrics_multi_label(targets, probs))
531
+ t = np.array([[0, 0, 0], [1, 1, 1]])
532
+ print(t[0, :])
533
+ print(np.unique(t[0, :]))
534
+
535
+
536
+
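For clarity, a small sketch of how prob_2_pred and the index helpers interact (the probabilities are made up):

import numpy as np

probs = np.array([[0.7, 0.4], [0.2, 0.9]])
preds = prob_2_pred(probs, threshold=0.5)  # [[1, 0], [0, 1]]
print(relevant_indexes(preds))             # [[0], [1]]
print(irrelevant_indexes(preds))           # [[1], [0]]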
pooling.py ADDED
@@ -0,0 +1,301 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import torch
+ import torch.nn as nn
+
+ from .modeling_bert import BertEncoder, BertPooler
+
+
+ class GlobalMaskMaxPooling1D(nn.Module):
+     def __init__(self, ):
+         super(GlobalMaskMaxPooling1D, self).__init__()
+
+     def forward(self, x, mask=None):
+         if mask is not None:
+             # (B, Seq_len) -> (B, Seq_len, 1)
+             mask = 1.0 - mask
+             mask = mask * (-2**10 + 1)
+             mask = torch.unsqueeze(mask, dim=-1)
+             x += mask
+         return torch.max(x, dim=1)[0]
+
+
+ class GlobalMaskMinPooling1D(nn.Module):
+     def __init__(self, ):
+         super(GlobalMaskMinPooling1D, self).__init__()
+
+     def forward(self, x, mask=None):
+         if mask is not None:
+             # (B, Seq_len) -> (B, Seq_len, 1)
+             mask = 1.0 - mask
+             mask = mask * (2**10 + 1)
+             mask = torch.unsqueeze(mask, dim=-1)
+             x += mask
+         return torch.min(x, dim=1)[0]
+
+
+ class GlobalMaskAvgPooling1D(nn.Module):
+     def __init__(self):
+         super(GlobalMaskAvgPooling1D, self).__init__()
+
+     def forward(self, x, mask=None):
+         if mask is not None:
+             # (B, Seq_len) -> (B, Seq_len, 1)
+             mask = torch.unsqueeze(mask, dim=-1)
+             x *= mask
+             return torch.sum(x, dim=1)/torch.sum(mask, dim=1)
+         else:
+             return torch.mean(x, dim=1)
+
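A minimal sketch of masked mean pooling with the module above (batch 2, 4 tokens, hidden size 8 are assumptions; the second sequence has one padded position that is excluded from the average):

import torch

pool = GlobalMaskAvgPooling1D()
x = torch.randn(2, 4, 8)                 # (B, Seq_len, Embed)
mask = torch.tensor([[1., 1., 1., 1.],
                     [1., 1., 1., 0.]])  # 1 = real token, 0 = padding
out = pool(x, mask)
print(out.shape)                         # torch.Size([2, 8])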
+
+ class GlobalMaskSumPooling1D(nn.Module):
+     def __init__(self, axis):
+         '''
+         sum pooling
+         :param axis: axis=0 adds up all the rows of the matrix; axis=1 adds up all the cols of the matrix
+         '''
+         super(GlobalMaskSumPooling1D, self).__init__()
+         self.axis = axis
+
+     def forward(self, x, mask=None):
+         if mask is not None:
+             # (B, Seq_len) -> (B, Seq_len, 1)
+             mask = torch.unsqueeze(mask, dim=-1)
+             x *= mask
+         return torch.sum(x, dim=self.axis)
+
+
+ class GlobalMaskWeightedAttentionPooling1D(nn.Module):
+     def __init__(self, embed_size, use_bias=False):
+         super(GlobalMaskWeightedAttentionPooling1D, self).__init__()
+         self.embed_size = embed_size
+         self.use_bias = use_bias
+
+         self.W = nn.Parameter(torch.Tensor(self.embed_size))
+         nn.init.trunc_normal_(self.W, std=0.01)
+         if self.use_bias:
+             self.b = nn.Parameter(torch.Tensor(1))
+             nn.init.trunc_normal_(self.b, std=0.01)
+
+     def forward(self, x, mask=None):
+         # (B, Len, Embed) x (Embed,) = (B, Len)
+         logits = torch.matmul(x, self.W)
+         if self.use_bias:
+             logits += self.b
+
+         if mask is not None:
+             attention_probs = nn.Softmax(dim=-1)(logits + (1.0 - mask) * -10000)
+         else:
+             attention_probs = nn.Softmax(dim=-1)(logits)
+         x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
+         return x
+
+
+ class GlobalMaskContextAttentionPooling1D(nn.Module):
+     def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
+         super(GlobalMaskContextAttentionPooling1D, self).__init__()
+         self.embed_size = embed_size
+         self.use_additive_bias = use_additive_bias
+         self.use_attention_bias = use_attention_bias
+         self.units = units if units else embed_size
+
+         self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
+         self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
+         if self.use_additive_bias:
+             self.b1 = nn.Parameter(torch.Tensor(self.units))
+             nn.init.trunc_normal_(self.b1, std=0.01)
+         if self.use_attention_bias:
+             self.b2 = nn.Parameter(torch.Tensor(1))
+             nn.init.trunc_normal_(self.b2, std=0.01)
+
+         self.c = nn.Parameter(torch.Tensor(self.units))
+
+         nn.init.trunc_normal_(self.U, std=0.01)
+         nn.init.trunc_normal_(self.V, std=0.01)
+         nn.init.trunc_normal_(self.c, std=0.01)
+
+     def forward(self, x, mask=None):
+         # (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
+         q = torch.matmul(x, self.U)
+         k = torch.matmul(x, self.V)
+         if self.use_additive_bias:
+             h = torch.tanh(q + k + self.b1)
+         else:
+             h = torch.tanh(q + k)
+
+         if self.use_attention_bias:
+             e = torch.matmul(h, self.c) + self.b2
+         else:
+             e = torch.matmul(h, self.c)
+         if mask is not None:
+             attention_probs = nn.Softmax(dim=-1)(e + (1.0 - mask) * -10000)
+         else:
+             attention_probs = nn.Softmax(dim=-1)(e)
+         x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
+         return x
+
+
+ class GlobalMaskValueAttentionPooling1D(nn.Module):
+     def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
+         super(GlobalMaskValueAttentionPooling1D, self).__init__()
+         self.embed_size = embed_size
+         self.use_additive_bias = use_additive_bias
+         self.use_attention_bias = use_attention_bias
+         self.units = units if units else embed_size
+
+         self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
+         self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
+         if self.use_additive_bias:
+             self.b1 = nn.Parameter(torch.Tensor(self.units))
+             nn.init.trunc_normal_(self.b1, std=0.01)
+         if self.use_attention_bias:
+             self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
+             nn.init.trunc_normal_(self.b2, std=0.01)
+
+         self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))
+
+         nn.init.trunc_normal_(self.U, std=0.01)
+         nn.init.trunc_normal_(self.V, std=0.01)
+         nn.init.trunc_normal_(self.W, std=0.01)
+
+     def forward(self, x, mask=None):
+         # (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
+         q = torch.matmul(x, self.U)
+         k = torch.matmul(x, self.V)
+         if self.use_additive_bias:
+             h = torch.tanh(q + k + self.b1)
+         else:
+             h = torch.tanh(q + k)
+
+         # (B, Len, Units) x (Units, Embed) = (B, Len, Embed)
+         if self.use_attention_bias:
+             e = torch.matmul(h, self.W) + self.b2
+         else:
+             e = torch.matmul(h, self.W)
+         if mask is not None:
+             attention_probs = nn.Softmax(dim=1)(e + torch.unsqueeze((1.0 - mask) * -10000, dim=-1))
+         else:
+             attention_probs = nn.Softmax(dim=1)(e)
+         x = torch.sum(attention_probs * x, dim=1)
+         return x
+
+     def __repr__(self):
+         return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.embed_size) + ')'
+
+
+ class GlobalMaskTransformerPooling1D(nn.Module):
+     def __init__(self, config):
+         super(GlobalMaskTransformerPooling1D, self).__init__()
+         self.embeddings = nn.Parameter(torch.Tensor(1, 1, config.hidden_size))
+         nn.init.trunc_normal_(self.embeddings, std=0.02)
+         config.num_hidden_layers = 2
+         self.encoder = BertEncoder(config)
+         self.pooler = BertPooler(config)
+
+     def forward(self, x, mask=None):
+         B, Seq_len, Embed = x.size()
+         cls_emb_batch = self.embeddings.expand(B, 1, Embed)
+         merged_output = torch.cat((cls_emb_batch, x), dim=1)  # [B, Seq_len + 1, Embed]
+         if mask is not None:
+             device = x.device
+             cls_mask = torch.ones(B, 1).to(device)
+             mask = torch.cat([cls_mask, mask], dim=1)
+             mask = mask[:, None, None, :]
+
+         sequence_output = self.encoder(merged_output,
+                                        attention_mask=mask,
+                                        head_mask=None,
+                                        encoder_hidden_states=None,
+                                        encoder_attention_mask=None,
+                                        output_attentions=False,
+                                        output_hidden_states=False,
+                                        return_dict=False)[0]
+         pooled_output = self.pooler(sequence_output)
+         return pooled_output
+
+
+ class GlobalMaxPool1d(nn.Module):
+     def __init__(self):
+         super(GlobalMaxPool1d, self).__init__()
+         self.fc = nn.AdaptiveMaxPool1d(1)
+
+     def forward(self, x):
+         x = x.permute(0, 2, 1)
+         x = self.fc(x)
+         x = torch.squeeze(x, dim=-1)
+         return x
+
+
+ class GlobalAvgPool1d(nn.Module):
+     def __init__(self, ):
+         super(GlobalAvgPool1d, self).__init__()
+         self.fc = nn.AdaptiveAvgPool1d(1)
+
+     def forward(self, x):
+         x = x.permute(0, 2, 1)
+         x = self.fc(x)
+         x = torch.squeeze(x, dim=-1)
+         return x
+
+
+ class AttentionPool1d(nn.Module):
+     def __init__(self, embed_size, device="cuda"):
+         super(AttentionPool1d, self).__init__()
+         self.embed_size = embed_size
+         self.W = nn.Parameter(torch.Tensor(self.embed_size, self.embed_size))
+         self.b = nn.Parameter(torch.Tensor(self.embed_size))
+         self.c = nn.Parameter(torch.Tensor(self.embed_size))
+         nn.init.trunc_normal_(self.W, std=0.02)
+         nn.init.trunc_normal_(self.b, std=0.02)
+         nn.init.trunc_normal_(self.c, std=0.02)
+
+     def forward(self, x):
+         '''
+         # simpler single-weight variant for reference:
+         # x: (B, Seq_len, Embed)
+         # mul: (B, Seq_len)
+         mul = torch.matmul(x, self.W)
+         attention_probs = nn.Softmax(dim=-1)(mul)
+         x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
+         '''
+         mul = torch.tanh(torch.matmul(x, self.W) + self.b)
+         mul = torch.matmul(mul, self.c)
+         attention_probs = nn.Softmax(dim=-1)(mul)
+         x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
+         return x
+
+
+ class TransformerPool1d(nn.Module):
+     def __init__(self, config, embeddings, embed_size, num_transformer_layers=2, CLS_ID=102, device="cuda"):
+         super(TransformerPool1d, self).__init__()
+         if embeddings:
+             self.embeddings = embeddings
+         else:
+             self.embeddings = nn.Parameter(torch.Tensor(1, 1, embed_size))
+             nn.init.trunc_normal_(self.embeddings, std=0.02)
+         # self.embeddings = BertEmbeddings(config)
+         self.CLS_ID = CLS_ID
+         self.device = device
+         config.num_hidden_layers = num_transformer_layers
+         self.encoder = BertEncoder(config)
+         self.pooler = BertPooler(config)
+
+     def forward(self, x):
+         # x: (B, Seq_len, Embed)
+         B, Seq_len, Embed = x.size()
+         # cls_emb_batch = self.embeddings(torch.tensor([[self.CLS_ID]] * x.size()[0], dtype=torch.long).to(self.device))  # B, 1
+         cls_emb_batch = self.embeddings.expand(B, 1, Embed)
+         merged_output = torch.cat((cls_emb_batch, x), dim=1)  # [B, Seq_len + 1, Embed]
+         sequence_output = self.encoder(merged_output,
+                                        attention_mask=None,
+                                        head_mask=None,
+                                        encoder_hidden_states=None,
+                                        encoder_attention_mask=None,
+                                        output_attentions=False,
+                                        output_hidden_states=False,
+                                        return_dict=False)[0]
+         pooled_output = self.pooler(sequence_output)
+         return pooled_output
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:234ed601e664ca2e736f2427dfb8544b47370f641bbd82612297efca3943892a
+ size 6320919985
regression_loss.py ADDED
@@ -0,0 +1,238 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ '''
+ @license: (C) Copyright 2021, Hey.
+ @author: Hey
+ @email: [email protected]
+ @tel: 137****6540
+ @datetime: 2023/6/15 22:53
+ @project: LucaOne
+ @file: regression_loss.py
+ @desc: regression loss
+ '''
+ import warnings
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from statsmodels.stats.stattools import durbin_watson
+
+ from .masked_loss import _MaskedLoss
+
+
+ def nanstd(input, dim=None, keepdim=False):
+     mu = torch.nanmean(input, dim=dim, keepdim=True)
+     return torch.sqrt(torch.nanmean((input - mu)**2, dim=dim, keepdim=keepdim))
+
+
+ def iqr(batch, dim=None, reduction='mean'):
+     if dim is None:
+         if len(batch.shape) == 1:
+             dim = 0
+         else:
+             dim = 1
+     if isinstance(batch, np.ndarray):
+         out = np.quantile(batch, 0.75, axis=dim) - \
+             np.quantile(batch, 0.25, axis=dim)
+     elif isinstance(batch, torch.Tensor):
+         out = torch.quantile(batch, 0.75, dim=dim) - \
+             torch.quantile(batch, 0.25, dim=dim)
+     if reduction == 'none':
+         return out
+     elif reduction == 'mean':
+         return out.mean()
+     else:
+         raise NotImplementedError
+
+
+ def naniqr(batch, dim=None, reduction='none'):
+     if dim is None:
+         if len(batch.shape) == 1:
+             dim = 0
+         else:
+             dim = 1
+     if isinstance(batch, np.ndarray):
+         out = np.nanquantile(batch, 0.75, axis=dim) - \
+             np.nanquantile(batch, 0.25, axis=dim)
+     elif isinstance(batch, torch.Tensor):
+         out = torch.nanquantile(batch, 0.75, dim=dim) - \
+             torch.nanquantile(batch, 0.25, dim=dim)
+     if reduction == 'none':
+         return out
+     elif reduction == 'mean':
+         return out.mean()
+     elif reduction == 'nanmean':
+         return torch.nanmean(out)
+     else:
+         raise NotImplementedError
+
+
+ def compute_dw(res, dim=1, replace_missing=0., reduction='none'):
+     """Durbin-Watson statistics
+     https://www.statsmodels.org/devel/generated/statsmodels.stats.stattools.durbin_watson.html
+     """
+     if isinstance(res, torch.Tensor):
+         res = res.detach().cpu().numpy()
+     if replace_missing is not None:
+         res = res.copy()
+         res[np.isnan(res)] = replace_missing
+     out = durbin_watson(res, axis=dim)
+     if reduction == 'mean':
+         return out.mean()
+     elif reduction == 'none':
+         return out
+     elif reduction == 'median':
+         return np.median(out)
+
+
+ def estimate_noise(x, dim=1, window_size=10, step=5, reduce='nanmean', keepdim=True):
+     noises = nanstd(x.unfold(dim, window_size, step), -1, keepdim=False)
+     if reduce == 'nanmedian':
+         return noises.nanmedian(dim, keepdim=keepdim).values
+     if reduce == 'nanmean':
+         return noises.nanmean(dim, keepdim=keepdim)
+     if reduce == 'median':
+         return noises.median(dim, keepdim=keepdim).values
+     if reduce == 'mean':
+         return noises.mean(dim, keepdim=keepdim)
+     if reduce == 'none':
+         return noises
+     raise ValueError
+
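A brief sketch of estimate_noise (values illustrative): it slides length-10 windows with stride 5 along dim, takes the NaN-aware std of each window, and averages them, so for i.i.d. noise it roughly recovers the noise level:

import torch

x = 0.5 * torch.randn(2, 100)     # (batch, time) with noise std 0.5
print(estimate_noise(x, dim=1))   # around 0.5 per row, shape (2, 1)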
+
+ class MaskedMSELoss(_MaskedLoss):
+     """Masked MSE loss."""
+     def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = nn.MSELoss(reduction='none')
+
+
+ class MaskedL1Loss(_MaskedLoss):
+     """Masked L1 loss."""
+
+     def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = nn.L1Loss(reduction='none')
+
+
+ class MaskedHuberLoss(_MaskedLoss):
+     """Masked Huber loss."""
+
+     def __init__(self, reduction='mean', ignore_nans=True, delta=1, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = nn.HuberLoss(reduction='none', delta=delta)
+
+
+ class IQRLoss(nn.Module):
+     """IQR of the residuals"""
+     def __init__(self, reduction='nanmean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__()
+         self.reduction = reduction
+         self.ignore_nans = ignore_nans
+         self.ignore_value = ignore_value
+
+     def forward(self, input, target=0.):
+         if isinstance(target, torch.Tensor) and not (target.size() == input.size()):
+             warnings.warn(
+                 "Using a target size ({}) that is different to the input size ({}). "
+                 "This will likely lead to incorrect results due to broadcasting. "
+                 "Please ensure they have the same size.".format(
+                     target.size(), input.size()),
+                 stacklevel=2,
+             )
+         if self.ignore_nans:
+             return naniqr(target - input, reduction=self.reduction)
+         else:
+             return iqr(target - input, reduction=self.reduction)
+
+
+ class MaskedLogCoshLoss(_MaskedLoss):
+     def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = LogCoshLoss(reduction='none')
+
+
+ class MaskedXTanhLoss(_MaskedLoss):
+     def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = XTanhLoss(reduction='none')
+
+
+ class MaskedXSigmoidLoss(_MaskedLoss):
+     def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = XSigmoidLoss(reduction='none')
+
+
+ class MaskedAlgebraicLoss(_MaskedLoss):
+     def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+         super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+         self.criterion = AlgebraicLoss(reduction='none')
+
+
+ class LogCoshLoss(torch.nn.Module):
+     def __init__(self, reduction='none'):
+         super().__init__()
+         self.reduction = reduction
+
+     def forward(self, input, target):
+         diff = input - target
+         if self.reduction == 'mean':
+             return torch.mean(torch.log(torch.cosh(diff + 1e-12)))
+         elif self.reduction == 'sum':
+             return torch.sum(torch.log(torch.cosh(diff + 1e-12)))
+         else:
+             return torch.log(torch.cosh(diff + 1e-12))
+
+
+ class XTanhLoss(torch.nn.Module):
+     def __init__(self, reduction='none'):
+         super().__init__()
+         self.reduction = reduction
+
+     def forward(self, input, target):
+         diff = input - target
+         if self.reduction == 'mean':
+             return torch.mean(diff * torch.tanh(diff))
+         elif self.reduction == 'sum':
+             return torch.sum(diff * torch.tanh(diff))
+         else:
+             return diff * torch.tanh(diff)
+
+
+ class XSigmoidLoss(torch.nn.Module):
+     def __init__(self, reduction='none'):
+         super().__init__()
+         self.reduction = reduction
+
+     def forward(self, input, target):
+         diff = input - target
+         if self.reduction == 'mean':
+             return torch.mean(2 * diff * torch.sigmoid(diff) - diff)
+         elif self.reduction == 'sum':
+             return torch.sum(2 * diff * torch.sigmoid(diff) - diff)
+         else:
+             return 2 * diff * torch.sigmoid(diff) - diff
+
+
+ class AlgebraicLoss(torch.nn.Module):
+     def __init__(self, reduction='none'):
+         super().__init__()
+         self.reduction = reduction
+
+     def forward(self, input, target):
+         diff = input - target
+         if self.reduction == 'mean':
+             return torch.mean(diff * diff / torch.sqrt(1 + diff * diff))
+         elif self.reduction == 'sum':
+             return torch.sum(diff * diff / torch.sqrt(1 + diff * diff))
+         else:
+             return diff * diff / torch.sqrt(1 + diff * diff)
+
+
+ if __name__ == "__main__":
+     import torch
+     label = torch.Tensor([[[1], [1], [-100]], [[1], [-100], [0]]])
+     pred = torch.Tensor([[[2], [1], [3]], [[2], [1], [3]]])
+     loss = MaskedMSELoss(reduction="mean", ignore_nans=True, ignore_value=-100.0)
+     print("loss:")
+     print(loss(pred, label))
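As a usage note, IQRLoss scores the spread of the residuals rather than their mean magnitude; a minimal sketch (shapes and noise level are made up):

import torch

criterion = IQRLoss(reduction='nanmean')
target = torch.randn(4, 50)
pred = target + 0.1 * torch.randn(4, 50)
print(criterion(pred, target))   # IQR of the residuals, roughly 1.35 * 0.1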
utils.py ADDED
@@ -0,0 +1,979 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import math
+ import os, csv, json
+ import io, textwrap, itertools
+ import subprocess
+ from Bio import SeqIO
+ import torch
+ import numpy as np
+ import sys, random
+ from sklearn.metrics import confusion_matrix
+ import matplotlib.pyplot as plt
+ import pynvml, requests
+ from collections import OrderedDict
+
+ plt.rcParams.update({'font.size': 18})
+ plt.rcParams['axes.unicode_minus'] = False
+
+ from .file_operator import file_reader
+ from .multi_label_metrics import prob_2_pred, relevant_indexes, metrics_multi_label
+ from .metrics import metrics_multi_class, metrics_binary, metrics_regression
+
+ common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}
+
+ # not {'O', 'U', 'Z', 'J', 'B'}
+ # Common amino acids
+ common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}
+
+
+ def to_device(device, batch):
+     '''
+     input to device
+     :param device:
+     :param batch:
+     :return:
+     '''
+     new_batch = {}
+     sample_num = 0
+     tens = None
+     for item1 in batch.items():
+         new_batch[item1[0]] = {}
+         if isinstance(item1[1], dict):
+             for item2 in item1[1].items():
+                 new_batch[item1[0]][item2[0]] = {}
+                 if isinstance(item2[1], dict):
+                     for item3 in item2[1].items():
+                         if item3[1] is not None and not isinstance(item3[1], int) and not isinstance(item3[1], str) and not isinstance(item3[1], float):
+                             new_batch[item1[0]][item2[0]][item3[0]] = item3[1].to(device)
+                             tens = item3[1]
+                         else:
+                             new_batch[item1[0]][item2[0]][item3[0]] = item3[1]
+                 else:
+                     if item2[1] is not None and not isinstance(item2[1], int) and not isinstance(item2[1], str) and not isinstance(item2[1], float):
+                         new_batch[item1[0]][item2[0]] = item2[1].to(device)
+                         tens = item2[1]
+                     else:
+                         new_batch[item1[0]][item2[0]] = item2[1]
+         else:
+             if item1[1] is not None and not isinstance(item1[1], int) and not isinstance(item1[1], str) and not isinstance(item1[1], float):
+                 new_batch[item1[0]] = item1[1].to(device)
+                 tens = item1[1]
+             else:
+                 new_batch[item1[0]] = item1[1]
+     if tens is not None:
+         sample_num = tens.shape[0]
+     return new_batch, sample_num
+
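A sketch of to_device on a nested batch dict (the keys and shapes are assumptions; non-tensor values pass through unchanged and sample_num is read from the first tensor moved):

import torch

batch = {"inputs": {"input_ids": torch.zeros(4, 8, dtype=torch.long), "seq_type": "prot"}}
new_batch, sample_num = to_device(torch.device("cpu"), batch)
print(sample_num)  # 4, the batch dimension of the tensor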
+
+ def get_parameter_number(model):
+     '''
+     calculate the parameter number of the model
+     :param model:
+     :return:
+     '''
+     param_size = 0
+     param_sum = 0
+     trainable_size = 0
+     trainable_num = 0
+     for param in model.parameters():
+         cur_size = param.nelement() * param.element_size()
+         cur_num = param.nelement()
+         param_size += cur_size
+         param_sum += cur_num
+         if param.requires_grad:
+             trainable_size += cur_size
+             trainable_num += cur_num
+     buffer_size = 0
+     buffer_sum = 0
+     for buffer in model.buffers():
+         buffer_size += buffer.nelement() * buffer.element_size()
+         buffer_sum += buffer.nelement()
+     '''
+     total_num = sum(p.numel() for p in model.parameters())
+     total_size = sum(p.numel() * p.element_size() for p in model.parameters())
+     total_num += sum(p.numel() for p in model.buffers())
+     total_size += sum(p.numel() * p.element_size() for p in model.buffers())
+     trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     trainable_size = sum(p.numel() * p.element_size() for p in model.parameters() if p.requires_grad)
+     '''
+     return {
+         'total_num': "%fM" % round((buffer_sum + param_sum)/(1024 * 1024), 2),
+         'total_size': "%fMB" % round((buffer_size + param_size)/(1024 * 1024), 2),
+         'param_sum': "%fM" % round(param_sum/(1024 * 1024), 2),
+         'param_size': "%fMB" % round(param_size/(1024 * 1024), 2),
+         'buffer_sum': "%fM" % round(buffer_sum/(1024 * 1024), 2),
+         'buffer_size': "%fMB" % round(buffer_size/(1024 * 1024), 2),
+         'trainable_num': "%fM" % round(trainable_num/(1024 * 1024), 10),
+         'trainable_size': "%fMB" % round(trainable_size/(1024 * 1024), 10)
+     }
+
+
+ def set_seed(args):
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+     if args.n_gpu > 0:
+         torch.cuda.manual_seed(args.seed)
+         torch.cuda.manual_seed_all(args.seed)
+
+
+ def label_id_2_label_name(output_mode, label_list, prob, threshold=0.5):
+     '''
+     convert label ids to label names
+     :param output_mode:
+     :param label_list:
+     :param prob:
+     :param threshold:
+     :return:
+     '''
+     if output_mode in ["multi-label", "multi_label"]:
+         res = []
+         pred = prob_2_pred(prob, threshold)
+         pred_index = relevant_indexes(pred)
+         for row in range(prob.shape[0]):
+             label_names = [label_list[idx] for idx in pred_index[row]]
+             res.append(label_names)
+         return res
+     elif output_mode in ["multi-class", "multi_class"]:
+         pred = np.argmax(prob, axis=1)
+         label_names = [label_list[idx] for idx in pred]
+         return label_names
+     elif output_mode in ["binary-class", "binary_class"]:
+         if prob.ndim == 2:
+             prob = prob.flatten(order="C")
+         pred = prob_2_pred(prob, threshold)
+         label_names = [label_list[idx] for idx in pred]
+         return label_names
+     else:
+         raise KeyError(output_mode)
+
+
+ def plot_bins(data, xlabel, ylabel, bins, filepath):
+     '''
+     plot bins
+     :param data:
+     :param xlabel:
+     :param ylabel:
+     :param bins: number of bins
+     :param filepath: png save filepath
+     :return:
+     '''
+     plt.figure(figsize=(40, 20), dpi=100)
+     plt.hist(data, bins=bins)
+     # plt.xticks(range(min(data), max(data)))
+     # plt.grid(linestyle='--', alpha=0.5)
+
+     plt.xlabel(xlabel)
+     plt.ylabel(ylabel)
+     if filepath is None:
+         plt.show()
+     else:
+         plt.savefig(filepath)
+     plt.clf()
+     plt.close()
+
+
+ def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
+     '''
+     :param targets: ground truth
+     :param preds: predictions
+     :param cm: confusion matrix
+     :param savepath: confusion matrix picture save path
+     '''
+
+     plt.figure(figsize=(40, 20), dpi=100)
+     if cm is None:
+         cm = confusion_matrix(targets, preds, labels=[0, 1])
+
+     plt.matshow(cm, cmap=plt.cm.Oranges)
+     plt.colorbar()
+
+     for x in range(len(cm)):
+         for y in range(len(cm)):
+             plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
+     plt.ylabel('True')
+     plt.xlabel('Prediction')
+     if savepath:
+         plt.savefig(savepath, dpi=100)
+     else:
+         plt.show()
+     plt.close("all")
+
+
+ def save_labels(filepath, label_list):
+     '''
+     save labels
+     :param filepath:
+     :param label_list:
+     :return:
+     '''
+     with open(filepath, "w") as wfp:
+         wfp.write("label" + "\n")
+         for label in label_list:
+             wfp.write(label + "\n")
+
+
+ def load_labels(filepath, header=True):
+     '''
+     load labels
+     :param filepath:
+     :param header: whether the file has a header or not
+     :return:
+     '''
+     label_list = []
+     with open(filepath, "r") as rfp:
+         for label in rfp:
+             label_list.append(label.strip())
+     if len(label_list) > 0 and (header or label_list[0] == "label"):
+         return label_list[1:]
+     return label_list
+
+
+ def load_vocab(vocab_path):
+     '''
+     load vocab
+     :param vocab_path:
+     :return:
+     '''
+     vocab = {}
+     with open(vocab_path, "r") as rfp:
+         for line in rfp:
+             v = line.strip()
+             vocab[v] = len(vocab)
+     return vocab
+
+
+ def subprocess_popen(statement):
+     '''
+     execute a shell command
+     :param statement:
+     :return:
+     '''
+     p = subprocess.Popen(statement, shell=True, stdout=subprocess.PIPE)
+     while p.poll() is None:
+         if p.wait() != 0:
+             print("fail.")
+             return False
+         else:
+             re = p.stdout.readlines()
+             result = []
+             for i in range(len(re)):
+                 res = re[i].decode('utf-8').strip('\r\n')
+                 result.append(res)
+             return result
+
+
+ def prepare_inputs(input_type, embedding_type, batch):
+     if input_type == "sequence":
+         inputs = {
+             "input_ids_a": batch[0],
+             "attention_mask_a": batch[1],
+             "token_type_ids_a": batch[2],
+             "input_ids_b": batch[4],
+             "attention_mask_b": batch[5],
+             "token_type_ids_b": batch[6],
+             "labels": batch[-1]
+         }
+     elif input_type == "embedding":
+         if embedding_type not in ["vector", "bos"]:
+             inputs = {
+                 "embedding_info_a": batch[0],
+                 "embedding_attention_mask_a": batch[1],
+                 "embedding_info_b": batch[2],
+                 "embedding_attention_mask_b": batch[3],
+                 "labels": batch[-1]
+             }
+         else:
+             inputs = {
+                 "embedding_info_a": batch[0],
+                 "embedding_attention_mask_a": None,
+                 "embedding_info_b": batch[1],
+                 "embedding_attention_mask_b": None,
+                 "labels": batch[-1]
+             }
+     elif input_type == "structure":
+         inputs = {
+             "struct_input_ids_a": batch[0],
+             "struct_contact_map_a": batch[1],
+             "struct_input_ids_b": batch[2],
+             "struct_contact_map_b": batch[3],
+             "labels": batch[-1]
+         }
+     elif input_type == "sefn":
+         if embedding_type not in ["vector", "bos"]:
+             inputs = {
+                 "input_ids_a": batch[0],
+                 "attention_mask_a": batch[1],
+                 "token_type_ids_a": batch[2],
+                 "embedding_info_a": batch[4],
+                 "embedding_attention_mask_a": batch[5],
+                 "input_ids_b": batch[6],
+                 "attention_mask_b": batch[7],
+                 "token_type_ids_b": batch[8],
+                 "embedding_info_b": batch[10],
+                 "embedding_attention_mask_b": batch[11],
+                 "labels": batch[-1],
+             }
+         else:
+             inputs = {
+                 "input_ids_a": batch[0],
+                 "attention_mask_a": batch[1],
+                 "token_type_ids_a": batch[2],
+                 "embedding_info_a": batch[4],
+                 "embedding_attention_mask_a": None,
+                 "input_ids_b": batch[5],
+                 "attention_mask_b": batch[6],
+                 "token_type_ids_b": batch[7],
+                 "embedding_info_b": batch[9],
+                 "embedding_attention_mask_b": None,
+                 "labels": batch[-1],
+             }
+     elif input_type == "ssfn":
+         inputs = {
+             "input_ids_a": batch[0],
+             "attention_mask_a": batch[1],
+             "token_type_ids_a": batch[2],
+             "struct_input_ids_a": batch[4],
+             "struct_contact_map_a": batch[5],
+             "input_ids_b": batch[6],
+             "attention_mask_b": batch[7],
+             "token_type_ids_b": batch[8],
+             "struct_input_ids_b": batch[10],
+             "struct_contact_map_b": batch[11],
+             "labels": batch[-1]
+         }
+     else:
+         inputs = None
+     return inputs
+
+
+ def gene_seq_replace_re(seq):
+     '''
+     restore a digit-encoded nucleic acid sequence back to nucleotides (1->A, 2->T, 3->C, 4->G, others->N)
+     :param seq:
+     :return:
+     '''
+     new_seq = ""
+     for ch in seq:
+         if ch == '1':
+             new_seq += "A"
+         elif ch == '2':
+             new_seq += "T"
+         elif ch == '3':
+             new_seq += "C"
+         elif ch == '4':
+             new_seq += "G"
+         else:  # unknown
+             new_seq += "N"
+     return new_seq
+
+
+ def gene_seq_replace(seq):
+     '''
+     nucleic acid encoding (gene replace: A->1, U/T->2, C->3, G->4, N->5)
+     :param seq:
+     :return:
+     '''
+     new_seq = ""
+     for ch in seq:
+         if ch in ["A", "a"]:
+             new_seq += "1"
+         elif ch in ["T", "U", "t", "u"]:
+             new_seq += "2"
+         elif ch in ["C", "c"]:
+             new_seq += "3"
+         elif ch in ["G", "g"]:
+             new_seq += "4"
+         else:  # unknown
+             new_seq += "5"
+     return new_seq
+
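A round-trip sketch of the two sequence codecs above (the input string is made up):

print(gene_seq_replace("ATUCGN"))                      # '122345'
print(gene_seq_replace_re(gene_seq_replace("ATCGN")))  # 'ATCGN'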
393
+
394
+ def get_labels(label_filepath, header=True):
395
+ '''
396
+ get labels from file, exists header
397
+ :param label_filepath:
398
+ :param header:
399
+ :return:
400
+ '''
401
+ with open(label_filepath, "r") as fp:
402
+ labels = []
403
+ multi_cols = False
404
+ cnt = 0
405
+ for line in fp:
406
+ line = line.strip()
407
+ cnt += 1
408
+ if cnt == 1 and (header or line == "label"):
409
+ if line.find(",") > 0:
410
+ multi_cols = True
411
+ continue
412
+ if multi_cols:
413
+ idx = line.find(",")
414
+ if idx > 0:
415
+ label_name = line[idx + 1:].strip()
416
+ else:
417
+ label_name = line
418
+ else:
419
+ label_name = line
420
+ labels.append(label_name)
421
+ return labels
422
+
423
+
424
+ def available_gpu_id():
425
+ '''
426
+ 计算可用的GPU id
427
+ :return:
428
+ '''
429
+ pynvml.nvmlInit()
430
+ if not torch.cuda.is_available():
431
+ print("GPU not available")
432
+ return -1
433
+ # 获取GPU数量
434
+ device_count = pynvml.nvmlDeviceGetCount()
435
+ max_available_gpu = -1
436
+ max_available_rate = 0
437
+
438
+ # 遍历所有GPU并检查可用性
439
+ for i in range(device_count):
440
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
441
+ memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
442
+ utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
443
+ # 假设如果GPU利用率小于某个阈值(例如10%),我们认为这个GPU目前是空闲的
444
+ if utilization.gpu < 10 and max_available_rate < 100 - utilization.gpu:
445
+ max_available_rate = 100 - utilization.gpu
446
+ max_available_gpu = i
447
+ # 打印可用的GPU ID
448
+ if max_available_gpu > -1:
449
+ print("Available GPU ID: %d, Free Rate: %0.2f%%" % (max_available_gpu, max_available_rate))
450
+ else:
451
+ print("No Available GPU!")
452
+
453
+ # Shutdown NVML
454
+ pynvml.nvmlShutdown()
455
+ return max_available_gpu
456
+
457
+
458
+ def eval_metrics(output_mode, truths, preds, threshold=0.5):
+     '''
+     Compute evaluation metrics for the given output mode.
+     :param output_mode: task type (multi_label/multi_class/regression/binary_class)
+     :param truths: ground-truth values
+     :param preds: predicted values
+     :param threshold: decision threshold for (multi-)label tasks
+     :return: metrics dict
+     '''
+     print("\ntruths size: ", truths.shape)
+     print("\npreds size: ", preds.shape)
+     if output_mode in ["multi-label", "multi_label"]:
+         return metrics_multi_label(truths, preds, threshold=threshold)
+     elif output_mode in ["multi-class", "multi_class"]:
+         return metrics_multi_class(truths, preds)
+     elif output_mode == "regression":
+         return metrics_regression(truths, preds)
+     elif output_mode in ["binary-class", "binary_class"]:
+         return metrics_binary(truths, preds, threshold=threshold)
+     else:
+         raise Exception("Unsupported output mode: %s" % output_mode)
+
+
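+ # Dispatch example (illustrative): eval_metrics("binary_class", truths, preds,
+ # threshold=0.5) calls metrics_binary(truths, preds, threshold=0.5).
+
+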
+ def load_trained_model(model_config, args, model_class, model_dirpath):
+     # load an existing checkpoint
+     print("load pretrained model: %s" % model_dirpath)
+     try:
+         model = model_class.from_pretrained(model_dirpath, args=args)
+     except Exception:
+         # fall back to building the model and loading the raw state dict
+         model = model_class(model_config, args=args)
+         pretrained_net_dict = torch.load(os.path.join(model_dirpath, "pytorch.pth"),
+                                          map_location=torch.device("cpu"))
+         model_state_dict_keys = set()
+         for key in model.state_dict():
+             model_state_dict_keys.add(key)
+         new_state_dict = OrderedDict()
+         for k, v in pretrained_net_dict.items():
+             if k.startswith("module."):
+                 # remove the `module.` prefix added by DataParallel/DDP
+                 name = k[7:]
+             else:
+                 name = k
+             if name in model_state_dict_keys:
+                 new_state_dict[name] = v
+         model.load_state_dict(new_state_dict)
+     return model
+
+
+ def clean_seq(protein_id, seq, return_rm_index=False):
+     '''
+     Remove invalid characters from a protein sequence, keeping only the
+     uppercase letters A-Z except 'J'.
+     :param protein_id: sequence id (used for logging)
+     :param seq: protein sequence
+     :param return_rm_index: whether to also return the removed indices
+     :return: cleaned sequence (and the removed-index set if requested)
+     '''
+     seq = seq.upper()
+     new_seq = ""
+     has_invalid_char = False
+     invalid_char_set = set()
+     return_rm_index_set = set()
+     for idx, ch in enumerate(seq):
+         if 'A' <= ch <= 'Z' and ch not in ['J']:
+             new_seq += ch
+         else:
+             invalid_char_set.add(ch)
+             return_rm_index_set.add(idx)
+             has_invalid_char = True
+     if has_invalid_char:
+         print("id: %s. Seq: %s" % (protein_id, seq))
+         print("invalid char set:", invalid_char_set)
+         print("return_rm_index:", return_rm_index_set)
+     if return_rm_index:
+         return new_seq, return_rm_index_set
+     return new_seq
+
+
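+ # Example (illustrative): clean_seq("P1", "mk#j*l") uppercases the input and
+ # keeps only A-Z except 'J', returning "MKL" ('#', 'J' and '*' are dropped).
+
+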
+ def sample_size(data_dirpath):
+     '''
+     Count the samples in a data file, or in all files of a directory
+     (hidden files are skipped; .tsv/.csv files are assumed to have a header).
+     :param data_dirpath: data file or directory path
+     :return: total sample count
+     '''
+     if os.path.isdir(data_dirpath):
+         new_filepaths = []
+         for filename in os.listdir(data_dirpath):
+             if not filename.startswith("."):
+                 new_filepaths.append(os.path.join(data_dirpath, filename))
+         filepaths = new_filepaths
+     else:
+         filepaths = [data_dirpath]
+     total = 0
+     for filepath in filepaths:
+         header = filepath.endswith(".tsv") or filepath.endswith(".csv")
+         print("sample_size filepath: %s" % filepath)
+         for _ in file_reader(filepath, header=header, header_filter=True):
+             total += 1
+     return total
+
+
+ def writer_info_tb(tb_writer, logs, global_step, prefix=None):
+     '''
+     Write log values to TensorBoard; nested dicts are flattened recursively,
+     using the parent key as the scalar-name prefix.
+     :param tb_writer: TensorBoard SummaryWriter
+     :param logs: dict of scalar values (possibly nested)
+     :param global_step: global training step
+     :param prefix: optional scalar-name prefix
+     :return:
+     '''
+     for key, value in logs.items():
+         if isinstance(value, dict):
+             writer_info_tb(tb_writer, value, global_step, prefix=key)
+         elif not math.isnan(value) and not math.isinf(value):
+             tb_writer.add_scalar(prefix + "_" + key if prefix else key, value, global_step)
+         else:
+             print("writer_info_tb NaN or Inf, Key-Value: %s=%s" % (key, value))
+
+
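+ # Example (illustrative): writer_info_tb(w, {"loss": 0.1, "eval": {"acc": 0.9}}, 100)
+ # logs scalar "loss" and, via the recursive call with prefix="eval", "eval_acc".
+
+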
+ def get_lr(optimizer):
+     '''
+     get learning rate
+     :param optimizer:
+     :return:
+     '''
+     for p in optimizer.param_groups:
+         if "lr" in p:
+             return p["lr"]
+
+
+ def metrics_merge(results, all_results):
+     '''
+     Accumulate a three-level nested metrics dict (results) into all_results,
+     summing values for keys that already exist.
+     :param results: metrics of the current batch/shard
+     :param all_results: accumulated metrics (modified in place)
+     :return: all_results
+     '''
+     for key1, value1 in results.items():
+         if key1 not in all_results:
+             all_results[key1] = {}
+         for key2, value2 in value1.items():
+             if key2 not in all_results[key1]:
+                 all_results[key1][key2] = {}
+             for key3, value3 in value2.items():
+                 if key3 not in all_results[key1][key2]:
+                     all_results[key1][key2][key3] = value3
+                 else:
+                     all_results[key1][key2][key3] += value3
+     return all_results
+
+
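+ # Example (illustrative): merging {"gene": {"mask": {"acc": 1}}} into an
+ # accumulator holding {"gene": {"mask": {"acc": 2}}} yields {"gene": {"mask": {"acc": 3}}}.
+
+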
+ def print_shape(item):
+     '''
+     print shape (recursively, for dicts/lists of tensors or arrays)
+     :param item:
+     :return:
+     '''
+     if isinstance(item, dict):
+         for key, value in item.items():
+             print(key + ":")
+             print_shape(value)
+     elif isinstance(item, list):
+         for idx, value in enumerate(item):
+             print("idx: %d" % idx)
+             print_shape(value)
+     else:
+         print("shape:", item.shape)
+
+
+ def process_outputs(output_mode, truth, pred, output_truth, output_pred, ignore_index, keep_seq=False):
+     '''
+     Flatten a batch of truths/preds for the given output mode, drop positions
+     equal to ignore_index, and append them to the accumulated arrays.
+     :return: updated (output_truth, output_pred)
+     '''
+     if keep_seq:
+         # to do
+         return None, None
+     else:
+         if output_mode in ["multi_class", "multi-class"]:
+             cur_truth = truth.view(-1)
+             cur_mask = cur_truth != ignore_index
+             cur_pred = pred.view(-1, pred.shape[-1])
+             cur_truth = cur_truth[cur_mask]
+             cur_pred = cur_pred[cur_mask, :]
+             sum_v = cur_mask.sum().item()
+         elif output_mode in ["multi_label", "multi-label"]:
+             cur_truth = truth.view(-1, truth.shape[-1])
+             cur_pred = pred.view(-1, pred.shape[-1])
+             sum_v = pred.shape[0]
+         elif output_mode in ["binary_class", "binary-class"]:
+             cur_truth = truth.view(-1)
+             cur_mask = cur_truth != ignore_index
+             cur_pred = pred.view(-1)
+             cur_truth = cur_truth[cur_mask]
+             cur_pred = cur_pred[cur_mask]
+             sum_v = cur_mask.sum().item()
+         elif output_mode in ["regression"]:
+             cur_truth = truth.view(-1)
+             cur_mask = cur_truth != ignore_index
+             cur_pred = pred.view(-1)
+             cur_truth = cur_truth[cur_mask]
+             cur_pred = cur_pred[cur_mask]
+             sum_v = cur_mask.sum().item()
+         else:
+             raise Exception("Unsupported output mode: %s" % output_mode)
+         if sum_v > 0:
+             cur_truth = cur_truth.detach().cpu().numpy()
+             cur_pred = cur_pred.detach().cpu().numpy()
+             if output_truth is None or output_pred is None:
+                 return cur_truth, cur_pred
+             else:
+                 output_truth = np.append(output_truth, cur_truth, axis=0)
+                 output_pred = np.append(output_pred, cur_pred, axis=0)
+                 return output_truth, output_pred
+         # note: when no valid positions remain, the inputs are returned unchanged
+         return truth, pred
+
+
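+ # Shape sketch (illustrative): for multi_class with pred (B, L, C) and truth
+ # (B, L), positions where truth == ignore_index are masked out, leaving an
+ # (N,) truth vector and an (N, C) prediction matrix to append along axis 0.
+
+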
+ def print_batch(value, key=None, debug_path=None, wfp=None, local_rank=-1):
+     '''
+     Print a batch (a tensor, or a dict/list of tensors) and optionally dump
+     the values to text files for debugging.
+     :param value: tensor, or a dict/list of tensors
+     :param key: name of the current entry (used in dump filenames)
+     :param debug_path: directory to write dump files into (optional)
+     :param wfp: file handle to write the summary to instead of stdout (optional)
+     :param local_rank: local rank (passed through to recursive calls)
+     :return:
+     '''
+     if isinstance(value, list):
+         for idx, v in enumerate(value):
+             if wfp is not None:
+                 if v is not None:
+                     # min, min ignoring the -100 ignore_index, and max
+                     wfp.write(str([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)]) + "\n")
+                     wfp.write(str(v.shape) + "\n")
+                 else:
+                     wfp.write("None\n")
+                 wfp.write("-" * 10 + "\n")
+             else:
+                 if v is not None:
+                     print([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)])
+                     print(v.shape)
+                 else:
+                     print("None")
+                 print("-" * 50)
+             if v is not None:
+                 try:
+                     value = v.detach().cpu().numpy().astype(int)
+                     if debug_path is not None:
+                         if value.ndim == 3:
+                             for dim_1_idx in range(value.shape[0]):
+                                 np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt='%i', delimiter=",")
+                         else:
+                             np.savetxt(os.path.join(debug_path, "%d.txt" % idx), value, fmt='%i', delimiter=",")
+                     else:
+                         if value.ndim == 3:
+                             for dim_1_idx in range(value.shape[0]):
+                                 # debug_path is None here, so write into the working directory
+                                 np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), value[dim_1_idx, :, :], fmt='%i', delimiter=",")
+                         else:
+                             np.savetxt("%d.txt" % idx, value, fmt='%i', delimiter=",")
+                 except Exception as e:
+                     print(e)
+     elif isinstance(value, dict):
+         for k, v in value.items():
+             if wfp is not None:
+                 wfp.write(str(k) + ":\n")
+             else:
+                 print(str(k) + ':')
+             print_batch(v, k, debug_path, wfp, local_rank)
+     else:
+         if wfp is not None:
+             if value is not None:
+                 wfp.write(str([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)]) + "\n")
+                 wfp.write(str(value.shape) + "\n")
+             else:
+                 wfp.write("None\n")
+             wfp.write("-" * 10 + "\n")
+         else:
+             if value is not None:
+                 print([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)])
+                 print(value.shape)
+             else:
+                 print("None")
+             print("-" * 10)
+         if value is not None:
+             # prot_structure values are floats; everything else is integer ids
+             if key != "prot_structure":
+                 fmt = '%i'
+                 d_type = int
+             else:
+                 fmt = '%0.4f'
+                 d_type = float
+             try:
+                 value = value.detach().cpu().numpy().astype(d_type)
+                 if debug_path is not None:
+                     if value.ndim == 3:
+                         for dim_1_idx in range(value.shape[0]):
+                             np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt=fmt, delimiter=",")
+                     else:
+                         np.savetxt(os.path.join(debug_path, "%s.txt" % key), value, fmt=fmt, delimiter=",")
+                 else:
+                     if value.ndim == 3:
+                         for dim_1_idx in range(value.shape[0]):
+                             np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), value[dim_1_idx, :, :], fmt=fmt, delimiter=",")
+                     else:
+                         np.savetxt("%s.txt" % key, value, fmt=fmt, delimiter=",")
+             except Exception as e:
+                 print(e)
+
+
+ def gcd(x, y):
+     '''
+     Greatest common divisor (Euclidean algorithm).
+     :param x:
+     :param y:
+     :return: gcd of x and y
+     '''
+     m = max(x, y)
+     n = min(x, y)
+     while m % n:
+         m, n = n, m % n
+     return n
+
+
+ def lcm(x, y):
+     '''
+     Least common multiple: x*y divided by gcd(x, y).
+     :param x:
+     :param y:
+     :return: lcm of x and y
+     '''
+     m = max(x, y)
+     n = min(x, y)
+     while m % n:
+         m, n = n, m % n
+     return x * y // n
+
+
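+ # Example: gcd(12, 18) == 6 and lcm(12, 18) == 36 (i.e. 12 * 18 // 6).
+
+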
+ def device_memory(gpu_id):
+     '''
+     Print the total/used/free memory of the given GPU.
+     :param gpu_id: GPU index; None or a negative value is a no-op
+     :return:
+     '''
+     if gpu_id is None or gpu_id < 0:
+         return
+     pynvml.nvmlInit()
+     device_cnt = pynvml.nvmlDeviceGetCount()
+     for idx in range(device_cnt):
+         if gpu_id != idx:
+             continue
+         handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
+         info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+         print(f"Device {idx}: {pynvml.nvmlDeviceGetName(handle)}")
+         print(f"Total memory: {info.total / 1024**3:.8f} GB")
+         print(f"Used memory: {info.used / 1024**3:.8f} GB")
+         print(f"Free memory: {info.free / 1024**3:.8f} GB")
+     pynvml.nvmlShutdown()
+
+
+ def calc_emb_filename_by_seq_id(seq_id, embedding_type):
+     """
+     Derive the embedding filename from a seq_id.
+     :param seq_id: sequence id (a FASTA-style id, possibly starting with '>')
+     :param embedding_type: embedding type, used as the filename prefix
+     :return: embedding filename
+     """
+     if seq_id[0] == ">":
+         seq_id = seq_id[1:]
+     if "|" in seq_id:
+         strs = seq_id.split("|")
+         if len(strs) > 1:
+             emb_filename = embedding_type + "_" + strs[1].strip() + ".pt"
+         else:
+             emb_filename = embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
+     else:
+         emb_filename = embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
+     return emb_filename
+
+
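+ # Example (illustrative): calc_emb_filename_by_seq_id(">sp|P12345|NAME", "esm")
+ # -> "esm_P12345.pt"; ids without '|' are sanitized (spaces removed, '/' -> '_').
+
+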
+ def download_file(url, local_filename):
+     with requests.get(url, stream=True) as r:
+         r.raise_for_status()
+         dir_name = os.path.dirname(local_filename)
+         if not os.path.exists(dir_name):
+             os.makedirs(dir_name)
+         with open(local_filename, 'wb') as f:
+             for chunk in r.iter_content(chunk_size=8192):
+                 if chunk:  # filter out keep-alive new chunks
+                     f.write(chunk)
+     return local_filename
+
+
+ def download_folder(base_url, file_names, local_dir):
+     if not os.path.exists(local_dir):
+         os.makedirs(local_dir)
+
+     for file_name in file_names:
+         file_url = f"{base_url}/{file_name}"
+         local_filename = os.path.join(local_dir, file_name)
+         download_file(file_url, local_filename)
+         print(f"Downloaded {file_name}")
+
+
+ def download_trained_checkpoint_lucaone(
+         llm_dir,
+         llm_type="lucaone_gplm",
+         llm_version="v2.0",
+         llm_task_level="token_level,span_level,seq_level,structure_level",
+         llm_time_str="20231125113045",
+         llm_step="5600000",
+         base_url="http://47.93.21.181/lucaone/TrainedCheckPoint"
+ ):
+     """
+     Download the trained checkpoint of LucaOne (skipped if already present).
+     :param llm_dir: local directory to store the checkpoint in
+     :param llm_type:
+     :param llm_version:
+     :param llm_task_level:
+     :param llm_time_str:
+     :param llm_step:
+     :param base_url:
+     :return:
+     """
+     print("------Download Trained LLM(LucaOne)------")
+     try:
+         logs_file_names = ["logs.txt"]
+         models_file_names = ["config.json", "pytorch.pth", "training_args.bin", "tokenizer/alphabet.pkl"]
+         logs_path = "logs/lucagplm/%s/%s/%s/%s" % (llm_version, llm_task_level, llm_type, llm_time_str)
+         models_path = "models/lucagplm/%s/%s/%s/%s/checkpoint-step%s" % (llm_version, llm_task_level, llm_type, llm_time_str, llm_step)
+         logs_local_dir = os.path.join(llm_dir, logs_path)
+         exists = True
+         for logs_file_name in logs_file_names:
+             if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
+                 exists = False
+                 break
+         models_local_dir = os.path.join(llm_dir, models_path)
+         if exists:
+             for models_file_name in models_file_names:
+                 if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
+                     exists = False
+                     break
+         if not exists:
+             print("*" * 20 + "Downloading" + "*" * 20)
+             print("Downloading LucaOne TrainedCheckPoint: LucaOne-%s-%s-%s ..." % (llm_version, llm_time_str, llm_step))
+             print("Wait a moment, please.")
+             # download logs
+             if not os.path.exists(logs_local_dir):
+                 os.makedirs(logs_local_dir)
+             logs_base_url = os.path.join(base_url, logs_path)
+             download_folder(logs_base_url, logs_file_names, logs_local_dir)
+             # download models
+             if not os.path.exists(models_local_dir):
+                 os.makedirs(models_local_dir)
+             models_base_url = os.path.join(base_url, models_path)
+             download_folder(models_base_url, models_file_names, models_local_dir)
+             print("LucaOne Download Succeed.")
+             print("*" * 50)
+     except Exception as e:
+         print(e)
+         print("Automatic download of the LucaOne trained checkpoint failed!")
+         print("You can manually download 'logs/' and 'models/' into the local directory: %s/ from %s" % (os.path.abspath(llm_dir), base_url))
+         raise Exception(e)
+
+
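+ # Usage sketch: the defaults point at the released v2.0 checkpoint, so e.g.
+ #   download_trained_checkpoint_lucaone(llm_dir="./llm")   # "./llm" is illustrative
+ # fetches logs/ and models/ under ./llm only if they are not already present.
+
+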
+ def download_trained_checkpoint_downstream_tasks(
+         save_dir="../",
+         dataset_name=["CentralDogma", "GenusTax", "InfA", "ncRNAFam", "ncRPI", "PPI", "ProtLoc", "ProtStab", "SpeciesTax", "SupKTax"],
+         dataset_type=["gene_protein", "gene", "gene_gene", "gene", "gene_protein", "protein", "protein", "protein", "gene", "gene"],
+         task_type=["binary_class", "multi_class", "binary_class", "multi_class", "binary_class", "binary_class", "multi_class", "regression", "multi_class", "multi_class"],
+         model_type=["lucappi2", "luca_base", "lucappi", "luca_base", "lucappi2", "lucappi", "luca_base", "luca_base", "luca_base", "luca_base"],
+         input_type=["matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix"],
+         time_str=["20240406173806", "20240412100337", "20240214105653", "20240414155526", "20240404105148", "20240216205421", "20240412140824", "20240404104215", "20240411144916", "20240212202328"],
+         step=[64000, 24500, 9603, 1958484, 716380, 52304, 466005, 70371, 24000, 37000],
+         base_url="http://47.93.21.181/lucaone/DownstreamTasksTrainedModels"
+ ):
+     """
+     Download the trained downstream-task models (skipped if already present).
+     :param save_dir: local directory to save into
+     :param dataset_name:
+     :param dataset_type:
+     :param task_type:
+     :param model_type:
+     :param input_type:
+     :param time_str:
+     :param step:
+     :param base_url:
+     :return:
+     """
+     assert len(dataset_name) == len(dataset_type) == len(task_type) == \
+            len(model_type) == len(input_type) == len(time_str) == len(step)
+     assert isinstance(dataset_name, list)
+     assert isinstance(dataset_type, list)
+     assert isinstance(task_type, list)
+     assert isinstance(model_type, list)
+     assert isinstance(input_type, list)
+     assert isinstance(time_str, list)
+     assert isinstance(step, list)
+     download_succeed_task_num = 0
+     print("------Download Trained Models------")
+     for idx in range(len(dataset_name)):
+         try:
+             logs_file_names = ["logs.txt", "label.txt"]
+             models_file_names = ["config.json", "pytorch_model.bin", "training_args.bin", "tokenizer/alphabet.pkl"]
+             logs_path = "logs/%s/%s/%s/%s/%s/%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx])
+             models_path = "models/%s/%s/%s/%s/%s/%s/checkpoint-%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx], str(step[idx]))
+             logs_local_dir = os.path.join(save_dir, logs_path)
+             exists = True
+             for logs_file_name in logs_file_names:
+                 if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
+                     exists = False
+                     break
+             models_local_dir = os.path.join(save_dir, models_path)
+             if exists:
+                 for models_file_name in models_file_names:
+                     if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
+                         exists = False
+                         break
+             if not exists:
+                 print("*" * 20 + "Downloading" + "*" * 20)
+                 print("Downloading Downstream Task: %s TrainedCheckPoint: %s-%s-%s ..." % (dataset_name[idx], dataset_name[idx], time_str[idx], str(step[idx])))
+                 print("Wait a moment, please.")
+                 # download logs
+                 if not os.path.exists(logs_local_dir):
+                     os.makedirs(logs_local_dir)
+                 logs_base_url = os.path.join(base_url, dataset_name[idx], logs_path)
+                 download_folder(logs_base_url, logs_file_names, logs_local_dir)
+                 # download models
+                 if not os.path.exists(models_local_dir):
+                     os.makedirs(models_local_dir)
+                 models_base_url = os.path.join(base_url, dataset_name[idx], models_path)
+                 download_folder(models_base_url, models_file_names, models_local_dir)
+                 print("Downstream Task: %s Trained Model Download Succeed." % dataset_name[idx])
+                 print("*" * 50)
+             download_succeed_task_num += 1
+         except Exception as e:
+             print(e)
+             print("Automatic download of the Downstream Task: %s trained checkpoint failed!" % dataset_name[idx])
+             print("You can manually download 'logs/' and 'models/' into the local directory: %s/ from %s" % (os.path.abspath(save_dir), os.path.join(base_url, dataset_name[idx])))
+             raise Exception(e)
+     print("%d Downstream Task Trained Model(s) Download Succeed." % download_succeed_task_num)