|
from typing import List, Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import BertTokenizer |
|
|
|
from src.models import BertForPunctuation |
|
|
|
# Class index -> mark appended after the word. The model's argmax index is
# looked up here (see add_punctuation_to_text); class 0 ('') means "no mark".
PUNCTUATION_SIGNS = ["", ",", ".", "?"]

# Token id inserted right after the predicted token in every model input
# window built by gen_model_inputs.
PAUSE_TOKEN = 0

# Checkpoint name passed to from_pretrained() for both model and tokenizer.
MODEL_NAME = "verbit/hebrew_punctuation"
|
|
|
|
|
def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Split every word into sub-word token ids and align the pauses to those tokens.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds (same length as word_list)
        tokenizer: tokenizer used to split words and map sub-tokens to ids

    Returns:
        original_word_idx: for each word, the position in ``x`` of its LAST sub-token
        x: flat list of token ids for all words, in order. A word for which the
            tokenizer produces no tokens is represented by id 0 (presumably the
            padding id of the vocabulary -- confirm against the tokenizer)
        pause: list aligned with ``x``; the word's pause is attached to its last
            sub-token and every earlier sub-token of the same word gets 0

    Raises:
        AssertionError: if word_list and pause_list differ in length.
    """
    assert len(word_list) == len(pause_list), "word_list and pause_list should have the same length"

    token_ids: List[int] = []
    token_pauses: List[float] = []
    last_token_positions: List[int] = []

    for word, word_pause in zip(word_list, pause_list):
        pieces = tokenizer.tokenize(word)
        ids = tokenizer.convert_tokens_to_ids(pieces) if pieces else [0]

        token_ids.extend(ids)
        # Leading sub-tokens carry no pause; only the word's final piece does.
        token_pauses.extend([0] * (len(ids) - 1))
        token_pauses.append(word_pause)
        last_token_positions.append(len(token_ids) - 1)

    return last_token_positions, token_ids, token_pauses
|
|
|
|
|
def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Generate one fixed-size model input window per indexed word in ``x``.

    Window ``i`` consists of the ``backward_context`` ids preceding ``x[i]``,
    ``x[i]`` itself, a pause token (``PAUSE_TOKEN``), and then the
    ``forward_context`` ids following ``x[i]``. Positions that fall outside
    ``x`` are padded with id 0.

    Args:
        x: list of indexed words (token ids)
        pause: list of pauses aligned with ``x``. The pause VALUES are not used:
            every window receives the constant ``PAUSE_TOKEN``; ``pause`` must
            merely be at least as long as ``x``.
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)

    Returns:
        A tensor of shape ``(len(x), backward_context + forward_context + 2)``,
        one row per indexed word in ``x``. An empty ``x`` yields an empty
        tensor with that same column count and dtype ``torch.long``.
    """
    # Width of a window before the pause token is inserted: context + the word itself.
    window = backward_context + forward_context + 1
    if not x:
        # torch.tensor([]) would be a 1-D float tensor; keep the 2-D integer layout.
        return torch.empty((0, window + 1), dtype=torch.long)

    tokenized_pause = [PAUSE_TOKEN] * len(pause)
    x_pad = [0] * backward_context + x + [0] * forward_context

    model_input = []
    for i in range(len(x)):
        segment = x_pad[i : i + window]
        # x[i] sits at index backward_context of the segment; put the pause right after it.
        segment.insert(backward_context + 1, tokenized_pause[i])
        model_input.append(segment)
    return torch.tensor(model_input)
|
|
|
|
|
def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Append the most probable punctuation mark to every word of ``text``.

    Args:
        text: text to insert punctuation into; it is split on whitespace
        punct_prob: 2-D matrix with one row per word and one column per entry of
            ``PUNCTUATION_SIGNS``; the arg-max column of each row selects the mark

    Returns:
        The words re-joined with single spaces, each followed by its chosen mark
        (the empty mark leaves the word unchanged). Pairing uses ``zip``, so if
        the row count and word count differ, the result is truncated to the shorter.
    """
    # Resolve every row's mark up front (also validates all indices).
    signs = [PUNCTUATION_SIGNS[k] for k in np.argmax(punct_prob, axis=1)]
    return ' '.join(word + sign for word, sign in zip(text.split(), signs))
|
|
|
|
|
def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Predict and insert punctuation for every whitespace-separated word of ``text``.

    The model is NOT moved to ``device`` and NOT switched to eval mode here;
    the caller is responsible for both (``main`` calls ``model.eval()``).

    Args:
        model: punctuation model; assumed to map a batch of input windows to
            per-class scores of shape (batch, len(PUNCTUATION_SIGNS)) -- confirm
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: number of word windows sent to the model per forward pass
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds; when None or
            empty, every word is given a pause of 0.0
        device: device the input tensors are moved to

    Returns:
        Text with punctuation, words re-joined with single spaces. Empty or
        whitespace-only ``text`` yields ``''``.

    Raises:
        AssertionError: (from tokenize_text) if pause_list length differs from the word count.
    """
    word_list = text.split()
    if not pause_list:
        pause_list = [0.0] * len(word_list)

    word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)
    if not word_idx:
        # No words: np.concatenate([]) below would raise, so short-circuit.
        return ''

    model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
    # Keep only the windows centred on each word's last sub-token: one prediction per word.
    model_inputs = model_inputs.index_select(0, torch.tensor(word_idx, dtype=torch.long)).to(device)

    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(model_inputs), batch_size):
            logits = model(model_inputs[start : start + batch_size])
            batch_probs.append(F.softmax(logits, dim=1).cpu().numpy())

    punct_probabilities_matrix = np.concatenate(batch_probs, axis=0)
    return add_punctuation_to_text(text, punct_probabilities_matrix)
|
|
|
|
|
def main():
    """Load the pretrained Hebrew punctuation model, punctuate a sample paragraph and print it."""
    model = BertForPunctuation.from_pretrained(MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    model.eval()

    # Unpunctuated Hebrew sample. It must be real Hebrew text: a UTF-8 -> GBK
    # mis-decoding had turned it into CJK mojibake that the model cannot read.
    text = """חברת ורביט פיתחה מערכת לתמלול המבוססת על בינה מלאכותית וגורם אנושי ושוקדת על תמלול עדויות ניצולי שואה
את התוצאות אפשר לראות כבר ברשת בהן חלקים מעדותו של טוביה ביילסקי שהיה מפקד גדוד הפרטיזנים היהודים בביילורוסיה"""

    punct_text = get_prediction(
        model=model,
        text=text,
        tokenizer=tokenizer,
        backward_context=model.config.backward_context,
        forward_context=model.config.forward_context,
    )
    print(punct_text)


if __name__ == "__main__":
    main()
|
|