import sys
import collections
import os
import regex as re  # 'regex' (not the stdlib 're') is required for the \p{...} Unicode classes used below
#from mosestokenizer import *  # only needed by preprocess_text_old_fr
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import unicodedata
import numpy as np
import argparse
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModel, AutoTokenizer, BertTokenizer
default_config = argparse.Namespace(
seed=871253,
lang='en',
#flavor='flaubert/flaubert_base_uncased',
flavor=None,
max_length=256,
batch_size=16,
updates=24000,
period=1000,
lr=1e-5,
dab_rate=0.1,
device='cuda',
debug=False
)
default_flavors = {
'fr': 'flaubert/flaubert_base_uncased',
'en': 'bert-base-uncased',
'zh': 'ckiplab/bert-base-chinese',
'tr': 'dbmdz/bert-base-turkish-uncased',
'de': 'dbmdz/bert-base-german-uncased',
'pt': 'neuralmind/bert-base-portuguese-cased'
}
class Config(argparse.Namespace):
def __init__(self, **kwargs):
for key, value in default_config.__dict__.items():
setattr(self, key, value)
for key, value in kwargs.items():
setattr(self, key, value)
assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de']
if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
self.flavor = default_flavors[self.lang]
#print(self.lang, self.flavor)
def init_random(seed):
# make sure everything is deterministic
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
#torch.use_deterministic_algorithms(True)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
punctuation = {
'O': 0,
'COMMA': 1,
'PERIOD': 2,
'QUESTION': 3,
'EXCLAMATION': 4,
}
punctuation_syms = ['', ',', '.', ' ?', ' !']
case = {
'LOWER': 0,
'UPPER': 1,
'CAPITALIZE': 2,
'OTHER': 3,
}
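# Illustrative example (not executed): for the sentence "Hello, world." the
# token "hello" would be labelled (COMMA, CAPITALIZE) and "world"
# (PERIOD, LOWER), i.e. y[:,0] = [1, 2] and y[:,1] = [2, 0].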
class Model(nn.Module):
def __init__(self, flavor, device):
super().__init__()
self.bert = AutoModel.from_pretrained(flavor)
        # need a proper way of determining representation size
        size = (self.bert.dim if hasattr(self.bert, 'dim')
                else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size')
                else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim')
                else self.bert.config.hidden_size)
self.punc = nn.Linear(size, 5)
self.case = nn.Linear(size, 4)
self.dropout = nn.Dropout(0.3)
self.to(device)
def forward(self, x):
output = self.bert(x)
representations = self.dropout(F.gelu(output['last_hidden_state']))
punc = self.punc(representations)
case = self.case(representations)
return punc, case
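# Shape sketch (illustrative): for input ids x of shape (batch, max_length),
# forward() returns punc scores of shape (batch, max_length, 5) and case
# scores of shape (batch, max_length, 4), one prediction per subword position.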
# randomly create sequences that align to punctuation boundaries
def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id):
for i, dropped in enumerate(torch.rand((len(x),)) < rate):
if dropped:
# select all indices that are sentence endings
indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
if len(indices) < 2:
continue
start = indices[0] + 1
end = indices[random.randint(1, len(indices) - 1)] + 1
length = end - start
if length + 2 > len(x[i]):
continue
x[i, 0] = cls_token_id
x[i, 1: length + 1] = x[i, start: end].clone()
x[i, length + 1] = sep_token_id
x[i, length + 2:] = pad_token_id
y[i, 0] = 0
y[i, 1: length + 1] = y[i, start: end].clone()
y[i, length + 1:] = 0
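# Worked example (illustrative): with rate=1.0 and a row whose punctuation
# labels y[i,:,0] mark sentence endings (label > 1) at positions 3 and 9, the
# tokens at positions 4..9 are shifted to the start of the row, wrapped in
# CLS/SEP, and the remainder padded, so the model also sees sequences that
# start at a sentence boundary instead of an arbitrary offset.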
def compute_performance(config, model, loader):
device = config.device
criterion = nn.CrossEntropyLoss()
model.eval()
total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0
num_ref = collections.defaultdict(float)
num_hyp = collections.defaultdict(float)
num_correct = collections.defaultdict(float)
for x, y in loader:
x = x.long().to(device)
y = y.long().to(device)
y1 = y[:,:,0]
y2 = y[:,:,1]
with torch.no_grad():
            y_scores1, y_scores2 = model(x)  # x has already been moved to the device
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
loss = loss1 + loss2
y_pred1 = torch.max(y_scores1, 2)[1]
y_pred2 = torch.max(y_scores2, 2)[1]
for label in range(1, 5):
ref = (y1 == label)
hyp = (y_pred1 == label)
correct = (ref * hyp == 1)
num_ref[label] += ref.sum()
num_hyp[label] += hyp.sum()
num_correct[label] += correct.sum()
num_ref[0] += ref.sum()
num_hyp[0] += hyp.sum()
num_correct[0] += correct.sum()
all_correct1 += (y_pred1 == y1).sum()
all_correct2 += (y_pred2 == y2).sum()
total_loss += loss.item()
num_loss += len(y)
num_perf += len(y) * config.max_length
recall = {}
precision = {}
fscore = {}
for label in range(0, 5):
recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0
precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0
fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0
return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore
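# Note: fscore[0] is the micro-averaged F-score over punctuation labels 1-4
# (the loop above also accumulates every label's counts into index 0); it is
# not the score of the 'O' class.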
def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5):
device = config.device
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr)
iteration = 0
while True:
model.train()
total_loss = num = 0
for x, y in tqdm(train_loader):
x = x.long().to(device)
y = y.long().to(device)
drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id)
y1 = y[:,:,0]
y2 = y[:,:,1]
optimizer.zero_grad()
y_scores1, y_scores2 = model(x)
loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1)))
loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1)))
loss = loss1 + loss2
loss.backward()
optimizer.step()
total_loss += loss.item()
num += len(y)
if iteration % valid_period == valid_period - 1:
train_loss = total_loss / num
valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader)
torch.save({
'iteration': iteration + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'valid_loss': valid_loss,
'valid_accuracy_case': valid_accuracy_case,
'valid_accuracy_punc': valid_accuracy_punc,
'valid_fscore': valid_fscore,
'config': config.__dict__,
}, '%s.%d' % (checkpoint_path, iteration + 1))
print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore)
total_loss = num = 0
iteration += 1
if iteration > iterations:
return
sys.stderr.flush()
sys.stdout.flush()
def batchify(max_length, x, y):
    print(x.shape)
    print(y.shape)
x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length)
y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2)
return x, y
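# Illustrative: with max_length=256, a stream of 1000 token ids becomes a
# (3, 256) matrix of inputs and a (3, 256, 2) matrix of labels; the trailing
# 1000 - 3 * 256 = 232 tokens are dropped.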
def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path):
X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn))
X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn))
train_set = TensorDataset(X_train, Y_train)
valid_set = TensorDataset(X_valid, Y_valid)
train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
model = Model(config.flavor, config.device)
fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr)
def run_eval(config, test_x_fn, test_y_fn, checkpoint_path):
X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn))
test_set = TensorDataset(X_test, Y_test)
test_loader = DataLoader(test_set, batch_size=config.batch_size)
loaded = torch.load(checkpoint_path, map_location=config.device)
if 'config' in loaded:
config = Config(**loaded['config'])
init(config)
model = Model(config.flavor, config.device)
model.load_state_dict(loaded['model_state_dict'], strict=False)
print(*compute_performance(config, model, test_loader))
def recase(token, label):
if label == case['LOWER']:
return token.lower()
elif label == case['CAPITALIZE']:
return token.lower().capitalize()
elif label == case['UPPER']:
return token.upper()
else:
return token
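# Illustrative examples: recase('hello', case['CAPITALIZE']) -> 'Hello',
# recase('nasa', case['UPPER']) -> 'NASA', and OTHER leaves the token as-is.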
class CasePuncPredictor:
def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device):
loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu')
if 'config' in loaded:
self.config = Config(**loaded['config'])
else:
self.config = Config(lang=lang, flavor=flavor, device=device)
init(self.config)
self.model = Model(self.config.flavor, self.config.device)
self.model.load_state_dict(loaded['model_state_dict'])
self.model.eval()
self.model.to(self.config.device)
self.rev_case = {b: a for a, b in case.items()}
self.rev_punc = {b: a for a, b in punctuation.items()}
def tokenize(self, text):
return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token]
def predict(self, tokens, getter=lambda x: x):
max_length = self.config.max_length
device = self.config.device
        if isinstance(tokens, str):
tokens = self.tokenize(tokens)
previous_label = punctuation['PERIOD']
for start in range(0, len(tokens), max_length):
instance = tokens[start: start + max_length]
            if isinstance(getter(instance[0]), str):
                ids = self.config.tokenizer.convert_tokens_to_ids([getter(token) for token in instance])
else:
ids = [getter(token) for token in instance]
if len(ids) < max_length:
                ids += [self.config.pad_token_id] * (max_length - len(ids))  # pad with the tokenizer's pad token, as generate_predictions does
x = torch.tensor([ids]).long().to(device)
y_scores1, y_scores2 = self.model(x)
y_pred1 = torch.max(y_scores1, 2)[1]
y_pred2 = torch.max(y_scores2, 2)[1]
for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
if id == self.config.cls_token_id or id == self.config.sep_token_id:
continue
                if previous_label is not None and previous_label > 1:
if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER
case_label = case['CAPITALIZE']
if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
punc_label = punctuation['PERIOD']
yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
previous_label = punc_label
def map_case_label(self, token, case_label):
if token.endswith('</w>'):
token = token[:-4]
if token.startswith('##'):
token = token[2:]
return recase(token, case[case_label])
def map_punc_label(self, token, punc_label):
if token.endswith('</w>'):
token = token[:-4]
if token.startswith('##'):
token = token[2:]
return token + punctuation_syms[punctuation[punc_label]]
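# Minimal usage sketch (commented out; 'checkpoint.24000' is a hypothetical
# path to a model saved by fit() above):
#
#   predictor = CasePuncPredictor('checkpoint.24000', lang='en')
#   tokens = list(enumerate(predictor.tokenize('hello how are you today')))
#   text = ''
#   for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
#       word = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
#       text += word if token[1].startswith('##') else ' ' + word
#   print(text.strip())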
def generate_predictions(config, checkpoint_path):
loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu')
if 'config' in loaded:
config = Config(**loaded['config'])
init(config)
model = Model(config.flavor, config.device)
model.load_state_dict(loaded['model_state_dict'], strict=False)
rev_case = {b: a for a, b in case.items()}
rev_punc = {b: a for a, b in punctuation.items()}
for line in sys.stdin:
# also drop punctuation that we may generate
line = ''.join([c for c in line if c not in mapped_punctuation])
if config.debug:
print(line)
tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
if config.debug:
print(tokens)
previous_label = punctuation['PERIOD']
first_time = True
was_word = False
for start in range(0, len(tokens), config.max_length):
instance = tokens[start: start + config.max_length]
ids = config.tokenizer.convert_tokens_to_ids(instance)
#print(len(ids), file=sys.stderr)
if len(ids) < config.max_length:
ids += [config.pad_token_id] * (config.max_length - len(ids))
x = torch.tensor([ids]).long().to(config.device)
y_scores1, y_scores2 = model(x)
y_pred1 = torch.max(y_scores1, 2)[1]
y_pred2 = torch.max(y_scores2, 2)[1]
for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
if config.debug:
print(id, token, punc_label, case_label, file=sys.stderr)
if id == config.cls_token_id or id == config.sep_token_id:
continue
                if previous_label is not None and previous_label > 1:
if case_label in [case['LOWER'], case['OTHER']]:
case_label = case['CAPITALIZE']
previous_label = punc_label
# different strategy due to sub-lexical token encoding in Flaubert
if config.lang == 'fr':
if token.endswith('</w>'):
cased_token = recase(token[:-4], case_label)
if was_word:
print(' ', end='')
print(cased_token + punctuation_syms[punc_label], end='')
was_word = True
else:
cased_token = recase(token, case_label)
if was_word:
print(' ', end='')
print(cased_token, end='')
was_word = False
else:
if token.startswith('##'):
cased_token = recase(token[2:], case_label)
print(cased_token, end='')
else:
cased_token = recase(token, case_label)
if not first_time:
print(' ', end='')
first_time = False
print(cased_token + punctuation_syms[punc_label], end='')
if previous_label == 0:
print('.', end='')
print()
def label_for_case(token):
token = re.sub(r'[^\p{Han}\p{Ll}\p{Lu}]', '', token)
if token == token.lower():
return 'LOWER'
elif token == token.lower().capitalize():
return 'CAPITALIZE'
elif token == token.upper():
return 'UPPER'
else:
return 'OTHER'
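# Illustrative: label_for_case('paris') -> 'LOWER', 'Paris' -> 'CAPITALIZE',
# 'NASA' -> 'UPPER', 'McDonald' -> 'OTHER' (mixed case matches none of the
# canonical forms).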
def make_tensors(config, input_fn, output_x_fn, output_y_fn):
# count file lines without loading them
size = 0
with open(input_fn) as fp:
for line in fp:
size += 1
with open(input_fn) as fp:
X = torch.IntTensor(size)
Y = torch.ByteTensor(size, 2)
offset = 0
for n, line in enumerate(fp):
word, case_label, punc_label = line.strip().split('\t')
id = config.tokenizer.convert_tokens_to_ids(word)
if config.debug:
                assert word.lower() == config.tokenizer.convert_ids_to_tokens(id)
X[offset] = id
Y[offset, 0] = punctuation[punc_label]
Y[offset, 1] = case[case_label]
offset += 1
torch.save(X, output_x_fn)
torch.save(Y, output_y_fn)
mapped_punctuation = {
'.': 'PERIOD',
'...': 'PERIOD',
',': 'COMMA',
';': 'COMMA',
':': 'COMMA',
'(': 'COMMA',
')': 'COMMA',
'?': 'QUESTION',
'!': 'EXCLAMATION',
',': 'COMMA',
'!': 'EXCLAMATION',
'?': 'QUESTION',
';': 'COMMA',
':': 'COMMA',
'(': 'COMMA',
'(': 'COMMA',
')': 'COMMA',
'[': 'COMMA',
']': 'COMMA',
'【': 'COMMA',
'】': 'COMMA',
'└': 'COMMA',
'└ ': 'COMMA',
'_': 'O',
'。': 'PERIOD',
'、': 'COMMA', # enumeration comma
'、': 'COMMA',
'…': 'PERIOD',
'—': 'COMMA',
'「': 'COMMA',
'」': 'COMMA',
'.': 'PERIOD',
'《': 'O',
'》': 'O',
',': 'COMMA',
'“': 'O',
'”': 'O',
'"': 'O',
'-': 'O',
'-': 'O',
'〉': 'COMMA',
'〈': 'COMMA',
'↑': 'O',
'〔': 'COMMA',
'〕': 'COMMA',
}
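# Note: lookups against this table are per character (see preprocess_text and
# generate_predictions), so multi-character keys such as '...' and '└ ' never
# match on their own; '.' and '└' already cover those cases character by
# character.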
def preprocess_text(config, max_token_count=-1):
global num_tokens_output
max_token_count = int(max_token_count)
num_tokens_output = 0
    def process_segment(text, punc_label):
global num_tokens_output
text = text.replace('\t', ' ')
tokens = config.tokenizer.tokenize(text)
for i, token in enumerate(tokens):
case_label = label_for_case(token)
if i == len(tokens) - 1:
                print(token.lower(), case_label, punc_label, sep='\t')
else:
print(token.lower(), case_label, 'O', sep='\t')
num_tokens_output += 1
            # a bit ugly, but the alternative is to throw an exception
if max_token_count > 0 and num_tokens_output >= max_token_count:
sys.exit(0)
for line in sys.stdin:
line = line.strip()
if line != '':
line = unicodedata.normalize("NFC", line)
if config.debug:
print(line)
start = 0
for i, char in enumerate(line):
if char in mapped_punctuation:
if i > start and line[start: i].strip() != '':
process_segment(line[start: i], mapped_punctuation[char])
start = i + 1
if start < len(line):
process_segment(line[start:], 'PERIOD')
def preprocess_text_old_fr(config):
    assert config.lang == 'fr'
    # requires the mosestokenizer package (see the commented-out import at the top)
    splitsents = MosesSentenceSplitter(config.lang)
    tokenize = MosesTokenizer(config.lang, extra=['-no-escape'])
    normalize = MosesPunctuationNormalizer(config.lang)
    for line in sys.stdin:
        if line.strip() != '':
            for sentence in splitsents([normalize(line)]):
                tokens = tokenize(sentence)
                previous_token = None
                for token in tokens:
                    if token in mapped_punctuation:
                        if previous_token is not None:
                            print(previous_token, mapped_punctuation[token], sep='\t')
                        previous_token = None
                    elif not re.search(r'[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
                        continue
                    else:
                        if previous_token is not None:
                            print(previous_token, 'O', sep='\t')
                        previous_token = token
                if previous_token is not None:
                    print(previous_token, 'PERIOD', sep='\t')
# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
self.keep_case = keep_case
def tokenize(self, text):
"""
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
tokenization using the given vocabulary.
        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in text.strip().split():
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
                    # optionally lowercase the substring before checking for inclusion in the vocab
if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab):
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
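# Illustrative (assuming a lowercased BERT vocab containing 'un', '##aff',
# '##able'): with keep_case=True, tokenize('Unaffable') returns
# ['Un', '##aff', '##able'] -- the pieces keep their original casing while the
# vocabulary match is done on the lowercased form.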
# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase
# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py
def bpe(self, token):
def to_lower(pair):
#print(' ',pair)
return (pair[0].lower(), pair[1].lower())
from transformers.models.xlm.tokenization_xlm import get_pairs
word = tuple(token[:-1]) + (token[-1] + "</w>",)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf")))
#print(bigram)
if to_lower(bigram) not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
if word == "\n </w>":
word = "\n</w>"
self.cache[token] = word
return word
def init(config):
init_random(config.seed)
if config.lang == 'fr':
config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False)
from transformers.models.xlm.tokenization_xlm import XLMTokenizer
assert isinstance(tokenizer, XLMTokenizer)
# monkey patch XLM tokenizer
import types
tokenizer.bpe = types.MethodType(bpe, tokenizer)
else:
# warning: needs to be BertTokenizer for monkey patching to work
config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False)
# warning: monkey patch tokenizer to keep case information
#from recasing_tokenizer import WordpieceTokenizer
config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token)
if config.lang == 'fr':
config.pad_token_id = tokenizer.pad_token_id
config.cls_token_id = tokenizer.bos_token_id
config.cls_token = tokenizer.bos_token
config.sep_token_id = tokenizer.sep_token_id
config.sep_token = tokenizer.sep_token
else:
config.pad_token_id = tokenizer.pad_token_id
config.cls_token_id = tokenizer.cls_token_id
config.cls_token = tokenizer.cls_token
config.sep_token_id = tokenizer.sep_token_id
config.sep_token = tokenizer.sep_token
if not torch.cuda.is_available() and config.device == 'cuda':
print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
def main(config, action, args):
init(config)
if action == 'train':
train(config, *args)
elif action == 'eval':
run_eval(config, *args)
elif action == 'predict':
generate_predictions(config, *args)
elif action == 'tensorize':
make_tensors(config, *args)
elif action == 'preprocess':
preprocess_text(config, *args)
else:
        print('invalid action "%s"' % action, file=sys.stderr)
sys.exit(1)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str)
parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*')
parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int)
parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str)
parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str)
parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int)
parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int)
parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str)
parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool)
parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool)
parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool)
parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool)
parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool)
config = Config(**parser.parse_args().__dict__)
main(config, config.action, config.action_args)