from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from transformers import AdamWeightDecay
import tensorflow as tf
import random
from transformers import logging as hf_logging
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import numpy as np
import textwrap
import argparse
import re
import warnings
import os
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
hf_logging.set_verbosity_error()
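# Fix all random seeds so that runs are reproducible.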
np.random.seed(1234)
tf.random.set_seed(1234)
random.seed(1234)
def create_arg_parser():
'''Creating command line arguments'''
parser = argparse.ArgumentParser()
parser.add_argument("-tf", "--transformer", default="google/byt5-small",
type=str, help="this argument takes the pretrained "
"language model URL from HuggingFace "
"default is ByT5-small, please visit "
"HuggingFace for full URL")
parser.add_argument("-c_model", "--custom_model",
type=str, help="this argument takes a custom "
"pretrained checkpoint")
parser.add_argument("-train", "--train_data", default='training_data10k.txt',
type=str, help="this argument takes the train "
"data file as input")
parser.add_argument("-dev", "--dev_data", default='validation_data.txt',
type=str, help="this argument takes the dev data file "
"as input")
parser.add_argument("-lr", "--learn_rate", default=5e-5, type=float,
help="Set a custom learn rate for "
"the model, default is 5e-5")
parser.add_argument("-bs", "--batch_size", default=8, type=int,
help="Set a custom batch size for "
"the pretrained language model, default is 8")
parser.add_argument("-sl_train", "--sequence_length_train", default=155,
type=int, help="Set a custom maximum sequence length"
"for the pretrained language model,"
"default is 155")
parser.add_argument("-sl_dev", "--sequence_length_dev", default=155,
type=int, help="Set a custom maximum sequence length"
"for the pretrained language model,"
"default is 155")
parser.add_argument("-ep", "--epochs", default=1, type=int,
help="This argument selects the amount of epochs "
"to run the model with, default is 1 epoch")
parser.add_argument("-es", "--early_stop", default="val_loss", type=str,
help="Set the value to monitor for earlystopping")
parser.add_argument("-es_p", "--early_stop_patience", default=2,
type=int, help="Set the patience value for "
"earlystopping, default is 2")
args = parser.parse_args()
return args
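# Example invocation (the script file name below is only a placeholder; the
# flags and defaults are the ones defined in create_arg_parser above):
#   python finetune_t5.py -tf google/byt5-small -train training_data10k.txt \
#       -dev validation_data.txt -lr 5e-5 -bs 8 -ep 1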
def read_data(data_file):
    '''Reading in data files'''
    with open(data_file) as file:
        return file.readlines()
def create_data(data):
'''Splitting Alpino format training data into separate
source and target sentences'''
source_text = []
target_text = []
for x in data:
source = []
target = []
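        # Collect every bracketed annotation in the sentence; after splitting,
        # s[1] is the annotation type (@alt, @mwu_alt, @mwu, @postag or
        # @phantom) and the remaining tokens supply the source and target text.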
spel = re.findall(r'\[.*?\]', x)
if spel:
for s in spel:
s = s.split()
if s[1] == '@alt':
target.append(''.join(s[2:3]))
source.append(''.join(s[3:-1]))
elif s[1] == '@mwu_alt':
target.append(''.join(s[2:3]))
source.append(''.join(s[3:-1]).replace('-', ''))
elif s[1] == '@mwu':
target.append(''.join(s[2:-1]))
source.append(' '.join(s[2:-1]))
elif s[1] == '@postag':
target.append(''.join(s[-2]))
source.append(''.join(s[-2]))
elif s[1] == '@phantom':
target.append(''.join(s[2]))
source.append('')
target2 = []
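        # Targets written as ~alternative keep only the part after the tilde.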
for t in target:
if t[0] == '~':
t = t.split('~')
target2.append(t[1])
else:
target2.append(t)
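        # Replace every annotation span with the EMPTY placeholder, then
        # rebuild the source and target sentences word by word.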
sent = re.sub(r'\[.*?\]', 'EMPTY', x)
word_c = 0
src = []
trg = []
for word in sent.split():
if word == 'EMPTY':
src.append(source[word_c])
trg.append(target2[word_c])
word_c += 1
else:
src.append(word)
trg.append(word)
source_text.append(' '.join(src))
target_text.append(' '.join(trg))
return source_text, target_text
def split_sent(data, max_length):
'''Splitting sentences if longer than given max_length value'''
short_sent = []
long_sent = []
for n in data:
n = n.split('|')
if len(n[1]) <= max_length:
short_sent.append(n[1])
elif len(n[1]) > max_length:
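            # Protect annotation spans before wrapping: spaces inside [...]
            # become '$$' and the space between adjacent annotations becomes
            # '##', so textwrap cannot split an annotation across two chunks.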
n[1] = re.sub(r'(\s)+(?=[^[]*?\])', '$$', n[1])
n[1] = n[1].replace("] [", "]##[")
lines = textwrap.wrap(n[1], max_length, break_long_words=False)
long_sent.append(lines)
new_data = []
for s in long_sent:
for s1 in s:
s1 = s1.replace(']##[', '] [')
s1 = s1.replace('$$', ' ')
s2 = s1.split()
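            # Keep only chunks that still contain more than two tokens.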
if len(s2) > 2:
new_data.append(s1)
for x in short_sent:
new_data.append(x)
return new_data
def preprocess_function(tk, s, t):
'''tokenizing text and labels'''
model_inputs = tk(s)
with tk.as_target_tokenizer():
labels = tk(t)
model_inputs["labels"] = labels["input_ids"]
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
def convert_tok(tok, sl):
'''Convert tokenized object to Tensors and add padding'''
input_ids = []
attention_mask = []
labels = []
decoder_attention_mask = []
for a, b, c, d in zip(tok['input_ids'], tok['attention_mask'], tok['labels'],
tok['decoder_attention_mask']):
input_ids.append(a)
attention_mask.append(b)
labels.append(c)
decoder_attention_mask.append(d)
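    # Post-pad every sequence with zeros up to the fixed length sl; zero is
    # the T5 padding id, which custom_loss and accuracy mask out again.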
input_ids_pad = pad_sequences(input_ids, padding='post', maxlen=sl)
attention_mask_pad = pad_sequences(attention_mask, padding='post',
maxlen=sl)
labels_pad = pad_sequences(labels, padding='post', maxlen=sl)
dec_attention_mask_pad = pad_sequences(decoder_attention_mask,
padding='post', maxlen=sl)
return {'input_ids': tf.constant(input_ids_pad), 'attention_mask':
tf.constant(attention_mask_pad), 'labels': tf.constant(labels_pad),
'decoder_attention_mask': tf.constant(dec_attention_mask_pad)}
def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
'''Finetune and save a given T5 version with given parameters'''
print('Training model: {}\nWith parameters:\nLearn rate: {}, '
'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
tk = AutoTokenizer.from_pretrained(model_name)
args = create_arg_parser()
source_train, target_train = create_data(train)
source_test, target_test = create_data(dev)
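    # Either continue from a custom checkpoint (converted from PyTorch
    # weights) or start from the pretrained HuggingFace model.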
if args.custom_model:
model = TFAutoModelForSeq2SeqLM.from_pretrained(args.custom_model,
from_pt=True)
else:
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
train_tok = preprocess_function(tk, source_train, target_train)
dev_tok = preprocess_function(tk, source_test, target_test)
tf_train = convert_tok(train_tok, sl_train)
tf_dev = convert_tok(dev_tok, sl_dev)
optim = AdamWeightDecay(learning_rate=lr)
model.compile(optimizer=optim, loss=custom_loss,
metrics=[accuracy])
ear_stop = tf.keras.callbacks.EarlyStopping(monitor=es, patience=es_p,
restore_best_weights=True,
mode="auto")
model.fit(tf_train, validation_data=tf_dev, epochs=ep,
batch_size=bs, callbacks=[ear_stop])
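    # model_name[7:] drops the 'google/' prefix, e.g. 'byt5-small_weights.h5'.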
model.save_weights('{}_weights.h5'.format(model_name[7:]))
return model
def custom_loss(y_true, y_pred):
'''Custom loss function'''
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction='none')
loss = loss_fn(y_true, y_pred)
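    # Mask out padding positions (label id 0) and average the loss over the
    # remaining real tokens only.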
mask = tf.cast(y_true != 0, loss.dtype)
loss *= mask
return tf.reduce_sum(loss)/tf.reduce_sum(mask)
def accuracy(y_true, y_pred):
    '''Custom accuracy function that ignores padding positions'''
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)
    match = tf.cast(y_true == y_pred, tf.float32)
    mask = tf.cast(y_true != 0, tf.float32)
    # Count a position as correct only if it is a real (non-padding) token.
    return tf.reduce_sum(match * mask) / tf.reduce_sum(mask)
def main():
args = create_arg_parser()
lr = args.learn_rate
bs = args.batch_size
sl_train = args.sequence_length_train
sl_dev = args.sequence_length_dev
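    # Sentences are wrapped to five characters below the maximum sequence length.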
split_length_train = (sl_train - 5)
split_length_dev = (sl_dev - 5)
ep = args.epochs
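    # Only these three T5 variants are supported as -tf values.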
if args.transformer == 'google/flan-t5-small':
model_name = 'google/flan-t5-small'
elif args.transformer == 'google/byt5-small':
model_name = 'google/byt5-small'
elif args.transformer == 'google/mt5-small':
model_name = 'google/mt5-small'
    else:
        raise ValueError('Unsupported transformer: {}'.format(args.transformer))
early_stop = args.early_stop
patience = args.early_stop_patience
train_d = read_data(args.train_data)
dev_d = read_data(args.dev_data)
train_data = split_sent(train_d, split_length_train)
dev_data = split_sent(dev_d, split_length_dev)
print('Train size: {}\nDev size: {}\n'.format(len(train_data),
len(dev_data)))
print(train_model(model_name, lr, bs, sl_train, sl_dev,
ep, early_stop, patience, train_data, dev_data))
if __name__ == '__main__':
main()