"""Inference endpoint handler that annotates transcript utterances.

Each utterance is annotated with three labels: ``uptake``, ``reasoning`` and
``question``. Labels stay ``None`` for utterances a model did not score.
"""

from typing import Dict, List, Any
import weakref

import numpy as np
from scipy.special import softmax
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification

from utils import clean_str, clean_str_nopunct
from utils import MultiHeadModel, BertInputBuilder, get_num_words

transformers.logging.set_verbosity_debug()

UPTAKE_MODEL = 'ddemszky/uptake-model'
REASONING_MODEL = 'ddemszky/student-reasoning'
QUESTION_MODEL = 'ddemszky/question-detection'

# Probability of the positive "nsp" class above which uptake is labelled 1.
UPTAKE_THRESHOLD = .8

# Keys of a built instance that are converted to tensors before a forward pass.
_TENSOR_KEYS = ("input_ids", "token_type_ids", "attention_mask")


def _prepare_instance(instance, device):
    """Turn a built input instance into batched tensors on ``device``.

    The instance is modified in place and also returned. An all-ones
    attention mask is created from the length of ``instance["input_ids"]``,
    then every key in ``_TENSOR_KEYS`` is converted to a tensor, given a
    leading batch dimension of size 1 and moved to ``device``.

    ``Tensor.to`` returns a new tensor rather than moving in place, so its
    result must be stored back into the instance.
    """
    # NOTE(review): the mask is built as a nested list and then unsqueezed, so
    # it ends up with one more dimension than a flat ``input_ids`` would;
    # kept as in the original - confirm against BertInputBuilder's output.
    instance["attention_mask"] = [[1] * len(instance["input_ids"])]
    for key in _TENSOR_KEYS:
        instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(device)  # Batch size = 1
    return instance


class Utterance:
    """A single speaker turn plus the labels produced by the models."""

    def __init__(self, speaker, text, uid=None,
                 transcript=None, starttime=None, endtime=None, **kwargs):
        """Store the utterance fields.

        Args:
            speaker: Speaker identifier.
            text: Raw utterance text.
            uid: Optional utterance identifier.
            transcript: Optional owning transcript; only a weak reference to
                it is kept.
            starttime: Optional start time, stored as given.
            endtime: Optional end time, stored as given.
            **kwargs: Extra custom properties, echoed back by ``to_dict``.
        """
        self.speaker = speaker
        self.text = text
        self.uid = uid
        self.starttime = starttime
        self.endtime = endtime
        self.transcript = weakref.ref(transcript) if transcript is not None else None
        self.props = kwargs

        # Model outputs; None until the corresponding model scores this turn.
        self.uptake = None
        self.reasoning = None
        self.question = None

    def get_clean_text(self, remove_punct=False):
        """Return the text passed through ``clean_str_nopunct`` or ``clean_str``."""
        if remove_punct:
            return clean_str_nopunct(self.text)
        return clean_str(self.text)

    def get_num_words(self):
        """Return the word count of the raw text as computed by ``utils.get_num_words``."""
        return get_num_words(self.text)

    def to_dict(self):
        """Return the utterance as a plain dict.

        Custom properties are merged last, so they override a built-in key of
        the same name.
        """
        return {
            'speaker': self.speaker,
            'text': self.text,
            'uid': self.uid,
            'starttime': self.starttime,
            'endtime': self.endtime,
            'uptake': self.uptake,
            'reasoning': self.reasoning,
            'question': self.question,
            **self.props
        }

    def __repr__(self):
        return f"Utterance(speaker='{self.speaker}'," \
               f"text='{self.text}', uid={self.uid}," \
               f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"


class Transcript:
    """An ordered list of utterances plus free-form transcript parameters."""

    def __init__(self, **kwargs):
        self.utterances = []
        self.params = kwargs

    def add_utterance(self, utterance):
        """Append ``utterance`` and point its ``transcript`` weakref at this object."""
        utterance.transcript = weakref.ref(self)
        self.utterances.append(utterance)

    def get_idx(self, idx):
        """Return the utterance at ``idx``, or None when ``idx`` is past the end."""
        if idx >= len(self.utterances):
            return None
        return self.utterances[idx]

    def get_uid(self, uid):
        """Return the first utterance whose ``uid`` equals ``uid``, or None."""
        for utt in self.utterances:
            if utt.uid == uid:
                return utt
        return None

    def length(self):
        """Return the number of utterances."""
        return len(self.utterances)

    def to_dict(self):
        """Return the transcript as a dict; transcript params are merged last."""
        return {
            'utterances': [utterance.to_dict() for utterance in self.utterances],
            **self.params
        }

    def __repr__(self):
        return f"Transcript(utterances={self.utterances}, custom_params={self.params})"


class QuestionModel:
    """Labels each utterance as a question (1) or not (0)."""

    def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(path, head2size={"is_question": 2})
        self.model.to(self.device)

    def run_inference(self, transcript):
        """Set ``utt.question`` on every utterance of ``transcript``.

        Text containing a literal ``?`` is labelled 1 without running the
        model; everything else gets the argmax of the ``is_question`` logits.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if "?" in utt.text:
                    utt.question = 1
                    continue
                text = utt.get_clean_text(remove_punct=True)
                instance = self.input_builder.build_inputs([], text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                # Cast to a built-in int so the label is JSON-serializable.
                utt.question = int(np.argmax(output["is_question_logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run one forward pass on ``instance`` and return the raw model output."""
        _prepare_instance(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"],
                          return_pooler_output=False)


class ReasoningModel:
    """Labels sufficiently long utterances with a reasoning class."""

    def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)

    def run_inference(self, transcript, min_num_words=8):
        """Set ``utt.reasoning`` for utterances with at least ``min_num_words`` words.

        Shorter utterances are skipped and keep ``reasoning`` as None.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if utt.get_num_words() < min_num_words:
                    continue
                instance = self.input_builder.build_inputs([], utt.text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                # Cast to a built-in int so the label is JSON-serializable.
                utt.reasoning = int(np.argmax(output["logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run one forward pass on ``instance`` and return the raw model output."""
        _prepare_instance(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"])


class UptakeModel:
    """Labels an utterance with whether it takes up the previous utterance."""

    def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)

    def run_inference(self, transcript, min_prev_words, uptake_speaker=None,
                      threshold=UPTAKE_THRESHOLD):
        """Set ``utt.uptake`` for eligible utterances of ``transcript``.

        An utterance is scored when it has a predecessor with at least
        ``min_prev_words`` words and, if ``uptake_speaker`` is given, when its
        speaker equals ``uptake_speaker``. Other utterances keep ``uptake``
        as None.

        Args:
            transcript: Transcript whose utterances are annotated in place.
            min_prev_words: Minimum word count of the preceding utterance.
            uptake_speaker: If not None, only this speaker's turns are scored.
            threshold: Softmax probability of class 1 above which uptake is 1.
        """
        self.model.eval()
        prev_num_words = 0
        prev_utt = None
        with torch.no_grad():
            for utt in transcript.utterances:
                speaker_ok = (uptake_speaker is None) or (utt.speaker == uptake_speaker)
                # The explicit prev_utt check keeps the first utterance from
                # being scored against nothing when min_prev_words <= 0.
                if speaker_ok and prev_utt is not None and prev_num_words >= min_prev_words:
                    textA = prev_utt.get_clean_text(remove_punct=False)
                    textB = utt.get_clean_text(remove_punct=False)
                    instance = self.input_builder.build_inputs([textA], textB,
                                                               max_length=self.max_length,
                                                               input_str=True)
                    output = self.get_prediction(instance)
                    utt.uptake = int(softmax(output["nsp_logits"][0].tolist())[1] > threshold)
                prev_num_words = utt.get_num_words()
                prev_utt = utt

    def get_prediction(self, instance):
        """Run one forward pass on ``instance`` and return the raw model output."""
        _prepare_instance(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"],
                          return_pooler_output=False)


class EndpointHandler():
    """Callable entry point that runs all three models over a request."""

    def __init__(self, path="."):
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj: `list`):
                List of dicts, where each dict represents an utterance; each
                utterance object must have a `speaker`, `text` and `uid` and
                can include a list of custom properties
            parameters (:obj: `dict`):
                Must contain `uptake_min_num_words`; may contain `filename`
                and `uptake_speaker`. Keys that are read with `pop` are
                removed from the dict passed in.
        Return:
            A :obj:`dict`: the annotated transcript, will be serialized and
            returned
        Raises:
            KeyError: if `uptake_min_num_words` is missing from parameters.
        """
        # get inputs
        utterances = data.pop("inputs", data)
        params = data.pop("parameters", None) or {}

        # Fail before any model is loaded if the required parameter is absent.
        if "uptake_min_num_words" not in params:
            raise KeyError("uptake_min_num_words")

        print("EXAMPLES")
        for utt in utterances[:3]:
            print("speaker %s: %s" % (utt["speaker"], utt["text"]))

        transcript = Transcript(filename=params.pop("filename", None))
        for utt in utterances:
            transcript.add_utterance(Utterance(**utt))

        print("Running inference on %d examples..." % transcript.length())

        # NOTE(review): each model is re-loaded on every call; consider caching
        # them on the handler if request latency matters.

        # Uptake
        uptake_model = UptakeModel(self.device, self.tokenizer, self.input_builder)
        uptake_model.run_inference(transcript,
                                   min_prev_words=params['uptake_min_num_words'],
                                   uptake_speaker=params.pop("uptake_speaker", None))

        # Reasoning
        reasoning_model = ReasoningModel(self.device, self.tokenizer, self.input_builder)
        reasoning_model.run_inference(transcript)

        # Question
        question_model = QuestionModel(self.device, self.tokenizer, self.input_builder)
        question_model.run_inference(transcript)

        return transcript.to_dict()