# revoicing/handler.py
# Author: Ashlee Kupor — "Add handler" (commit 5499fc9)
from datetime import datetime
from typing import Any, Dict, List, Tuple

import pandas as pd
import spacy
import torch
import webvtt
from simpletransformers.classification import ClassificationArgs, ClassificationModel
# Shared spaCy pipeline; used for sentence segmentation of long utterances.
nlp = spacy.load("en_core_web_sm")
# Tokenizer handle (not referenced elsewhere in this module).
tokenizer = nlp.tokenizer
# Maximum number of spaCy tokens an utterance may have before it is split
# into sentence-bounded segments (see EndpointHandler.handle_long_utterances).
token_limit = 200
class Utterance(object):
    """One speaker turn from a transcript, linked to the two turns before it.

    Attributes mirror the constructor arguments: start/end times, the
    speaker label, the spoken text, a running index, and references to the
    previous and previous-previous Utterance (or None).
    """

    def __init__(self, starttime, endtime, speaker, text,
                 idx, prev_utterance, prev_prev_utterance):
        # Store every constructor argument as a same-named attribute.
        self.__dict__.update(
            starttime=starttime,
            endtime=endtime,
            speaker=speaker,
            text=text,
            idx=idx,
            prev_utterance=prev_utterance,
            prev_prev_utterance=prev_prev_utterance,
        )
class EndpointHandler():
    """Inference handler that classifies section-leader "revoicing".

    Parses a .vtt transcript into speaker turns, pairs each turn with a
    (truncated) prior-context string, splits turns longer than
    ``token_limit`` spaCy tokens into sentence-bounded segments, and runs
    a RoBERTa classifier over the resulting ``[prior_text, text]`` pairs.
    """

    def __init__(self, path="."):
        """Load the RoBERTa classification model stored at ``path``.

        Uses the GPU when torch reports one available.
        """
        print("Loading models...")
        cuda_available = torch.cuda.is_available()
        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )

    def utterance_to_str(self, utterance: "Utterance") -> Tuple[list, str]:
        """Build model input(s) for one utterance, attaching prior context.

        Returns:
            ``(list_of_[prior, text]_pairs, 'list')`` when the utterance
            exceeds ``token_limit`` tokens and was segmented, otherwise
            ``([prior, text], 'single')``.
        """
        doc = nlp(utterance.text)
        # Prior context is truncated from the front so its tail is kept.
        prior_text = self.truncate_end(self.get_prior_text(utterance))
        if len(doc) > token_limit:
            utterance_text_list = self.handle_long_utterances(doc)
            # Every segment shares the same prior context.
            utterance_with_prior_text = [[prior_text, text] for text in utterance_text_list]
            return utterance_with_prior_text, 'list'
        else:
            return [prior_text, utterance.text], 'single'

    def truncate_end(self, prior_text: str) -> str:
        """Keep at most the last 256 characters of ``prior_text``.

        NOTE(review): the budget comes from a 512 *token* sequence limit but
        is applied to *characters* — confirm this approximation is intended.
        """
        max_seq_length = 512
        prior_text_max_length = int(max_seq_length / 2)  # divide by 2 because 2 columns
        if len(prior_text) > prior_text_max_length:
            starting_index = len(prior_text) - prior_text_max_length
            return prior_text[starting_index:]
        return prior_text

    def format_speaker(self, speaker: str, source: str) -> str:
        """Render a speaker/source tag such as ``***STUDENT (audio)*** : ``.

        Any speaker other than ``'student'`` is tagged SECTION_LEADER; any
        source other than ``'not chat'`` is tagged as chat.
        """
        prior_text = ''
        if speaker == 'student':
            prior_text += '***STUDENT '
        else:
            prior_text += '***SECTION_LEADER '
        if source == 'not chat':
            prior_text += '(audio)*** : '
        else:
            prior_text += '(chat)*** : '
        return prior_text

    def get_prior_text(self, utterance: "Utterance") -> str:
        """Build the prior-context string from the two preceding turns.

        Returns ``'No prior utterance'`` unless *both* preceding utterances
        exist. The string opens with a double quote that is never closed —
        preserved from the original formatting.
        """
        if utterance.prev_utterance is not None and utterance.prev_prev_utterance is not None:
            # TODO: add in the source (everything is currently tagged 'not chat')
            prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
            prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
        else:
            prior_text = 'No prior utterance'
        return prior_text

    def handle_long_utterances(self, doc) -> List[str]:
        """Split a spaCy ``Doc`` into segments on sentence boundaries.

        A segment is emitted once it reaches ``token_limit`` tokens (or at
        the final sentence). Each segment keeps a leading space, preserved
        from the original join behaviour.
        """
        total_sent = sum(1 for _ in doc.sents)
        sent_count = 0
        token_count = 0
        split_utterance = ''
        utterances = []
        for sent in doc.sents:
            # Accumulate one sentence at a time.
            split_utterance = split_utterance + ' ' + sent.text
            token_count += len(sent)
            sent_count += 1
            if token_count >= token_limit or sent_count == total_sent:
                # Save this segment and reset the accumulators.
                utterances.append(split_utterance)
                split_utterance = ''
                token_count = 0
        return utterances

    def convert_time(self, time_str):
        """Convert an ``HH:MM:SS.mmm`` timestamp string to milliseconds (float)."""
        time = datetime.strptime(time_str, "%H:%M:%S.%f")
        return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000

    def process_vtt_transcript(self, vttfile) -> List["Utterance"]:
        """Process raw vtt file.

        Consecutive captions by the same speaker are merged into one
        Utterance; a caption with no ``speaker:`` prefix inherits the
        previous speaker. Each Utterance is linked to the two before it.
        """
        utterances_list = []
        text = ""
        prev_start = "00:00:00.000"
        prev_end = "00:00:00.000"
        idx = 0
        prev_speaker = None
        prev_utterance = None
        prev_prev_utterance = None
        for caption in webvtt.read(vttfile):
            # Get speaker: a "name: text" caption restates or changes it.
            check_for_speaker = caption.text.split(":")
            if len(check_for_speaker) > 1:  # the speaker was changed or restated
                speaker = check_for_speaker[0]
            else:
                speaker = prev_speaker
            # Get utterance text (the part after the speaker prefix, if any).
            new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
            # If speaker was changed, flush the accumulated batch.
            if (prev_speaker is not None) and (speaker != prev_speaker):
                utterance = Utterance(starttime=self.convert_time(prev_start),
                                      endtime=self.convert_time(prev_end),
                                      speaker=prev_speaker,
                                      text=text.strip(),
                                      idx=idx,
                                      prev_utterance=prev_utterance,
                                      prev_prev_utterance=prev_prev_utterance)
                utterances_list.append(utterance)
                # Start new batch and shift the prior-utterance chain.
                prev_start = caption.start
                text = ""
                prev_prev_utterance = prev_utterance
                prev_utterance = utterance
                idx += 1
            text += new_text + " "
            prev_end = caption.end
            prev_speaker = speaker
        # Append the final in-progress batch, if any captions were read.
        if prev_speaker is not None:
            utterance = Utterance(starttime=self.convert_time(prev_start),
                                  endtime=self.convert_time(prev_end),
                                  speaker=prev_speaker,
                                  text=text.strip(),
                                  idx=idx,
                                  prev_utterance=prev_utterance,
                                  prev_prev_utterance=prev_prev_utterance)
            utterances_list.append(utterance)
        return utterances_list

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        ''' data_file is a str pointing to filename of type .vtt '''
        utterances_list = []
        for utterance in self.process_vtt_transcript(data_file):
            # TODO: filter out to only have SL utterances
            utterance_str, is_list = self.utterance_to_str(utterance)
            if is_list == 'list':
                # Segmented long utterance: already a list of [prior, text] pairs.
                utterances_list.extend(utterance_str)
            else:
                utterances_list.append(utterance_str)
        predictions, raw_outputs = self.model.predict(utterances_list)
        return predictions