|
from typing import Any, Dict, List, Optional

import torch
from scipy.special import softmax
from transformers import BertTokenizer

from utils import clean_str, clean_str_nopunct
from utils import MultiHeadModel, BertInputBuilder, get_num_words
|
|
|
# Hub identifier of the uptake model. EndpointHandler does not read this
# constant: it loads weights from its own ``path`` argument instead.
MODEL_CHECKPOINT = 'ddemszky/uptake-model'
|
|
|
class EndpointHandler:
    """Score conversational uptake between pairs of speaker utterances.

    The ``__init__(path)`` / ``__call__(data)`` shape appears to follow the
    Hugging Face Inference Endpoints custom-handler interface -- confirm
    against the deployment configuration.
    """

    # Keys that ``__call__`` reads from ``data["parameters"]``.
    _REQUIRED_PARAMS = ("speaker_A", "speaker_B", "student_min_words")

    def __init__(self, path=".", max_length=120):
        """Load the tokenizer, input builder and model onto the best device.

        Args:
            path: Location passed to ``MultiHeadModel.from_pretrained``
                (defaults to the current directory).
            max_length: Maximum sequence length handed to
                ``BertInputBuilder.build_inputs``. Defaults to 120, the value
                that used to be hard-coded.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        # Only a two-way "nsp" head is configured. Its class-1 probability
        # is what get_uptake_score reports.
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)
        self.max_length = max_length

    def get_clean_text(self, text, remove_punct=False):
        """Normalize ``text`` with ``clean_str``, or with ``clean_str_nopunct`` if ``remove_punct``."""
        if remove_punct:
            return clean_str_nopunct(text)
        return clean_str(text)

    def get_prediction(self, instance):
        """Run the model on a single encoded input pair.

        Args:
            instance: Dict produced by ``BertInputBuilder.build_inputs``.
                It holds unbatched ``input_ids`` and ``token_type_ids``
                sequences. It is mutated in place: an ``attention_mask`` is
                added, and all three entries are replaced by tensors of
                shape ``(1, seq_len)`` on ``self.device``.

        Returns:
            The raw model output, which callers index by ``"nsp_logits"``.
        """
        # All-ones mask: every position is attended to. This assumes
        # build_inputs adds no padding (TODO confirm). The mask is kept 1-D
        # so that, like the other inputs, it becomes (1, seq_len) after
        # unsqueeze, which is the documented BERT attention_mask shape.
        instance["attention_mask"] = [1] * len(instance["input_ids"])
        for key in ("input_ids", "token_type_ids", "attention_mask"):
            # Tensor.to returns a new tensor, so the result must be stored.
            # Discarding it would leave the inputs on the CPU while the
            # model sits on CUDA.
            instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device)

        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"],
                          return_pooler_output=False)

    def get_uptake_score(self, utterances, speakerA, speakerB):
        """Compute the uptake score for one utterance pair.

        Args:
            utterances: Record indexable by ``speakerA`` and ``speakerB``.
            speakerA: Key of the first utterance. Its cleaned text is passed
                to ``build_inputs`` as a one-element list.
            speakerB: Key of the second utterance.

        Returns:
            The softmax probability of class index 1 of the ``nsp`` head,
            as a Python float.
        """
        textA = self.get_clean_text(utterances[speakerA], remove_punct=False)
        textB = self.get_clean_text(utterances[speakerB], remove_punct=False)
        instance = self.input_builder.build_inputs([textA], textB,
                                                   max_length=self.max_length,
                                                   input_str=True)
        output = self.get_prediction(instance)
        return float(softmax(output["nsp_logits"][0].tolist())[1])

    def __call__(self, data: Dict[str, Any]) -> List[Optional[float]]:
        """Score every utterance record in a request payload.

        Args:
            data: Request payload, mutated in place. It contains:
                - ``inputs``: list of utterance records, each indexable by
                  the speaker keys below. If this key is absent, ``data``
                  itself is used.
                - ``parameters``: dict with ``speaker_A``, ``speaker_B``
                  (the record keys to compare) and ``student_min_words``
                  (the minimum word count for the ``speaker_A`` text).

        Returns:
            One entry per input record. Each entry is either the uptake
            score, or ``None`` when the ``speaker_A`` text has fewer than
            ``student_min_words`` words.

        Raises:
            ValueError: If ``parameters`` is missing or lacks a required key.
        """
        inputs = data.pop("inputs", data)
        params = data.pop("parameters", None)
        if params is None:
            params = {}
        missing = [key for key in self._REQUIRED_PARAMS if key not in params]
        if missing:
            raise ValueError(
                "'parameters' is missing required key(s): " + ", ".join(missing))

        speaker_a = params["speaker_A"]
        speaker_b = params["speaker_B"]
        min_words = params["student_min_words"]
        utterances = inputs

        print("EXAMPLES")
        for utt_pair in utterances[:3]:
            print(f"speaker A: {utt_pair[speaker_a]}")
            print(f"speaker B: {utt_pair[speaker_b]}")
            print("----")

        print("Running inference on %d examples..." % len(utterances))
        self.model.eval()
        uptake_scores = []
        with torch.no_grad():
            for utt in utterances:
                if get_num_words(utt[speaker_a]) < min_words:
                    # Skipped record: append None so the output stays
                    # aligned one-to-one with the inputs.
                    uptake_scores.append(None)
                    continue
                uptake_scores.append(self.get_uptake_score(utterances=utt,
                                                           speakerA=speaker_a,
                                                           speakerB=speaker_b))
        return uptake_scores