# hebrew_punctuation / src / inference.py
# Repository page header (not code): author verbit-research, commit 5af7e8d
# "add_model_src_code (#1)" (verified), file size 5.75 kB.
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertTokenizer
from src.models import BertForPunctuation
# Punctuation sign for each predicted class index; index 0 (the empty string)
# stands for "no punctuation after this word".
PUNCTUATION_SIGNS = ['', ',', '.', '?']
# Id inserted into every model input window at the pause position.
PAUSE_TOKEN = 0
# Name passed to from_pretrained() to load both the model and the tokenizer.
MODEL_NAME = "verbit/hebrew_punctuation"
def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Splits every word into sub-word tokens and aligns the pauses with those tokens.

    A single word may expand to several tokens. The pause of the word is attached
    to its last token; the preceding tokens of the same word get a pause of 0.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds, one entry per word
        tokenizer: tokenizer providing `tokenize` and `convert_tokens_to_ids`
    Returns:
        original_word_idx: for each word, the position in `x` of its last token
        x: token ids of all words, concatenated in order
        pause: pause for each token, aligned with `x`
    """
    assert len(word_list) == len(pause_list), "word_list and pause_list should have the same length"
    token_ids: List[int] = []
    token_pauses: List[float] = []
    last_token_positions: List[int] = []
    for word, word_pause in zip(word_list, pause_list):
        pieces = tokenizer.tokenize(word)
        # a word that produces no tokens is still represented by the single id 0
        ids = tokenizer.convert_tokens_to_ids(pieces) if pieces else [0]
        token_ids.extend(ids)
        # each word is mapped to the position of its last token
        last_token_positions.append(len(token_ids) - 1)
        # leading tokens of the word get 0, only the last token carries the pause
        token_pauses.extend([0] * (len(ids) - 1))
        token_pauses.append(word_pause)
    return last_token_positions, token_ids, token_pauses
def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Generates inputs for model out of list of indexed words.

    For every position i in `x` a window is built that holds `backward_context`
    ids before position i, the id at position i, a pause token, and
    `forward_context` ids after position i. Positions outside of `x` are
    padded with 0.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses; only its length is used here, every
            position receives the same PAUSE_TOKEN
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)
    Returns:
        An integer tensor with one row per element of x and
        backward_context + forward_context + 2 columns. For an empty x the
        tensor has zero rows but keeps the same number of columns and dtype.
    """
    # backward context + predicted token + pause token + forward context
    window_size = backward_context + forward_context + 2
    if len(x) == 0:
        # torch.tensor([]) would yield a 1-D float tensor; keep the dtype and the
        # column count consistent with the non-empty case instead
        return torch.empty((0, window_size), dtype=torch.long)

    tokenized_pause = [PAUSE_TOKEN] * len(pause)
    x_pad = [0] * backward_context + x + [0] * forward_context
    model_input = []
    for i in range(len(x)):
        # window_size - 1 ids: the pause token is added separately below
        segment = x_pad[i : i + window_size - 1]
        # the pause token goes right after the predicted token
        segment.insert(backward_context + 1, tokenized_pause[i])
        model_input.append(segment)
    return torch.tensor(model_input, dtype=torch.long)
def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Appends to every word of the text its highest-scoring punctuation sign.

    Args:
        text: text to insert punctuation to; it is split on whitespace into words
        punct_prob: matrix of probabilities, row i holds the values of word i
            for each entry of PUNCTUATION_SIGNS
    Returns:
        the words joined by single spaces, each followed by its selected sign
    """
    tokens = text.split()
    best_sign_idx = np.argmax(punct_prob, axis=1)
    signs = [PUNCTUATION_SIGNS[k] for k in best_sign_idx]
    # the "no punctuation" sign is the empty string, so plain concatenation
    # leaves such words unchanged
    return ' '.join(token + sign for token, sign in zip(tokens, signs))
def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Generates punctuation predictions for a text and returns the punctuated text.

    The text is split on whitespace into words, one input window is built per
    word, the model is run on the windows in batches, and the most probable
    punctuation sign is appended to each word.

    Args:
        model: punctuation model; called with a batch of input windows and
            expected to return one row of scores per window
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: batch size
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds; when missing or
            empty, a pause of 0.0 is used for every word
        device: device the model inputs are moved to; the model itself is not
            moved by this function
    Returns:
        text with punctuation, words joined by single spaces; an empty string
        when the text contains no words
    """
    word_list = text.split()
    if not word_list:
        # nothing to punctuate: without this guard no batch is produced and
        # np.concatenate below fails on an empty list
        return ''
    if not pause_list:
        # make default pauses if pauses are not provided
        pause_list = [0.0] * len(word_list)

    word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)
    model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
    # keep only the window of the last token of each word: one prediction per word
    model_inputs = model_inputs.index_select(0, torch.LongTensor(word_idx)).to(device)

    output = []
    with torch.no_grad():
        for start in range(0, len(model_inputs), batch_size):
            # slicing past the end is clamped, so the last batch may be shorter
            scores = model(model_inputs[start : start + batch_size])
            probabilities = F.softmax(scores, dim=1)
            output.append(probabilities.detach().cpu().numpy())
    punct_probabilities_matrix = np.concatenate(output, axis=0)
    return add_punctuation_to_text(text, punct_probabilities_matrix)
def main():
    """
    Demo: loads the model and tokenizer named by MODEL_NAME, predicts punctuation
    for a sample text and prints the punctuated text.
    """
    punctuation_model = BertForPunctuation.from_pretrained(MODEL_NAME)
    bert_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    punctuation_model.eval()
    sample_text = """讞讘专转 讜专讘讬讟 驻讬转讞讛 诪注专讻转 诇转诪诇讜诇 讛诪讘讜住住转 注诇 讘讬谞讛 诪诇讗讻讜转讬转 讜讙讜专诐 讗谞讜砖讬 讜砖讜拽讚转 注诇 转诪诇讜诇 注讚讜讬讜转 谞讬爪讜诇讬 砖讜讗讛
讗转 讛转讜爪讗讜转 讗驻砖专 诇专讗讜转 讻讘专 讘专砖转 讘讛谉 讞诇拽讬诐 诪注讚讜转讜 砖诇 讟讜讘讬讛 讘讬讬诇住拽讬 砖讛讬讛 诪驻拽讚 讙讚讜讚 讛驻专讟讬讝谞讬诐 讛讬讛讜讚讬诐 讘讘讬讬诇讜专讜住讬讛"""
    # the context window sizes are read from the loaded model's config
    model_config = punctuation_model.config
    print(
        get_prediction(
            model=punctuation_model,
            text=sample_text,
            tokenizer=bert_tokenizer,
            backward_context=model_config.backward_context,
            forward_context=model_config.forward_context,
        )
    )
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()