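# Pretrain a word-level BERT masked-language model from scratch on the GO
# (Gene Ontology) annotation strings in the damlab/uniprot dataset.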
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import datasets
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from tokenizers.decoders import WordPiece

from transformers import PreTrainedTokenizerFast
from transformers import BertConfig, BertForMaskedLM, BertModel, BertForPreTraining
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

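# Run on a single GPU and disable Weights & Biases logging.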
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

NUM_TRAIN_EPOCHS = 100

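# Load the UniProt training split and keep only records that have GO annotations.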
go_uni = datasets.load_dataset("damlab/uniprot")["train"].filter(
    lambda x: x["go"] is not None
)


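# Word-level tokenizer: each whitespace-separated GO term becomes one token.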
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = WhitespaceSplit()

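# Learn the vocabulary directly from the GO annotation strings.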
trainer = WordLevelTrainer(
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[BOS]", "[EOS]"]
)
tokenizer.train_from_iterator(go_uni["go"], trainer=trainer)

cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
print("[CLS] id:", cls_token_id, "[SEP] id:", sep_token_id)

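# Add BERT-style [CLS]/[SEP] tokens around single sequences and sequence pairs.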
tokenizer.post_processor = TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
)

tokenizer.decoder = WordPiece(prefix="##")

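# Wrap the raw tokenizer so it can be used with the transformers Trainer API.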
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    # tokenizer_file="tokenizer.json", # You can load from the tokenizer file, alternatively
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)

wrapped_tokenizer.save_pretrained("./")


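# Tokenize the GO strings, truncating to the model's maximum sequence length.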
def tkn_func(examples):
    return wrapped_tokenizer(examples["go"], max_length=256, truncation=True)


tokenized_dataset = go_uni.map(
    tkn_func, batched=True, remove_columns=go_uni.column_names
)
split_dataset = tokenized_dataset.train_test_split(seed=1234)


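# Dynamic masking for the MLM objective: 15% of tokens are selected for masking in each batch.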
data_collator = DataCollatorForLanguageModeling(
    tokenizer=wrapped_tokenizer, mlm_probability=0.15, pad_to_multiple_of=8,
)

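# Training setup: evaluate and log every 10 steps; gradient accumulation of 96
# with a per-device batch of 24 gives an effective batch of 2,304 sequences.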
training_args = TrainingArguments(
    "trainer",
    evaluation_strategy="steps",
    load_best_model_at_end=False,
    save_strategy="no",
    logging_first_step=True,
    logging_steps=10,
    eval_steps=10,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_steps=10,
    weight_decay=0.01,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    gradient_accumulation_steps=96,
    lr_scheduler_type="cosine_with_restarts",
)


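# BERT encoder configuration sized to the learned GO-term vocabulary.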
encoder_bert = BertConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=32,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=256,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    pad_token_id=0,
    position_embedding_type="absolute",
)


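# The Trainer calls model_init to get a freshly initialized (untrained) BERT MLM.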
def model_init():
    return BertForMaskedLM(encoder_bert)


trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=data_collator,
)

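# Train from scratch and save the final model next to the tokenizer files.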
results = trainer.train()
trainer.save_model("./")
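
# Optional: a minimal sketch of reloading the saved model and tokenizer for
# masked-GO-term prediction (the GO IDs below are placeholder examples):
#
#     from transformers import pipeline
#
#     fill_mask = pipeline(
#         "fill-mask",
#         model=AutoModelForMaskedLM.from_pretrained("./"),
#         tokenizer=AutoTokenizer.from_pretrained("./"),
#     )
#     fill_mask("GO:0005524 [MASK] GO:0004672")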