Ashlee Kupor committed
Commit 5499fc9 · 1 Parent(s): e102f04

Add handler

Files changed (2)
  1. handler.py +180 -0
  2. test_run_handler.py +13 -0
handler.py ADDED
@@ -0,0 +1,180 @@
+ from simpletransformers.classification import ClassificationModel, ClassificationArgs
+ from typing import Dict, List, Any, Tuple
+ import pandas as pd
+ import webvtt
+ from datetime import datetime
+ import torch
+ import spacy
+
+ nlp = spacy.load("en_core_web_sm")
+ tokenizer = nlp.tokenizer
+ token_limit = 200
+
+ class Utterance(object):
+
+     def __init__(self, starttime, endtime, speaker, text,
+                  idx, prev_utterance, prev_prev_utterance):
+         self.starttime = starttime
+         self.endtime = endtime
+         self.speaker = speaker
+         self.text = text
+         self.idx = idx
+         self.prev_utterance = prev_utterance
+         self.prev_prev_utterance = prev_prev_utterance
+
+ class EndpointHandler():
+     def __init__(self, path="."):
+         print("Loading models...")
+         cuda_available = torch.cuda.is_available()
+         self.model = ClassificationModel(
+             "roberta", path, use_cuda=cuda_available
+         )
+
+     def utterance_to_str(self, utterance: Utterance) -> Tuple[List, str]:
+         # Revoicing input: pair the prior context with the utterance text; the prior text is truncated to fit the max sequence length
+
+         doc = nlp(utterance.text)
+         prior_text = self.truncate_end(self.get_prior_text(utterance))
+
+         if len(doc) > token_limit:
+             utterance_text_list = self.handle_long_utterances(doc)
+             utterance_with_prior_text = []
+             for text in utterance_text_list:
+                 utterance_with_prior_text.append([prior_text, text])
+             return utterance_with_prior_text, 'list'
+
+         else:
+             return [prior_text, utterance.text], 'single'
+
+     def truncate_end(self, prior_text: str) -> str:
+         max_seq_length = 512
+         prior_text_max_length = int(max_seq_length / 2) # divide by 2 because 2 columns
+
+         if len(prior_text) > prior_text_max_length:
+             starting_index = len(prior_text) - prior_text_max_length
+             return prior_text[starting_index:]
+         return prior_text
+
+     def format_speaker(self, speaker: str, source: str) -> str:
+         prior_text = ''
+         if speaker == 'student':
+             prior_text += '***STUDENT '
+         else:
+             prior_text += '***SECTION_LEADER '
+         if source == 'not chat':
+             prior_text += '(audio)*** : '
+         else:
+             prior_text += '(chat)*** : '
+         return prior_text
+
+     def get_prior_text(self, utterance: Utterance) -> str:
+         prior_text = ''
+         if utterance.prev_utterance is not None and utterance.prev_prev_utterance is not None:
+             #TODO: add in the source
+             prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
+             prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
+         else:
+             prior_text = 'No prior utterance'
+         return prior_text
+
+     def handle_long_utterances(self, doc) -> List[str]:
+         split_count = 1
+         total_sent = len([x for x in doc.sents])
+         sent_count = 0
+         token_count = 0
+         split_utterance = ''
+         utterances = []
+         for sent in doc.sents:
+             # add a sentence to split
+             split_utterance = split_utterance + ' ' + sent.text
+             token_count += len(sent)
+             sent_count += 1
+             if token_count >= token_limit or sent_count == total_sent:
+                 # save utterance segment
+                 utterances.append(split_utterance)
+
+                 # restart count
+                 split_utterance = ''
+                 token_count = 0
+                 split_count += 1
+
+         return utterances
+
+     def convert_time(self, time_str): # "HH:MM:SS.mmm" string -> milliseconds
+         time = datetime.strptime(time_str, "%H:%M:%S.%f")
+         return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
+
+     def process_vtt_transcript(self, vttfile) -> List[Utterance]:
+         """Process raw vtt file."""
+
+         utterances_list = []
+         text = ""
+         prev_start = "00:00:00.000"
+         prev_end = "00:00:00.000"
+         idx = 0
+         prev_speaker = None
+         prev_utterance = None
+         prev_prev_utterance = None
+         for caption in webvtt.read(vttfile):
+
+             # Get speaker
+             check_for_speaker = caption.text.split(":")
+             if len(check_for_speaker) > 1: # the speaker was changed or restated
+                 speaker = check_for_speaker[0]
+             else:
+                 speaker = prev_speaker
+
+             # Get utterance
+             new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
+
+             # If speaker was changed, start new batch
+             if (prev_speaker is not None) and (speaker != prev_speaker):
+                 utterance = Utterance(starttime=self.convert_time(prev_start),
+                                       endtime=self.convert_time(prev_end),
+                                       speaker=prev_speaker,
+                                       text=text.strip(),
+                                       idx=idx,
+                                       prev_utterance=prev_utterance,
+                                       prev_prev_utterance=prev_prev_utterance)
+
+                 utterances_list.append(utterance)
+
+                 # Start new batch
+                 prev_start = caption.start
+                 text = ""
+                 prev_prev_utterance = prev_utterance
+                 prev_utterance = utterance
+                 idx += 1
+             text += new_text + " "
+             prev_end = caption.end
+             prev_speaker = speaker
+
+         # Append last one
+         if prev_speaker is not None:
+             utterance = Utterance(starttime=self.convert_time(prev_start),
+                                   endtime=self.convert_time(prev_end),
+                                   speaker=prev_speaker,
+                                   text=text.strip(),
+                                   idx=idx,
+                                   prev_utterance=prev_utterance,
+                                   prev_prev_utterance=prev_prev_utterance)
+             utterances_list.append(utterance)
+
+         return utterances_list
+
+
+     def __call__(self, data_file: str) -> List[Dict[str, Any]]:
+         '''data_file is the path to a .vtt transcript file.'''
+
+         utterances_list = []
+         for utterance in self.process_vtt_transcript(data_file):
+             #TODO: filter out to only have SL utterances
+             utterance_str, is_list = self.utterance_to_str(utterance)
+             if is_list == 'list':
+                 utterances_list.extend(utterance_str)
+             else:
+                 utterances_list.append(utterance_str)
+
+         predictions, raw_outputs = self.model.predict(utterances_list)
+
+         return predictions
test_run_handler.py ADDED
@@ -0,0 +1,13 @@
+ from handler import EndpointHandler
+
+ # init handler
+ my_handler = EndpointHandler(path=".")
+
+ # prepare sample payload
+ test_payload = 'test.transcript.vtt'
+
+ # test the handler
+ test_pred = my_handler(test_payload)
+
+ # show results
+ print("test_pred", test_pred)