"""GO-language / trainer.py

Train a word-level tokenizer over Gene Ontology (GO) annotation strings from the
damlab/uniprot dataset, then pretrain a BERT masked-language model on them.
"""
import os

import datasets
from tokenizers import Tokenizer
from tokenizers.decoders import WordPiece
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from transformers import (
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)
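# Pin training to a single GPU and disable Weights & Biases logging.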
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"
NUM_TRAIN_EPOCHS = 100
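# Load the damlab/uniprot training split, keeping only records that have GO annotations.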
go_uni = datasets.load_dataset("damlab/uniprot")["train"].filter(
lambda x: x["go"] is not None
)
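# Word-level tokenizer over whitespace-separated GO terms; unseen terms map to [UNK].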
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = WhitespaceSplit()
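# Learn the GO vocabulary, reserving BERT-style special tokens plus [BOS]/[EOS].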
vocab_trainer = WordLevelTrainer(
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[BOS]", "[EOS]"]
)
tokenizer.train_from_iterator(go_uni["go"], trainer=vocab_trainer)
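# Post-process encodings into BERT format: [CLS] ... [SEP] (with segment ids for pairs).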
cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
print(cls_token_id, sep_token_id)
tokenizer.post_processor = TemplateProcessing(
single=f"[CLS]:0 $A:0 [SEP]:0",
pair=f"[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
)
tokenizer.decoder = WordPiece(prefix="##")
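# Wrap the raw tokenizer for use with transformers' Trainer and data collator, then save it.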
wrapped_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
# tokenizer_file="tokenizer.json", # You can load from the tokenizer file, alternatively
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
)
wrapped_tokenizer.save_pretrained("./")
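# Tokenize the GO strings, truncating to the 256-token context used by the model config below.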
def tkn_func(examples):
return wrapped_tokenizer(examples["go"], max_length=256, truncation=True)
tokenized_dataset = go_uni.map(
tkn_func, batched=True, remove_columns=go_uni.column_names
)
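# Hold out a test split (train_test_split's default proportion) for periodic evaluation.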
split_dataset = tokenized_dataset.train_test_split(seed=1234)
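# Masked-LM collator: mask 15% of tokens and pad batches to a multiple of 8.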
data_collator = DataCollatorForLanguageModeling(
tokenizer=wrapped_tokenizer, mlm_probability=0.15, pad_to_multiple_of=8,
)
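# Evaluate/log every 10 optimization steps; cosine-with-restarts schedule with a 10-step warmup.
# Effective per-device batch size is 24 * 96 = 2304 sequences per optimizer step.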
training_args = TrainingArguments(
"trainer",
evaluation_strategy="steps",
load_best_model_at_end=False,
save_strategy="no",
logging_first_step=True,
logging_steps=10,
eval_steps=10,
num_train_epochs=NUM_TRAIN_EPOCHS,
warmup_steps=10,
weight_decay=0.01,
per_device_train_batch_size=24,
per_device_eval_batch_size=24,
gradient_accumulation_steps=96,
lr_scheduler_type="cosine_with_restarts",
)
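# BERT encoder sized to the learned GO vocabulary and the 256-token inputs above.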
encoder_bert = BertConfig(
vocab_size=tokenizer.get_vocab_size(),
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=32,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=256,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
    pad_token_id=wrapped_tokenizer.pad_token_id,  # use the trained tokenizer's [PAD] id rather than assuming 0
position_embedding_type="absolute",
)
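# Trainer calls model_init to build a freshly initialized model (no pretrained weights).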
def model_init():
return BertForMaskedLM(encoder_bert)
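# Assemble the Trainer with the MLM collator and the train/test splits.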
trainer = Trainer(
model_init=model_init,
args=training_args,
train_dataset=split_dataset["train"],
eval_dataset=split_dataset["test"],
data_collator=data_collator,
)
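# Run masked-LM pretraining and save the final model alongside the tokenizer.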
results = trainer.train()
trainer.save_model("./")
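
# A minimal reload sketch for a later session (assumes this script's outputs in "./";
# the GO string below is only an illustrative whitespace-separated example):
#
#   from transformers import AutoModelForMaskedLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("./")
#   model = AutoModelForMaskedLM.from_pretrained("./")
#   outputs = model(**tok("GO:0005737 GO:0005829", return_tensors="pt"))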