import argparse
import os
import random
import re
import textwrap
import warnings

# Suppress TensorFlow C++ logging; this must be set before TensorFlow is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, AdamWeightDecay
from transformers import logging as hf_logging

hf_logging.set_verbosity_error()

# Fix the random seeds for reproducibility
np.random.seed(1234)
tf.random.set_seed(1234)
random.seed(1234)
def create_arg_parser():
    '''Create the command line arguments'''
    parser = argparse.ArgumentParser()

    parser.add_argument("-tf", "--transformer", default="google/byt5-small",
                        type=str, help="Name of the pretrained language model "
                        "on HuggingFace, default is google/byt5-small")
    parser.add_argument("-c_model", "--custom_model",
                        type=str, help="Path to a custom pretrained checkpoint")
    parser.add_argument("-train", "--train_data", default='training_data10k.txt',
                        type=str, help="Train data file")
    parser.add_argument("-dev", "--dev_data", default='validation_data.txt',
                        type=str, help="Dev data file")
    parser.add_argument("-lr", "--learn_rate", default=5e-5, type=float,
                        help="Set a custom learning rate for the model, "
                        "default is 5e-5")
    parser.add_argument("-bs", "--batch_size", default=8, type=int,
                        help="Set a custom batch size for the pretrained "
                        "language model, default is 8")
    parser.add_argument("-sl_train", "--sequence_length_train", default=155,
                        type=int, help="Set a custom maximum sequence length "
                        "for the train data, default is 155")
    parser.add_argument("-sl_dev", "--sequence_length_dev", default=155,
                        type=int, help="Set a custom maximum sequence length "
                        "for the dev data, default is 155")
    parser.add_argument("-ep", "--epochs", default=1, type=int,
                        help="Number of epochs to train for, default is 1")
    parser.add_argument("-es", "--early_stop", default="val_loss", type=str,
                        help="Value to monitor for early stopping")
    parser.add_argument("-es_p", "--early_stop_patience", default=2,
                        type=int, help="Patience value for early stopping, "
                        "default is 2")
    args = parser.parse_args()
    return args
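# Example invocation (sketch; 'train.py' is a placeholder for this script's
# file name, the flags and defaults are the ones defined above):
#   python train.py -tf google/byt5-small -train training_data10k.txt \
#       -dev validation_data.txt -lr 5e-5 -bs 8 -ep 3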
def read_data(data_file):
    '''Read a data file and return its lines'''
    with open(data_file, encoding='utf-8') as file:
        return file.readlines()
def create_data(data):
    '''Split Alpino-format training data into separate
    source and target sentences'''
    source_text = []
    target_text = []
    for x in data:
        source = []
        target = []
        # Collect every bracketed annotation, e.g. '[ @alt ... ]'
        spel = re.findall(r'\[.*?\]', x)
        if spel:
            for s in spel:
                s = s.split()
                if s[1] == '@alt':
                    # First token after the tag is the target form,
                    # the remaining tokens (joined) form the source
                    target.append(''.join(s[2:3]))
                    source.append(''.join(s[3:-1]))
                elif s[1] == '@mwu_alt':
                    # As @alt, but hyphens are stripped from the source side
                    target.append(''.join(s[2:3]))
                    source.append(''.join(s[3:-1]).replace('-', ''))
                elif s[1] == '@mwu':
                    # Multi-word unit: concatenated in the target,
                    # space-separated in the source
                    target.append(''.join(s[2:-1]))
                    source.append(' '.join(s[2:-1]))
                elif s[1] == '@postag':
                    # POS-tag annotation: the word itself is unchanged
                    target.append(s[-2])
                    source.append(s[-2])
                elif s[1] == '@phantom':
                    # Phantom token: present in the target, absent from the source
                    target.append(s[2])
                    source.append('')

        # Drop a leading '~' alternative marker from the target forms
        target2 = []
        for t in target:
            if t and t[0] == '~':
                target2.append(t.split('~')[1])
            else:
                target2.append(t)

        # Rebuild the sentence, substituting each annotation placeholder
        sent = re.sub(r'\[.*?\]', 'EMPTY', x)
        word_c = 0
        src = []
        trg = []
        for word in sent.split():
            if word == 'EMPTY':
                src.append(source[word_c])
                trg.append(target2[word_c])
                word_c += 1
            else:
                src.append(word)
                trg.append(word)
        source_text.append(' '.join(src))
        target_text.append(' '.join(trg))
    return source_text, target_text
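# Illustration of create_data on a single made-up annotated line (the @alt
# annotation is hypothetical, but follows the format parsed above):
#   line   : 'dit is [ @alt een en ] voorbeeld'
#   source : 'dit is en voorbeeld'    (original spelling)
#   target : 'dit is een voorbeeld'   (normalised form)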
def split_sent(data, max_length):
    '''Split sentences into chunks if they are longer than max_length characters'''
    short_sent = []
    long_sent = []
    for n in data:
        # Input lines look like 'id|sentence'; keep only the sentence part
        n = n.split('|')
        if len(n[1]) <= max_length:
            short_sent.append(n[1])
        else:
            # Protect bracketed annotations so textwrap cannot break inside them:
            # whitespace inside brackets becomes '$$', adjacent brackets are glued
            n[1] = re.sub(r'(\s)+(?=[^[]*?\])', '$$', n[1])
            n[1] = n[1].replace("] [", "]##[")
            lines = textwrap.wrap(n[1], max_length, break_long_words=False)
            long_sent.append(lines)

    new_data = []
    for s in long_sent:
        for s1 in s:
            # Restore the protected annotations
            s1 = s1.replace(']##[', '] [')
            s1 = s1.replace('$$', ' ')
            # Keep only chunks of at least three tokens
            if len(s1.split()) > 2:
                new_data.append(s1)

    for x in short_sent:
        new_data.append(x)
    return new_data
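# Note: the length check above counts characters, not tokens. For the default
# byte-level ByT5 model this is roughly the number of input tokens; for the
# subword models (mT5, Flan-T5) it is only a rough upper bound.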
def preprocess_function(tk, s, t):
    '''Tokenize the source texts and the target labels'''
    model_inputs = tk(s)

    # Tokenize the targets in target mode so the labels are encoded correctly
    with tk.as_target_tokenizer():
        labels = tk(t)

    model_inputs["labels"] = labels["input_ids"]
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs
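# In newer transformers releases as_target_tokenizer() is deprecated in favour
# of the text_target argument; an equivalent sketch (assuming such a version):
#   model_inputs = tk(s)
#   labels = tk(text_target=t)
#   model_inputs["labels"] = labels["input_ids"]
#   model_inputs["decoder_attention_mask"] = labels["attention_mask"]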
def convert_tok(tok, sl):
    '''Convert tokenized output to padded Tensors'''
    # Pad (or truncate) every sequence to length sl; padding is appended
    # after the tokens, with 0 as the padding value
    input_ids_pad = pad_sequences(tok['input_ids'], padding='post', maxlen=sl)
    attention_mask_pad = pad_sequences(tok['attention_mask'], padding='post',
                                       maxlen=sl)
    labels_pad = pad_sequences(tok['labels'], padding='post', maxlen=sl)
    dec_attention_mask_pad = pad_sequences(tok['decoder_attention_mask'],
                                           padding='post', maxlen=sl)
    return {'input_ids': tf.constant(input_ids_pad),
            'attention_mask': tf.constant(attention_mask_pad),
            'labels': tf.constant(labels_pad),
            'decoder_attention_mask': tf.constant(dec_attention_mask_pad)}
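# Shape sketch: with N examples and sl=155, every entry of the returned dict is
# an int32 tensor of shape (N, 155); the 0-padding is what the masks in
# custom_loss and accuracy below rely on.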
def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
    '''Finetune and save a given T5 variant with the given parameters'''
    print('Training model: {}\nWith parameters:\nLearn rate: {}, '
          'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
          'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))

    tk = AutoTokenizer.from_pretrained(model_name)

    args = create_arg_parser()
    source_train, target_train = create_data(train)
    source_test, target_test = create_data(dev)

    # Load either a custom (PyTorch) checkpoint or the pretrained model
    if args.custom_model:
        model = TFAutoModelForSeq2SeqLM.from_pretrained(args.custom_model,
                                                        from_pt=True)
    else:
        model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

    train_tok = preprocess_function(tk, source_train, target_train)
    dev_tok = preprocess_function(tk, source_test, target_test)

    tf_train = convert_tok(train_tok, sl_train)
    tf_dev = convert_tok(dev_tok, sl_dev)

    optim = AdamWeightDecay(learning_rate=lr)
    model.compile(optimizer=optim, loss=custom_loss, metrics=[accuracy])
    ear_stop = tf.keras.callbacks.EarlyStopping(monitor=es, patience=es_p,
                                                restore_best_weights=True,
                                                mode="auto")
    model.fit(tf_train, validation_data=tf_dev, epochs=ep,
              batch_size=bs, callbacks=[ear_stop])
    # Save only the weights; the file name drops the 'google/' organisation prefix
    model.save_weights('{}_weights.h5'.format(model_name.split('/')[-1]))
    return model
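# Reloading the fine-tuned weights later (sketch mirroring the save call above;
# the model must be built from the same pretrained checkpoint before loading):
#   model = TFAutoModelForSeq2SeqLM.from_pretrained('google/byt5-small')
#   model.load_weights('byt5-small_weights.h5')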
def custom_loss(y_true, y_pred):
    '''Sparse categorical cross-entropy that ignores padding positions'''
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_fn(y_true, y_pred)

    # Mask out padding positions (label id 0) before averaging
    mask = tf.cast(y_true != 0, loss.dtype)
    loss *= mask
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)
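# Masking example (made-up numbers): for y_true = [[12, 7, 0, 0]] only the first
# two positions contribute, so the result is the mean over two token losses
# rather than over all four positions.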
def accuracy(y_true, y_pred):
    '''Token-level accuracy over non-padding positions only'''
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)

    # Count matches only at non-padding positions (label id 0 is padding)
    match = tf.cast(y_true == y_pred, tf.float32)
    mask = tf.cast(y_true != 0, tf.float32)
    return tf.reduce_sum(match * mask) / tf.reduce_sum(mask)
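# Quick sanity check for the two metrics above (toy logits, run separately):
#   y_true = tf.constant([[5, 3, 0, 0]])
#   y_pred = tf.one_hot([[5, 2, 0, 0]], depth=10) * 10.0  # pretend logits
#   custom_loss(y_true, y_pred)  # averaged over the two non-padding positions
#   accuracy(y_true, y_pred)     # 0.5: one of the two real tokens is correct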
def main():
    args = create_arg_parser()

    lr = args.learn_rate
    bs = args.batch_size
    sl_train = args.sequence_length_train
    sl_dev = args.sequence_length_dev
    # Split long sentences slightly below the maximum sequence length
    split_length_train = sl_train - 5
    split_length_dev = sl_dev - 5
    ep = args.epochs

    # Only the three supported T5 variants are accepted
    supported = ['google/byt5-small', 'google/flan-t5-small', 'google/mt5-small']
    if args.transformer in supported:
        model_name = args.transformer
    else:
        raise ValueError('Unknown model {}, please pick one of: {}'.format(
            args.transformer, ', '.join(supported)))

    early_stop = args.early_stop
    patience = args.early_stop_patience

    train_d = read_data(args.train_data)
    dev_d = read_data(args.dev_data)
    train_data = split_sent(train_d, split_length_train)
    dev_data = split_sent(dev_d, split_length_dev)

    print('Train size: {}\nDev size: {}\n'.format(len(train_data),
                                                  len(dev_data)))
    print(train_model(model_name, lr, bs, sl_train, sl_dev,
                      ep, early_stop, patience, train_data, dev_data))


if __name__ == '__main__':
    main()