yangheng committed on
Commit 51f9ad7
1 Parent(s): 3eb2fd3
README.md CHANGED
@@ -1,3 +1,35 @@
  ---
  license: mit
+ language:
+ - rna
+ - dna
+
+ tags:
+ - Genomic-Language-Modeling
+ - OmniGenome Foundation Model
  ---
+
+ # Multi-species Foundation Model for Universal RNA and DNA Downstream Tasks
+
+ ## Notes
+ We keep updating the checkpoints; the current checkpoint has been trained for 0.85 epochs.
+
+ ## Training Examples
+ Refer to the GitHub repository: [https://github.com/yangheng95/OmniGenome](https://github.com/yangheng95/OmniGenome)
+
+ ## Usage
+ This model can be used as a drop-in replacement for genomic foundation models such as CDSBERT, Nucleotide Transformer, DNABERT-2, etc.
+ ```python
+ from transformers import AutoModel
+ model = AutoModel.from_pretrained("yangheng/OmniGenome-52M", trust_remote_code=True)
+ ```
+
+ ## Subtasks
+ - Secondary structure prediction
+ - Genome Sequence Classification
+ - Genome Sequence Regression
+ - Single Nucleotide Repair
+ - Genome Masked Language Modeling
+ - etc.
+
+ Part of the code is adapted from ESM2.
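Beyond the `AutoModel` snippet in the README above, a minimal token-level usage sketch (e.g. for the secondary structure prediction subtask) is given below. The checkpoint's `config.json` in the next section registers `OmniGenomeForTokenClassification` in `architectures` and `auto_map`, so the token-classification auto class applies; the tokenizer call and the example RNA sequence are assumptions for illustration, not part of this commit.

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

repo = "yangheng/OmniGenome-52M"
# Assumption: the repository ships a compatible tokenizer; if not, follow the
# training examples in the GitHub repository linked in the README.
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)

rna = "GGGAAACGCUUCGGCGUUUCCC"  # hypothetical RNA sequence, illustration only
inputs = tokenizer(rna, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits        # shape: (1, sequence_length, num_labels)
per_token_pred = logits.argmax(dim=-1)     # per-nucleotide class indices
```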
config.json CHANGED
@@ -1,17 +1,16 @@
  {
-   "MPRNAfold_config": null,
-   "_name_or_path": "../output/checkpoint-500-legacy",
+   "OmniGenomefold_config": null,
+   "_name_or_path": "./",
    "architectures": [
-     "MPRNAForMaskedLM"
+     "OmniGenomeForTokenClassification"
    ],
    "attention_probs_dropout_prob": 0.0,
    "auto_map": {
-     "AutoConfig": "configuration_mprna.MPRNAConfig",
-     "AutoModel": "modeling_mprna.MPRNAModel",
-     "AutoModelForMaskedLM": "modeling_mprna.MPRNAForMaskedLM",
-     "AutoModelForSequenceClassification": "modeling_mprna.RNA2StructForSequenceClassification",
-     "AutoModelForTokenClassification": "modeling_mprna.RNA2StructForTokenClassification",
-     "AutoTokenizer": "tokenization_mprna.MPRNATokenizer"
+     "AutoConfig": "configuration_omnigenome.OmniGenomeConfig",
+     "AutoModel": "modeling_omnigenome.OmniGenomeModel",
+     "AutoModelForMaskedLM": "modeling_omnigenome.OmniGenomeForMaskedLM",
+     "AutoModelForSeq2SeqLM": "modeling_omnigenome.OmniGenomeForSeq2SeqLM",
+     "AutoModelForTokenClassification": "modeling_omnigenome.OmniGenomeForTokenClassification"
    },
    "classifier_dropout": null,
    "emb_layer_norm_before": false,
@@ -24,7 +23,7 @@
    "layer_norm_eps": 1e-05,
    "mask_token_id": 23,
    "max_position_embeddings": 1026,
-   "model_type": "mprna",
+   "model_type": "omnigenome",
    "num_attention_heads": 30,
    "num_hidden_layers": 32,
    "pad_token_id": 1,
configuration_mprna.py → configuration_omnigenome.py RENAMED
@@ -12,7 +12,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """ MPRNA model configuration"""
+ """ OmniGenome model configuration"""

  from dataclasses import asdict, dataclass
  from typing import Optional
@@ -24,18 +24,19 @@ from transformers.utils import logging
  logger = logging.get_logger(__name__)

  # TODO Update this
- MPRNA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-     "yangheng/MPRNA-small": "https://huggingface.co/yangheng/MPRNA-small/resolve/main/config.json",
-     # See all MPRNA models at https://huggingface.co/models?filter=MPRNA
+ OmniGenome_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "yangheng/OmniGenome-52M": "https://huggingface.co/yangheng/OmniGenome-52M/resolve/main/config.json",
+     "yangheng/OmniGenome-186M": "https://huggingface.co/yangheng/OmniGenome-186M/resolve/main/config.json",
+     # See all OmniGenome models at https://huggingface.co/models?filter=OmniGenome
  }


- class MPRNAConfig(PretrainedConfig):
+ class OmniGenomeConfig(PretrainedConfig):
      r"""
-     This is the configuration class to store the configuration of a [`MPRNAModel`]. It is used to instantiate a MPRNA model
+     This is the configuration class to store the configuration of a [`OmniGenomeModel`]. It is used to instantiate a OmniGenome model
      according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-     defaults will yield a similar configuration to that of the MPRNA
-     [yangheng/MPRNA-small](https://huggingface.co/yangheng/MPRNA-small) architecture.
+     defaults will yield a similar configuration to that of the OmniGenome
+     [yangheng/OmniGenome-52M](https://huggingface.co/yangheng/OmniGenome-52M) architecture.

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
      documentation from [`PretrainedConfig`] for more information.
@@ -43,14 +44,14 @@ class MPRNAConfig(PretrainedConfig):

      Args:
          vocab_size (`int`, *optional*):
-             Vocabulary size of the MPRNA model. Defines the number of different tokens that can be represented by the
-             `inputs_ids` passed when calling [`MPRNAModel`].
+             Vocabulary size of the OmniGenome model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`OmniGenomeModel`].
          mask_token_id (`int`, *optional*):
              The index of the mask token in the vocabulary. This must be included in the config because of the
              "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
          pad_token_id (`int`, *optional*):
              The index of the padding token in the vocabulary. This must be included in the config because certain parts
-             of the MPRNA code use this instead of the attention mask.
+             of the OmniGenome code use this instead of the attention mask.
          hidden_size (`int`, *optional*, defaults to 768):
              Dimensionality of the encoder layers and the pooler layer.
          num_hidden_layers (`int`, *optional*, defaults to 12):
@@ -89,11 +90,11 @@ class MPRNAConfig(PretrainedConfig):
      Examples:

      ```python
-     # >>> from transformers import MPRNAModel, MPRNAConfig
+     # >>> from transformers import OmniGenomeModel, OmniGenomeConfig
      #
-     # >>> # Initializing a MPRNA yangheng/MPRNA-small style configuration >>> configuration = MPRNAConfig()
+     # >>> # Initializing a OmniGenome yangheng/OmniGenome-52M style configuration >>> configuration = OmniGenomeConfig()
      #
-     # >>> # Initializing a model from the configuration >>> model = MPRNAModel(configuration)
+     # >>> # Initializing a model from the configuration >>> model = OmniGenomeModel(configuration)
      #
      # >>> # Accessing the model configuration >>> configuration = model.config
      ```"""
@@ -119,7 +120,7 @@ class MPRNAConfig(PretrainedConfig):
          emb_layer_norm_before=None,
          token_dropout=False,
          is_folding_model=False,
-         MPRNAfold_config=None,
+         OmniGenomefold_config=None,
          vocab_list=None,
          **kwargs,
      ):
@@ -142,30 +143,13 @@ class MPRNAConfig(PretrainedConfig):
          self.emb_layer_norm_before = emb_layer_norm_before
          self.token_dropout = token_dropout
          self.is_folding_model = is_folding_model
-         if is_folding_model:
-             if MPRNAfold_config is None:
-                 logger.info(
-                     "No MPRNAfold_config supplied for folding model, using default values."
-                 )
-                 MPRNAfold_config = MPRNAFoldConfig()
-             elif isinstance(MPRNAfold_config, dict):
-                 MPRNAfold_config = MPRNAFoldConfig(**MPRNAfold_config)
-             self.MPRNAfold_config = MPRNAfold_config
-             if vocab_list is None:
-                 logger.warning(
-                     "No vocab_list supplied for folding model, assuming the MPRNA-2 vocabulary!"
-                 )
-                 self.vocab_list = get_default_vocab_list()
-             else:
-                 self.vocab_list = vocab_list
-         else:
-             self.MPRNAfold_config = None
-             self.vocab_list = None
-         if self.MPRNAfold_config is not None and getattr(
-             self.MPRNAfold_config, "use_MPRNA_attn_map", False
+         self.OmniGenomefold_config = None
+         self.vocab_list = None
+         if self.OmniGenomefold_config is not None and getattr(
+             self.OmniGenomefold_config, "use_OmniGenome_attn_map", False
          ):
              raise ValueError(
-                 "The HuggingFace port of MPRNAFold does not support use_MPRNA_attn_map at this time!"
+                 "The HuggingFace port of OmniGenomeFold does not support use_OmniGenome_attn_map at this time!"
              )

      def to_dict(self):
@@ -176,41 +160,6 @@ class MPRNAConfig(PretrainedConfig):
              `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
          """
          output = super().to_dict()
-         if isinstance(self.MPRNAfold_config, MPRNAFoldConfig):
-             output["MPRNAfold_config"] = self.MPRNAfold_config.to_dict()
-         return output
-
-
- @dataclass
- class MPRNAFoldConfig:
-     MPRNA_type: str = None
-     fp16_MPRNA: bool = True
-     use_MPRNA_attn_map: bool = False
-     MPRNA_ablate_pairwise: bool = False
-     MPRNA_ablate_sequence: bool = False
-     MPRNA_input_dropout: float = 0
-
-     embed_aa: bool = True
-     bypass_lm: bool = False
-
-     lddt_head_hid_dim: int = 128
-     trunk: "TrunkConfig" = None
-
-     def __post_init__(self):
-         if self.trunk is None:
-             self.trunk = TrunkConfig()
-         elif isinstance(self.trunk, dict):
-             self.trunk = TrunkConfig(**self.trunk)
-
-     def to_dict(self):
-         """
-         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-         Returns:
-             `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-         """
-         output = asdict(self)
-         output["trunk"] = self.trunk.to_dict()
          return output


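For reference, the renamed classes above are wired up through the `auto_map` entries shown in the `config.json` diff, which is what `trust_remote_code=True` resolves against. Below is a minimal inspection sketch; the printed values mirror the fields in this commit's `config.json`, and the download itself is assumed to succeed.

```python
from transformers import AutoConfig

# Resolves "configuration_omnigenome.OmniGenomeConfig" through the auto_map added
# in this commit; trust_remote_code is required because the class lives in the repo.
config = AutoConfig.from_pretrained("yangheng/OmniGenome-52M", trust_remote_code=True)

print(config.model_type)               # "omnigenome"
print(config.architectures)            # ["OmniGenomeForTokenClassification"]
print(config.num_hidden_layers)        # 32
print(config.num_attention_heads)      # 30
print(config.max_position_embeddings)  # 1026

# After this commit the folding branch is gone from __init__: the
# OmniGenomefold_config and vocab_list attributes are simply fixed to None.
print(config.OmniGenomefold_config, config.vocab_list)  # None None
```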
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:856b9640ad214afeed65eea9fe2e165093b8a002815b8cf84431211b6efbb4db
- size 743619996
+ oid sha256:1e9fec6baa4327e0554e927998fe8079d4223517276478567502fb5a6cb59790
+ size 745777424
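A quick way to confirm which weights file a local download corresponds to is to compare against the size recorded in the updated LFS pointer above (745777424 bytes). The sketch below is one hedged way to check it with `huggingface_hub`; it is not part of this commit and assumes network access to the Hub.

```python
from huggingface_hub import HfApi

# files_metadata=True populates per-file sizes on the returned siblings.
info = HfApi().model_info("yangheng/OmniGenome-52M", files_metadata=True)
for sibling in info.siblings:
    if sibling.rfilename == "model.safetensors":
        print(sibling.size)  # expected to match the pointer above: 745777424
```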
modeling_mprna.py → modeling_omnigenome.py RENAMED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """ PyTorch MPRNA model."""
16
 
17
  import math
18
  from typing import List, Optional, Tuple, Union
@@ -23,25 +23,36 @@ from torch import nn
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
  from transformers import add_start_docstrings, PreTrainedModel
25
 
26
- from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, \
27
- BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
28
-
29
- from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 
 
 
30
 
31
- from transformers.utils import logging, add_code_sample_docstrings, add_start_docstrings_to_model_forward
 
 
 
32
 
33
- from .configuration_mprna import MPRNAConfig
 
 
 
 
34
 
 
35
 
36
  logger = logging.get_logger(__name__)
37
 
38
- _CHECKPOINT_FOR_DOC = "yangheng/MPRNA-small"
39
- _CONFIG_FOR_DOC = "MPRNAConfig"
40
 
41
- MPRNA_PRETRAINED_MODEL_ARCHIVE_LIST = [
42
- "yangheng/MPRNA-small",
43
- # This is not a complete list of all MPRNA models!
44
- # See all MPRNA models at https://huggingface.co/models?filter=MPRNA
45
  ]
46
 
47
 
@@ -59,7 +70,7 @@ def apply_rotary_pos_emb(x, cos, sin):
59
 
60
  def gelu(x):
61
  """
62
- This is the gelu implementation from the original MPRNA repo. Using F.gelu yields subtly wrong results.
63
  """
64
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
65
 
@@ -81,6 +92,7 @@ def average_product_correct(x):
81
  return normalized
82
 
83
 
 
84
  class RotaryEmbedding(torch.nn.Module):
85
  """
86
  Rotary position embeddings based on those in
@@ -118,7 +130,7 @@ class RotaryEmbedding(torch.nn.Module):
118
  return self._cos_cached, self._sin_cached
119
 
120
  def forward(
121
- self, q: torch.Tensor, k: torch.Tensor
122
  ) -> Tuple[torch.Tensor, torch.Tensor]:
123
  self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
124
  k, seq_dimension=-2
@@ -130,14 +142,15 @@ class RotaryEmbedding(torch.nn.Module):
130
  )
131
 
132
 
133
- class MPRNAContactPredictionHead(nn.Module):
 
134
  """Performs symmetrization, apc, and computes a logistic regression on the output features"""
135
 
136
  def __init__(
137
- self,
138
- in_features: int,
139
- bias=True,
140
- eos_idx: int = 2,
141
  ):
142
  super().__init__()
143
  self.in_features = in_features
@@ -165,7 +178,8 @@ class MPRNAContactPredictionHead(nn.Module):
165
  return self.activation(self.regression(attentions).squeeze(3))
166
 
167
 
168
- class MPRNAEmbeddings(nn.Module):
 
169
  """
170
  Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
171
  """
@@ -203,12 +217,12 @@ class MPRNAEmbeddings(nn.Module):
203
  self.mask_token_id = config.mask_token_id
204
 
205
  def forward(
206
- self,
207
- input_ids=None,
208
- attention_mask=None,
209
- position_ids=None,
210
- inputs_embeds=None,
211
- past_key_values_length=0,
212
  ):
213
  if position_ids is None:
214
  if input_ids is not None:
@@ -224,11 +238,11 @@ class MPRNAEmbeddings(nn.Module):
224
  if inputs_embeds is None:
225
  inputs_embeds = self.word_embeddings(input_ids)
226
 
227
- # Note that if we want to support MPRNA-1 (not 1b!) in future then we need to support an
228
  # embedding_scale factor here.
229
  embeddings = inputs_embeds
230
 
231
- # Matt: MPRNA has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
232
  # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
233
  # masked tokens are treated as if they were selected for input dropout and zeroed out.
234
  # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
@@ -240,16 +254,16 @@ class MPRNAEmbeddings(nn.Module):
240
  (input_ids == self.mask_token_id).unsqueeze(-1), 0.0
241
  )
242
  mask_ratio_train = (
243
- 0.15 * 0.8
244
- ) # Hardcoded as the ratio used in all MPRNA model training runs
245
  src_lengths = attention_mask.sum(-1)
246
  mask_ratio_observed = (input_ids == self.mask_token_id).sum(
247
  -1
248
  ).float() / src_lengths
249
  embeddings = (
250
- embeddings
251
- * (1 - mask_ratio_train)
252
- / (1 - mask_ratio_observed)[:, None, None]
253
  ).to(embeddings.dtype)
254
 
255
  if self.position_embedding_type == "absolute":
@@ -287,11 +301,12 @@ class MPRNAEmbeddings(nn.Module):
287
  return position_ids.unsqueeze(0).expand(input_shape)
288
 
289
 
290
- class MPRNASelfAttention(nn.Module):
 
291
  def __init__(self, config, position_embedding_type=None):
292
  super().__init__()
293
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
294
- config, "embedding_size"
295
  ):
296
  raise ValueError(
297
  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
@@ -312,8 +327,8 @@ class MPRNASelfAttention(nn.Module):
312
  )
313
  self.rotary_embeddings = None
314
  if (
315
- self.position_embedding_type == "relative_key"
316
- or self.position_embedding_type == "relative_key_query"
317
  ):
318
  self.max_position_embeddings = config.max_position_embeddings
319
  self.distance_embedding = nn.Embedding(
@@ -333,14 +348,14 @@ class MPRNASelfAttention(nn.Module):
333
  return x.permute(0, 2, 1, 3)
334
 
335
  def forward(
336
- self,
337
- hidden_states: torch.Tensor,
338
- attention_mask: Optional[torch.FloatTensor] = None,
339
- head_mask: Optional[torch.FloatTensor] = None,
340
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
341
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
342
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
343
- output_attentions: Optional[bool] = False,
344
  ) -> Tuple[torch.Tensor]:
345
  mixed_query_layer = self.query(hidden_states)
346
 
@@ -370,10 +385,10 @@ class MPRNASelfAttention(nn.Module):
370
  query_layer = self.transpose_for_scores(mixed_query_layer)
371
 
372
  # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
373
- # MPRNA scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
374
  # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
375
- # MPRNA code and fix rotary embeddings.
376
- query_layer = query_layer * self.attention_head_size**-0.5
377
 
378
  if self.is_decoder:
379
  # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -392,8 +407,8 @@ class MPRNASelfAttention(nn.Module):
392
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
393
 
394
  if (
395
- self.position_embedding_type == "relative_key"
396
- or self.position_embedding_type == "relative_key_query"
397
  ):
398
  seq_length = hidden_states.size()[1]
399
  position_ids_l = torch.arange(
@@ -423,13 +438,13 @@ class MPRNASelfAttention(nn.Module):
423
  "bhrd,lrd->bhlr", key_layer, positional_embedding
424
  )
425
  attention_scores = (
426
- attention_scores
427
- + relative_position_scores_query
428
- + relative_position_scores_key
429
  )
430
 
431
  if attention_mask is not None:
432
- # Apply the attention mask is (precomputed for all layers in MPRNAModel forward() function)
433
  attention_scores = attention_scores + attention_mask
434
 
435
  # Normalize the attention scores to probabilities.
@@ -458,7 +473,8 @@ class MPRNASelfAttention(nn.Module):
458
  return outputs
459
 
460
 
461
- class MPRNASelfOutput(nn.Module):
 
462
  def __init__(self, config):
463
  super().__init__()
464
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -471,11 +487,12 @@ class MPRNASelfOutput(nn.Module):
471
  return hidden_states
472
 
473
 
474
- class MPRNAAttention(nn.Module):
 
475
  def __init__(self, config):
476
  super().__init__()
477
- self.self = MPRNASelfAttention(config)
478
- self.output = MPRNASelfOutput(config)
479
  self.pruned_heads = set()
480
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
481
 
@@ -498,19 +515,19 @@ class MPRNAAttention(nn.Module):
498
  # Update hyper params and store pruned heads
499
  self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
500
  self.self.all_head_size = (
501
- self.self.attention_head_size * self.self.num_attention_heads
502
  )
503
  self.pruned_heads = self.pruned_heads.union(heads)
504
 
505
  def forward(
506
- self,
507
- hidden_states,
508
- attention_mask=None,
509
- head_mask=None,
510
- encoder_hidden_states=None,
511
- encoder_attention_mask=None,
512
- past_key_value=None,
513
- output_attentions=False,
514
  ):
515
  hidden_states_ln = self.LayerNorm(hidden_states)
516
  self_outputs = self.self(
@@ -524,12 +541,13 @@ class MPRNAAttention(nn.Module):
524
  )
525
  attention_output = self.output(self_outputs[0], hidden_states)
526
  outputs = (attention_output,) + self_outputs[
527
- 1:
528
- ] # add attentions if we output them
529
  return outputs
530
 
531
 
532
- class MPRNAIntermediate(nn.Module):
 
533
  def __init__(self, config):
534
  super().__init__()
535
  self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -540,7 +558,8 @@ class MPRNAIntermediate(nn.Module):
540
  return hidden_states
541
 
542
 
543
- class MPRNAOutput(nn.Module):
 
544
  def __init__(self, config):
545
  super().__init__()
546
  self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -553,12 +572,13 @@ class MPRNAOutput(nn.Module):
553
  return hidden_states
554
 
555
 
556
- class MPRNALayer(nn.Module):
 
557
  def __init__(self, config):
558
  super().__init__()
559
  self.chunk_size_feed_forward = config.chunk_size_feed_forward
560
  self.seq_len_dim = 1
561
- self.attention = MPRNAAttention(config)
562
  self.is_decoder = config.is_decoder
563
  self.add_cross_attention = config.add_cross_attention
564
  if self.add_cross_attention:
@@ -566,20 +586,20 @@ class MPRNALayer(nn.Module):
566
  raise RuntimeError(
567
  f"{self} should be used as a decoder model if cross attention is added"
568
  )
569
- self.crossattention = MPRNAAttention(config)
570
- self.intermediate = MPRNAIntermediate(config)
571
- self.output = MPRNAOutput(config)
572
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
573
 
574
  def forward(
575
- self,
576
- hidden_states,
577
- attention_mask=None,
578
- head_mask=None,
579
- encoder_hidden_states=None,
580
- encoder_attention_mask=None,
581
- past_key_value=None,
582
- output_attentions=False,
583
  ):
584
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
585
  self_attn_past_key_value = (
@@ -600,8 +620,8 @@ class MPRNALayer(nn.Module):
600
  present_key_value = self_attention_outputs[-1]
601
  else:
602
  outputs = self_attention_outputs[
603
- 1:
604
- ] # add self attentions if we output attention weights
605
 
606
  cross_attn_present_key_value = None
607
  if self.is_decoder and encoder_hidden_states is not None:
@@ -626,7 +646,7 @@ class MPRNALayer(nn.Module):
626
  )
627
  attention_output = cross_attention_outputs[0]
628
  outputs = (
629
- outputs + cross_attention_outputs[1:-1]
630
  ) # add cross attentions if we output attention weights
631
 
632
  # add cross-attn cache to positions 3,4 of present_key_value tuple
@@ -649,12 +669,13 @@ class MPRNALayer(nn.Module):
649
  return layer_output
650
 
651
 
652
- class MPRNAEncoder(nn.Module):
 
653
  def __init__(self, config):
654
  super().__init__()
655
  self.config = config
656
  self.layer = nn.ModuleList(
657
- [MPRNALayer(config) for _ in range(config.num_hidden_layers)]
658
  )
659
  self.emb_layer_norm_after = nn.LayerNorm(
660
  config.hidden_size, eps=config.layer_norm_eps
@@ -662,17 +683,17 @@ class MPRNAEncoder(nn.Module):
662
  self.gradient_checkpointing = False
663
 
664
  def forward(
665
- self,
666
- hidden_states,
667
- attention_mask=None,
668
- head_mask=None,
669
- encoder_hidden_states=None,
670
- encoder_attention_mask=None,
671
- past_key_values=None,
672
- use_cache=None,
673
- output_attentions=False,
674
- output_hidden_states=False,
675
- return_dict=True,
676
  ):
677
  if self.gradient_checkpointing and self.training:
678
  if use_cache:
@@ -752,8 +773,8 @@ class MPRNAEncoder(nn.Module):
752
  )
753
 
754
 
755
- # Copied from transformers.models.bert.modeling_bert.BertPooler
756
- class MPRNAPooler(nn.Module):
757
  def __init__(self, config):
758
  super().__init__()
759
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -768,19 +789,20 @@ class MPRNAPooler(nn.Module):
768
  return pooled_output
769
 
770
 
771
- class MPRNAPreTrainedModel(PreTrainedModel):
 
772
  """
773
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
774
  models.
775
  """
776
 
777
- config_class = MPRNAConfig
778
- base_model_prefix = "MPRNA"
779
  supports_gradient_checkpointing = True
780
  _no_split_modules = [
781
- "MPRNALayer",
782
- "MPRNAFoldTriangularSelfAttentionBlock",
783
- "MPRNAEmbeddings",
784
  ]
785
 
786
  # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -801,7 +823,7 @@ class MPRNAPreTrainedModel(PreTrainedModel):
801
  module.weight.data.fill_(1.0)
802
 
803
 
804
- MPRNA_START_DOCSTRING = r"""
805
 
806
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
807
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -812,12 +834,12 @@ MPRNA_START_DOCSTRING = r"""
812
  and behavior.
813
 
814
  Parameters:
815
- config ([`MPRNAConfig`]): Model configuration class with all the parameters of the
816
  model. Initializing with a config file does not load the weights associated with the model, only the
817
  configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
818
  """
819
 
820
- MPRNA_INPUTS_DOCSTRING = r"""
821
  Args:
822
  input_ids (`torch.LongTensor` of shape `({0})`):
823
  Indices of input sequence tokens in the vocabulary.
@@ -860,10 +882,11 @@ MPRNA_INPUTS_DOCSTRING = r"""
860
 
861
 
862
  @add_start_docstrings(
863
- "The bare MPRNA Model transformer outputting raw hidden-states without any specific head on top.",
864
- MPRNA_START_DOCSTRING,
865
  )
866
- class MPRNAModel(MPRNAPreTrainedModel):
 
867
  """
868
 
869
  The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
@@ -880,12 +903,12 @@ class MPRNAModel(MPRNAPreTrainedModel):
880
  super().__init__(config)
881
  self.config = config
882
 
883
- self.embeddings = MPRNAEmbeddings(config)
884
- self.encoder = MPRNAEncoder(config)
885
 
886
- self.pooler = MPRNAPooler(config) if add_pooling_layer else None
887
 
888
- self.contact_head = MPRNAContactPredictionHead(
889
  in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
890
  )
891
 
@@ -907,7 +930,7 @@ class MPRNAModel(MPRNAPreTrainedModel):
907
  self.encoder.layer[layer].attention.prune_heads(heads)
908
 
909
  @add_start_docstrings_to_model_forward(
910
- MPRNA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")
911
  )
912
  @add_code_sample_docstrings(
913
  checkpoint=_CHECKPOINT_FOR_DOC,
@@ -915,19 +938,19 @@ class MPRNAModel(MPRNAPreTrainedModel):
915
  config_class=_CONFIG_FOR_DOC,
916
  )
917
  def forward(
918
- self,
919
- input_ids: Optional[torch.Tensor] = None,
920
- attention_mask: Optional[torch.Tensor] = None,
921
- position_ids: Optional[torch.Tensor] = None,
922
- head_mask: Optional[torch.Tensor] = None,
923
- inputs_embeds: Optional[torch.Tensor] = None,
924
- encoder_hidden_states: Optional[torch.Tensor] = None,
925
- encoder_attention_mask: Optional[torch.Tensor] = None,
926
- past_key_values: Optional[List[torch.FloatTensor]] = None,
927
- use_cache: Optional[bool] = None,
928
- output_attentions: Optional[bool] = None,
929
- output_hidden_states: Optional[bool] = None,
930
- return_dict: Optional[bool] = None,
931
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
932
  r"""
933
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1077,9 +1100,10 @@ class MPRNAModel(MPRNAPreTrainedModel):
1077
 
1078
 
1079
  @add_start_docstrings(
1080
- """MPRNA Model with a `language modeling` head on top.""", MPRNA_START_DOCSTRING
1081
  )
1082
- class MPRNAForMaskedLM(MPRNAPreTrainedModel):
 
1083
  _tied_weights_keys = ["lm_head.decoder.weight"]
1084
 
1085
  def __init__(self, config):
@@ -1087,14 +1111,13 @@ class MPRNAForMaskedLM(MPRNAPreTrainedModel):
1087
 
1088
  if config.is_decoder:
1089
  logger.warning(
1090
- "If you want to use `MPRNAForMaskedLM` make sure `config.is_decoder=False` for "
1091
  "bi-directional self-attention."
1092
  )
1093
 
1094
- self.MPRNA = MPRNAModel(config, add_pooling_layer=False)
1095
- self.lm_head = MPRNALMHead(config)
1096
-
1097
- self.init_weights()
1098
 
1099
  def get_output_embeddings(self):
1100
  return self.lm_head.decoder
@@ -1103,7 +1126,7 @@ class MPRNAForMaskedLM(MPRNAPreTrainedModel):
1103
  self.lm_head.decoder = new_embeddings
1104
 
1105
  @add_start_docstrings_to_model_forward(
1106
- MPRNA_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1107
  )
1108
  @add_code_sample_docstrings(
1109
  checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1112,18 +1135,18 @@ class MPRNAForMaskedLM(MPRNAPreTrainedModel):
1112
  mask="<mask>",
1113
  )
1114
  def forward(
1115
- self,
1116
- input_ids: Optional[torch.LongTensor] = None,
1117
- attention_mask: Optional[torch.Tensor] = None,
1118
- position_ids: Optional[torch.LongTensor] = None,
1119
- head_mask: Optional[torch.Tensor] = None,
1120
- inputs_embeds: Optional[torch.FloatTensor] = None,
1121
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1122
- encoder_attention_mask: Optional[torch.Tensor] = None,
1123
- labels: Optional[torch.LongTensor] = None,
1124
- output_attentions: Optional[bool] = None,
1125
- output_hidden_states: Optional[bool] = None,
1126
- return_dict: Optional[bool] = None,
1127
  ) -> Union[Tuple, MaskedLMOutput]:
1128
  r"""
1129
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1137,7 +1160,7 @@ class MPRNAForMaskedLM(MPRNAPreTrainedModel):
1137
  return_dict if return_dict is not None else self.config.use_return_dict
1138
  )
1139
 
1140
- outputs = self.MPRNA(
1141
  input_ids,
1142
  attention_mask=attention_mask,
1143
  position_ids=position_ids,
@@ -1175,11 +1198,12 @@ class MPRNAForMaskedLM(MPRNAPreTrainedModel):
1175
  )
1176
 
1177
  def predict_contacts(self, tokens, attention_mask):
1178
- return self.MPRNA.predict_contacts(tokens, attention_mask=attention_mask)
1179
 
1180
 
1181
- class MPRNALMHead(nn.Module):
1182
- """MPRNA Head for masked language modeling."""
 
1183
 
1184
  def __init__(self, config):
1185
  super().__init__()
@@ -1201,24 +1225,22 @@ class MPRNALMHead(nn.Module):
1201
 
1202
  @add_start_docstrings(
1203
  """
1204
- MPRNA Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1205
  output) e.g. for GLUE tasks.
1206
  """,
1207
- MPRNA_START_DOCSTRING,
1208
  )
1209
- class MPRNAForSequenceClassification(MPRNAPreTrainedModel):
1210
  def __init__(self, config):
1211
  super().__init__(config)
1212
  self.num_labels = config.num_labels
1213
  self.config = config
1214
-
1215
- self.MPRNA = MPRNAModel(config, add_pooling_layer=False)
1216
- self.classifier = MPRNAClassificationHead(config)
1217
-
1218
- self.init_weights()
1219
 
1220
  @add_start_docstrings_to_model_forward(
1221
- MPRNA_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1222
  )
1223
  @add_code_sample_docstrings(
1224
  checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1226,16 +1248,16 @@ class MPRNAForSequenceClassification(MPRNAPreTrainedModel):
1226
  config_class=_CONFIG_FOR_DOC,
1227
  )
1228
  def forward(
1229
- self,
1230
- input_ids: Optional[torch.LongTensor] = None,
1231
- attention_mask: Optional[torch.Tensor] = None,
1232
- position_ids: Optional[torch.LongTensor] = None,
1233
- head_mask: Optional[torch.Tensor] = None,
1234
- inputs_embeds: Optional[torch.FloatTensor] = None,
1235
- labels: Optional[torch.LongTensor] = None,
1236
- output_attentions: Optional[bool] = None,
1237
- output_hidden_states: Optional[bool] = None,
1238
- return_dict: Optional[bool] = None,
1239
  ) -> Union[Tuple, SequenceClassifierOutput]:
1240
  r"""
1241
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1247,7 +1269,7 @@ class MPRNAForSequenceClassification(MPRNAPreTrainedModel):
1247
  return_dict if return_dict is not None else self.config.use_return_dict
1248
  )
1249
 
1250
- outputs = self.MPRNA(
1251
  input_ids,
1252
  attention_mask=attention_mask,
1253
  position_ids=position_ids,
@@ -1268,7 +1290,7 @@ class MPRNAForSequenceClassification(MPRNAPreTrainedModel):
1268
  if self.num_labels == 1:
1269
  self.config.problem_type = "regression"
1270
  elif self.num_labels > 1 and (
1271
- labels.dtype == torch.long or labels.dtype == torch.int
1272
  ):
1273
  self.config.problem_type = "single_label_classification"
1274
  else:
@@ -1301,24 +1323,156 @@ class MPRNAForSequenceClassification(MPRNAPreTrainedModel):
1301
 
1302
  @add_start_docstrings(
1303
  """
1304
- MPRNA Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1305
- Named-Entity-Recognition (NER) tasks.
 
 
1306
  """,
1307
- MPRNA_START_DOCSTRING,
1308
  )
1309
- class MPRNAForTokenClassification(MPRNAPreTrainedModel):
 
1310
  def __init__(self, config):
1311
  super().__init__(config)
1312
  self.num_labels = config.num_labels
1313
 
1314
- self.MPRNA = MPRNAModel(config, add_pooling_layer=False)
1315
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
1316
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1317
 
1318
- self.init_weights()
1319
 
1320
  @add_start_docstrings_to_model_forward(
1321
- MPRNA_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1322
  )
1323
  @add_code_sample_docstrings(
1324
  checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1326,16 +1480,16 @@ class MPRNAForTokenClassification(MPRNAPreTrainedModel):
1326
  config_class=_CONFIG_FOR_DOC,
1327
  )
1328
  def forward(
1329
- self,
1330
- input_ids: Optional[torch.LongTensor] = None,
1331
- attention_mask: Optional[torch.Tensor] = None,
1332
- position_ids: Optional[torch.LongTensor] = None,
1333
- head_mask: Optional[torch.Tensor] = None,
1334
- inputs_embeds: Optional[torch.FloatTensor] = None,
1335
- labels: Optional[torch.LongTensor] = None,
1336
- output_attentions: Optional[bool] = None,
1337
- output_hidden_states: Optional[bool] = None,
1338
- return_dict: Optional[bool] = None,
1339
  ) -> Union[Tuple, TokenClassifierOutput]:
1340
  r"""
1341
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1345,7 +1499,7 @@ class MPRNAForTokenClassification(MPRNAPreTrainedModel):
1345
  return_dict if return_dict is not None else self.config.use_return_dict
1346
  )
1347
 
1348
- outputs = self.MPRNA(
1349
  input_ids,
1350
  attention_mask=attention_mask,
1351
  position_ids=position_ids,
@@ -1380,7 +1534,8 @@ class MPRNAForTokenClassification(MPRNAPreTrainedModel):
1380
  )
1381
 
1382
 
1383
- class MPRNAClassificationHead(nn.Module):
 
1384
  """Head for sentence-level classification tasks."""
1385
 
1386
  def __init__(self, config):
@@ -1400,7 +1555,7 @@ class MPRNAClassificationHead(nn.Module):
1400
 
1401
 
1402
  def create_position_ids_from_input_ids(
1403
- input_ids, padding_idx, past_key_values_length=0
1404
  ):
1405
  """
1406
  Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
@@ -1414,6 +1569,6 @@ def create_position_ids_from_input_ids(
1414
  # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
1415
  mask = input_ids.ne(padding_idx).int()
1416
  incremental_indices = (
1417
- torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
1418
- ) * mask
1419
  return incremental_indices.long() + padding_idx
 
1
  # coding=utf-8
2
+ # Copyright 2022 ColaLab-UoE (https://colalab.ai/), Meta and The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """ PyTorch OmniGenome model."""
16
 
17
  import math
18
  from typing import List, Optional, Tuple, Union
 
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
  from transformers import add_start_docstrings, PreTrainedModel
25
 
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPastAndCrossAttentions,
28
+ BaseModelOutputWithPoolingAndCrossAttentions,
29
+ MaskedLMOutput,
30
+ SequenceClassifierOutput,
31
+ TokenClassifierOutput,
32
+ )
33
 
34
+ from transformers.pytorch_utils import (
35
+ find_pruneable_heads_and_indices,
36
+ prune_linear_layer,
37
+ )
38
 
39
+ from transformers.utils import (
40
+ logging,
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ )
44
 
45
+ from .configuration_omnigenome import OmniGenomeConfig
46
 
47
  logger = logging.get_logger(__name__)
48
 
49
+ _CHECKPOINT_FOR_DOC = "yangheng/OmniGenome-52M"
50
+ _CONFIG_FOR_DOC = "OmniGenomeConfig"
51
 
52
+ OmniGenome_PRETRAINED_MODEL_ARCHIVE_LIST = [
53
+ "yangheng/OmniGenome-52M",
54
+ # This is not a complete list of all OmniGenome models!
55
+ # See all OmniGenome models at https://huggingface.co/models?filter=OmniGenome
56
  ]
57
 
58
 
 
70
 
71
  def gelu(x):
72
  """
73
+ This is the gelu implementation from the original OmniGenome repo. Using F.gelu yields subtly wrong results.
74
  """
75
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
76
 
 
92
  return normalized
93
 
94
 
95
+ # Copied from transformers.models.esm.modeling_esm.RotaryEmbedding
96
  class RotaryEmbedding(torch.nn.Module):
97
  """
98
  Rotary position embeddings based on those in
 
130
  return self._cos_cached, self._sin_cached
131
 
132
  def forward(
133
+ self, q: torch.Tensor, k: torch.Tensor
134
  ) -> Tuple[torch.Tensor, torch.Tensor]:
135
  self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
136
  k, seq_dimension=-2
 
142
  )
143
 
144
 
145
+ # Copied from transformers.models.esm.modeling_esm.EsmContactPredictionHead with Esm->OmniGenome
146
+ class OmniGenomeContactPredictionHead(nn.Module):
147
  """Performs symmetrization, apc, and computes a logistic regression on the output features"""
148
 
149
  def __init__(
150
+ self,
151
+ in_features: int,
152
+ bias=True,
153
+ eos_idx: int = 2,
154
  ):
155
  super().__init__()
156
  self.in_features = in_features
 
178
  return self.activation(self.regression(attentions).squeeze(3))
179
 
180
 
181
+ # Copied from transformers.models.esm.modeling_esm.EsmEmbeddings with Esm->OmniGenome
182
+ class OmniGenomeEmbeddings(nn.Module):
183
  """
184
  Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
185
  """
 
217
  self.mask_token_id = config.mask_token_id
218
 
219
  def forward(
220
+ self,
221
+ input_ids=None,
222
+ attention_mask=None,
223
+ position_ids=None,
224
+ inputs_embeds=None,
225
+ past_key_values_length=0,
226
  ):
227
  if position_ids is None:
228
  if input_ids is not None:
 
238
  if inputs_embeds is None:
239
  inputs_embeds = self.word_embeddings(input_ids)
240
 
241
+ # Note that if we want to support OmniGenome-1 (not 1b!) in future then we need to support an
242
  # embedding_scale factor here.
243
  embeddings = inputs_embeds
244
 
245
+ # Matt: OmniGenome has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
246
  # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
247
  # masked tokens are treated as if they were selected for input dropout and zeroed out.
248
  # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
 
254
  (input_ids == self.mask_token_id).unsqueeze(-1), 0.0
255
  )
256
  mask_ratio_train = (
257
+ 0.15 * 0.8
258
+ ) # Hardcoded as the ratio used in all OmniGenome model training runs
259
  src_lengths = attention_mask.sum(-1)
260
  mask_ratio_observed = (input_ids == self.mask_token_id).sum(
261
  -1
262
  ).float() / src_lengths
263
  embeddings = (
264
+ embeddings
265
+ * (1 - mask_ratio_train)
266
+ / (1 - mask_ratio_observed)[:, None, None]
267
  ).to(embeddings.dtype)
268
 
269
  if self.position_embedding_type == "absolute":
 
301
  return position_ids.unsqueeze(0).expand(input_shape)
302
 
303
 
304
+ # Copied from transformers.models.esm.modeling_esm.EsmSelfAttention with Esm->OmniGenome
305
+ class OmniGenomeSelfAttention(nn.Module):
306
  def __init__(self, config, position_embedding_type=None):
307
  super().__init__()
308
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
309
+ config, "embedding_size"
310
  ):
311
  raise ValueError(
312
  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
 
327
  )
328
  self.rotary_embeddings = None
329
  if (
330
+ self.position_embedding_type == "relative_key"
331
+ or self.position_embedding_type == "relative_key_query"
332
  ):
333
  self.max_position_embeddings = config.max_position_embeddings
334
  self.distance_embedding = nn.Embedding(
 
348
  return x.permute(0, 2, 1, 3)
349
 
350
  def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ attention_mask: Optional[torch.FloatTensor] = None,
354
+ head_mask: Optional[torch.FloatTensor] = None,
355
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
356
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
357
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
358
+ output_attentions: Optional[bool] = False,
359
  ) -> Tuple[torch.Tensor]:
360
  mixed_query_layer = self.query(hidden_states)
361
 
 
385
  query_layer = self.transpose_for_scores(mixed_query_layer)
386
 
387
  # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
388
+ # OmniGenome scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
389
  # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
390
+ # OmniGenome code and fix rotary embeddings.
391
+ query_layer = query_layer * self.attention_head_size ** -0.5
392
 
393
  if self.is_decoder:
394
  # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
 
407
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
408
 
409
  if (
410
+ self.position_embedding_type == "relative_key"
411
+ or self.position_embedding_type == "relative_key_query"
412
  ):
413
  seq_length = hidden_states.size()[1]
414
  position_ids_l = torch.arange(
 
438
  "bhrd,lrd->bhlr", key_layer, positional_embedding
439
  )
440
  attention_scores = (
441
+ attention_scores
442
+ + relative_position_scores_query
443
+ + relative_position_scores_key
444
  )
445
 
446
  if attention_mask is not None:
447
+ # Apply the attention mask is (precomputed for all layers in OmniGenomeModel forward() function)
448
  attention_scores = attention_scores + attention_mask
449
 
450
  # Normalize the attention scores to probabilities.
 
473
  return outputs
474
 
475
 
476
+ # Copied from transformers.models.esm.modeling_esm.EsmSelfOutput with Esm->OmniGenome
477
+ class OmniGenomeSelfOutput(nn.Module):
478
  def __init__(self, config):
479
  super().__init__()
480
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 
487
  return hidden_states
488
 
489
 
490
+ # Copied from transformers.models.esm.modeling_esm.EsmAttention with Esm->OmniGenome
491
+ class OmniGenomeAttention(nn.Module):
492
  def __init__(self, config):
493
  super().__init__()
494
+ self.self = OmniGenomeSelfAttention(config)
495
+ self.output = OmniGenomeSelfOutput(config)
496
  self.pruned_heads = set()
497
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
498
 
 
515
  # Update hyper params and store pruned heads
516
  self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
517
  self.self.all_head_size = (
518
+ self.self.attention_head_size * self.self.num_attention_heads
519
  )
520
  self.pruned_heads = self.pruned_heads.union(heads)
521
 
522
  def forward(
523
+ self,
524
+ hidden_states,
525
+ attention_mask=None,
526
+ head_mask=None,
527
+ encoder_hidden_states=None,
528
+ encoder_attention_mask=None,
529
+ past_key_value=None,
530
+ output_attentions=False,
531
  ):
532
  hidden_states_ln = self.LayerNorm(hidden_states)
533
  self_outputs = self.self(
 
541
  )
542
  attention_output = self.output(self_outputs[0], hidden_states)
543
  outputs = (attention_output,) + self_outputs[
544
+ 1:
545
+ ] # add attentions if we output them
546
  return outputs
547
 
548
 
549
+ # Copied from transformers.models.esm.modeling_esm.EsmIntermediate with Esm->OmniGenome
550
+ class OmniGenomeIntermediate(nn.Module):
551
  def __init__(self, config):
552
  super().__init__()
553
  self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
 
558
  return hidden_states
559
 
560
 
561
+ # Copied from transformers.models.esm.modeling_esm.EsmOutput with Esm->OmniGenome
562
+ class OmniGenomeOutput(nn.Module):
563
  def __init__(self, config):
564
  super().__init__()
565
  self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
 
572
  return hidden_states
573
 
574
 
575
+ # Copied from transformers.models.esm.modeling_esm.EsmLayer with Esm->OmniGenome
576
+ class OmniGenomeLayer(nn.Module):
577
  def __init__(self, config):
578
  super().__init__()
579
  self.chunk_size_feed_forward = config.chunk_size_feed_forward
580
  self.seq_len_dim = 1
581
+ self.attention = OmniGenomeAttention(config)
582
  self.is_decoder = config.is_decoder
583
  self.add_cross_attention = config.add_cross_attention
584
  if self.add_cross_attention:
 
586
  raise RuntimeError(
587
  f"{self} should be used as a decoder model if cross attention is added"
588
  )
589
+ self.crossattention = OmniGenomeAttention(config)
590
+ self.intermediate = OmniGenomeIntermediate(config)
591
+ self.output = OmniGenomeOutput(config)
592
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
593
 
594
  def forward(
595
+ self,
596
+ hidden_states,
597
+ attention_mask=None,
598
+ head_mask=None,
599
+ encoder_hidden_states=None,
600
+ encoder_attention_mask=None,
601
+ past_key_value=None,
602
+ output_attentions=False,
603
  ):
604
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
605
  self_attn_past_key_value = (
 
620
  present_key_value = self_attention_outputs[-1]
621
  else:
622
  outputs = self_attention_outputs[
623
+ 1:
624
+ ] # add self attentions if we output attention weights
625
 
626
  cross_attn_present_key_value = None
627
  if self.is_decoder and encoder_hidden_states is not None:
 
646
  )
647
  attention_output = cross_attention_outputs[0]
648
  outputs = (
649
+ outputs + cross_attention_outputs[1:-1]
650
  ) # add cross attentions if we output attention weights
651
 
652
  # add cross-attn cache to positions 3,4 of present_key_value tuple
 
669
  return layer_output
670
 
671
 
672
+ # Copied from transformers.models.esm.modeling_esm.EsmEncoder with Esm->OmniGenome
673
+ class OmniGenomeEncoder(nn.Module):
674
  def __init__(self, config):
675
  super().__init__()
676
  self.config = config
677
  self.layer = nn.ModuleList(
678
+ [OmniGenomeLayer(config) for _ in range(config.num_hidden_layers)]
679
  )
680
  self.emb_layer_norm_after = nn.LayerNorm(
681
  config.hidden_size, eps=config.layer_norm_eps
 
683
  self.gradient_checkpointing = False
684
 
685
  def forward(
686
+ self,
687
+ hidden_states,
688
+ attention_mask=None,
689
+ head_mask=None,
690
+ encoder_hidden_states=None,
691
+ encoder_attention_mask=None,
692
+ past_key_values=None,
693
+ use_cache=None,
694
+ output_attentions=False,
695
+ output_hidden_states=False,
696
+ return_dict=True,
697
  ):
698
  if self.gradient_checkpointing and self.training:
699
  if use_cache:
 
773
  )
774
 
775
 
776
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->OmniGenome
777
+ class OmniGenomePooler(nn.Module):
778
  def __init__(self, config):
779
  super().__init__()
780
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 
789
  return pooled_output
790
 
791
 
792
+ # Copied from transformers.models.esm.modeling_esm.EsmPreTrainedModel with Esm->OmniGenome
793
+ class OmniGenomePreTrainedModel(PreTrainedModel):
794
  """
795
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
796
  models.
797
  """
798
 
799
+ config_class = OmniGenomeConfig
800
+ base_model_prefix = "OmniGenome"
801
  supports_gradient_checkpointing = True
802
  _no_split_modules = [
803
+ "OmniGenomeLayer",
804
+ "OmniGenomeFoldTriangularSelfAttentionBlock",
805
+ "OmniGenomeEmbeddings",
806
  ]
807
 
808
  # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
 
823
  module.weight.data.fill_(1.0)
824
 
825
 
826
+ OmniGenome_START_DOCSTRING = r"""
827
 
828
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
829
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
 
834
  and behavior.
835
 
836
  Parameters:
837
+ config ([`OmniGenomeConfig`]): Model configuration class with all the parameters of the
838
  model. Initializing with a config file does not load the weights associated with the model, only the
839
  configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
840
  """
841
 
842
+ OmniGenome_INPUTS_DOCSTRING = r"""
843
  Args:
844
  input_ids (`torch.LongTensor` of shape `({0})`):
845
  Indices of input sequence tokens in the vocabulary.
 
882
 
883
 
884
  @add_start_docstrings(
885
+ "The bare OmniGenome Model transformer outputting raw hidden-states without any specific head on top.",
886
+ OmniGenome_START_DOCSTRING,
887
  )
888
+ # Copied from transformers.models.esm.modeling_esm.EsmModel with Esm->OmniGenome
889
+ class OmniGenomeModel(OmniGenomePreTrainedModel):
890
  """
891
 
892
  The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
 
903
  super().__init__(config)
904
  self.config = config
905
 
906
+ self.embeddings = OmniGenomeEmbeddings(config)
907
+ self.encoder = OmniGenomeEncoder(config)
908
 
909
+ self.pooler = OmniGenomePooler(config) if add_pooling_layer else None
910
 
911
+ self.contact_head = OmniGenomeContactPredictionHead(
912
  in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
913
  )
914
 
 
930
  self.encoder.layer[layer].attention.prune_heads(heads)
931
 
932
  @add_start_docstrings_to_model_forward(
933
+ OmniGenome_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")
934
  )
935
  @add_code_sample_docstrings(
936
  checkpoint=_CHECKPOINT_FOR_DOC,
 
938
  config_class=_CONFIG_FOR_DOC,
939
  )
940
  def forward(
941
+ self,
942
+ input_ids: Optional[torch.Tensor] = None,
943
+ attention_mask: Optional[torch.Tensor] = None,
944
+ position_ids: Optional[torch.Tensor] = None,
945
+ head_mask: Optional[torch.Tensor] = None,
946
+ inputs_embeds: Optional[torch.Tensor] = None,
947
+ encoder_hidden_states: Optional[torch.Tensor] = None,
948
+ encoder_attention_mask: Optional[torch.Tensor] = None,
949
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
950
+ use_cache: Optional[bool] = None,
951
+ output_attentions: Optional[bool] = None,
952
+ output_hidden_states: Optional[bool] = None,
953
+ return_dict: Optional[bool] = None,
954
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
955
  r"""
956
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 
1100
 
1101
 
1102
  @add_start_docstrings(
1103
+ """OmniGenome Model with a `language modeling` head on top.""", OmniGenome_START_DOCSTRING
1104
  )
1105
+ # Copied from transformers.models.esm.modeling_esm.EsmForMaskedLM with Esm->OmniGenome
1106
+ class OmniGenomeForMaskedLM(OmniGenomePreTrainedModel):
1107
  _tied_weights_keys = ["lm_head.decoder.weight"]
1108
 
1109
  def __init__(self, config):
 
1111
 
1112
  if config.is_decoder:
1113
  logger.warning(
1114
+ "If you want to use `OmniGenomeForMaskedLM` make sure `config.is_decoder=False` for "
1115
  "bi-directional self-attention."
1116
  )
1117
 
1118
+ self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
1119
+ self.lm_head = OmniGenomeLMHead(config)
1120
+ # self.init_weights()
 
1121
 
1122
  def get_output_embeddings(self):
1123
  return self.lm_head.decoder
 
1126
  self.lm_head.decoder = new_embeddings
1127
 
1128
  @add_start_docstrings_to_model_forward(
1129
+ OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1130
  )
1131
  @add_code_sample_docstrings(
1132
  checkpoint=_CHECKPOINT_FOR_DOC,
 
1135
  mask="<mask>",
1136
  )
1137
  def forward(
1138
+ self,
1139
+ input_ids: Optional[torch.LongTensor] = None,
1140
+ attention_mask: Optional[torch.Tensor] = None,
1141
+ position_ids: Optional[torch.LongTensor] = None,
1142
+ head_mask: Optional[torch.Tensor] = None,
1143
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1144
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1145
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1146
+ labels: Optional[torch.LongTensor] = None,
1147
+ output_attentions: Optional[bool] = None,
1148
+ output_hidden_states: Optional[bool] = None,
1149
+ return_dict: Optional[bool] = None,
1150
  ) -> Union[Tuple, MaskedLMOutput]:
1151
  r"""
1152
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1160
  return_dict if return_dict is not None else self.config.use_return_dict
1161
  )
1162
 
1163
+ outputs = self.OmniGenome(
1164
  input_ids,
1165
  attention_mask=attention_mask,
1166
  position_ids=position_ids,
 
1198
  )
1199
 
1200
  def predict_contacts(self, tokens, attention_mask):
1201
+ return self.OmniGenome.predict_contacts(tokens, attention_mask=attention_mask)
1202
 
1203
 
1204
+ # Copied from transformers.models.esm.modeling_esm.EsmLMHead with Esm->OmniGenome
1205
+ class OmniGenomeLMHead(nn.Module):
1206
+ """OmniGenome Head for masked language modeling."""
1207
 
1208
  def __init__(self, config):
1209
  super().__init__()
 
 
 @add_start_docstrings(
     """
+    OmniGenome Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
     output) e.g. for GLUE tasks.
     """,
+    OmniGenome_START_DOCSTRING,
 )
+class OmniGenomeForSequenceClassification(OmniGenomePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
+        self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.classifier = OmniGenomeClassificationHead(config)
+        # self.init_weights()
 
     @add_start_docstrings_to_model_forward(
+        OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
     )
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
 
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
             return_dict if return_dict is not None else self.config.use_return_dict
         )
 
+        outputs = self.OmniGenome(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
 
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
                 elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
                 ):
                     self.config.problem_type = "single_label_classification"
                 else:
 
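For reference, the branch above follows the usual `transformers` convention for inferring `problem_type`; the multi-label branch is collapsed in this diff, so the sketch below of how the three cases typically resolve is an assumption based on that convention rather than a copy of the hidden code.

```
import torch

def resolve_problem_type(num_labels, labels):
    # Assumed behaviour, mirroring the standard transformers classification heads.
    if num_labels == 1:
        return "regression"                    # MSELoss on a continuous target
    if num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return "single_label_classification"   # CrossEntropyLoss over class indices
    return "multi_label_classification"        # BCEWithLogitsLoss on multi-hot targets

print(resolve_problem_type(1, torch.tensor([0.7])))           # regression
print(resolve_problem_type(3, torch.tensor([2])))             # single_label_classification
print(resolve_problem_type(3, torch.tensor([[1., 0., 1.]])))  # multi_label_classification
```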
 
 @add_start_docstrings(
     """
+    OmniGenome Model with a token classification head on top (a linear layer on top of the hidden-states output).
+    Note that this model is pre-trained for RNA secondary structure prediction and can be used for zero-shot RNA
+    secondary structure prediction; see https://github.com/yangheng95/OmniGenome for more advanced usage.
+    It can also be fine-tuned for other token classification tasks.
     """,
+    OmniGenome_START_DOCSTRING,
 )
+# Copied from transformers.models.esm.modeling_esm.EsmForTokenClassification with Esm->OmniGenome
+class OmniGenomeForTokenClassification(OmniGenomePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.OmniGenome = OmniGenomeModel(config, add_pooling_layer=False)
+        self.lm_head = OmniGenomeLMHead(config)
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.classifier = torch.nn.Linear(self.config.hidden_size, self.num_labels)
+        self.activation = torch.nn.Tanh()
+        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+        # self.init_weights()
+
+    @add_start_docstrings_to_model_forward(
+        OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+    )
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        mlm_outputs = self.OmniGenome(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        try:
+            last_hidden_state = mlm_outputs[0]
+            last_hidden_state = self.dense(last_hidden_state)
+        except:
+            last_hidden_state = mlm_outputs.hidden_states[-1]
+            last_hidden_state = self.dense(last_hidden_state)
+
+        logits = self.classifier(last_hidden_state)
+        logits = torch.softmax(logits, dim=-1)
+        logits = self.activation(logits)
+        logits = self.dropout(logits)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + mlm_outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=mlm_outputs.hidden_states,
+            attentions=mlm_outputs.attentions,
+        )
+
+    @staticmethod
+    def verify_secondary_structure(structure):
+        structure = list(structure)
+        left_brackets = []
+        right_brackets = []
+        for i, char in enumerate(structure):
+            if char == "(":
+                left_brackets.append(i)
+            elif char == ")":
+                if left_brackets:
+                    left_brackets.pop()
+                else:
+                    right_brackets.append(i)
+
+        for i in left_brackets:
+            structure[i] = "."
+        for i in right_brackets:
+            structure[i] = "."
+
+        structure = "".join(structure)
+
+        return structure
+
+    def predict_structure(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> List[str]:
+        """
+        Predicts the secondary structure of each sequence in the batch from `input_ids` and `attention_mask`.
+        """
+        outputs = self.forward(input_ids, attention_mask, **kwargs)
+
+        logits = torch.argmax(outputs.logits, dim=-1)
+        lengths = torch.sum(torch.ne(torch.tensor(0), attention_mask), dim=-1)
+        structures = []
+        for i, length in enumerate(lengths):
+            structure = logits[i, :length].cpu().numpy()
+            structure = "".join(self.config.id2label[label] for label in structure)
+            if self.config.verify_ss:
+                structure = self.verify_secondary_structure(structure)
+            structures.append(structure)
+        return structures
+
+
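Putting the pieces above together, a zero-shot secondary-structure call through `predict_structure` might look like the sketch below; the checkpoint id and the input sequence are placeholders for illustration. When `config.verify_ss` is enabled, `verify_secondary_structure` replaces unmatched brackets with dots, e.g. `"((..)"` becomes `".(..)"`.

```
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

ckpt = "yangheng/OmniGenome-52M"  # assumed checkpoint id for illustration
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("GCCGUUAACGAUGGC", return_tensors="pt")  # placeholder RNA sequence
with torch.no_grad():
    structures = model.predict_structure(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )
print(structures[0])  # dot-bracket string over the tokenized sequence
```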
+@add_start_docstrings(
+    """
+    OmniGenome Model with a simple genetic algorithm-based RNA design head on top.
+    """,
+    OmniGenome_START_DOCSTRING,
+)
+class OmniGenomeMaskedLMForRNADesign(OmniGenomePreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.OmniGenome = OmniGenomeForMaskedLM(config)
+        self.num_generation = config.num_generation
+        self.num_population = config.num_population
+        # self.init_weights()
 
     @add_start_docstrings_to_model_forward(
+        OmniGenome_INPUTS_DOCSTRING.format("batch_size, sequence_length")
     )
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,

         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = True,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

             return_dict if return_dict is not None else self.config.use_return_dict
         )
 
+        outputs = self.OmniGenome(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,

         )
 
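The genetic-algorithm machinery itself is not visible in this hunk (most of the `forward` body is collapsed), so the following is only a generic sketch of how a design loop driven by `num_population` and `num_generation` could be organised; it is not the repository's actual implementation, and `fitness_fn` is a placeholder the caller must supply, for example agreement between a predicted and a target secondary structure.

```
import torch

NUCLEOTIDES = "ACGU"

def mutate(seq, rate=0.1):
    # Randomly re-draw a fraction of positions.
    chars = list(seq)
    for i in range(len(chars)):
        if torch.rand(1).item() < rate:
            chars[i] = NUCLEOTIDES[torch.randint(len(NUCLEOTIDES), (1,)).item()]
    return "".join(chars)

def design(seed, fitness_fn, num_population=50, num_generation=20):
    # Keep the best half of the population each generation and refill it with mutants.
    population = [mutate(seed) for _ in range(num_population)]
    for _ in range(num_generation):
        survivors = sorted(population, key=fitness_fn, reverse=True)[: num_population // 2]
        population = survivors + [mutate(s) for s in survivors]
    return max(population, key=fitness_fn)
```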
 
+# Copied from transformers.models.esm.modeling_esm.EsmClassificationHead with Esm->OmniGenome
+class OmniGenomeClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 
     def __init__(self, config):
 
 
 
 def create_position_ids_from_input_ids(
+    input_ids, padding_idx, past_key_values_length=0
 ):
     """
     Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
 
     # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
     mask = input_ids.ne(padding_idx).int()
     incremental_indices = (
+        torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
+    ) * mask
     return incremental_indices.long() + padding_idx
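As a quick sanity check of `create_position_ids_from_input_ids`, reproducing its arithmetic on a padded batch (with `padding_idx=1`, matching this repository's `pad_token_id`, and the default `past_key_values_length=0`) gives position ids that start at `padding_idx + 1` and stay at `padding_idx` on padding:

```
import torch

input_ids = torch.tensor([[5, 6, 7, 1, 1]])            # 1 is the padding index
mask = input_ids.ne(1).int()                           # [[1, 1, 1, 0, 0]]
incremental = torch.cumsum(mask, dim=1).type_as(mask) * mask
print(incremental.long() + 1)                          # tensor([[2, 3, 4, 1, 1]])
```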