File size: 5,748 Bytes
5af7e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertTokenizer

from src.models import BertForPunctuation

# Punctuation appended to a word, indexed by the argmax over the model's output columns
# (index 0 is the empty string, i.e. "no punctuation").
PUNCTUATION_SIGNS = ['', ',', '.', '?']
# Token id inserted right after the predicted token in every model input segment.
# Note that it has the same value (0) as the id used for context padding.
PAUSE_TOKEN = 0
# Identifier passed to `from_pretrained` for both the model and the tokenizer
# (presumably a Hugging Face Hub repository id - TODO confirm).
MODEL_NAME = "verbit/hebrew_punctuation"


def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Tokenizes words into token ids and aligns the per-word pauses with the produced tokens.

    A single word may be split into several tokens. In that case the pause of the word is attached
    to its last token and every preceding token of that word gets a pause of 0.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds (one entry per word)
        tokenizer: tokenizer providing `tokenize` and `convert_tokens_to_ids`

    Returns:
        original_word_idx: for every word, the position in `x` of its last token
        x: flat list of token ids of all words
        pause: list of pauses aligned with `x` (one entry per token)

    Raises:
        AssertionError: if `word_list` and `pause_list` have different lengths
    """
    # Raised explicitly instead of using an `assert` statement: asserts are stripped under `python -O`
    # and zip() below would then silently drop the unmatched tail of the longer list.
    if len(word_list) != len(pause_list):
        raise AssertionError("word_list and pause_list should have the same length")

    x: List[int] = []
    pause: List[float] = []
    # a word can produce more than one token, so remember which token position closes each word
    original_word_idx: List[int] = []

    for word, word_pause in zip(word_list, pause_list):
        tokens = tokenizer.tokenize(word)
        # if the tokenizer yields nothing for this word, use a single id 0 so the word keeps a slot
        token_ids = tokenizer.convert_tokens_to_ids(tokens) if tokens else [0]

        x += token_ids
        # zero pause for the leading tokens of the word, the real pause on its last token
        pause += [0] * (len(token_ids) - 1) + [word_pause]
        original_word_idx.append(len(x) - 1)

    return original_word_idx, x, pause


def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Generates inputs for model out of list of indexed words.

    For every token in `x` one segment is built with the following layout:
    `backward_context` preceding tokens, the predicted token itself, the pause token and
    `forward_context` following tokens. Positions outside of `x` are padded with 0.

    Note: only the length of `pause` is used; every segment gets the constant PAUSE_TOKEN,
    the pause durations themselves are not encoded.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)

    Returns:
        A tensor of model inputs with one row of `backward_context + forward_context + 2` ids
        for each indexed word in x. For empty `x` an empty integer tensor with that row width is returned.
    """
    # context on both sides + predicted token
    window = backward_context + forward_context + 1

    if len(x) == 0:
        # torch.tensor([]) would yield a 1-D float tensor; keep rank and dtype consistent with
        # the non-empty case instead (+1 column for the inserted pause token)
        return torch.empty((0, window + 1), dtype=torch.long)

    tokenized_pause = [PAUSE_TOKEN] * len(pause)
    x_pad = [0] * backward_context + x + [0] * forward_context

    model_input = []
    for i in range(len(x)):
        # after padding, the predicted token x[i] sits at offset `backward_context` of the slice
        segment = x_pad[i : i + window]
        # the pause token goes right after the predicted token
        segment.insert(backward_context + 1, tokenized_pause[i])
        model_input.append(segment)
    return torch.tensor(model_input)


def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Appends the most probable punctuation sign to every word of the text.

    The text is split on whitespace; row `i` of `punct_prob` belongs to word `i` and the column
    with the highest value selects the entry of PUNCTUATION_SIGNS appended to that word.
    Words without a corresponding row are kept unchanged (they are not dropped), surplus rows
    are ignored.

    Args:
        text: text to insert punctuation to
        punct_prob: 2-D matrix of probabilities, one row per word and one column per punctuation sign

    Returns:
        text with punctuation, words joined by single spaces
    """
    words = text.split()
    punctuation_idx = np.argmax(punct_prob, axis=1)
    predictions_count = len(punctuation_idx)

    new_words = []
    for position, word in enumerate(words):
        if position < predictions_count:
            # the sign may be the empty string, which leaves the word as it is
            word = word + PUNCTUATION_SIGNS[punctuation_idx[position]]
        new_words.append(word)

    return ' '.join(new_words)


def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Predicts punctuation for the given text and returns the punctuated text.

    The text is split on whitespace, every word is tokenized, one model input segment is built
    per word (centered on the last token of the word) and the model is run batch by batch.

    Only the inputs are moved to `device`; the model is used as passed in, so the caller has to
    place it on the same device beforehand.

    Args:
        model: punctuation model
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: batch size
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds; zeros are used when it is None or empty
        device: device to run model on

    Returns:
        text with punctuation; an empty string if the text contains no words
    """
    word_list = text.split()
    if not word_list:
        # nothing to punctuate; without this guard np.concatenate below fails on an empty list
        return ''

    if not pause_list:
        # make default pauses if pauses are not provided
        pause_list = [0.0] * len(word_list)

    word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)

    model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
    # keep only one segment per word: the one whose predicted token is the last token of the word
    word_positions = torch.tensor(word_idx, dtype=torch.long)
    model_inputs = model_inputs.index_select(0, word_positions).to(device)
    inputs_length = len(model_inputs)

    output = []
    with torch.no_grad():
        for ndx in range(0, inputs_length, batch_size):
            # slicing past the end is clamped, so the last batch is simply shorter
            logits = model(model_inputs[ndx : ndx + batch_size])
            probabilities = F.softmax(logits, dim=1)
            output.append(probabilities.cpu().numpy())

    punct_probabilities_matrix = np.concatenate(output, axis=0)

    return add_punctuation_to_text(text, punct_probabilities_matrix)


def main():
    """Loads the pretrained model and tokenizer, punctuates a sample text and prints the result."""
    model = BertForPunctuation.from_pretrained(MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    # switch the model to evaluation mode before running inference
    model.eval()

    # NOTE(review): this sample text looks mis-encoded (mojibake) in this copy of the file;
    # presumably it is meant to be Hebrew given MODEL_NAME - verify the file encoding.
    text = """讞讘专转 讜专讘讬讟 驻讬转讞讛 诪注专讻转 诇转诪诇讜诇 讛诪讘讜住住转 注诇 讘讬谞讛 诪诇讗讻讜转讬转 讜讙讜专诐 讗谞讜砖讬 讜砖讜拽讚转 注诇 转诪诇讜诇 注讚讜讬讜转 谞讬爪讜诇讬 砖讜讗讛 
    讗转 讛转讜爪讗讜转 讗驻砖专 诇专讗讜转 讻讘专 讘专砖转 讘讛谉 讞诇拽讬诐 诪注讚讜转讜 砖诇 讟讜讘讬讛 讘讬讬诇住拽讬 砖讛讬讛 诪驻拽讚 讙讚讜讚 讛驻专讟讬讝谞讬诐 讛讬讛讜讚讬诐 讘讘讬讬诇讜专讜住讬讛"""
    # the context window sizes are read from the model config instead of the function defaults
    punct_text = get_prediction(
        model=model,
        text=text,
        tokenizer=tokenizer,
        backward_context=model.config.backward_context,
        forward_context=model.config.forward_context,
    )
    print(punct_text)


# run the demo only when the file is executed as a script
if __name__ == "__main__":
    main()