import sys import collections import os import regex as re import re #from mosestokenizer import * from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import random import unicodedata import numpy as np import argparse from torch.utils.data import TensorDataset, DataLoader from transformers import AutoModel, AutoTokenizer, BertTokenizer default_config = argparse.Namespace( seed=871253, lang='en', #flavor='flaubert/flaubert_base_uncased', flavor=None, max_length=256, batch_size=16, updates=24000, period=1000, lr=1e-5, dab_rate=0.1, device='cuda', debug=False ) default_flavors = { 'fr': 'flaubert/flaubert_base_uncased', 'en': 'bert-base-uncased', 'zh': 'ckiplab/bert-base-chinese', 'tr': 'dbmdz/bert-base-turkish-uncased', 'de': 'dbmdz/bert-base-german-uncased', 'pt': 'neuralmind/bert-base-portuguese-cased' } class Config(argparse.Namespace): def __init__(self, **kwargs): for key, value in default_config.__dict__.items(): setattr(self, key, value) for key, value in kwargs.items(): setattr(self, key, value) assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de'] if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None): self.flavor = default_flavors[self.lang] #print(self.lang, self.flavor) def init_random(seed): # make sure everything is deterministic os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' #torch.use_deterministic_algorithms(True) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) # NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label! punctuation = { 'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3, 'EXCLAMATION': 4, } punctuation_syms = ['', ',', '.', ' ?', ' !'] case = { 'LOWER': 0, 'UPPER': 1, 'CAPITALIZE': 2, 'OTHER': 3, } class Model(nn.Module): def __init__(self, flavor, device): super().__init__() self.bert = AutoModel.from_pretrained(flavor) # need a proper way of determining representation size size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size self.punc = nn.Linear(size, 5) self.case = nn.Linear(size, 4) self.dropout = nn.Dropout(0.3) self.to(device) def forward(self, x): output = self.bert(x) representations = self.dropout(F.gelu(output['last_hidden_state'])) punc = self.punc(representations) case = self.case(representations) return punc, case # randomly create sequences that align to punctuation boundaries def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id): for i, dropped in enumerate(torch.rand((len(x),)) < rate): if dropped: # select all indices that are sentence endings indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0] if len(indices) < 2: continue start = indices[0] + 1 end = indices[random.randint(1, len(indices) - 1)] + 1 length = end - start if length + 2 > len(x[i]): continue x[i, 0] = cls_token_id x[i, 1: length + 1] = x[i, start: end].clone() x[i, length + 1] = sep_token_id x[i, length + 2:] = pad_token_id y[i, 0] = 0 y[i, 1: length + 1] = y[i, start: end].clone() y[i, length + 1:] = 0 def compute_performance(config, model, loader): device = config.device criterion = nn.CrossEntropyLoss() model.eval() total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0 num_ref = collections.defaultdict(float) num_hyp = collections.defaultdict(float) num_correct = collections.defaultdict(float) for x, y in loader: x = x.long().to(device) y = y.long().to(device) y1 = y[:,:,0] y2 = y[:,:,1] with torch.no_grad(): y_scores1, y_scores2 = model(x.to(device)) loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1))) loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1))) loss = loss1 + loss2 y_pred1 = torch.max(y_scores1, 2)[1] y_pred2 = torch.max(y_scores2, 2)[1] for label in range(1, 5): ref = (y1 == label) hyp = (y_pred1 == label) correct = (ref * hyp == 1) num_ref[label] += ref.sum() num_hyp[label] += hyp.sum() num_correct[label] += correct.sum() num_ref[0] += ref.sum() num_hyp[0] += hyp.sum() num_correct[0] += correct.sum() all_correct1 += (y_pred1 == y1).sum() all_correct2 += (y_pred2 == y2).sum() total_loss += loss.item() num_loss += len(y) num_perf += len(y) * config.max_length recall = {} precision = {} fscore = {} for label in range(0, 5): recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0 precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0 fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0 return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore def fit(config, model, checkpoint_path, train_loader, valid_loader, iterations, valid_period=200, lr=1e-5): device = config.device criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr) iteration = 0 while True: model.train() total_loss = num = 0 for x, y in tqdm(train_loader): x = x.long().to(device) y = y.long().to(device) drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id) y1 = y[:,:,0] y2 = y[:,:,1] optimizer.zero_grad() y_scores1, y_scores2 = model(x) loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1))) loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1))) loss = loss1 + loss2 loss.backward() optimizer.step() total_loss += loss.item() num += len(y) if iteration % valid_period == valid_period - 1: train_loss = total_loss / num valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader) torch.save({ 'iteration': iteration + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss, 'valid_loss': valid_loss, 'valid_accuracy_case': valid_accuracy_case, 'valid_accuracy_punc': valid_accuracy_punc, 'valid_fscore': valid_fscore, 'config': config.__dict__, }, '%s.%d' % (checkpoint_path, iteration + 1)) print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore) total_loss = num = 0 iteration += 1 if iteration > iterations: return sys.stderr.flush() sys.stdout.flush() def batchify(max_length, x, y): print (x.shape) print (y.shape) x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length) y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2) return x, y def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path): X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn)) X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn)) train_set = TensorDataset(X_train, Y_train) valid_set = TensorDataset(X_valid, Y_valid) train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True) valid_loader = DataLoader(valid_set, batch_size=config.batch_size) model = Model(config.flavor, config.device) fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr) def run_eval(config, test_x_fn, test_y_fn, checkpoint_path): X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn)) test_set = TensorDataset(X_test, Y_test) test_loader = DataLoader(test_set, batch_size=config.batch_size) loaded = torch.load(checkpoint_path, map_location=config.device) if 'config' in loaded: config = Config(**loaded['config']) init(config) model = Model(config.flavor, config.device) model.load_state_dict(loaded['model_state_dict'], strict=False) print(*compute_performance(config, model, test_loader)) def recase(token, label): if label == case['LOWER']: return token.lower() elif label == case['CAPITALIZE']: return token.lower().capitalize() elif label == case['UPPER']: return token.upper() else: return token class CasePuncPredictor: def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, device=default_config.device): loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu') if 'config' in loaded: self.config = Config(**loaded['config']) else: self.config = Config(lang=lang, flavor=flavor, device=device) init(self.config) self.model = Model(self.config.flavor, self.config.device) self.model.load_state_dict(loaded['model_state_dict']) self.model.eval() self.model.to(self.config.device) self.rev_case = {b: a for a, b in case.items()} self.rev_punc = {b: a for a, b in punctuation.items()} def tokenize(self, text): return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token] def predict(self, tokens, getter=lambda x: x): max_length = self.config.max_length device = self.config.device if type(tokens) == str: tokens = self.tokenize(tokens) previous_label = punctuation['PERIOD'] for start in range(0, len(tokens), max_length): instance = tokens[start: start + max_length] if type(getter(instance[0])) == str: ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance) else: ids = [getter(token) for token in instance] if len(ids) < max_length: ids += [0] * (max_length - len(ids)) x = torch.tensor([ids]).long().to(device) y_scores1, y_scores2 = self.model(x) y_pred1 = torch.max(y_scores1, 2)[1] y_pred2 = torch.max(y_scores2, 2)[1] for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]): if id == self.config.cls_token_id or id == self.config.sep_token_id: continue if previous_label != None and previous_label > 1: if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER case_label = case['CAPITALIZE'] if i + start == len(tokens) - 2 and punc_label == punctuation['O']: punc_label = punctuation['PERIOD'] yield (token, self.rev_case[case_label], self.rev_punc[punc_label]) previous_label = punc_label def map_case_label(self, token, case_label): if token.endswith(''): token = token[:-4] if token.startswith('##'): token = token[2:] return recase(token, case[case_label]) def map_punc_label(self, token, punc_label): if token.endswith(''): token = token[:-4] if token.startswith('##'): token = token[2:] return token + punctuation_syms[punctuation[punc_label]] def generate_predictions(config, checkpoint_path): loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu') if 'config' in loaded: config = Config(**loaded['config']) init(config) model = Model(config.flavor, config.device) model.load_state_dict(loaded['model_state_dict'], strict=False) rev_case = {b: a for a, b in case.items()} rev_punc = {b: a for a, b in punctuation.items()} for line in sys.stdin: # also drop punctuation that we may generate line = ''.join([c for c in line if c not in mapped_punctuation]) if config.debug: print(line) tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token] if config.debug: print(tokens) previous_label = punctuation['PERIOD'] first_time = True was_word = False for start in range(0, len(tokens), config.max_length): instance = tokens[start: start + config.max_length] ids = config.tokenizer.convert_tokens_to_ids(instance) #print(len(ids), file=sys.stderr) if len(ids) < config.max_length: ids += [config.pad_token_id] * (config.max_length - len(ids)) x = torch.tensor([ids]).long().to(config.device) y_scores1, y_scores2 = model(x) y_pred1 = torch.max(y_scores1, 2)[1] y_pred2 = torch.max(y_scores2, 2)[1] for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]): if config.debug: print(id, token, punc_label, case_label, file=sys.stderr) if id == config.cls_token_id or id == config.sep_token_id: continue if previous_label != None and previous_label > 1: if case_label in [case['LOWER'], case['OTHER']]: case_label = case['CAPITALIZE'] previous_label = punc_label # different strategy due to sub-lexical token encoding in Flaubert if config.lang == 'fr': if token.endswith(''): cased_token = recase(token[:-4], case_label) if was_word: print(' ', end='') print(cased_token + punctuation_syms[punc_label], end='') was_word = True else: cased_token = recase(token, case_label) if was_word: print(' ', end='') print(cased_token, end='') was_word = False else: if token.startswith('##'): cased_token = recase(token[2:], case_label) print(cased_token, end='') else: cased_token = recase(token, case_label) if not first_time: print(' ', end='') first_time = False print(cased_token + punctuation_syms[punc_label], end='') if previous_label == 0: print('.', end='') print() def label_for_case(token): token = re.sub(r'[^\p{Han}\p{Ll}\p{Lu}]', '', token) if token == token.lower(): return 'LOWER' elif token == token.lower().capitalize(): return 'CAPITALIZE' elif token == token.upper(): return 'UPPER' else: return 'OTHER' def make_tensors(config, input_fn, output_x_fn, output_y_fn): # count file lines without loading them size = 0 with open(input_fn) as fp: for line in fp: size += 1 with open(input_fn) as fp: X = torch.IntTensor(size) Y = torch.ByteTensor(size, 2) offset = 0 for n, line in enumerate(fp): word, case_label, punc_label = line.strip().split('\t') id = config.tokenizer.convert_tokens_to_ids(word) if config.debug: assert word.lower() == tokenizer.convert_ids_to_tokens(id) X[offset] = id Y[offset, 0] = punctuation[punc_label] Y[offset, 1] = case[case_label] offset += 1 torch.save(X, output_x_fn) torch.save(Y, output_y_fn) mapped_punctuation = { '.': 'PERIOD', '...': 'PERIOD', ',': 'COMMA', ';': 'COMMA', ':': 'COMMA', '(': 'COMMA', ')': 'COMMA', '?': 'QUESTION', '!': 'EXCLAMATION', ',': 'COMMA', '!': 'EXCLAMATION', '?': 'QUESTION', ';': 'COMMA', ':': 'COMMA', '(': 'COMMA', '(': 'COMMA', ')': 'COMMA', '[': 'COMMA', ']': 'COMMA', '【': 'COMMA', '】': 'COMMA', '└': 'COMMA', '└ ': 'COMMA', '_': 'O', '。': 'PERIOD', '、': 'COMMA', # enumeration comma '、': 'COMMA', '…': 'PERIOD', '—': 'COMMA', '「': 'COMMA', '」': 'COMMA', '.': 'PERIOD', '《': 'O', '》': 'O', ',': 'COMMA', '“': 'O', '”': 'O', '"': 'O', '-': 'O', '-': 'O', '〉': 'COMMA', '〈': 'COMMA', '↑': 'O', '〔': 'COMMA', '〕': 'COMMA', } def preprocess_text(config, max_token_count=-1): global num_tokens_output max_token_count = int(max_token_count) num_tokens_output = 0 def process_segment(text, punctuation): global num_tokens_output text = text.replace('\t', ' ') tokens = config.tokenizer.tokenize(text) for i, token in enumerate(tokens): case_label = label_for_case(token) if i == len(tokens) - 1: print(token.lower(), case_label, punctuation, sep='\t') else: print(token.lower(), case_label, 'O', sep='\t') num_tokens_output += 1 # a bit too ugly, but alternative is to throw an exception if max_token_count > 0 and num_tokens_output >= max_token_count: sys.exit(0) for line in sys.stdin: line = line.strip() if line != '': line = unicodedata.normalize("NFC", line) if config.debug: print(line) start = 0 for i, char in enumerate(line): if char in mapped_punctuation: if i > start and line[start: i].strip() != '': process_segment(line[start: i], mapped_punctuation[char]) start = i + 1 if start < len(line): process_segment(line[start:], 'PERIOD') def preprocess_text_old_fr(config): assert config.lang == 'fr' splitsents = MosesSentenceSplitter(lang) tokenize = MosesTokenizer(lang, extra=['-no-escape']) normalize = MosesPunctuationNormalizer(lang) for line in sys.stdin: if line.strip() != '': for sentence in splitsents([normalize(line)]): tokens = tokenize(sentence) previous_token = None for token in tokens: if token in mapped_punctuation: if previous_token != None: print(previous_token, mapped_punctuation[token], sep='\t') previous_token = None elif not re.search(r'[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens continue else: if previous_token != None: print(previous_token, 'O', sep='\t') previous_token = token if previous_token != None: print(previous_token, 'PERIOD', sep='\t') # modification of the wordpiece tokenizer to keep case information even if vocab is lower cased # forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py class WordpieceTokenizer(object): """Runs WordPiece tokenization.""" def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True): self.vocab = vocab self.unk_token = unk_token self.max_input_chars_per_word = max_input_chars_per_word self.keep_case = keep_case def tokenize(self, text): """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization using the given vocabulary. For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. Args: text: A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer`. Returns: A list of wordpiece tokens. """ output_tokens = [] for token in text.strip().split(): chars = list(token) if len(chars) > self.max_input_chars_per_word: output_tokens.append(self.unk_token) continue is_bad = False start = 0 sub_tokens = [] while start < len(chars): end = len(chars) cur_substr = None while start < end: substr = "".join(chars[start:end]) if start > 0: substr = "##" + substr # optionaly lowercase substring before checking for inclusion in vocab if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab): cur_substr = substr break end -= 1 if cur_substr is None: is_bad = True break sub_tokens.append(cur_substr) start = end if is_bad: output_tokens.append(self.unk_token) else: output_tokens.extend(sub_tokens) return output_tokens # modification of XLM bpe tokenizer for keeping case information when vocab is lowercase # forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py def bpe(self, token): def to_lower(pair): #print(' ',pair) return (pair[0].lower(), pair[1].lower()) from transformers.models.xlm.tokenization_xlm import get_pairs word = tuple(token[:-1]) + (token[-1] + "",) if token in self.cache: return self.cache[token] pairs = get_pairs(word) if not pairs: return token + "" while True: bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf"))) #print(bigram) if to_lower(bigram) not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) except ValueError: new_word.extend(word[i:]) break else: new_word.extend(word[i:j]) i = j if word[i] == first and i < len(word) - 1 and word[i + 1] == second: new_word.append(first + second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = " ".join(word) if word == "\n ": word = "\n" self.cache[token] = word return word def init(config): init_random(config.seed) if config.lang == 'fr': config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False) from transformers.models.xlm.tokenization_xlm import XLMTokenizer assert isinstance(tokenizer, XLMTokenizer) # monkey patch XLM tokenizer import types tokenizer.bpe = types.MethodType(bpe, tokenizer) else: # warning: needs to be BertTokenizer for monkey patching to work config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False) # warning: monkey patch tokenizer to keep case information #from recasing_tokenizer import WordpieceTokenizer config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token) if config.lang == 'fr': config.pad_token_id = tokenizer.pad_token_id config.cls_token_id = tokenizer.bos_token_id config.cls_token = tokenizer.bos_token config.sep_token_id = tokenizer.sep_token_id config.sep_token = tokenizer.sep_token else: config.pad_token_id = tokenizer.pad_token_id config.cls_token_id = tokenizer.cls_token_id config.cls_token = tokenizer.cls_token config.sep_token_id = tokenizer.sep_token_id config.sep_token = tokenizer.sep_token if not torch.cuda.is_available() and config.device == 'cuda': print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr) config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu') def main(config, action, args): init(config) if action == 'train': train(config, *args) elif action == 'eval': run_eval(config, *args) elif action == 'predict': generate_predictions(config, *args) elif action == 'tensorize': make_tensors(config, *args) elif action == 'preprocess': preprocess_text(config, *args) else: print('invalid action "%s"' % action) sys.exit(1) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str) parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*') parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int) parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str) parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str) parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int) parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int) parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str) parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool) parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool) parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool) parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool) parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool) config = Config(**parser.parse_args().__dict__) main(config, config.action, config.action_args)