Upload 7 files

- alphabet.py +0 -9
- lucaone_gplm.py +142 -109
- modeling_bert.py +27 -33
- modeling_gplm.py +26 -75
alphabet.py
CHANGED
@@ -6,7 +6,6 @@ import json
 import itertools
 from typing import Sequence, List
 from transformers import PreTrainedTokenizer
-from .batch_converter import BatchConverter
 
 gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']
 
@@ -63,14 +62,6 @@ class Alphabet(object):
     def to_dict(self):
         return self.tok_to_idx.copy()
 
-    def get_batch_converter(self, no_position_embeddings, no_token_type_embeddings, truncation_seq_length: int = None, ignore_index: int = -100, mlm_probability=0.15):
-        return BatchConverter(self,
-                              no_position_embeddings=no_position_embeddings,
-                              no_token_type_embeddings=no_token_type_embeddings,
-                              truncation_seq_length=truncation_seq_length,
-                              ignore_index=ignore_index,
-                              mlm_probability=mlm_probability)
-
     @classmethod
     def from_predefined(cls, name: str):
         if name.lower() == "prot":
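With `get_batch_converter` and the `BatchConverter` import removed, `Alphabet` now only defines the vocabulary itself. A minimal sketch of how a caller might turn a sequence into token ids with the remaining API, assuming `from_predefined("prot")` and `to_dict()` behave as shown in the diff; the manual lookup loop and the `"<unk>"` fallback key are illustrative, not the repository's batching code:

```python
from alphabet import Alphabet  # module path as in this repo; adjust to your install

# Build the predefined protein alphabet and grab its token -> index map.
alphabet = Alphabet.from_predefined("prot")
tok_to_idx = alphabet.to_dict()

# Encode one sequence character by character; "<unk>" is a hypothetical fallback entry.
seq = "MKTAYIAKQR"
token_ids = [tok_to_idx.get(ch, tok_to_idx.get("<unk>", 0)) for ch in seq]
print(token_ids)
```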
lucaone_gplm.py
CHANGED
@@ -37,6 +37,7 @@ class LucaGPLM(PreTrainedModel):
         self.use_embed_layer_norm = config.use_embed_layer_norm
         self.use_last_layer_norm = config.use_last_layer_norm
         self.embed_scale = config.embed_scale
+        self.embedding_inference = True
         self._init_submodules()
 
     def _init_submodules(self):
@@ -72,22 +73,23 @@ class LucaGPLM(PreTrainedModel):
         )
         self.layer_size = len(self.layers)
 
-        self.contact_head = ContactPredictionHead(
-            self.num_layers * self.attention_heads,
-            self.prepend_bos,
-            self.append_eos,
-            eos_idx=self.eos_idx,
-        )
+        if not self.embedding_inference:
+            self.contact_head = ContactPredictionHead(
+                self.num_layers * self.attention_heads,
+                self.prepend_bos,
+                self.append_eos,
+                eos_idx=self.eos_idx,
+            )
         if self.use_last_layer_norm:
             self.last_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
         else:
             self.last_layer_norm = None
-        self.lm_head = RobertaLMHead(
-            embed_dim=self.embed_dim,
-            output_dim=self.alphabet_size,
-            weight=self.embed_tokens.weight,
-        )
+        if not self.embedding_inference:
+            self.lm_head = RobertaLMHead(
+                embed_dim=self.embed_dim,
+                output_dim=self.alphabet_size,
+                weight=self.embed_tokens.weight,
+            )
 
     def _init_embedding(self, pretrained_token_matrix, token_matrix):
         '''
@@ -103,7 +105,7 @@ class LucaGPLM(PreTrainedModel):
         31->38
         32->4
         '''
-        print("Load pretrained exists embedding vectors:")
+        # print("Load pretrained exists embedding vectors:")
         token_matrix[2, :] = pretrained_token_matrix[0, :]
         token_matrix[0, :] = pretrained_token_matrix[1, :]
         token_matrix[3, :] = pretrained_token_matrix[2, :]
@@ -117,7 +119,7 @@ class LucaGPLM(PreTrainedModel):
         return token_matrix
 
     def _init_submodules_new(self, pretrained_model_name):
-        print("Load pretrained model exists weights:")
+        # print("Load pretrained model exists weights:")
         from esm import pretrained
         from collections import OrderedDict
         pretrained, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
@@ -143,33 +145,16 @@ class LucaGPLM(PreTrainedModel):
             elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
                 del our_model_state_dict[name]
                 new_state_dict[name] = weight
-
+        '''
         print("Exists layer names:")
         print(new_state_dict.keys())
         print("Not exists Layer names:")
         print(our_model_state_dict.keys())
+        '''
         new_state_dict.update(our_model_state_dict)
         self.load_state_dict(new_state_dict)
 
     def __calc_loss__(self, task_level_type, output_mode, logits, label, label_size, loss_fct, loss_reduction):
-        '''
-        if label_size <= 2 or output_mode in ["binary_class", "binary-class"]:
-            loss = loss_fct(logits.view(-1), label.view(-1).float())
-        elif output_mode in ["multi_label", "multi-label"]:
-            loss = loss_fct(logits.view(-1, label_size), label.view(-1, label_size).float())
-        elif output_mode in ["multi_class", "multi-class"]:
-            loss = loss_fct(logits.view(-1, label_size), label.view(-1))
-        else:
-            loss = loss_fct(logits.view(-1), label.view(-1))
-        return loss
-        '''
-        '''
-        print(task_level_type, output_mode, label_size, loss_fct, loss_reduction)
-        print("logits:")
-        print(logits.shape)
-        print("label:")
-        print(label.shape)
-        '''
         if output_mode in ["regression"]:
             if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
                 # structure-level regression
@@ -307,7 +292,8 @@ class LucaGPLM(PreTrainedModel):
         representation_matrix = hidden_representations[self.layer_size]
         # mask task
         # B * Seq_len * vocab_size
-        lm_mask_logits = self.lm_head(x)
+        if not self.embedding_inference:
+            lm_mask_logits = self.lm_head(x)
         # the lm head output vector serves as the representation vector
         # (B, E)
         representation_vector = representation_matrix[:, 0, :]
@@ -329,14 +315,15 @@ class LucaGPLM(PreTrainedModel):
             attentions = attentions * attention_mask[:, None, None, :, :]
             representations["attentions"] = attentions
             # predict the contact matrix
-            if return_contacts:
+            if return_contacts and hasattr(self, "contact_head") \
+                    and not self.embedding_inference:
                 contacts = self.contact_head(input_ids, attentions)
                 representations["contacts"] = contacts
         '''
         print("output_keys:")
         print(output_keys)
         '''
-        if output_keys:
+        if not self.embedding_inference and output_keys:
             for item in output_keys.items():
                 cur_task_level_type = item[0]
                 if cur_task_level_type not in logits:
@@ -466,107 +453,153 @@ class LucaGPLM(PreTrainedModel):
                 use_last_layer_norm=use_last_layer_norm
             )
             has_pair_b = True
-        if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
-            cur_representation_vector = encoding["representation_vector"]
-            cur_representation_vector_b = encoding_b["representation_vector"]
-
-            pair_logits = {}
-            pair_outputs = {}
-            for item1 in pair_output_keys.items():
-                cur_task_level_type = item1[0]
-                if cur_task_level_type not in pair_outputs:
-                    pair_outputs[cur_task_level_type] = {}
-                    pair_logits[cur_task_level_type] = {}
-                for cur_task_level_name in item1[1]:
-                    cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
-                        torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
-                    )
-                    cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
-                    if cur_hidden_layer is not None:
-                        cur_logits = cur_hidden_layer(cur_logits)
-                    cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
-                    pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
-                    pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
-
-            if pair_label is not None:
-                pair_loss = {}
-                for item1 in pair_output_keys.items():
-                    cur_task_level_type = item1[0]
-                    if cur_task_level_type not in pair_label:
-                        continue
-                    pair_loss[cur_task_level_type] = {}
-                    for cur_task_level_name in item1[1]:
-                        if cur_task_level_name not in pair_label[cur_task_level_type]:
-                            continue
-                        cur_label = pair_label[cur_task_level_type][cur_task_level_name]
-                        cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
-                        cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
-                        cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
-                        cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
-                        cur_loss = self.__calc_loss__(
-                            task_level_type=cur_task_level_type,
-                            output_mode=cur_output_mode, logits=cur_logits,
-                            label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
-                            loss_reduction="meanmean")
-                        pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss
-
-                if not return_dict:
-                    return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs]] + [[encoding, encoding_b]]
-                return AllOutput(
-                    losses=losses,
-                    outputs=outputs,
-                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
-                    attentions=encoding["attentions"] if "attentions" in encoding else None,
-                    global_attentions=None,
-                    contacts=encoding["contacts"] if "contacts" in encoding else None,
-                    losses_b=losses_b,
-                    outputs_b=outputs_b,
-                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
-                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
-                    global_attentions_b=None,
-                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
-                    pair_outputs=pair_outputs,
-                    pair_losses=pair_loss)
-            else:
-                if not return_dict:
-                    return [[losses, losses_b], [outputs, outputs_b]] + [[encoding, encoding_b]]
-                return AllOutput(
-                    losses=losses,
-                    outputs=outputs,
-                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
-                    attentions=encoding["attentions"] if "attentions" in encoding else None,
-                    global_attentions=None,
-                    contacts=encoding["contacts"] if "contacts" in encoding else None,
-                    losses_b=losses_b,
-                    outputs_b=outputs_b,
-                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
-                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
-                    global_attentions_b=None,
-                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
-                )
-        elif has_pair:
-            if not return_dict:
-                return [[losses], [outputs], [encoding]]
-            return AllOutput(
-                losses=losses,
-                outputs=outputs,
-                hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
-                attentions=encoding["attentions"] if "attentions" in encoding else None,
-                global_attentions=None,
-                contacts=encoding["contacts"] if "contacts" in encoding else None
-            )
-        else:
-            if not return_dict:
-                return [[losses_b], [outputs_b], [encoding_b]]
-            return AllOutput(
-                losses_b=losses_b,
-                outputs_b=outputs_b,
-                hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
-                attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
-                global_attentions_b=None,
-                contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
-            )
+        if not self.embedding_inference:
+            if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
+                cur_representation_vector = encoding["representation_vector"]
+                cur_representation_vector_b = encoding_b["representation_vector"]
+
+                pair_logits = {}
+                pair_outputs = {}
+                for item1 in pair_output_keys.items():
+                    cur_task_level_type = item1[0]
+                    if cur_task_level_type not in pair_outputs:
+                        pair_outputs[cur_task_level_type] = {}
+                        pair_logits[cur_task_level_type] = {}
+                    for cur_task_level_name in item1[1]:
+                        cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
+                            torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
+                        )
+                        cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
+                        if cur_hidden_layer is not None:
+                            cur_logits = cur_hidden_layer(cur_logits)
+                        cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
+                        pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
+                        pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
+
+                if pair_label is not None:
+                    pair_loss = {}
+                    for item1 in pair_output_keys.items():
+                        cur_task_level_type = item1[0]
+                        if cur_task_level_type not in pair_label:
+                            continue
+                        if cur_task_level_type in pair_label:
+                            pair_loss[cur_task_level_type] = {}
+                            for cur_task_level_name in item1[1]:
+                                if cur_task_level_name not in pair_label[cur_task_level_type]:
+                                    continue
+                                cur_label = pair_label[cur_task_level_type][cur_task_level_name]
+                                cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
+                                cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
+                                cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
+                                cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
+                                cur_loss = self.__calc_loss__(
+                                    task_level_type=cur_task_level_type,
+                                    output_mode=cur_output_mode, logits=cur_logits,
+                                    label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
+                                    loss_reduction="meanmean")
+                                pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss
+
+                    if not return_dict:
+                        return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs]] + [[encoding, encoding_b]]
+                    return AllOutput(
+                        losses=losses,
+                        outputs=outputs,
+                        hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
+                        attentions=encoding["attentions"] if "attentions" in encoding else None,
+                        global_attentions=None,
+                        contacts=encoding["contacts"] if "contacts" in encoding else None,
+                        losses_b=losses_b,
+                        outputs_b=outputs_b,
+                        hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
+                        attentions_b=encoding_b["attentions"] if "hidden_states" in encoding_b else None,
+                        global_attentions_b=None,
+                        contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
+                        pair_outputs=pair_outputs,
+                        pair_losses=pair_loss)
+                else:
+                    if not return_dict:
+                        return [[losses, losses_b], [outputs, outputs_b]] + [[encoding, encoding_b]]
+                    return AllOutput(
+                        losses=losses,
+                        outputs=outputs,
+                        hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
+                        attentions=encoding["attentions"] if "attentions" in encoding else None,
+                        global_attentions=None,
+                        contacts=encoding["contacts"] if "contacts" in encoding else None,
+                        losses_b=losses_b,
+                        outputs_b=outputs_b,
+                        hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
+                        attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
+                        global_attentions_b=None,
+                        contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
+                    )
+            elif has_pair:
+                if not return_dict:
+                    return [[losses], [outputs], [encoding]]
+                return AllOutput(
+                    losses=losses,
+                    outputs=outputs,
+                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
+                    attentions=encoding["attentions"] if "attentions" in encoding else None,
+                    global_attentions=None,
+                    contacts=encoding["contacts"] if "contacts" in encoding else None
+                )
+            else:
+                if not return_dict:
+                    return [[losses_b], [outputs_b], [encoding_b]]
+                return AllOutput(
+                    losses_b=losses_b,
+                    outputs_b=outputs_b,
+                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
+                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
+                    global_attentions_b=None,
+                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
+                )
+        else:
+            if has_pair and has_pair_b:
+                if not return_dict:
+                    return [[None, None], [None, None]] + [[encoding, encoding_b]]
+                return AllOutput(
+                    losses=None,
+                    outputs=None,
+                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
+                    attentions=encoding["attentions"] if "attentions" in encoding else None,
+                    global_attentions=None,
+                    contacts=encoding["contacts"] if "contacts" in encoding else None,
+                    losses_b=None,
+                    outputs_b=None,
+                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
+                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
+                    global_attentions_b=None,
+                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
+                )
+            elif has_pair:
+                if not return_dict:
+                    return [[None], [None], [encoding]]
+                return AllOutput(
+                    losses=None,
+                    outputs=None,
+                    hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
+                    attentions=encoding["attentions"] if "attentions" in encoding else None,
+                    global_attentions=None,
+                    contacts=encoding["contacts"] if "contacts" in encoding else None
+                )
+            else:
+                if not return_dict:
+                    return [[None], [None], [encoding_b]]
+                return AllOutput(
+                    losses_b=None,
+                    outputs_b=None,
+                    hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
+                    attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
+                    global_attentions_b=None,
+                    contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
+                )
 
     def predict_contacts(self, input_ids, position_ids=None, token_type_ids=None):
-        return self(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, return_contacts=True)["contacts"]
+        return self(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            return_contacts=True)["contacts"]
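The thread running through this file's changes is the new `self.embedding_inference` flag: when it is set, the LM head, contact head, and the pair-level classification/loss branches are never built or executed, and the forward pass returns only representations with `None` losses and outputs. A minimal sketch of that gating pattern on a toy encoder, assuming nothing about the real `LucaGPLM` beyond what the diff shows; class and layer names below are illustrative stand-ins, not the model's implementation:

```python
import torch
import torch.nn as nn

class EncoderForEmbedding(nn.Module):
    """Toy encoder that skips training-only heads when used for embedding inference."""

    def __init__(self, hidden_size=64, vocab_size=32, embedding_inference=True):
        super().__init__()
        self.embedding_inference = embedding_inference
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True),
            num_layers=2,
        )
        if not self.embedding_inference:
            # Only needed for masked-LM training; mirrors the `lm_head` gating above.
            self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids):
        x = self.encoder(self.embed(input_ids))
        out = {"representation_matrix": x, "representation_vector": x[:, 0, :]}
        if not self.embedding_inference:
            out["lm_logits"] = self.lm_head(x)
        return out

model = EncoderForEmbedding()
print(model(torch.randint(0, 32, (2, 10)))["representation_vector"].shape)  # torch.Size([2, 64])
```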
modeling_bert.py
CHANGED
@@ -6,7 +6,7 @@
 @email: sanyuan.**@**.com
 @tel: 137****6540
 @datetime: 2022/12/2 09:38
-@project:
+@project: LucaOne
 @file: modeling_bert
 @desc: transformer layers
 '''
@@ -179,22 +179,20 @@ class BertEmbeddings(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        if hasattr(config, "no_token_embeddings"):
-            self.no_token_embeddings = config.no_token_embeddings
-        else:
-            self.no_token_embeddings = False
-        if not self.no_token_embeddings:
-            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         if hasattr(config, "no_position_embeddings"):
             self.no_position_embeddings = config.no_position_embeddings
         else:
             self.no_position_embeddings = False
+
         if hasattr(config, "no_token_type_embeddings"):
             self.no_token_type_embeddings = config.no_token_type_embeddings
         else:
             self.no_token_type_embeddings = False
+
         if not self.no_position_embeddings:
             self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
         if not self.no_token_type_embeddings:
             self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
 
@@ -206,7 +204,10 @@ class BertEmbeddings(nn.Module):
         if not self.no_position_embeddings:
             self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
             self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-
+
+        if not self.no_token_type_embeddings:
+            if not hasattr(self, "position_ids"):
+                self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         if version.parse(torch.__version__) > version.parse("1.6.0"):
             self.register_buffer(
                 "token_type_ids",
@@ -229,21 +230,20 @@ class BertEmbeddings(nn.Module):
 
         seq_length = input_shape[1]
 
-        if not self.no_position_embeddings and position_ids is None:
+        if (not self.no_position_embeddings or not self.no_token_type_embeddings) and position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
         # issue #5664
         if not self.no_token_type_embeddings:
             if token_type_ids is None:
                 if hasattr(self, "token_type_ids"):
                     buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                     buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                     token_type_ids = buffered_token_type_ids_expanded
                 else:
                     token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
-        if self.no_token_embeddings and inputs_embeds is None:
-            raise Exception("The model has not token_embeddings layer, the inputs_embeds cannot None")
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -898,14 +898,11 @@ class BertModel(BertPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    def __init__(self, config, use_pretrained_embedding=False, add_pooling_layer=True):
+    def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
-        self.use_pretrained_embedding = use_pretrained_embedding
-        self.add_pooling_layer = add_pooling_layer
-
-        self.embeddings = nn.Linear(config.embedding_input_size, config.hidden_size) if use_pretrained_embedding else BertEmbeddings(config)
 
+        self.embeddings = BertEmbeddings(config)
         self.encoder = BertEncoder(config)
 
         self.pooler = BertPooler(config) if add_pooling_layer else None
@@ -1029,16 +1026,13 @@ class BertModel(BertPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-            embedding_output = self.embeddings(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                token_type_ids=token_type_ids,
-                inputs_embeds=inputs_embeds,
-                past_key_values_length=past_key_values_length,
-            )
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
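After this change `BertEmbeddings` always owns a word-embedding table, while position and token-type embeddings remain optional behind `no_position_embeddings` / `no_token_type_embeddings` config attributes read with `hasattr`. A hedged sketch of building such a config: extra keyword arguments on a Hugging Face config are stored as plain attributes, so the flags below would be picked up by the patched module; the flag names come from the diff, the other values are arbitrary.

```python
from transformers import BertConfig

# Arbitrary small config; the two no_* flags are the ones the patched
# BertEmbeddings checks with hasattr(config, ...), defaulting to False.
config = BertConfig(
    vocab_size=39,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    no_position_embeddings=True,
    no_token_type_embeddings=True,
)

print(getattr(config, "no_position_embeddings", False))    # True
print(getattr(config, "no_token_type_embeddings", False))  # True
```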
modeling_gplm.py
CHANGED
@@ -1,6 +1,15 @@
 #!/usr/bin/env python
 # encoding: utf-8
-
+'''
+@license: (C) Copyright 2021, Hey.
+@author: Hey
+@email: [email protected]
+@tel: 137****6540
+@datetime: 2023/7/24 10:01
+@project: LucaOne
+@file: modeling_gplm
+@desc: LucaOne Model Detail
+'''
 import math
 from typing import Dict, Optional, Sequence, Tuple, List, Union
 import uuid
@@ -11,19 +20,14 @@ from torch.nn import Parameter
 
 
 def gelu(x):
-    """Implementation of the gelu activation function.
-    OpenAI GPT's gelu: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-    """
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
 def symmetrize(x):
-    "Make layer symmetric in final two dimensions, used for contact prediction."
     return x + x.transpose(-1, -2)
 
 
 def apc(x):
-    "Perform average product correct, used for contact prediction."
     a1 = x.sum(-1, keepdims=True)
     a2 = x.sum(-2, keepdims=True)
     a12 = x.sum((-1, -2), keepdims=True)
@@ -57,7 +61,22 @@ class LucaGPLM1LayerNorm(nn.Module):
         x = (self.weight * x) + self.bias
         return x
 
-
+
+try:
+    # Optimized LayerNorm
+    from apex.normalization import FusedLayerNorm as _FusedLayerNorm
+    class LucaGPLM1bLayerNorm(_FusedLayerNorm):
+        @torch.jit.unused
+        def forward(self, x):
+            if not x.is_cuda:
+                return super().forward(x)
+            else:
+                with torch.cuda.device(x.device):
+                    return super().forward(x)
+
+except ImportError as e:
+    print("import apex err:", e)
+    from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
 
 
 class LucaGPLMTransformerLayer(nn.Module):
@@ -141,7 +160,6 @@ class LucaGPLMTransformerLayer(nn.Module):
 
 
 class AxialTransformerLayer(nn.Module):
-    """Implements an Axial MSA Transformer block."""
     def __init__(
         self,
         embedding_dim: int = 768,
@@ -197,10 +215,6 @@ class AxialTransformerLayer(nn.Module):
         self_attn_padding_mask: Optional[torch.Tensor] = None,
         need_head_weights: bool = False,
     ):
-        """
-        LayerNorm is applied either before or after the self-attention/ffn
-        modules similar to the original Transformer implementation.
-        """
         x, row_attn = self.row_self_attention(
             x,
             self_attn_mask=self_attn_mask,
@@ -219,13 +233,6 @@ class AxialTransformerLayer(nn.Module):
 
 
 class LearnedPositionalEmbedding(nn.Embedding):
-    """
-    This module learns positional embeddings up to a fixed maximum size.
-    Padding ids are ignored by either offsetting based on padding_idx
-    or by setting padding_idx to None and ensuring that the appropriate
-    position ids are passed to the forward function.
-    """
-
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
         if padding_idx is not None:
             num_embeddings_ = num_embeddings + padding_idx + 1
@@ -293,8 +300,6 @@ class SinusoidalPositionalEmbedding(nn.Module):
 
 
 class RobertaLMHead(nn.Module):
-    """Head for masked language modeling."""
-
     def __init__(self, embed_dim, output_dim, weight):
         super().__init__()
         self.dense = nn.Linear(embed_dim, embed_dim)
@@ -312,8 +317,6 @@ class RobertaLMHead(nn.Module):
 
 
 class ContactPredictionHead(nn.Module):
-    """Performs symmetrization, apc, and computes a logistic regression on the output features"""
-
     def __init__(
         self,
         in_features: int,
@@ -697,11 +700,6 @@ def with_incremental_state(cls):
 
 @with_incremental_state
 class LucaGPLMMultiheadAttention(nn.Module):
-    """Multi-headed attention.
-
-    See "Attention Is All You Need" for more details.
-    """
-
     def __init__(
         self,
         embed_dim,
@@ -768,18 +766,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
         self.onnx_trace = True
 
     def reset_parameters(self):
-        '''
-        if self.qkv_same_dim:
-            # Empirically observed the convergence to be much better with
-            # the scaled initialization
-            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
-        else:
-            nn.init.xavier_uniform_(self.k_proj.weight)
-            nn.init.xavier_uniform_(self.v_proj.weight)
-            nn.init.xavier_uniform_(self.q_proj.weight)
-        '''
         nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
         nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
         nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
@@ -806,23 +792,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
         before_softmax: bool = False,
         need_head_weights: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        """Input shape: Time x Batch x Channel
-
-        Args:
-            key_padding_mask (ByteTensor, optional): mask to exclude
-                keys that are pads, of shape `(batch, src_len)`, where
-                padding elements are indicated by 1s.
-            need_weights (bool, optional): return the attention weights,
-                averaged over heads (default: False).
-            attn_mask (ByteTensor, optional): typically used to
-                implement causal attention, where the mask prevents the
-                attention from looking forward in time (default: None).
-            before_softmax (bool, optional): return the raw attention
-                weights and values before the attention softmax.
-            need_head_weights (bool, optional): return the attention
-                weights for each head. Implies *need_weights*. Default:
-                return the average attention weights over all heads.
-        """
         if need_head_weights:
             need_weights = True
 
@@ -1081,7 +1050,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
     def reorder_incremental_state(
         self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
-        """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
@@ -1121,7 +1089,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
         keys_to_remove = []
         for k in state_dict.keys():
             if k.endswith(prefix + "in_proj_weight"):
-                # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
                 items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                 items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
@@ -1158,22 +1125,8 @@ def apply_rotary_pos_emb(x, cos, sin):
 
 
 class RotaryEmbedding(torch.nn.Module):
-    """
-    The rotary position embeddings from RoFormer_ (Su et. al).
-    A crucial insight from the method is that the query and keys are
-    transformed by rotation matrices which depend on the relative positions.
-    Other implementations are available in the Rotary Transformer repo_ and in
-    GPT-NeoX_, GPT-NeoX was an inspiration
-    .. _RoFormer: https://arxiv.org/abs/2104.09864
-    .. _repo: https://github.com/ZhuiyiTechnology/roformer
-    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
-    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
-        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
-    """
-
     def __init__(self, dim: int, *_, **__):
         super().__init__()
-        # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
 
@@ -1184,8 +1137,6 @@ class RotaryEmbedding(torch.nn.Module):
     def _update_cos_sin_tables(self, x, seq_dimension=1):
         seq_len = x.shape[seq_dimension]
 
-        # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
             self._seq_len_cached = seq_len
             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
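Most of this file's diff strips docstrings and comments; the functional additions are the module header and the apex `FusedLayerNorm` fallback for `LucaGPLM1bLayerNorm`. The `symmetrize`/`apc` helpers that feed `ContactPredictionHead` are unchanged; below is a small self-contained sketch of that preprocessing on a dummy attention tensor, with the tail of `apc` (subtracting the average product) filled in the way ESM-style implementations usually do it, which is an assumption beyond what the diff shows:

```python
import torch

def symmetrize(x):
    # Make the last two dimensions symmetric, as in the diff.
    return x + x.transpose(-1, -2)

def apc(x):
    # Average product correction; the first three lines mirror the diff,
    # the rest is the usual ESM-style completion (assumption).
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)
    avg = a1 * a2
    avg.div_(a12)
    return x - avg

# Dummy stack of attention maps with shape (layers * heads, seq_len, seq_len).
attentions = torch.rand(20, 64, 64)
corrected = apc(symmetrize(attentions))
print(corrected.shape)  # torch.Size([20, 64, 64])
```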