verbit-research nirraviv89 committed on
Commit
5af7e8d
•
1 Parent(s): cd89bf9

add_model_src_code (#1)

Browse files

- add model src code (6f8733e4a99d15adc3c3f5794ceb9f442ca289b2)
- rename and documentation (fbc3442307396d208fb11eae95ceef4d7b81ed89)


Co-authored-by: Nir Raviv <[email protected]>

Files changed (4) hide show
  1. requirements.txt +3 -0
  2. src/config.py +44 -0
  3. src/inference.py +174 -0
  4. src/models.py +24 -0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy==1.23.5
2
+ torch==2.2.2
3
+ transformers==4.44.2
src/config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig
2
+
3
+
4
class PunctuationBertConfig(BertConfig):
    r"""
    Configuration class for [`BertForPunctuation`]. It extends [`BertConfig`] with the
    extra hyper-parameters needed by the punctuation classification head.

    Args:
        backward_context (`int`, *optional*, defaults to 15):
            Size of the backward context window (number of tokens before the predicted token).
        forward_context (`int`, *optional*, defaults to 16):
            Size of the forward context window (number of tokens after the predicted token).
        output_size (`int`, *optional*, defaults to 4):
            Number of punctuation classes.
        dropout (`float`, *optional*, defaults to 0.3):
            Dropout rate applied before the final classification layer.

    Examples:
    ```python
    >>> from src.config import PunctuationBertConfig
    >>> from src.models import BertForPunctuation

    >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
    >>> configuration = PunctuationBertConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = BertForPunctuation(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    def __init__(
        self,
        backward_context=15,
        forward_context=16,
        output_size=4,
        dropout=0.3,
        **kwargs,
    ):
        # All standard BERT options (vocab_size, hidden_size, ...) pass through to BertConfig.
        super().__init__(**kwargs)
        self.backward_context = backward_context
        self.forward_context = forward_context
        self.output_size = output_size
        self.dropout = dropout
src/inference.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import BertTokenizer
7
+
8
+ from src.models import BertForPunctuation
9
+
10
# Class index i in the model output maps to PUNCTUATION_SIGNS[i];
# '' means "no punctuation after this word".
PUNCTUATION_SIGNS = ['', ',', '.', '?']
# Token id inserted into every input segment as a pause placeholder.
PAUSE_TOKEN = 0
# Hugging Face hub id for both the model weights and the tokenizer.
MODEL_NAME = "verbit/hebrew_punctuation"
13
+
14
+
15
def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Tokenizes text and generates pause list for each word.

    A single word may produce several sub-word tokens, so the returned
    index list maps each original word to the position of its *last*
    token in the flattened token id list.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds
        tokenizer: tokenizer

    Returns:
        original_word_idx: index of each original word's last token
        x: flattened list of token ids
        pause: pause value per token (0 for all but a word's last token)
    """
    assert len(word_list) == len(pause_list), "word_list and pause_list should have the same length"

    token_ids: List[int] = []
    pauses: List[float] = []
    original_word_idx: List[int] = []

    for word, word_pause in zip(word_list, pause_list):
        word_tokens = tokenizer.tokenize(word)
        # If the tokenizer yields nothing for this word, fall back to a single 0 id.
        word_ids = tokenizer.convert_tokens_to_ids(word_tokens) if word_tokens else [0]

        token_ids.extend(word_ids)
        original_word_idx.append(len(token_ids) - 1)
        # The pause belongs to the word's last token; earlier sub-tokens get 0.
        pauses.extend((len(word_ids) - 1) * [0] + [word_pause])

    return original_word_idx, token_ids, pauses
49
+
50
+
51
def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Generates one model input row per token id in ``x``.

    Each row holds the token's backward context, the token itself, a pause
    token, and the forward context; the sequence is zero-padded at both ends.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)

    Returns:
        A tensor of shape (len(x), backward_context + forward_context + 2)
    """
    # NOTE(review): the float pause values are never used here — every row gets
    # the constant PAUSE_TOKEN; confirm whether real pause quantization was intended.
    pause_tokens = [PAUSE_TOKEN] * len(pause)
    window = backward_context + forward_context + 1
    padded = [0] * backward_context + x + [0] * forward_context

    rows = []
    for idx in range(len(x)):
        segment = padded[idx : idx + window]
        # The predicted token sits at index backward_context; the pause token
        # is inserted immediately after it.
        segment.insert(backward_context + 1, pause_tokens[idx])
        rows.append(segment)
    return torch.tensor(rows)
78
+
79
+
80
def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Inserts punctuation into the text according to per-word probabilities.

    Args:
        text: text to insert punctuation to
        punct_prob: matrix of probabilities, one row per word, one column per
            punctuation class

    Returns:
        text with punctuation
    """
    # Pick the most likely punctuation class for every word; '' means none,
    # so appending the sign unconditionally is a no-op for that class.
    best_class = np.argmax(punct_prob, axis=1)
    suffixes = [PUNCTUATION_SIGNS[i] for i in best_class]
    punctuated = [word + suffix for word, suffix in zip(text.split(), suffixes)]
    return ' '.join(punctuated)
104
+
105
+
106
def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Runs the punctuation model over the text and returns the punctuated text.

    Args:
        model: punctuation model
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: batch size
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds
        device: device to run model on

    Returns:
        text with punctuation
    """
    words = text.split()
    if not pause_list:
        # No pause information supplied: assume zero pause after every word.
        pause_list = [0.0] * len(words)

    word_idx, token_ids, pauses = tokenize_text(word_list=words, pause_list=pause_list, tokenizer=tokenizer)

    # Build one input row per token, then keep only the rows for each word's
    # last token (predictions are made once per word).
    inputs = gen_model_inputs(token_ids, pauses, forward_context, backward_context)
    inputs = inputs.index_select(0, torch.LongTensor(word_idx)).to(device)

    prob_batches = []
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            logits = model(inputs[start : start + batch_size])
            prob_batches.append(F.softmax(logits, dim=1).cpu().data.numpy())

    punct_probabilities = np.concatenate(prob_batches, axis=0)

    return add_punctuation_to_text(text, punct_probabilities)
154
+
155
+
156
def main():
    """Load the pretrained Hebrew punctuation model and print a punctuated sample text."""
    # Weights and tokenizer are fetched from the Hugging Face hub.
    model = BertForPunctuation.from_pretrained(MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    # Inference mode: disables dropout and uses BatchNorm running statistics.
    model.eval()

    # Unpunctuated Hebrew demo text.
    text = """ื—ื‘ืจืช ื•ืจื‘ื™ื˜ ืคื™ืชื—ื” ืžืขืจื›ืช ืœืชืžืœื•ืœ ื”ืžื‘ื•ืกืกืช ืขืœ ื‘ื™ื ื” ืžืœืื›ื•ืชื™ืช ื•ื’ื•ืจื ืื ื•ืฉื™ ื•ืฉื•ืงื“ืช ืขืœ ืชืžืœื•ืœ ืขื“ื•ื™ื•ืช ื ื™ืฆื•ืœื™ ืฉื•ืื”
    ืืช ื”ืชื•ืฆืื•ืช ืืคืฉืจ ืœืจืื•ืช ื›ื‘ืจ ื‘ืจืฉืช ื‘ื”ืŸ ื—ืœืงื™ื ืžืขื“ื•ืชื• ืฉืœ ื˜ื•ื‘ื™ื” ื‘ื™ื™ืœืกืงื™ ืฉื”ื™ื” ืžืคืงื“ ื’ื“ื•ื“ ื”ืคืจื˜ื™ื–ื ื™ื ื”ื™ื”ื•ื“ื™ื ื‘ื‘ื™ื™ืœื•ืจื•ืกื™ื”"""
    punct_text = get_prediction(
        model=model,
        text=text,
        tokenizer=tokenizer,
        # Context sizes must match those the model was trained with.
        backward_context=model.config.backward_context,
        forward_context=model.config.forward_context,
    )
    print(punct_text)


if __name__ == "__main__":
    main()
src/models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from transformers import BertForMaskedLM, PreTrainedModel
3
+
4
+ from src.config import PunctuationBertConfig
5
+
6
+
7
class BertForPunctuation(PreTrainedModel):
    """BERT-based classifier that predicts a punctuation class for the token
    at the center of a fixed-size input segment."""

    config_class = PunctuationBertConfig

    def __init__(self, config):
        super().__init__(config)
        # Each input row holds backward + forward context plus the predicted
        # token and a pause token, hence the +2.
        num_positions = config.backward_context + config.forward_context + 2
        # The MLM logits for the whole segment are flattened into one feature vector.
        feature_dim = num_positions * config.vocab_size
        self.bert = BertForMaskedLM(config)
        self.bn = nn.BatchNorm1d(feature_dim)
        self.fc = nn.Linear(feature_dim, config.output_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return punctuation-class logits of shape (batch, output_size) for
        input token-id segments of shape (batch, segment_size)."""
        mlm_logits = self.bert(x)[0]
        flat = mlm_logits.view(mlm_logits.shape[0], -1)
        return self.fc(self.dropout(self.bn(flat)))