# rpunct-gr-app / myrpunct/punctuate.py
# Last change by wldmr: "Update myrpunct/punctuate.py" (commit 4a97738)
# -*- coding: utf-8 -*-
# 💾⚙️🔮
__author__ = "Daulet N."
__email__ = "[email protected]"
import logging
from langdetect import detect
from simpletransformers.ner import NERModel, NERArgs
class RestorePuncts:
    """Restore punctuation and capitalisation of English text with a BERT token classifier.

    The checkpoint ``felflare/bert-restore-punctuation`` is loaded through
    simpletransformers' ``NERModel`` and assigns every word one of
    ``valid_labels``.  The first character of a label is the punctuation mark
    to append after the word (``'O'`` means none); the second character is
    ``'U'`` to capitalise the word or ``'O'`` to leave it unchanged.
    """

    def __init__(self, wrds_per_pred=250, use_cuda=False):
        """Load the punctuation model.

        Args:
            wrds_per_pred (int): Words per slice sent to the model; each slice
                additionally carries ``overlap_wrds`` words of trailing context.
            use_cuda (bool): Passed through to ``NERModel``.
        """
        self.wrds_per_pred = wrds_per_pred
        # Words of trailing context each slice borrows from the next slice.
        self.overlap_wrds = 30
        self.valid_labels = [
            'OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U',
            ':O', ';O', ':U', "'O", '-O', '?O', '?U',
        ]
        self.model_hf = "felflare/bert-restore-punctuation"
        self.model_args = NERArgs()
        self.model_args.silent = True
        # Transformer input limit, in sub-word tokens.
        self.model_args.max_seq_length = 512
        # NOTE: to switch simpletransformers' multiprocessing off, set
        # self.model_args.use_multiprocessing = False before building the model.
        self.model = NERModel(
            "bert", self.model_hf, labels=self.valid_labels,
            use_cuda=use_cuda, args=self.model_args,
        )
        # Debug output: reports which multiprocessing setting is in effect.
        print("class init ...")
        print("use_multiprocessing: ", self.model_args.use_multiprocessing)

    def status(self):
        """Print a fixed marker line (debug helper)."""
        print("function called")

    def punctuate(self, text: str, lang: str = '') -> str:
        """Restore punctuation and capitalisation in arbitrarily long text.

        Unless ``lang`` is supplied, the language of ``text`` is detected with
        ``langdetect`` and non-English input is rejected.  Texts of 10
        characters or fewer are too short for detection and are assumed to be
        English.

        Args:
            text (str): Text to punctuate, from a few words to any length.
            lang (str): Explicit language of ``text``; pass ``'en'`` to skip
                detection.

        Returns:
            str: The punctuated text, or ``""`` for empty/whitespace-only input.

        Raises:
            ValueError: If the text is detected (or declared) as non-English.
        """
        words = text.split()
        if not words:
            # Nothing to punctuate; also avoids running detection on no text.
            return ""

        if not lang:
            # Detection is skipped for very short strings: assume English
            # instead of rejecting them as "non-English".
            lang = detect(text) if len(text) > 10 else 'en'
        if lang != 'en':
            raise ValueError(
                "Non English text detected. Restore Punctuation works only for English.\n"
                "If you are certain the input is English, pass argument lang='en' to this function.\n"
                f"Punctuate received: {text}"
            )

        # Collapse newlines, tabs and repeated spaces so every slice holds real
        # words only (exactly wrds_per_pred + overlap, except at the end).
        clean_text = " ".join(words)
        splits = self.split_on_toks(clean_text, self.wrds_per_pred, self.overlap_wrds)
        # predict() returns (predictions, raw_outputs) for a one-element batch;
        # keep only the per-word labels of that single input.
        preds_lst = [self.predict(chunk['text'])[0][0] for chunk in splits]
        combined_preds = self.combine_results(clean_text, preds_lst)
        return self.punctuate_texts(combined_preds)

    def predict(self, input_slice):
        """Run the model on a single text slice.

        Returns:
            tuple: ``(predictions, raw_outputs)`` from ``NERModel.predict`` for
            the batch ``[input_slice]``.  ``predictions[0]`` is consumed by
            ``combine_results`` as a list of one-entry ``{word: label}`` dicts.
        """
        predictions, raw_outputs = self.model.predict([input_slice])
        return predictions, raw_outputs

    @staticmethod
    def split_on_toks(text, length, overlap):
        """Split ``text`` into overlapping word slices with character offsets.

        Slice ``i`` holds words ``[i*length, (i+1)*length)`` followed by the
        next ``overlap`` words as context.  Newlines become spaces and words are
        split on single spaces, so ``start_idx``/``end_idx`` index directly into
        ``text``.  This is meant to keep each model input under the 512-token
        transformer limit.

        Example output:
            [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]

        Raises:
            ValueError: If ``length`` is not positive (the loop could never end).
        """
        if length < 1:
            raise ValueError(f"length must be a positive number of words, got {length}")
        wrds = text.replace('\n', ' ').split(" ")
        resp = []
        lst_chunk_idx = 0
        i = 0
        while True:
            # Words owned by this chunk, plus the overlapping context after it.
            wrds_len = wrds[(length * i):(length * (i + 1))]
            wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
            wrds_split = wrds_len + wrds_ovlp
            if not wrds_split:
                break
            wrds_str = " ".join(wrds_split)
            nxt_chunk_start_idx = len(" ".join(wrds_len))
            lst_char_idx = len(" ".join(wrds_split))
            resp.append({
                "text": wrds_str,
                "start_idx": lst_chunk_idx,
                "end_idx": lst_char_idx + lst_chunk_idx,
            })
            # +1 skips the space separating this chunk's own words from the next.
            lst_chunk_idx += nxt_chunk_start_idx + 1
            i += 1
        logging.info("Sliced transcript into %d slices.", len(resp))
        return resp

    @staticmethod
    def combine_results(full_text: str, text_slices):
        """Merge per-slice predictions into one prediction per word of ``full_text``.

        Predictions are matched in order against the words of ``full_text``.
        For every slice except the last, its final two predictions are skipped
        so those words are taken from the following slice, which also sees the
        text after them.  A final slice of at most three words is dropped; its
        words are expected to lie in the previous slice's overlap.

        Args:
            full_text (str): The complete input text.
            text_slices (list[list[dict]]): For each slice, the model output as a
                list of one-entry ``{word: label}`` dicts.

        Returns:
            list[tuple]: ``(word, label)`` pairs, one per word of ``full_text``.

        Raises:
            AssertionError: If the merged words do not reproduce ``full_text``.
        """
        # Split on any whitespace, as the model does (str.split()), so '\r',
        # '\t' or repeated spaces cannot desynchronise the matching below.
        split_full_text = full_text.split()
        n_words = len(split_full_text)
        if len(text_slices) > 1 and len(text_slices[-1]) <= 3:
            text_slices = text_slices[:-1]
        last_slice_no = len(text_slices) - 1

        output_text = []
        index = 0
        for slice_no, _slice in enumerate(text_slices):
            # Identify the last slice by position, not content: identical slices
            # (repetitive text) must not be mistaken for the last one.
            is_last = slice_no == last_slice_no
            keep_upto = len(_slice) - 3  # highest index kept in a non-last slice
            for ix, wrd in enumerate(_slice):
                if index == n_words:
                    break
                pred_item_tuple = next(iter(wrd.items()))
                if str(pred_item_tuple[0]) != split_full_text[index]:
                    continue
                if is_last or ix <= keep_upto:
                    output_text.append(pred_item_tuple)
                    index += 1
        assert [i[0] for i in output_text] == split_full_text, (
            "Combined predictions do not match the words of the input text."
        )
        return output_text

    @staticmethod
    def punctuate_texts(full_pred: list):
        """Apply ``(word, label)`` predictions to build the punctuated text.

        A label ending in ``'U'`` capitalises its word; a first character other
        than ``'O'`` is appended after the word.  A final ``'.'`` is added when
        the result ends in a letter or digit.  Empty input yields ``""``.
        """
        punct_words = []
        for word, label in full_pred:
            # NOTE: str.capitalize() also lower-cases the rest of the word.
            punct_wrd = word.capitalize() if label[-1] == "U" else word
            if label[0] != "O":
                punct_wrd += label[0]
            punct_words.append(punct_wrd)
        punct_resp = " ".join(punct_words).strip()
        # Append a trailing period if the text ends in a letter or digit.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "."
        return punct_resp
if __name__ == "__main__":
    # Demo: punctuate the bundled sample text and print the result.
    from pathlib import Path

    # Resolve ../tests relative to this file rather than the current working
    # directory, and read it as UTF-8 instead of the platform default encoding.
    # The file is read before the (slow) model load so a missing sample fails fast.
    sample_path = Path(__file__).resolve().parent.parent / "tests" / "sample_text.txt"
    test_sample = sample_path.read_text(encoding="utf-8")

    punct_model = RestorePuncts()
    punctuated = punct_model.punctuate(test_sample)
    print(punctuated)