import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import datasets
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from tokenizers.decoders import WordPiece

from transformers import PreTrainedTokenizerFast
from transformers import BertConfig, BertForMaskedLM, BertModel, BertForPreTraining
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

NUM_TRAIN_EPOCHS = 100

# Load the UniProt dataset and keep only records that carry GO annotations.
go_uni = datasets.load_dataset("damlab/uniprot")["train"].filter(
    lambda x: x["go"] is not None
)

# Train a word-level tokenizer over the whitespace-separated GO terms.
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = WhitespaceSplit()
word_level_trainer = WordLevelTrainer(
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[BOS]", "[EOS]"]
)
tokenizer.train_from_iterator(go_uni["go"], trainer=word_level_trainer)

cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
print(cls_token_id, sep_token_id)

# Add [CLS]/[SEP] framing so the tokenizer output matches BERT's expected input format.
tokenizer.post_processor = TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
)
tokenizer.decoder = WordPiece(prefix="##")

# Wrap the raw tokenizer so it exposes the standard transformers tokenizer API.
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    # tokenizer_file="tokenizer.json",  # Alternatively, you can load from the tokenizer file
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)
wrapped_tokenizer.save_pretrained("./")


def tkn_func(examples):
    return wrapped_tokenizer(examples["go"], max_length=256, truncation=True)


tokenized_dataset = go_uni.map(
    tkn_func, batched=True, remove_columns=go_uni.column_names
)
split_dataset = tokenized_dataset.train_test_split(seed=1234)

# Dynamically mask 15% of tokens for the masked-language-modelling objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=wrapped_tokenizer,
    mlm_probability=0.15,
    pad_to_multiple_of=8,
)

training_args = TrainingArguments(
    "trainer",
    evaluation_strategy="steps",
    load_best_model_at_end=False,
    save_strategy="no",
    logging_first_step=True,
    logging_steps=10,
    eval_steps=10,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_steps=10,
    weight_decay=0.01,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    gradient_accumulation_steps=96,
    lr_scheduler_type="cosine_with_restarts",
)

# BERT encoder sized for the GO-term vocabulary and 256-token sequences.
encoder_bert = BertConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=32,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=256,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    pad_token_id=wrapped_tokenizer.pad_token_id,  # use the trained tokenizer's [PAD] id (the default 0 is [UNK] here)
    position_embedding_type="absolute",
)


def model_init():
    return BertForMaskedLM(encoder_bert)


trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=data_collator,
)
results = trainer.train()
trainer.save_model("./")
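

# --- Optional sanity check (a minimal sketch, not part of the original training run) ---
# This assumes the model and tokenizer were both saved to "./" by the code above and
# uses the generic transformers fill-mask pipeline to predict a masked GO term. The
# input is taken from the dataset itself rather than a hand-written GO string, and is
# truncated so it fits within max_position_embeddings=256.
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="./", tokenizer="./")

example = go_uni[0]["go"].split()[:64]      # one annotated protein's GO terms, kept short
example[0] = wrapped_tokenizer.mask_token   # mask the first term
for prediction in fill_mask(" ".join(example), top_k=5):
    print(prediction["token_str"], round(prediction["score"], 3))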