from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertTokenizer

from src.models import BertForPunctuation

# Punctuation classes predicted by the model, indexed by output column ('' means "no punctuation").
PUNCTUATION_SIGNS = ['', ',', '.', '?']
# Token id inserted into every model input window directly after the predicted token.
PAUSE_TOKEN = 0
# Identifier passed to `from_pretrained` for both the model and the tokenizer.
MODEL_NAME = "verbit/hebrew_punctuation"


def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Tokenizes text and generates a per-token pause list.

    A single word may be split into several sub-word tokens, so this also records, for each
    original word, the index of its LAST token in the flattened token list.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds (same length as word_list)
        tokenizer: tokenizer

    Returns:
        original_word_idx: for each word, index of its last token in `x`
        x: list of token ids for the whole text
        pause: list of pauses aligned with `x`; only a word's last token carries its pause,
            earlier sub-tokens of the same word get 0

    Raises:
        ValueError: if word_list and pause_list have different lengths
    """
    # Explicit check instead of `assert`, which is stripped under `python -O` and would let
    # `zip` silently truncate mismatched inputs.
    if len(word_list) != len(pause_list):
        raise ValueError("word_list and pause_list should have the same length")

    x: List[int] = []
    pause: List[float] = []
    # Tokenization can yield more than one token per word, so keep a mapping from words to tokens.
    original_word_idx: List[int] = []
    for word, word_pause in zip(word_list, pause_list):
        tokens = tokenizer.tokenize(word)
        # If the word produces no tokens at all, pad it with id 0 so every word still maps to
        # exactly one prediction position.
        token_ids = tokenizer.convert_tokens_to_ids(tokens) if tokens else [0]
        x.extend(token_ids)
        # Leading sub-tokens get no pause; the word's pause is attached to its last token.
        pause.extend([0] * (len(token_ids) - 1) + [word_pause])
        original_word_idx.append(len(x) - 1)
    return original_word_idx, x, pause


def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Generates model inputs, one sliding window per indexed word, and inserts a pause token
    into each window.

    Each window holds `backward_context` tokens before the predicted token, the predicted token
    itself, PAUSE_TOKEN, then `forward_context` tokens after it. Positions outside the text are
    padded with id 0.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses. Only its length matters here: every window gets the
            constant PAUSE_TOKEN regardless of the pause value.
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)

    Returns:
        A long tensor of shape (len(x), backward_context + forward_context + 2), one row per
        indexed word in x.
    """
    window = backward_context + forward_context + 1
    tokenized_pause = [PAUSE_TOKEN] * len(pause)
    x_pad = [0] * backward_context + list(x) + [0] * forward_context

    model_input = []
    for i in range(len(x)):
        segment = x_pad[i : i + window]
        # The predicted token sits at index `backward_context`; the pause marker follows it.
        segment.insert(backward_context + 1, tokenized_pause[i])
        model_input.append(segment)

    if not model_input:
        # torch.tensor([]) would be a 1-D float tensor; keep the 2-D long shape callers expect.
        return torch.empty((0, window + 1), dtype=torch.long)
    return torch.tensor(model_input, dtype=torch.long)


def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Inserts punctuation into text, picking the most probable punctuation sign for every word.

    Args:
        text: text to insert punctuation into; words are split on whitespace
        punct_prob: matrix of probabilities with one row per word and one column per entry of
            PUNCTUATION_SIGNS

    Returns:
        The words of `text`, each followed by its predicted punctuation, joined by single spaces.
    """
    words = text.split()
    punctuation_idx = np.argmax(punct_prob, axis=1)
    # The '' class (no punctuation) appends nothing, so plain concatenation covers every case.
    return ' '.join(word + PUNCTUATION_SIGNS[i] for word, i in zip(words, punctuation_idx))


def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Predicts punctuation for every whitespace-separated word of `text`.

    NOTE(review): this function does not call `model.eval()`; callers are expected to do so
    (as `main` does) and to have moved the model to `device`.

    Args:
        model: punctuation model
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: number of windows passed to the model per forward call (must be >= 1)
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds; if None or empty, 0.0 is used
            for every word
        device: device to run model on

    Returns:
        Text with punctuation. Words are re-joined with single spaces, and empty or
        whitespace-only input yields ''.

    Raises:
        ValueError: if batch_size < 1, or if pause_list length differs from the word count
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    word_list = text.split()
    if not word_list:
        # Nothing to punctuate; np.concatenate below would fail on an empty list of batches.
        return ''

    if not pause_list:
        # make default pauses if pauses are not provided
        pause_list = [0.0] * len(word_list)

    word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)
    model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
    # Keep only the windows centred on the last token of each original word.
    model_inputs = model_inputs.index_select(0, torch.tensor(word_idx, dtype=torch.long)).to(device)

    output = []
    with torch.no_grad():
        for start in range(0, len(model_inputs), batch_size):
            # Slicing clamps at the end of the tensor, so no explicit min() is needed.
            logits = model(model_inputs[start : start + batch_size])
            probs = F.softmax(logits, dim=1)
            output.append(probs.detach().cpu().numpy())

    punct_probabilities_matrix = np.concatenate(output, axis=0)
    return add_punctuation_to_text(text, punct_probabilities_matrix)


def main():
    """Loads the pretrained model and tokenizer, punctuates a sample Hebrew text and prints it."""
    model = BertForPunctuation.from_pretrained(MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    model.eval()

    text = """חברת ורביט פיתחה מערכת לתמלול המבוססת על בינה מלאכותית וגורם אנושי ושוקדת על תמלול עדויות ניצולי שואה את התוצאות אפשר לראות כבר ברשת בהן חלקים מעדותו של טוביה ביילסקי שהיה מפקד גדוד הפרטיזנים היהודים בביילורוסיה"""

    punct_text = get_prediction(
        model=model,
        text=text,
        tokenizer=tokenizer,
        backward_context=model.config.backward_context,
        forward_context=model.config.forward_context,
    )
    print(punct_text)


if __name__ == "__main__":
    main()