Ashlee Kupor committed
Commit a373a60 · 1 Parent(s): 12e14fb

Add handler
Files changed (2)
  1. handler.py +171 -0
  2. test_run_handler.py +13 -0
handler.py ADDED
@@ -0,0 +1,171 @@
+ from simpletransformers.classification import ClassificationModel, ClassificationArgs
+ from typing import Dict, List, Any, Tuple
+ import pandas as pd
+ import webvtt
+ from datetime import datetime
+ import torch
+ import spacy
+
+ nlp = spacy.load("en_core_web_sm")
+ tokenizer = nlp.tokenizer
+ token_limit = 200
+
+ class Utterance(object):
+
+     def __init__(self, starttime, endtime, speaker, text,
+                  idx, prev_utterance, prev_prev_utterance):
+         self.starttime = starttime
+         self.endtime = endtime
+         self.speaker = speaker
+         self.text = text
+         self.idx = idx
+         self.prev_utterance = prev_utterance
+         self.prev_prev_utterance = prev_prev_utterance
+
+ class EndpointHandler():
+     def __init__(self, path="."):
+         print("Loading models...")
+         cuda_available = torch.cuda.is_available()
+         self.model = ClassificationModel(
+             "roberta", path, use_cuda=cuda_available
+         )
+
+     def utterance_to_str(self, utterance: Utterance) -> Tuple[List, str]:
+         # Eliciting uses prior text as context.
+         doc = nlp(utterance.text)
+         prior_text = self.get_prior_text(utterance)
+
+         if len(doc) > token_limit:
+             # Split long utterances and pair each segment with the prior text.
+             utterance_text_list = self.handle_long_utterances(doc)
+             utterance_with_prior_text = []
+             for text in utterance_text_list:
+                 utterance_with_prior_text.append([prior_text, text])
+             return utterance_with_prior_text, 'list'
+         else:
+             return [prior_text, utterance.text], 'single'
+
+     def format_speaker(self, speaker: str, source: str) -> str:
+         prior_text = ''
+         if speaker == 'student':
+             prior_text += '***STUDENT '
+         else:
+             prior_text += '***SECTION_LEADER '
+         if source == 'not chat':
+             prior_text += '(audio)*** : '
+         else:
+             prior_text += '(chat)*** : '
+         return prior_text
+
+     def get_prior_text(self, utterance: Utterance) -> str:
+         prior_text = ''
+         if utterance.prev_utterance is not None and utterance.prev_prev_utterance is not None:
+             # TODO: add in the source
+             prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
+             prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
+         else:
+             prior_text = 'No prior utterance'
+         return prior_text
+
+     def handle_long_utterances(self, doc: spacy.tokens.Doc) -> List[str]:
+         # Split a long utterance into segments of roughly token_limit tokens,
+         # breaking on sentence boundaries.
+         split_count = 1
+         total_sent = len([x for x in doc.sents])
+         sent_count = 0
+         token_count = 0
+         split_utterance = ''
+         utterances = []
+         for sent in doc.sents:
+             # Add a sentence to the current segment
+             split_utterance = split_utterance + ' ' + sent.text
+             token_count += len(sent)
+             sent_count += 1
+             if token_count >= token_limit or sent_count == total_sent:
+                 # Save utterance segment
+                 utterances.append(split_utterance)
+
+                 # Restart count
+                 split_utterance = ''
+                 token_count = 0
+                 split_count += 1
+
+         return utterances
+
+     def convert_time(self, time_str):
+         # Convert an "HH:MM:SS.mmm" timestamp to milliseconds.
+         time = datetime.strptime(time_str, "%H:%M:%S.%f")
+         return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
+
+     def process_vtt_transcript(self, vttfile) -> List[Utterance]:
+         """Process a raw vtt file into a list of Utterance objects."""
+
+         utterances_list = []
+         text = ""
+         prev_start = "00:00:00.000"
+         prev_end = "00:00:00.000"
+         idx = 0
+         prev_speaker = None
+         prev_utterance = None
+         prev_prev_utterance = None
+         for caption in webvtt.read(vttfile):
+
+             # Get speaker
+             check_for_speaker = caption.text.split(":")
+             if len(check_for_speaker) > 1:  # the speaker was changed or restated
+                 speaker = check_for_speaker[0]
+             else:
+                 speaker = prev_speaker
+
+             # Get utterance
+             new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
+
+             # If the speaker changed, close out the previous batch
+             if (prev_speaker is not None) and (speaker != prev_speaker):
+                 utterance = Utterance(starttime=self.convert_time(prev_start),
+                                       endtime=self.convert_time(prev_end),
+                                       speaker=prev_speaker,
+                                       text=text.strip(),
+                                       idx=idx,
+                                       prev_utterance=prev_utterance,
+                                       prev_prev_utterance=prev_prev_utterance)
+
+                 utterances_list.append(utterance)
+
+                 # Start new batch
+                 prev_start = caption.start
+                 text = ""
+                 prev_prev_utterance = prev_utterance
+                 prev_utterance = utterance
+                 idx += 1
+             text += new_text + " "
+             prev_end = caption.end
+             prev_speaker = speaker
+
+         # Append the last utterance
+         if prev_speaker is not None:
+             utterance = Utterance(starttime=self.convert_time(prev_start),
+                                   endtime=self.convert_time(prev_end),
+                                   speaker=prev_speaker,
+                                   text=text.strip(),
+                                   idx=idx,
+                                   prev_utterance=prev_utterance,
+                                   prev_prev_utterance=prev_prev_utterance)
+             utterances_list.append(utterance)
+
+         return utterances_list
+
+     def __call__(self, data_file: str) -> List[Dict[str, Any]]:
+         ''' data_file is a str pointing to a filename of type .vtt '''
+
+         utterances_list = []
+         for utterance in self.process_vtt_transcript(data_file):
+             # TODO: filter out to only have SL utterances
+             utterance_str, is_list = self.utterance_to_str(utterance)
+             if is_list == 'list':
+                 utterances_list.extend(utterance_str)
+             else:
+                 utterances_list.append(utterance_str)
+
+         predictions, raw_outputs = self.model.predict(utterances_list)
+
+         return predictions
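
For context (not part of the commit): a minimal sketch of the payload shape that ultimately reaches `self.model.predict`. Each utterance becomes a `[prior_text, text]` pair, where `prior_text` is either the two preceding formatted utterances or the literal string `'No prior utterance'`. The example values below are hypothetical.

```python
# Hypothetical illustration of the list built in __call__ and passed to
# self.model.predict: each entry is a [prior_text, utterance_text] pair.
example_inputs = [
    [
        '"***STUDENT (audio)*** : I think it returns None \n '
        '***SECTION_LEADER (audio)*** : Walk me through the return statement \n ',
        'So what do we think this function returns?',
    ],
    ['No prior utterance', "Welcome everyone, let's get started."],
]
```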
test_run_handler.py ADDED
@@ -0,0 +1,13 @@
+ from handler import EndpointHandler
+
+ # init handler
+ my_handler = EndpointHandler(path=".")
+
+ # prepare sample payload
+ test_payload = 'test.transcript.vtt'
+
+ # test the handler
+ test_pred = my_handler(test_payload)
+
+ # show results
+ print("test_pred", test_pred)
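
The test expects a `test.transcript.vtt` file next to the script, which is not included in this commit. A hypothetical way to generate one, assuming Zoom-style captions where the speaker name precedes a colon (the format that `process_vtt_transcript` parses):

```python
# Hypothetical sample transcript (not part of the commit); the handler
# splits each caption on ":" to recover the speaker name.
sample_vtt = """WEBVTT

1
00:00:01.000 --> 00:00:04.000
Section Leader: Welcome everyone, let's get started.

2
00:00:04.500 --> 00:00:08.000
student: I had a question about the recursion example.
"""

with open("test.transcript.vtt", "w") as f:
    f.write(sample_vtt)
```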