import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import datasets
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from tokenizers.decoders import WordPiece

from transformers import PreTrainedTokenizerFast
from transformers import BertConfig, BertForMaskedLM, BertModel, BertForPreTraining
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# Pin training to a single GPU and disable Weights & Biases logging.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

NUM_TRAIN_EPOCHS = 100

# Keep only UniProt records that have Gene Ontology (GO) annotations.
go_uni = datasets.load_dataset("damlab/uniprot")["train"].filter(
    lambda x: x["go"] is not None
)

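# Optional sanity check: the "go" field is assumed to be a whitespace-separated string
# of GO terms, which is what the WhitespaceSplit pre-tokenizer below expects.
print(str(go_uni[0]["go"])[:120])
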
# Word-level tokenizer: each whitespace-separated GO term becomes a single token.
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = WhitespaceSplit()

# Named to avoid shadowing the Hugging Face Trainer created further down.
wordlevel_trainer = WordLevelTrainer(
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[BOS]", "[EOS]"]
)
tokenizer.train_from_iterator(go_uni["go"], trainer=wordlevel_trainer)

cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
print(cls_token_id, sep_token_id)

# Wrap single sequences and pairs in BERT-style [CLS]/[SEP] templates.
tokenizer.post_processor = TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
)

tokenizer.decoder = WordPiece(prefix="##")

# Wrap the raw tokenizer so it plugs into the transformers Trainer and data collator,
# and save it alongside the model outputs.
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)

wrapped_tokenizer.save_pretrained("./")


def tkn_func(examples):
    return wrapped_tokenizer(examples["go"], max_length=256, truncation=True)


# Tokenize the GO annotations, drop the raw columns, and hold out a test split.
tokenized_dataset = go_uni.map(
    tkn_func, batched=True, remove_columns=go_uni.column_names
)
split_dataset = tokenized_dataset.train_test_split(seed=1234)

# Dynamic masking: 15% of tokens are masked on the fly for masked-LM training.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=wrapped_tokenizer, mlm_probability=0.15, pad_to_multiple_of=8,
)

training_args = TrainingArguments(
    "trainer",  # output directory
    evaluation_strategy="steps",
    load_best_model_at_end=False,
    save_strategy="no",
    logging_first_step=True,
    logging_steps=10,
    eval_steps=10,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_steps=10,
    weight_decay=0.01,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    gradient_accumulation_steps=96,  # effective batch size of 24 * 96 = 2304 sequences
    lr_scheduler_type="cosine_with_restarts",
)

encoder_bert = BertConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=32,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=256,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    # Use the id the trained tokenizer actually assigned to [PAD]; with the special-token
    # order above it is not 0 ([UNK] comes first), so a hard-coded 0 would point at [UNK].
    pad_token_id=wrapped_tokenizer.pad_token_id,
    position_embedding_type="absolute",
)


def model_init():
    # Fresh BERT masked-LM initialized from the config above (trained from scratch).
    return BertForMaskedLM(encoder_bert)


trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=data_collator,
)

results = trainer.train()
trainer.save_model("./")

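# Usage sketch (optional, not part of training): query the checkpoint saved above with
# the standard transformers fill-mask pipeline. Assumes "./" still holds the tokenizer
# files written by save_pretrained() and the weights written by save_model().
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="./", tokenizer="./")
example_terms = str(go_uni[0]["go"]).split()[:50]  # stay well under the 256-position limit
example_terms[0] = fill_mask.tokenizer.mask_token  # mask the first GO term
for prediction in fill_mask(" ".join(example_terms), top_k=5):
    print(prediction["token_str"], round(prediction["score"], 3))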