Christina Theodoris committed
Commit: fb130e6
Parent(s): 86fe0dd

update kwargs for pretrainer

Files changed: geneformer/pretrainer.py (+10 -9)
geneformer/pretrainer.py CHANGED

@@ -106,9 +106,8 @@ class TensorType(ExplicitEnum):
 
 class GeneformerPreCollator(SpecialTokensMixin):
     def __init__(self, *args, **kwargs) -> None:
-
-
-
+        super().__init__(mask_token="<mask>", pad_token="<pad>")
+
         self.token_dictionary = kwargs.get("token_dictionary")
         # self.mask_token = "<mask>"
         # self.mask_token_id = self.token_dictionary.get("<mask>")
@@ -120,8 +119,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
         #     self.token_dictionary.get("<pad>"),
         # ]
         self.model_input_names = ["input_ids"]
-
-    def convert_ids_to_tokens(self,value):
+
+    def convert_ids_to_tokens(self, value):
         return self.token_dictionary.get(value)
 
     def _get_padding_truncation_strategies(
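The added super().__init__ call registers "<mask>" and "<pad>" with SpecialTokensMixin instead of setting them by hand, which is why the direct assignments above stay commented out. A minimal sketch of what that buys the collator, using a toy token_dictionary (the real dictionary maps gene tokens to ids; this one is an assumption for illustration only):

from geneformer.pretrainer import GeneformerPreCollator

# Toy vocabulary; hypothetical, for illustration only.
token_dictionary = {"<pad>": 0, "<mask>": 1, "GENE_A": 2, "GENE_B": 3}

precollator = GeneformerPreCollator(token_dictionary=token_dictionary)

# SpecialTokensMixin now supplies the special token strings...
assert precollator.mask_token == "<mask>"
assert precollator.pad_token == "<pad>"

# ...while the token dictionary still supplies their ids.
assert precollator.token_dictionary.get("<mask>") == 1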
@@ -391,7 +390,6 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
         for key, value in encoded_inputs.items():
             encoded_inputs[key] = to_py_obj(value)
-
 
         # Convert padding_strategy in PaddingStrategy
         padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
@@ -596,15 +594,17 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
 class GeneformerPretrainer(Trainer):
     def __init__(self, *args, **kwargs):
-        data_collator = kwargs.get("data_collator",None)
+        data_collator = kwargs.get("data_collator", None)
         token_dictionary = kwargs.pop("token_dictionary")
+        mlm = kwargs.pop("mlm", True)
+        mlm_probability = kwargs.pop("mlm_probability", 0.15)
 
         if data_collator is None:
             precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
 
             # # Data Collator Functions
             data_collator = DataCollatorForLanguageModeling(
-                tokenizer=precollator, mlm=True, mlm_probability=0.15
+                tokenizer=precollator, mlm=mlm, mlm_probability=mlm_probability
             )
             kwargs["data_collator"] = data_collator
 
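With this hunk, masking behavior no longer requires building a custom data collator: mlm and mlm_probability are popped from the trainer's kwargs and fall back to the previously hardcoded values (True and 0.15). A minimal usage sketch under stated assumptions; the file paths, output directory, and train_dataset below are placeholders, not part of the commit:

import pickle

from transformers import BertConfig, BertForMaskedLM, TrainingArguments

from geneformer.pretrainer import GeneformerPretrainer

# Placeholder path; the pickled dictionary maps tokens
# (including "<mask>" and "<pad>") to ids.
with open("token_dictionary.pkl", "rb") as f:
    token_dictionary = pickle.load(f)

model = BertForMaskedLM(BertConfig(vocab_size=len(token_dictionary)))

trainer = GeneformerPretrainer(
    model=model,
    args=TrainingArguments(output_dir="pretrain_out"),
    train_dataset=train_dataset,  # assumed: a pre-tokenized datasets.Dataset
    token_dictionary=token_dictionary,
    mlm_probability=0.2,  # new kwarg; omit to keep the 0.15 default
)
trainer.train()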
@@ -694,6 +694,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
     Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
     length while keeping a bit of randomness.
     """
+
    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
@@ -757,7 +758,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
-
+
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
 
        if not self.drop_last:
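For reference, the seeding pattern in this hunk is what makes the length-grouped shuffle deterministic per epoch: every rank builds a fresh generator from seed + epoch, so all ranks agree on the index order within an epoch while the order still changes between epochs. A standalone sketch of the scheme (the toy lengths, seed, and batch size are assumptions):

import torch
from transformers.trainer_pt_utils import get_length_grouped_indices

lengths = [7, 3, 9, 2, 8, 4, 6, 5]  # toy per-example sequence lengths
seed, batch_size = 42, 4

for epoch in range(2):
    # Deterministically shuffle based on epoch and seed, as in the sampler above
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    indices = get_length_grouped_indices(lengths, batch_size, generator=g)
    print(f"epoch {epoch}: {indices}")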
|