from datetime import datetime
from typing import Any, List, Tuple, Union

import spacy
import torch
import webvtt
from simpletransformers.classification import ClassificationModel
from spacy.tokens import Doc

# spaCy pipeline; "en_core_web_sm" provides the tokenizer and sentence segmenter.
nlp = spacy.load("en_core_web_sm")
token_limit = 200


class Utterance:
    """One speaker turn from a transcript, with links to the two preceding turns."""

    def __init__(self, starttime, endtime, speaker, text,
                 idx, prev_utterance, prev_prev_utterance):
        self.starttime = starttime
        self.endtime = endtime
        self.speaker = speaker
        self.text = text
        self.idx = idx
        self.prev_utterance = prev_utterance
        self.prev_prev_utterance = prev_prev_utterance


class EndpointHandler:

    def __init__(self, path="."):
        print("Loading models...")
        cuda_available = torch.cuda.is_available()
        # Load the fine-tuned RoBERTa classifier; fall back to CPU if no GPU is available.
        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )
    def utterance_to_str(self, utterance: Utterance) -> Tuple[Union[List[List[str]], List[str]], str]:
        """Pair an utterance with its prior context, splitting it first if it exceeds token_limit."""
        doc = nlp(utterance.text)
        prior_text = self.truncate_end(self.get_prior_text(utterance))

        if len(doc) > token_limit:
            # Long utterance: pair each chunk with the same prior context.
            utterance_text_list = self.handle_long_utterances(doc)
            utterance_with_prior_text = []
            for text in utterance_text_list:
                utterance_with_prior_text.append([prior_text, text])
            return utterance_with_prior_text, 'list'
        else:
            return [prior_text, utterance.text], 'single'
    def truncate_end(self, prior_text: str) -> str:
        """Keep only the tail of the prior context; note the cut is by characters, not tokens."""
        max_seq_length = 512
        prior_text_max_length = int(max_seq_length / 2)

        if len(prior_text) > prior_text_max_length:
            starting_index = len(prior_text) - prior_text_max_length
            return prior_text[starting_index:]
        return prior_text
    def format_speaker(self, speaker: str, source: str) -> str:
        """Build the speaker/source tag that prefixes each context utterance."""
        prior_text = ''
        if speaker == 'student':
            prior_text += '***STUDENT '
        else:
            prior_text += '***SECTION_LEADER '
        if source == 'not chat':
            prior_text += '(audio)*** : '
        else:
            prior_text += '(chat)*** : '
        return prior_text
    def get_prior_text(self, utterance: Utterance) -> str:
        """Concatenate the two preceding utterances as tagged context."""
        if utterance.prev_utterance is not None and utterance.prev_prev_utterance is not None:
            prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
            prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
        else:
            prior_text = 'No prior utterance'
        return prior_text
    def handle_long_utterances(self, doc: Doc) -> List[str]:
        """Split a long doc into chunks of roughly token_limit tokens, breaking on sentence boundaries."""
        total_sent = sum(1 for _ in doc.sents)
        sent_count = 0
        token_count = 0
        split_utterance = ''
        utterances = []
        for sent in doc.sents:
            split_utterance = split_utterance + ' ' + sent.text
            token_count += len(sent)
            sent_count += 1
            # Emit a chunk once it reaches the token limit, or at the final sentence.
            if token_count >= token_limit or sent_count == total_sent:
                utterances.append(split_utterance)
                split_utterance = ''
                token_count = 0
        return utterances
    def convert_time(self, time_str: str) -> float:
        """Convert an "HH:MM:SS.mmm" timestamp to milliseconds,
        e.g. convert_time("00:01:02.500") -> 62500.0."""
        time = datetime.strptime(time_str, "%H:%M:%S.%f")
        return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
    def process_vtt_transcript(self, vttfile) -> List[Utterance]:
        """Parse a raw .vtt file into Utterances, merging consecutive captions by the same speaker."""
        utterances_list = []
        text = ""
        prev_start = "00:00:00.000"
        prev_end = "00:00:00.000"
        idx = 0
        prev_speaker = None
        prev_utterance = None
        prev_prev_utterance = None
        for caption in webvtt.read(vttfile):
            # Captions look like "Speaker: text"; carry the previous speaker forward
            # when no speaker tag is present. Split only on the first colon so that
            # colons inside the caption text are preserved.
            check_for_speaker = caption.text.split(":", 1)
            if len(check_for_speaker) > 1:
                speaker = check_for_speaker[0]
            else:
                speaker = prev_speaker

            new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]

            # Speaker changed: close out the previous speaker's utterance.
            if (prev_speaker is not None) and (speaker != prev_speaker):
                utterance = Utterance(starttime=self.convert_time(prev_start),
                                      endtime=self.convert_time(prev_end),
                                      speaker=prev_speaker,
                                      text=text.strip(),
                                      idx=idx,
                                      prev_utterance=prev_utterance,
                                      prev_prev_utterance=prev_prev_utterance)
                utterances_list.append(utterance)

                prev_start = caption.start
                text = ""
                prev_prev_utterance = prev_utterance
                prev_utterance = utterance
                idx += 1
            text += new_text + " "
            prev_end = caption.end
            prev_speaker = speaker

        # Flush the final utterance.
        if prev_speaker is not None:
            utterance = Utterance(starttime=self.convert_time(prev_start),
                                  endtime=self.convert_time(prev_end),
                                  speaker=prev_speaker,
                                  text=text.strip(),
                                  idx=idx,
                                  prev_utterance=prev_utterance,
                                  prev_prev_utterance=prev_prev_utterance)
            utterances_list.append(utterance)

        return utterances_list
    def __call__(self, data_file: str) -> List[Any]:
        """Classify every utterance in a transcript; data_file is a path to a .vtt file."""
        utterances_list = []
        for utterance in self.process_vtt_transcript(data_file):
            # Each prediction input is a [prior_context, utterance_text] pair;
            # long utterances contribute one pair per chunk.
            utterance_str, result_type = self.utterance_to_str(utterance)
            if result_type == 'list':
                utterances_list.extend(utterance_str)
            else:
                utterances_list.append(utterance_str)

        predictions, _ = self.model.predict(utterances_list)
        return predictions
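

# A minimal smoke-test sketch. Assumptions (not part of the handler above):
# a fine-tuned RoBERTa model saved in the current directory, and a transcript
# named "session.vtt"; both names are hypothetical placeholders.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    preds = handler(data_file="session.vtt")
    print(preds)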