ddemszky commited on
Commit
91cb36e
·
1 Parent(s): 8b8145a

added custom handler

Browse files
__pycache__/handler.cpython-39.pyc ADDED
Binary file (8.82 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (6.53 kB). View file
 
handler.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from scipy.special import softmax
3
+ import numpy as np
4
+ import weakref
5
+
6
+ from utils import clean_str, clean_str_nopunct
7
+ import torch
8
+ from transformers import BertTokenizer, BertForSequenceClassification
9
+ from utils import MultiHeadModel, BertInputBuilder, get_num_words
10
+
11
+ UPTAKE_MODEL='ddemszky/uptake-model'
12
+ REASONING_MODEL ='ddemszky/student-reasoning'
13
+ QUESTION_MODEL ='ddemszky/question-detection'
14
+
15
+ class Utterance:
16
+ def __init__(self, speaker, text, uid=None,
17
+ transcript=None, starttime=None, endtime=None, **kwargs):
18
+ self.speaker = speaker
19
+ self.text = text
20
+ self.uid = uid
21
+ self.starttime = starttime
22
+ self.endtime = endtime
23
+ self.transcript = weakref.ref(transcript) if transcript else None
24
+ self.props = kwargs
25
+
26
+ self.uptake = None
27
+ self.reasoning = None
28
+ self.question = None
29
+
30
+ def get_clean_text(self, remove_punct=False):
31
+ if remove_punct:
32
+ return clean_str_nopunct(self.text)
33
+ return clean_str(self.text)
34
+
35
+ def get_num_words(self):
36
+ return get_num_words(self.text)
37
+
38
+ def to_dict(self):
39
+ return {
40
+ 'speaker': self.speaker,
41
+ 'text': self.text,
42
+ 'uid': self.uid,
43
+ 'starttime': self.starttime,
44
+ 'endtime': self.endtime,
45
+ 'uptake': self.uptake,
46
+ 'reasoning': self.reasoning,
47
+ 'question': self.question,
48
+ **self.props
49
+ }
50
+
51
+ def __repr__(self):
52
+ return f"Utterance(speaker='{self.speaker}'," \
53
+ f"text='{self.text}', uid={self.uid}," \
54
+ f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
55
+
56
+ class Transcript:
57
+ def __init__(self, **kwargs):
58
+ self.utterances = []
59
+ self.params = kwargs
60
+
61
+ def add_utterance(self, utterance):
62
+ utterance.transcript = weakref.ref(self)
63
+ self.utterances.append(utterance)
64
+
65
+ def get_idx(self, idx):
66
+ if idx >= len(self.utterances):
67
+ return None
68
+ return self.utterances[idx]
69
+
70
+ def get_uid(self, uid):
71
+ for utt in self.utterances:
72
+ if utt.uid == uid:
73
+ return utt
74
+ return None
75
+
76
+ def length(self):
77
+ return len(self.utterances)
78
+
79
+ def to_dict(self):
80
+ return {
81
+ 'utterances': [utterance.to_dict() for utterance in self.utterances],
82
+ **self.params
83
+ }
84
+
85
+ def __repr__(self):
86
+ return f"Transcript(utterances={self.utterances}, custom_params={self.params})"
87
+
88
+ class QuestionModel:
89
+ def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
90
+ print("Loading models...")
91
+ self.device = device
92
+ self.tokenizer = tokenizer
93
+ self.input_builder = input_builder
94
+ self.max_length = max_length
95
+ self.model = MultiHeadModel.from_pretrained(path, head2size={"is_question": 2})
96
+ self.model.to(self.device)
97
+
98
+
99
+ def run_inference(self, transcript):
100
+ self.model.eval()
101
+ with torch.no_grad():
102
+ for i, utt in enumerate(transcript.utterances):
103
+ if "?" in utt.text:
104
+ utt.question = 1
105
+ else:
106
+ text = utt.get_clean_text(remove_punct=True)
107
+ instance = self.input_builder.build_inputs([], text,
108
+ max_length=self.max_length,
109
+ input_str=True)
110
+ output = self.get_prediction(instance)
111
+ print(output)
112
+ utt.question = np.argmax(output["is_question_logits"][0].tolist())
113
+
114
+ def get_prediction(self, instance):
115
+ instance["attention_mask"] = [[1] * len(instance["input_ids"])]
116
+ for key in ["input_ids", "token_type_ids", "attention_mask"]:
117
+ instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
118
+ instance[key].to(self.device)
119
+
120
+ output = self.model(input_ids=instance["input_ids"],
121
+ attention_mask=instance["attention_mask"],
122
+ token_type_ids=instance["token_type_ids"],
123
+ return_pooler_output=False)
124
+ return output
125
+
126
+ class ReasoningModel:
127
+ def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
128
+ print("Loading models...")
129
+ self.device = device
130
+ self.tokenizer = tokenizer
131
+ self.input_builder = input_builder
132
+ self.max_length = max_length
133
+ self.model = BertForSequenceClassification.from_pretrained(path)
134
+ self.model.to(self.device)
135
+
136
+ def run_inference(self, transcript, min_num_words=8):
137
+ self.model.eval()
138
+ with torch.no_grad():
139
+ for i, utt in enumerate(transcript.utterances):
140
+ if utt.get_num_words() >= min_num_words:
141
+ instance = self.input_builder.build_inputs([], utt.text,
142
+ max_length=self.max_length,
143
+ input_str=True)
144
+ output = self.get_prediction(instance)
145
+ utt.reasoning = np.argmax(output["logits"][0].tolist())
146
+
147
+ def get_prediction(self, instance):
148
+ instance["attention_mask"] = [[1] * len(instance["input_ids"])]
149
+ for key in ["input_ids", "token_type_ids", "attention_mask"]:
150
+ instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
151
+ instance[key].to(self.device)
152
+
153
+ output = self.model(input_ids=instance["input_ids"],
154
+ attention_mask=instance["attention_mask"],
155
+ token_type_ids=instance["token_type_ids"])
156
+ return output
157
+
158
+ class UptakeModel:
159
+ def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
160
+ print("Loading models...")
161
+ self.device = device
162
+ self.tokenizer = tokenizer
163
+ self.input_builder = input_builder
164
+ self.max_length = max_length
165
+ self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
166
+ self.model.to(self.device)
167
+
168
+ def run_inference(self, transcript, min_prev_words, uptake_speaker=None):
169
+ self.model.eval()
170
+ prev_num_words = 0
171
+ prev_utt = None
172
+ with torch.no_grad():
173
+ for i, utt in enumerate(transcript.utterances):
174
+ if ((uptake_speaker is None) or (utt.speaker == uptake_speaker)) and (prev_num_words >= min_prev_words):
175
+ textA = prev_utt.get_clean_text(remove_punct=False)
176
+ textB = utt.get_clean_text(remove_punct=False)
177
+ instance = self.input_builder.build_inputs([textA], textB,
178
+ max_length=self.max_length,
179
+ input_str=True)
180
+ output = self.get_prediction(instance)
181
+
182
+ utt.uptake = int(softmax(output["nsp_logits"][0].tolist())[1] > .8)
183
+ prev_num_words = utt.get_num_words()
184
+ prev_utt = utt
185
+
186
+ def get_prediction(self, instance):
187
+ instance["attention_mask"] = [[1] * len(instance["input_ids"])]
188
+ for key in ["input_ids", "token_type_ids", "attention_mask"]:
189
+ instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
190
+ instance[key].to(self.device)
191
+
192
+ output = self.model(input_ids=instance["input_ids"],
193
+ attention_mask=instance["attention_mask"],
194
+ token_type_ids=instance["token_type_ids"],
195
+ return_pooler_output=False)
196
+ return output
197
+
198
+
199
+ class EndpointHandler():
200
+ def __init__(self):
201
+ print("Loading models...")
202
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
203
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
204
+ self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
205
+
206
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
207
+ """
208
+ data args:
209
+ inputs (:obj: `list`):
210
+ List of dicts, where each dict represents an utterance; each utterance object must have a `speaker`,
211
+ `text` and `uid`and can include list of custom properties
212
+ parameters (:obj: `dict`)
213
+ Return:
214
+ A :obj:`list` | `dict`: will be serialized and returned
215
+ """
216
+ # get inputs
217
+ utterances = data.pop("inputs", data)
218
+ params = data.pop("parameters", None)
219
+
220
+ print("EXAMPLES")
221
+ for utt in utterances[:3]:
222
+ print("speaker %s: %s" % (utt["speaker"], utt["text"]))
223
+
224
+ transcript = Transcript(filename=params.pop("filename", None))
225
+ for utt in utterances:
226
+ transcript.add_utterance(Utterance(**utt))
227
+
228
+ print("Running inference on %d examples..." % transcript.length())
229
+
230
+ # Uptake
231
+ uptake_model = UptakeModel(self.device, self.tokenizer, self.input_builder)
232
+ uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
233
+ uptake_speaker=params.pop("uptake_speaker", None))
234
+
235
+ # Reasoning
236
+ reasoning_model = ReasoningModel(self.device, self.tokenizer, self.input_builder)
237
+ reasoning_model.run_inference(transcript)
238
+
239
+ # Question
240
+ question_model = QuestionModel(self.device, self.tokenizer, self.input_builder)
241
+ question_model.run_inference(transcript)
242
+
243
+ return transcript.to_dict()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ clean-text==1.1.4
2
+ num2words==0.5.10
3
+ numpy==1.22.4
4
+ scipy==1.7.3
5
+ torch==1.10.2
6
+ transformers==4.25.1
test.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from handler import EndpointHandler
3
+
4
+ # init handler
5
+ my_handler = EndpointHandler()
6
+
7
+ # prepare sample payload
8
+ example = {
9
+ "inputs": [
10
+ {"uid": "1", "speaker": "Alice", "text": "How much is the fish?" },
11
+ {"uid": "2", "speaker": "Bob", "text": "I do not know about the fish. Because you put a long side and it’s a long side. What do you think." },
12
+ {"uid": "3", "speaker": "Alice", "text": "OK, thank you Bob." },
13
+ ],
14
+ "parameters": {
15
+ "uptake_min_num_words": 5,
16
+ "uptake_speaker": "Bob",
17
+ "filename": "sample.csv",
18
+ }
19
+ }
20
+
21
+ # test the handler
22
+ print(my_handler(example))
utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
3
+ from torch import nn
4
+ from itertools import chain
5
+ from torch.nn import MSELoss, CrossEntropyLoss
6
+ from cleantext import clean
7
+ from num2words import num2words
8
+ import re
9
+ import string
10
+
11
+ punct_chars = list((set(string.punctuation) | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'}))
12
+ punct_chars.sort()
13
+ punctuation = ''.join(punct_chars)
14
+ replace = re.compile('[%s]' % re.escape(punctuation))
15
+
16
+ def get_num_words(text):
17
+ if not isinstance(text, str):
18
+ print("%s is not a string" % text)
19
+ text = replace.sub(' ', text)
20
+ text = re.sub(r'\s+', ' ', text)
21
+ text = text.strip()
22
+ text = re.sub(r'\[.+\]', " ", text)
23
+ return len(text.split())
24
+
25
+ def number_to_words(num):
26
+ try:
27
+ return num2words(re.sub(",", "", num))
28
+ except:
29
+ return num
30
+
31
+
32
+ clean_str = lambda s: clean(s,
33
+ fix_unicode=True, # fix various unicode errors
34
+ to_ascii=True, # transliterate to closest ASCII representation
35
+ lower=True, # lowercase text
36
+ no_line_breaks=True, # fully strip line breaks as opposed to only normalizing them
37
+ no_urls=True, # replace all URLs with a special token
38
+ no_emails=True, # replace all email addresses with a special token
39
+ no_phone_numbers=True, # replace all phone numbers with a special token
40
+ no_numbers=True, # replace all numbers with a special token
41
+ no_digits=False, # replace all digits with a special token
42
+ no_currency_symbols=False, # replace all currency symbols with a special token
43
+ no_punct=False, # fully remove punctuation
44
+ replace_with_url="<URL>",
45
+ replace_with_email="<EMAIL>",
46
+ replace_with_phone_number="<PHONE>",
47
+ replace_with_number=lambda m: number_to_words(m.group()),
48
+ replace_with_digit="0",
49
+ replace_with_currency_symbol="<CUR>",
50
+ lang="en"
51
+ )
52
+
53
+ clean_str_nopunct = lambda s: clean(s,
54
+ fix_unicode=True, # fix various unicode errors
55
+ to_ascii=True, # transliterate to closest ASCII representation
56
+ lower=True, # lowercase text
57
+ no_line_breaks=True, # fully strip line breaks as opposed to only normalizing them
58
+ no_urls=True, # replace all URLs with a special token
59
+ no_emails=True, # replace all email addresses with a special token
60
+ no_phone_numbers=True, # replace all phone numbers with a special token
61
+ no_numbers=True, # replace all numbers with a special token
62
+ no_digits=False, # replace all digits with a special token
63
+ no_currency_symbols=False, # replace all currency symbols with a special token
64
+ no_punct=True, # fully remove punctuation
65
+ replace_with_url="<URL>",
66
+ replace_with_email="<EMAIL>",
67
+ replace_with_phone_number="<PHONE>",
68
+ replace_with_number=lambda m: number_to_words(m.group()),
69
+ replace_with_digit="0",
70
+ replace_with_currency_symbol="<CUR>",
71
+ lang="en"
72
+ )
73
+
74
+
75
+
76
+ class MultiHeadModel(BertPreTrainedModel):
77
+ """Pre-trained BERT model that uses our loss functions"""
78
+
79
+ def __init__(self, config, head2size):
80
+ super(MultiHeadModel, self).__init__(config, head2size)
81
+ config.num_labels = 1
82
+ self.bert = BertModel(config)
83
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
84
+ module_dict = {}
85
+ for head_name, num_labels in head2size.items():
86
+ module_dict[head_name] = nn.Linear(config.hidden_size, num_labels)
87
+ self.heads = nn.ModuleDict(module_dict)
88
+
89
+ self.init_weights()
90
+
91
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None,
92
+ head2labels=None, return_pooler_output=False, head2mask=None,
93
+ nsp_loss_weights=None):
94
+
95
+ device = "cuda" if torch.cuda.is_available() else "cpu"
96
+
97
+ # Get logits
98
+ output = self.bert(
99
+ input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
100
+ output_attentions=False, output_hidden_states=False, return_dict=True)
101
+ pooled_output = self.dropout(output["pooler_output"]).to(device)
102
+
103
+ head2logits = {}
104
+ return_dict = {}
105
+ for head_name, head in self.heads.items():
106
+ head2logits[head_name] = self.heads[head_name](pooled_output)
107
+ head2logits[head_name] = head2logits[head_name].float()
108
+ return_dict[head_name + "_logits"] = head2logits[head_name]
109
+
110
+
111
+ if head2labels is not None:
112
+ for head_name, labels in head2labels.items():
113
+ num_classes = head2logits[head_name].shape[1]
114
+
115
+ # Regression (e.g. for politeness)
116
+ if num_classes == 1:
117
+
118
+ # Only consider positive examples
119
+ if head2mask is not None and head_name in head2mask:
120
+ num_positives = head2labels[head2mask[head_name]].sum() # use certain labels as mask
121
+ if num_positives == 0:
122
+ return_dict[head_name + "_loss"] = torch.tensor([0]).to(device)
123
+ else:
124
+ loss_fct = MSELoss(reduction='none')
125
+ loss = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
126
+ return_dict[head_name + "_loss"] = loss.dot(head2labels[head2mask[head_name]].float().view(-1)) / num_positives
127
+ else:
128
+ loss_fct = MSELoss()
129
+ return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
130
+ else:
131
+ loss_fct = CrossEntropyLoss(weight=nsp_loss_weights.float())
132
+ return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name], labels.view(-1))
133
+
134
+
135
+ if return_pooler_output:
136
+ return_dict["pooler_output"] = output["pooler_output"]
137
+
138
+ return return_dict
139
+
140
+ class InputBuilder(object):
141
+ """Base class for building inputs from segments."""
142
+
143
+ def __init__(self, tokenizer):
144
+ self.tokenizer = tokenizer
145
+ self.mask = [tokenizer.mask_token_id]
146
+
147
+ def build_inputs(self, history, reply, max_length):
148
+ raise NotImplementedError
149
+
150
+ def mask_seq(self, sequence, seq_id):
151
+ sequence[seq_id] = self.mask
152
+ return sequence
153
+
154
+ @classmethod
155
+ def _combine_sequence(self, history, reply, max_length, flipped=False):
156
+ # Trim all inputs to max_length
157
+ history = [s[:max_length] for s in history]
158
+ reply = reply[:max_length]
159
+ if flipped:
160
+ return [reply] + history
161
+ return history + [reply]
162
+
163
+
164
+ class BertInputBuilder(InputBuilder):
165
+ """Processor for BERT inputs"""
166
+
167
+ def __init__(self, tokenizer):
168
+ InputBuilder.__init__(self, tokenizer)
169
+ self.cls = [tokenizer.cls_token_id]
170
+ self.sep = [tokenizer.sep_token_id]
171
+ self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
172
+ self.padded_inputs = ["input_ids", "token_type_ids"]
173
+ self.flipped = False
174
+
175
+
176
+ def build_inputs(self, history, reply, max_length, input_str=True):
177
+ """See base class."""
178
+ if input_str:
179
+ history = [self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(t)) for t in history]
180
+ reply = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(reply))
181
+ sequence = self._combine_sequence(history, reply, max_length, self.flipped)
182
+ sequence = [s + self.sep for s in sequence]
183
+ sequence[0] = self.cls + sequence[0]
184
+
185
+ instance = {}
186
+ instance["input_ids"] = list(chain(*sequence))
187
+ last_speaker = 0
188
+ other_speaker = 1
189
+ seq_length = len(sequence)
190
+ instance["token_type_ids"] = [last_speaker if ((seq_length - i) % 2 == 1) else other_speaker
191
+ for i, s in enumerate(sequence) for _ in s]
192
+ return instance