Yuanfei committed
Commit ca6b592 · verified · 1 Parent(s): 4a61dbb

Upload 7 files

Files changed (4)
  1. alphabet.py +0 -9
  2. lucaone_gplm.py +142 -109
  3. modeling_bert.py +27 -33
  4. modeling_gplm.py +26 -75
alphabet.py CHANGED
@@ -6,7 +6,6 @@ import json
 import itertools
 from typing import Sequence, List
 from transformers import PreTrainedTokenizer
-from .batch_converter import BatchConverter

 gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']

@@ -63,14 +62,6 @@ class Alphabet(object):
     def to_dict(self):
         return self.tok_to_idx.copy()

-    def get_batch_converter(self, no_position_embeddings, no_token_type_embeddings, truncation_seq_length: int = None, ignore_index: int = -100, mlm_probability=0.15):
-        return BatchConverter(self,
-                              no_position_embeddings=no_position_embeddings,
-                              no_token_type_embeddings=no_token_type_embeddings,
-                              truncation_seq_length=truncation_seq_length,
-                              ignore_index=ignore_index,
-                              mlm_probability=mlm_probability)
-
     @classmethod
     def from_predefined(cls, name: str):
         if name.lower() == "prot":
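With get_batch_converter() and the BatchConverter import gone, callers that only need the vocabulary can rely on the methods that remain. A minimal usage sketch (the import path is an assumption; only from_predefined and to_dict appear in the retained code):

from alphabet import Alphabet   # assumed module path for this repo's Alphabet class

alphabet = Alphabet.from_predefined("prot")   # "prot" branch shown in from_predefined
tok_to_idx = alphabet.to_dict()               # copy of the token -> index mapping
print(len(tok_to_idx))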
lucaone_gplm.py CHANGED
@@ -37,6 +37,7 @@ class LucaGPLM(PreTrainedModel):
37
  self.use_embed_layer_norm = config.use_embed_layer_norm
38
  self.use_last_layer_norm = config.use_last_layer_norm
39
  self.embed_scale = config.embed_scale
 
40
  self._init_submodules()
41
 
42
  def _init_submodules(self):
@@ -72,22 +73,23 @@ class LucaGPLM(PreTrainedModel):
72
  )
73
  self.layer_size = len(self.layers)
74
 
75
- self.contact_head = ContactPredictionHead(
76
- self.num_layers * self.attention_heads,
77
- self.prepend_bos,
78
- self.append_eos,
79
- eos_idx=self.eos_idx,
80
- )
 
81
  if self.use_last_layer_norm:
82
  self.last_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
83
  else:
84
  self.last_layer_norm = None
85
-
86
- self.lm_head = RobertaLMHead(
87
- embed_dim=self.embed_dim,
88
- output_dim=self.alphabet_size,
89
- weight=self.embed_tokens.weight,
90
- )
91
 
92
  def _init_embedding(self, pretrained_token_matrix, token_matrix):
93
  '''
@@ -103,7 +105,7 @@ class LucaGPLM(PreTrainedModel):
103
  31->38
104
  32->4
105
  '''
106
- print("Load pretrained exsists embedding vectors:")
107
  token_matrix[2, :] = pretrained_token_matrix[0, :]
108
  token_matrix[0, :] = pretrained_token_matrix[1, :]
109
  token_matrix[3, :] = pretrained_token_matrix[2, :]
@@ -117,7 +119,7 @@ class LucaGPLM(PreTrainedModel):
117
  return token_matrix
118
 
119
  def _init_submodules_new(self, pretrained_model_name):
120
- print("Load pretrained model exists weights:")
121
  from esm import pretrained
122
  from collections import OrderedDict
123
  pretrained, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
@@ -143,33 +145,16 @@ class LucaGPLM(PreTrainedModel):
143
  elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
144
  del our_model_state_dict[name]
145
  new_state_dict[name] = weight
146
-
147
  print("Exists layer names:")
148
  print(new_state_dict.keys())
149
  print("Not exists Layer names:")
150
  print(our_model_state_dict.keys())
 
151
  new_state_dict.update(our_model_state_dict)
152
  self.load_state_dict(new_state_dict)
153
 
154
  def __calc_loss__(self, task_level_type, output_mode, logits, label, label_size, loss_fct, loss_reduction):
155
- '''
156
- if label_size <= 2 or output_mode in ["binary_class", "binary-class"]:
157
- loss = loss_fct(logits.view(-1), label.view(-1).float())
158
- elif output_mode in ["multi_label", "multi-label"]:
159
- loss = loss_fct(logits.view(-1, label_size), label.view(-1, label_size).float())
160
- elif output_mode in ["multi_class", "multi-class"]:
161
- loss = loss_fct(logits.view(-1, label_size), label.view(-1))
162
- else:
163
- loss = loss_fct(logits.view(-1), label.view(-1))
164
- return loss
165
- '''
166
- '''
167
- print(task_level_type, output_mode, label_size, loss_fct, loss_reduction)
168
- print("logits:")
169
- print(logits.shape)
170
- print("label:")
171
- print(label.shape)
172
- '''
173
  if output_mode in ["regression"]:
174
  if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
175
  # structure-level regression
@@ -307,7 +292,8 @@ class LucaGPLM(PreTrainedModel):
307
  representation_matrix = hidden_representations[self.layer_size]
308
  # mask 任务
309
  # B * Seq_len * vocab_size
310
- lm_mask_logits = self.lm_head(x)
 
311
  # lm head的输出向量作为表征向量
312
  # (B, E)
313
  representation_vector = representation_matrix[:, 0, :]
@@ -329,14 +315,15 @@ class LucaGPLM(PreTrainedModel):
329
  attentions = attentions * attention_mask[:, None, None, :, :]
330
  representations["attentions"] = attentions
331
  # 预测contact矩阵
332
- if return_contacts:
 
333
  contacts = self.contact_head(input_ids, attentions)
334
  representations["contacts"] = contacts
335
  '''
336
  print("output_keys:")
337
  print(output_keys)
338
  '''
339
- if output_keys:
340
  for item in output_keys.items():
341
  cur_task_level_type = item[0]
342
  if cur_task_level_type not in logits:
@@ -466,107 +453,153 @@ class LucaGPLM(PreTrainedModel):
466
  use_last_layer_norm=use_last_layer_norm
467
  )
468
  has_pair_b = True
469
- if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
470
- cur_representation_vector = encoding["representation_vector"]
471
- cur_representation_vector_b = encoding_b["representation_vector"]
472
 
473
- pair_logits = {}
474
- pair_outputs = {}
475
- for item1 in pair_output_keys.items():
476
- cur_task_level_type = item1[0]
477
- if cur_task_level_type not in pair_outputs:
478
- pair_outputs[cur_task_level_type] = {}
479
- pair_logits[cur_task_level_type] = {}
480
- for cur_task_level_name in item1[1]:
481
- cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
482
- torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
483
- )
484
- cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
485
- if cur_hidden_layer is not None:
486
- cur_logits = cur_hidden_layer(cur_logits)
487
- cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
488
- pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
489
- pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
490
 
491
- if pair_label is not None:
492
- pair_loss = {}
493
  for item1 in pair_output_keys.items():
494
  cur_task_level_type = item1[0]
495
- if cur_task_level_type not in pair_label:
496
- continue
497
- if cur_task_level_type in pair_label:
498
- pair_loss[cur_task_level_type] = {}
499
  for cur_task_level_name in item1[1]:
500
- if cur_task_level_name not in pair_label[cur_task_level_type]:
501
  continue
502
- cur_label = pair_label[cur_task_level_type][cur_task_level_name]
503
- cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
504
- cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
505
- cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
506
- cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
507
- cur_loss = self.__calc_loss__(
508
- task_level_type=cur_task_level_type,
509
- output_mode=cur_output_mode, logits=cur_logits,
510
- label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
511
- loss_reduction="meanmean")
512
- pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss
513
 
514
  if not return_dict:
515
- return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs]] + [[encoding, encoding_b]]
516
  return AllOutput(
517
  losses=losses,
518
  outputs=outputs,
519
  hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
520
  attentions=encoding["attentions"] if "attentions" in encoding else None,
521
  global_attentions=None,
522
- contacts=encoding["contacts"] if "contacts" in encoding else None,
523
  losses_b=losses_b,
524
  outputs_b=outputs_b,
525
  hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
526
- attentions_b=encoding_b["attentions"] if "hidden_states" in encoding_b else None,
527
  global_attentions_b=None,
528
- contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
529
- pair_outputs=pair_outputs,
530
- pair_losses=pair_loss)
531
- else:
532
  if not return_dict:
533
- return [[losses, losses_b], [outputs, outputs_b]] + [[encoding, encoding_b]]
534
  return AllOutput(
535
- losses=losses,
536
- outputs=outputs,
537
  hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
538
  attentions=encoding["attentions"] if "attentions" in encoding else None,
539
  global_attentions=None,
540
  contacts=encoding["contacts"] if "contacts" in encoding else None,
541
- losses_b=losses_b,
542
- outputs_b=outputs_b,
543
  hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
544
  attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
545
  global_attentions_b=None,
546
  contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
547
  )
548
- elif has_pair:
549
- if not return_dict:
550
- return [[losses], [outputs], [encoding]]
551
- return AllOutput(
552
- losses=losses,
553
- outputs=outputs,
554
- hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
555
- attentions=encoding["attentions"] if "attentions" in encoding else None,
556
- global_attentions=None,
557
- contacts=encoding["contacts"] if "contacts" in encoding else None
558
- )
559
- else:
560
- if not return_dict:
561
- return [[losses_b], [outputs_b], [encoding_b]]
562
- return AllOutput(
563
- losses_b=losses_b,
564
- outputs_b=outputs_b,
565
- hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
566
- attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
567
- global_attentions_b=None,
568
- contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
569
- )
570
 
571
  def predict_contacts(self, input_ids, position_ids=None, token_type_ids=None):
572
- return self(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, return_contacts=True)["contacts"]
 
37
  self.use_embed_layer_norm = config.use_embed_layer_norm
38
  self.use_last_layer_norm = config.use_last_layer_norm
39
  self.embed_scale = config.embed_scale
40
+ self.embedding_inference = True
41
  self._init_submodules()
42
 
43
  def _init_submodules(self):
 
73
  )
74
  self.layer_size = len(self.layers)
75
 
76
+ if not self.embedding_inference:
77
+ self.contact_head = ContactPredictionHead(
78
+ self.num_layers * self.attention_heads,
79
+ self.prepend_bos,
80
+ self.append_eos,
81
+ eos_idx=self.eos_idx,
82
+ )
83
  if self.use_last_layer_norm:
84
  self.last_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
85
  else:
86
  self.last_layer_norm = None
87
+ if not self.embedding_inference:
88
+ self.lm_head = RobertaLMHead(
89
+ embed_dim=self.embed_dim,
90
+ output_dim=self.alphabet_size,
91
+ weight=self.embed_tokens.weight,
92
+ )
93
 
94
  def _init_embedding(self, pretrained_token_matrix, token_matrix):
95
  '''
 
105
  31->38
106
  32->4
107
  '''
108
+ # print("Load pretrained exists embedding vectors:")
109
  token_matrix[2, :] = pretrained_token_matrix[0, :]
110
  token_matrix[0, :] = pretrained_token_matrix[1, :]
111
  token_matrix[3, :] = pretrained_token_matrix[2, :]
 
119
  return token_matrix
120
 
121
  def _init_submodules_new(self, pretrained_model_name):
122
+ # print("Load pretrained model exists weights:")
123
  from esm import pretrained
124
  from collections import OrderedDict
125
  pretrained, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
 
145
  elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
146
  del our_model_state_dict[name]
147
  new_state_dict[name] = weight
148
+ '''
149
  print("Exists layer names:")
150
  print(new_state_dict.keys())
151
  print("Not exists Layer names:")
152
  print(our_model_state_dict.keys())
153
+ '''
154
  new_state_dict.update(our_model_state_dict)
155
  self.load_state_dict(new_state_dict)
156
 
157
  def __calc_loss__(self, task_level_type, output_mode, logits, label, label_size, loss_fct, loss_reduction):
158
  if output_mode in ["regression"]:
159
  if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
160
  # structure-level regression
 
292
  representation_matrix = hidden_representations[self.layer_size]
293
  # mask 任务
294
  # B * Seq_len * vocab_size
295
+ if not self.embedding_inference:
296
+ lm_mask_logits = self.lm_head(x)
297
  # lm head的输出向量作为表征向量
298
  # (B, E)
299
  representation_vector = representation_matrix[:, 0, :]
 
315
  attentions = attentions * attention_mask[:, None, None, :, :]
316
  representations["attentions"] = attentions
317
  # 预测contact矩阵
318
+ if return_contacts and hasattr(self, "contact_head") \
319
+ and not self.embedding_inference:
320
  contacts = self.contact_head(input_ids, attentions)
321
  representations["contacts"] = contacts
322
  '''
323
  print("output_keys:")
324
  print(output_keys)
325
  '''
326
+ if not self.embedding_inference and output_keys:
327
  for item in output_keys.items():
328
  cur_task_level_type = item[0]
329
  if cur_task_level_type not in logits:
 
453
  use_last_layer_norm=use_last_layer_norm
454
  )
455
  has_pair_b = True
456
 
457
+ if not self.embedding_inference:
458
+ if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
459
+ cur_representation_vector = encoding["representation_vector"]
460
+ cur_representation_vector_b = encoding_b["representation_vector"]
461
 
462
+ pair_logits = {}
463
+ pair_outputs = {}
464
  for item1 in pair_output_keys.items():
465
  cur_task_level_type = item1[0]
466
+ if cur_task_level_type not in pair_outputs:
467
+ pair_outputs[cur_task_level_type] = {}
468
+ pair_logits[cur_task_level_type] = {}
 
469
  for cur_task_level_name in item1[1]:
470
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
471
+ torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
472
+ )
473
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
474
+ if cur_hidden_layer is not None:
475
+ cur_logits = cur_hidden_layer(cur_logits)
476
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
477
+ pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
478
+ pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
479
+
480
+ if pair_label is not None:
481
+ pair_loss = {}
482
+ for item1 in pair_output_keys.items():
483
+ cur_task_level_type = item1[0]
484
+ if cur_task_level_type not in pair_label:
485
  continue
486
+ if cur_task_level_type in pair_label:
487
+ pair_loss[cur_task_level_type] = {}
488
+ for cur_task_level_name in item1[1]:
489
+ if cur_task_level_name not in pair_label[cur_task_level_type]:
490
+ continue
491
+ cur_label = pair_label[cur_task_level_type][cur_task_level_name]
492
+ cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
493
+ cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
494
+ cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
495
+ cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
496
+ cur_loss = self.__calc_loss__(
497
+ task_level_type=cur_task_level_type,
498
+ output_mode=cur_output_mode, logits=cur_logits,
499
+ label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
500
+ loss_reduction="meanmean")
501
+ pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss
502
 
503
+ if not return_dict:
504
+ return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs]] + [[encoding, encoding_b]]
505
+ return AllOutput(
506
+ losses=losses,
507
+ outputs=outputs,
508
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
509
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
510
+ global_attentions=None,
511
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
512
+ losses_b=losses_b,
513
+ outputs_b=outputs_b,
514
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
515
+ attentions_b=encoding_b["attentions"] if "hidden_states" in encoding_b else None,
516
+ global_attentions_b=None,
517
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
518
+ pair_outputs=pair_outputs,
519
+ pair_losses=pair_loss)
520
+ else:
521
+ if not return_dict:
522
+ return [[losses, losses_b], [outputs, outputs_b]] + [[encoding, encoding_b]]
523
+ return AllOutput(
524
+ losses=losses,
525
+ outputs=outputs,
526
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
527
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
528
+ global_attentions=None,
529
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
530
+ losses_b=losses_b,
531
+ outputs_b=outputs_b,
532
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
533
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
534
+ global_attentions_b=None,
535
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
536
+ )
537
+ elif has_pair:
538
  if not return_dict:
539
+ return [[losses], [outputs], [encoding]]
540
  return AllOutput(
541
  losses=losses,
542
  outputs=outputs,
543
  hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
544
  attentions=encoding["attentions"] if "attentions" in encoding else None,
545
  global_attentions=None,
546
+ contacts=encoding["contacts"] if "contacts" in encoding else None
547
+ )
548
+ else:
549
+ if not return_dict:
550
+ return [[losses_b], [outputs_b], [encoding_b]]
551
+ return AllOutput(
552
  losses_b=losses_b,
553
  outputs_b=outputs_b,
554
  hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
555
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
556
  global_attentions_b=None,
557
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
558
+ )
559
+ else:
560
+ if has_pair and has_pair_b:
561
  if not return_dict:
562
+ return [[None, None], [None, None]] + [[encoding, encoding_b]]
563
  return AllOutput(
564
+ losses=None,
565
+ outputs=None,
566
  hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
567
  attentions=encoding["attentions"] if "attentions" in encoding else None,
568
  global_attentions=None,
569
  contacts=encoding["contacts"] if "contacts" in encoding else None,
570
+ losses_b=None,
571
+ outputs_b=None,
572
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
573
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
574
+ global_attentions_b=None,
575
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
576
+ )
577
+ elif has_pair:
578
+ if not return_dict:
579
+ return [[None], [None], [encoding]]
580
+ return AllOutput(
581
+ losses=None,
582
+ outputs=None,
583
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
584
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
585
+ global_attentions=None,
586
+ contacts=encoding["contacts"] if "contacts" in encoding else None
587
+ )
588
+ else:
589
+ if not return_dict:
590
+ return [[None], [None], [encoding_b]]
591
+ return AllOutput(
592
+ losses_b=None,
593
+ outputs_b=None,
594
  hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
595
  attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
596
  global_attentions_b=None,
597
  contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
598
  )
599
 
600
  def predict_contacts(self, input_ids, position_ids=None, token_type_ids=None):
601
+ return self(
602
+ input_ids=input_ids,
603
+ position_ids=position_ids,
604
+ token_type_ids=token_type_ids,
605
+ return_contacts=True)["contacts"]
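The main change in lucaone_gplm.py is the new embedding_inference flag (hard-coded to True in __init__): when it is set, the mask-LM head and contact head are never built and forward() skips the task-specific logits and losses, returning only representation outputs. A minimal sketch of that gating, using stand-in layers and arbitrary sizes rather than the real RobertaLMHead/ContactPredictionHead:

import torch.nn as nn

class HeadGatingSketch(nn.Module):
    def __init__(self, embed_dim: int = 64, alphabet_size: int = 40, embedding_inference: bool = True):
        super().__init__()
        self.embedding_inference = embedding_inference
        self.embed_tokens = nn.Embedding(alphabet_size, embed_dim)
        if not self.embedding_inference:
            # Only needed for pretraining losses and contact prediction.
            self.lm_head = nn.Linear(embed_dim, alphabet_size)   # stand-in for RobertaLMHead
            self.contact_head = nn.Linear(embed_dim, 1)          # stand-in for ContactPredictionHead

model = HeadGatingSketch()            # embedding_inference=True: heads are absent
print(hasattr(model, "lm_head"))      # False; forward() must guard on the flag, as the diff does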
modeling_bert.py CHANGED
@@ -6,7 +6,7 @@
6
  @email: sanyuan.**@**.com
7
  @tel: 137****6540
8
  @datetime: 2022/12/2 09:38
9
- @project: LucaOneTasks
10
  @file: modeling_bert
11
  @desc: transformer layers
12
  '''
@@ -179,22 +179,20 @@ class BertEmbeddings(nn.Module):
179
 
180
  def __init__(self, config):
181
  super().__init__()
182
- if hasattr(config, "no_token_embeddings"):
183
- self.no_token_embeddings = config.no_token_embeddings
184
- else:
185
- self.no_token_embeddings = False
186
- if not self.no_token_embeddings:
187
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
188
  if hasattr(config, "no_position_embeddings"):
189
  self.no_position_embeddings = config.no_position_embeddings
190
  else:
191
  self.no_position_embeddings = False
 
192
  if hasattr(config, "no_token_type_embeddings"):
193
  self.no_token_type_embeddings = config.no_token_type_embeddings
194
  else:
195
  self.no_token_type_embeddings = False
 
196
  if not self.no_position_embeddings:
197
  self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
 
198
  if not self.no_token_type_embeddings:
199
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
200
 
@@ -206,7 +204,10 @@ class BertEmbeddings(nn.Module):
206
  if not self.no_position_embeddings:
207
  self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
208
  self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
209
- if not self.no_token_type_embeddings and not self.no_position_embeddings:
210
  if version.parse(torch.__version__) > version.parse("1.6.0"):
211
  self.register_buffer(
212
  "token_type_ids",
@@ -229,21 +230,20 @@ class BertEmbeddings(nn.Module):
229
 
230
  seq_length = input_shape[1]
231
 
232
- if not self.no_position_embeddings and position_ids is None :
233
  position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
234
 
235
  # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
236
  # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
237
  # issue #5664
238
- if not self.no_token_type_embeddings and token_type_ids is None:
239
- if hasattr(self, "token_type_ids"):
240
- buffered_token_type_ids = self.token_type_ids[:, :seq_length]
241
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
242
- token_type_ids = buffered_token_type_ids_expanded
243
- else:
244
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device if input_ids is not None else inputs_embeds.device)
245
- if self.no_token_embeddings and inputs_embeds is None:
246
- raise Exception("The model has not token_embeddings layer, the inputs_embeds cannot None")
247
 
248
  if inputs_embeds is None:
249
  inputs_embeds = self.word_embeddings(input_ids)
@@ -898,14 +898,11 @@ class BertModel(BertPreTrainedModel):
898
  `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
899
  """
900
 
901
- def __init__(self, config, use_pretrained_embedding=False, add_pooling_layer=True):
902
  super().__init__(config)
903
  self.config = config
904
- self.use_pretrained_embedding = use_pretrained_embedding
905
- self.add_pooling_layer = add_pooling_layer
906
-
907
- self.embeddings = nn.Linear(config.embedding_input_size, config.hidden_size) if use_pretrained_embedding else BertEmbeddings(config)
908
 
 
909
  self.encoder = BertEncoder(config)
910
 
911
  self.pooler = BertPooler(config) if add_pooling_layer else None
@@ -1029,16 +1026,13 @@ class BertModel(BertPreTrainedModel):
1029
  # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1030
  head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1031
 
1032
- if self.use_pretrained_embedding:
1033
- embedding_output = self.embeddings(inputs_embeds)
1034
- else:
1035
- embedding_output = self.embeddings(
1036
- input_ids=input_ids,
1037
- position_ids=position_ids,
1038
- token_type_ids=token_type_ids,
1039
- inputs_embeds=inputs_embeds,
1040
- past_key_values_length=past_key_values_length,
1041
- )
1042
  encoder_outputs = self.encoder(
1043
  embedding_output,
1044
  attention_mask=extended_attention_mask,
 
6
  @email: sanyuan.**@**.com
7
  @tel: 137****6540
8
  @datetime: 2022/12/2 09:38
9
+ @project: LucaOne
10
  @file: modeling_bert
11
  @desc: transformer layers
12
  '''
 
179
 
180
  def __init__(self, config):
181
  super().__init__()
182
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
183
  if hasattr(config, "no_position_embeddings"):
184
  self.no_position_embeddings = config.no_position_embeddings
185
  else:
186
  self.no_position_embeddings = False
187
+
188
  if hasattr(config, "no_token_type_embeddings"):
189
  self.no_token_type_embeddings = config.no_token_type_embeddings
190
  else:
191
  self.no_token_type_embeddings = False
192
+
193
  if not self.no_position_embeddings:
194
  self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
195
+
196
  if not self.no_token_type_embeddings:
197
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
198
 
 
204
  if not self.no_position_embeddings:
205
  self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
206
  self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
207
+
208
+ if not self.no_token_type_embeddings:
209
+ if not hasattr(self, "position_ids"):
210
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
211
  if version.parse(torch.__version__) > version.parse("1.6.0"):
212
  self.register_buffer(
213
  "token_type_ids",
 
230
 
231
  seq_length = input_shape[1]
232
 
233
+ if (not self.no_position_embeddings or not self.no_token_type_embeddings) and position_ids is None:
234
  position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
235
 
236
  # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
237
  # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
238
  # issue #5664
239
+ if not self.no_token_type_embeddings:
240
+ if token_type_ids is None:
241
+ if hasattr(self, "token_type_ids"):
242
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
243
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
244
+ token_type_ids = buffered_token_type_ids_expanded
245
+ else:
246
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
247
 
248
  if inputs_embeds is None:
249
  inputs_embeds = self.word_embeddings(input_ids)
 
898
  `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
899
  """
900
 
901
+ def __init__(self, config, add_pooling_layer=True):
902
  super().__init__(config)
903
  self.config = config
904
 
905
+ self.embeddings = BertEmbeddings(config)
906
  self.encoder = BertEncoder(config)
907
 
908
  self.pooler = BertPooler(config) if add_pooling_layer else None
 
1026
  # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1027
  head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1028
 
1029
+ embedding_output = self.embeddings(
1030
+ input_ids=input_ids,
1031
+ position_ids=position_ids,
1032
+ token_type_ids=token_type_ids,
1033
+ inputs_embeds=inputs_embeds,
1034
+ past_key_values_length=past_key_values_length,
1035
+ )
1036
  encoder_outputs = self.encoder(
1037
  embedding_output,
1038
  attention_mask=extended_attention_mask,
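For modeling_bert.py the net effect is that BertEmbeddings always owns word_embeddings (the no_token_embeddings path is dropped) and BertModel always goes through BertEmbeddings, while position and token-type embeddings stay optional. A minimal sketch of that constructor logic, assuming a transformers-style config; getattr is shorthand for the hasattr checks in the diff, and the config values below are hypothetical:

from types import SimpleNamespace
import torch.nn as nn

class BertEmbeddingsSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Always built after this commit.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_token_id)
        self.no_position_embeddings = getattr(config, "no_position_embeddings", False)
        self.no_token_type_embeddings = getattr(config, "no_token_type_embeddings", False)
        if not self.no_position_embeddings:
            self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        if not self.no_token_type_embeddings:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

cfg = SimpleNamespace(vocab_size=30, hidden_size=16, pad_token_id=0,
                      max_position_embeddings=64, type_vocab_size=2,
                      no_token_type_embeddings=True)          # hypothetical values
emb = BertEmbeddingsSketch(cfg)
print(hasattr(emb, "token_type_embeddings"))                  # False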
modeling_gplm.py CHANGED
@@ -1,6 +1,15 @@
1
  #!/usr/bin/env python
2
  # encoding: utf-8
3
-
4
  import math
5
  from typing import Dict, Optional, Sequence, Tuple, List, Union
6
  import uuid
@@ -11,19 +20,14 @@ from torch.nn import Parameter
11
 
12
 
13
  def gelu(x):
14
- """Implementation of the gelu activation function.
15
- OpenAI GPT's gelu: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
16
- """
17
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
18
 
19
 
20
  def symmetrize(x):
21
- "Make layer symmetric in final two dimensions, used for contact prediction."
22
  return x + x.transpose(-1, -2)
23
 
24
 
25
  def apc(x):
26
- "Perform average product correct, used for contact prediction."
27
  a1 = x.sum(-1, keepdims=True)
28
  a2 = x.sum(-2, keepdims=True)
29
  a12 = x.sum((-1, -2), keepdims=True)
@@ -57,7 +61,22 @@ class LucaGPLM1LayerNorm(nn.Module):
57
  x = (self.weight * x) + self.bias
58
  return x
59
 
60
- from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
61
 
62
 
63
  class LucaGPLMTransformerLayer(nn.Module):
@@ -141,7 +160,6 @@ class LucaGPLMTransformerLayer(nn.Module):
141
 
142
 
143
  class AxialTransformerLayer(nn.Module):
144
- """Implements an Axial MSA Transformer block."""
145
  def __init__(
146
  self,
147
  embedding_dim: int = 768,
@@ -197,10 +215,6 @@ class AxialTransformerLayer(nn.Module):
197
  self_attn_padding_mask: Optional[torch.Tensor] = None,
198
  need_head_weights: bool = False,
199
  ):
200
- """
201
- LayerNorm is applied either before or after the self-attention/ffn
202
- modules similar to the original Transformer implementation.
203
- """
204
  x, row_attn = self.row_self_attention(
205
  x,
206
  self_attn_mask=self_attn_mask,
@@ -219,13 +233,6 @@ class AxialTransformerLayer(nn.Module):
219
 
220
 
221
  class LearnedPositionalEmbedding(nn.Embedding):
222
- """
223
- This module learns positional embeddings up to a fixed maximum size.
224
- Padding ids are ignored by either offsetting based on padding_idx
225
- or by setting padding_idx to None and ensuring that the appropriate
226
- position ids are passed to the forward function.
227
- """
228
-
229
  def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
230
  if padding_idx is not None:
231
  num_embeddings_ = num_embeddings + padding_idx + 1
@@ -293,8 +300,6 @@ class SinusoidalPositionalEmbedding(nn.Module):
293
 
294
 
295
  class RobertaLMHead(nn.Module):
296
- """Head for masked language modeling."""
297
-
298
  def __init__(self, embed_dim, output_dim, weight):
299
  super().__init__()
300
  self.dense = nn.Linear(embed_dim, embed_dim)
@@ -312,8 +317,6 @@ class RobertaLMHead(nn.Module):
312
 
313
 
314
  class ContactPredictionHead(nn.Module):
315
- """Performs symmetrization, apc, and computes a logistic regression on the output features"""
316
-
317
  def __init__(
318
  self,
319
  in_features: int,
@@ -697,11 +700,6 @@ def with_incremental_state(cls):
697
 
698
  @with_incremental_state
699
  class LucaGPLMMultiheadAttention(nn.Module):
700
- """Multi-headed attention.
701
-
702
- See "Attention Is All You Need" for more details.
703
- """
704
-
705
  def __init__(
706
  self,
707
  embed_dim,
@@ -768,18 +766,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
768
  self.onnx_trace = True
769
 
770
  def reset_parameters(self):
771
- '''
772
- if self.qkv_same_dim:
773
- # Empirically observed the convergence to be much better with
774
- # the scaled initialization
775
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
776
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
777
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
778
- else:
779
- nn.init.xavier_uniform_(self.k_proj.weight)
780
- nn.init.xavier_uniform_(self.v_proj.weight)
781
- nn.init.xavier_uniform_(self.q_proj.weight)
782
- '''
783
  nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
784
  nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
785
  nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
@@ -806,23 +792,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
806
  before_softmax: bool = False,
807
  need_head_weights: bool = False,
808
  ) -> Tuple[Tensor, Optional[Tensor]]:
809
- """Input shape: Time x Batch x Channel
810
-
811
- Args:
812
- key_padding_mask (ByteTensor, optional): mask to exclude
813
- keys that are pads, of shape `(batch, src_len)`, where
814
- padding elements are indicated by 1s.
815
- need_weights (bool, optional): return the attention weights,
816
- averaged over heads (default: False).
817
- attn_mask (ByteTensor, optional): typically used to
818
- implement causal attention, where the mask prevents the
819
- attention from looking forward in time (default: None).
820
- before_softmax (bool, optional): return the raw attention
821
- weights and values before the attention softmax.
822
- need_head_weights (bool, optional): return the attention
823
- weights for each head. Implies *need_weights*. Default:
824
- return the average attention weights over all heads.
825
- """
826
  if need_head_weights:
827
  need_weights = True
828
 
@@ -1081,7 +1050,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
1081
  def reorder_incremental_state(
1082
  self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
1083
  ):
1084
- """Reorder buffered internal state (for incremental generation)."""
1085
  input_buffer = self._get_input_buffer(incremental_state)
1086
  if input_buffer is not None:
1087
  for k in input_buffer.keys():
@@ -1121,7 +1089,6 @@ class LucaGPLMMultiheadAttention(nn.Module):
1121
  keys_to_remove = []
1122
  for k in state_dict.keys():
1123
  if k.endswith(prefix + "in_proj_weight"):
1124
- # in_proj_weight used to be q + k + v with same dimensions
1125
  dim = int(state_dict[k].shape[0] / 3)
1126
  items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
1127
  items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
@@ -1158,22 +1125,8 @@ def apply_rotary_pos_emb(x, cos, sin):
1158
 
1159
 
1160
  class RotaryEmbedding(torch.nn.Module):
1161
- """
1162
- The rotary position embeddings from RoFormer_ (Su et. al).
1163
- A crucial insight from the method is that the query and keys are
1164
- transformed by rotation matrices which depend on the relative positions.
1165
- Other implementations are available in the Rotary Transformer repo_ and in
1166
- GPT-NeoX_, GPT-NeoX was an inspiration
1167
- .. _RoFormer: https://arxiv.org/abs/2104.09864
1168
- .. _repo: https://github.com/ZhuiyiTechnology/roformer
1169
- .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
1170
- .. warning: Please note that this embedding is not registered on purpose, as it is transformative
1171
- (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
1172
- """
1173
-
1174
  def __init__(self, dim: int, *_, **__):
1175
  super().__init__()
1176
- # Generate and save the inverse frequency buffer (non trainable)
1177
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1178
  self.register_buffer("inv_freq", inv_freq)
1179
 
@@ -1184,8 +1137,6 @@ class RotaryEmbedding(torch.nn.Module):
1184
  def _update_cos_sin_tables(self, x, seq_dimension=1):
1185
  seq_len = x.shape[seq_dimension]
1186
 
1187
- # Reset the tables if the sequence length has changed,
1188
- # or if we're on a new device (possibly due to tracing for instance)
1189
  if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1190
  self._seq_len_cached = seq_len
1191
  t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
 
1
  #!/usr/bin/env python
2
  # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: [email protected]
7
+ @tel: 137****6540
8
+ @datetime: 2023/7/24 10:01
9
+ @project: LucaOne
10
+ @file: modeling_gplm
11
+ @desc: LucaOne Model Detail
12
+ '''
13
  import math
14
  from typing import Dict, Optional, Sequence, Tuple, List, Union
15
  import uuid
 
20
 
21
 
22
  def gelu(x):
23
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
24
 
25
 
26
  def symmetrize(x):
 
27
  return x + x.transpose(-1, -2)
28
 
29
 
30
  def apc(x):
 
31
  a1 = x.sum(-1, keepdims=True)
32
  a2 = x.sum(-2, keepdims=True)
33
  a12 = x.sum((-1, -2), keepdims=True)
 
61
  x = (self.weight * x) + self.bias
62
  return x
63
 
64
+
65
+ try:
66
+ # Optimized LayerNorm
67
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
68
+ class LucaGPLM1bLayerNorm(_FusedLayerNorm):
69
+ @torch.jit.unused
70
+ def forward(self, x):
71
+ if not x.is_cuda:
72
+ return super().forward(x)
73
+ else:
74
+ with torch.cuda.device(x.device):
75
+ return super().forward(x)
76
+
77
+ except ImportError as e:
78
+ print("import apex err:", e)
79
+ from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
80
 
81
 
82
  class LucaGPLMTransformerLayer(nn.Module):
 
160
 
161
 
162
  class AxialTransformerLayer(nn.Module):
 
163
  def __init__(
164
  self,
165
  embedding_dim: int = 768,
 
215
  self_attn_padding_mask: Optional[torch.Tensor] = None,
216
  need_head_weights: bool = False,
217
  ):
218
  x, row_attn = self.row_self_attention(
219
  x,
220
  self_attn_mask=self_attn_mask,
 
233
 
234
 
235
  class LearnedPositionalEmbedding(nn.Embedding):
 
236
  def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
237
  if padding_idx is not None:
238
  num_embeddings_ = num_embeddings + padding_idx + 1
 
300
 
301
 
302
  class RobertaLMHead(nn.Module):
 
 
303
  def __init__(self, embed_dim, output_dim, weight):
304
  super().__init__()
305
  self.dense = nn.Linear(embed_dim, embed_dim)
 
317
 
318
 
319
  class ContactPredictionHead(nn.Module):
 
 
320
  def __init__(
321
  self,
322
  in_features: int,
 
700
 
701
  @with_incremental_state
702
  class LucaGPLMMultiheadAttention(nn.Module):
 
703
  def __init__(
704
  self,
705
  embed_dim,
 
766
  self.onnx_trace = True
767
 
768
  def reset_parameters(self):
769
  nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
770
  nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
771
  nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
 
792
  before_softmax: bool = False,
793
  need_head_weights: bool = False,
794
  ) -> Tuple[Tensor, Optional[Tensor]]:
795
  if need_head_weights:
796
  need_weights = True
797
 
 
1050
  def reorder_incremental_state(
1051
  self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
1052
  ):
 
1053
  input_buffer = self._get_input_buffer(incremental_state)
1054
  if input_buffer is not None:
1055
  for k in input_buffer.keys():
 
1089
  keys_to_remove = []
1090
  for k in state_dict.keys():
1091
  if k.endswith(prefix + "in_proj_weight"):
 
1092
  dim = int(state_dict[k].shape[0] / 3)
1093
  items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
1094
  items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
 
1125
 
1126
 
1127
  class RotaryEmbedding(torch.nn.Module):
1128
  def __init__(self, dim: int, *_, **__):
1129
  super().__init__()
 
1130
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1131
  self.register_buffer("inv_freq", inv_freq)
1132
 
 
1137
  def _update_cos_sin_tables(self, x, seq_dimension=1):
1138
  seq_len = x.shape[seq_dimension]
1139
 
 
 
1140
  if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1141
  self._seq_len_cached = seq_len
1142
  t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
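modeling_gplm.py mostly drops docstrings and comments while restoring the optional apex FusedLayerNorm fallback; the contact-prediction helpers themselves are unchanged. A small numeric sketch of symmetrize and the average product correction they implement (the apc body below is the standard ESM-style completion of the three sums shown in the diff; keepdim is used instead of the repo's keepdims alias):

import torch

def symmetrize(x):
    # Make the last two dimensions symmetric.
    return x + x.transpose(-1, -2)

def apc(x):
    a1 = x.sum(-1, keepdim=True)
    a2 = x.sum(-2, keepdim=True)
    a12 = x.sum((-1, -2), keepdim=True)
    avg = a1 * a2 / a12          # expected coupling from row/column marginals
    return x - avg               # remove the background signal

attn = torch.rand(2, 16, 16)     # e.g. stacked L x L attention maps
print(apc(symmetrize(attn)).shape)   # torch.Size([2, 16, 16])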