# uptake-model / handler.py
# Author: ddemszky
# Commit 7800c33: "add custom handler" (3.39 kB)
from typing import Any, Dict, List, Optional

import torch
from scipy.special import softmax
from transformers import BertTokenizer

from utils import clean_str, clean_str_nopunct
from utils import MultiHeadModel, BertInputBuilder, get_num_words
# Hub id of the uptake model. Not used by EndpointHandler, which loads its
# weights from the `path` argument it is constructed with.
MODEL_CHECKPOINT = "ddemszky/uptake-model"
class EndpointHandler:
    """Inference-endpoint handler that scores conversational uptake.

    Each input item holds two texts, selected by the ``speaker_A`` and
    ``speaker_B`` request parameters. They are cleaned, encoded as a sentence
    pair by ``BertInputBuilder`` and run through a ``MultiHeadModel`` with a
    2-way ``"nsp"`` head. The uptake score is the softmax probability of
    class 1 of that head.
    """

    # Request parameters that must be present whenever there are inputs.
    _REQUIRED_PARAMS = ("speaker_A", "speaker_B", "student_min_words")

    def __init__(self, path="."):
        """Load the tokenizer and model and move the model to the best device.

        Args:
            path: location passed to ``MultiHeadModel.from_pretrained``.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # The tokenizer is always the stock bert-base-uncased one,
        # independent of ``path``.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)
        # Encoded inputs are truncated to this many tokens by build_inputs.
        self.max_length = 120

    def get_clean_text(self, text, remove_punct=False):
        """Normalize ``text`` with the project's cleaning helpers.

        Args:
            text: raw utterance text.
            remove_punct: if true, use ``clean_str_nopunct``; otherwise
                ``clean_str``.
        """
        if remove_punct:
            return clean_str_nopunct(text)
        return clean_str(text)

    def get_prediction(self, instance):
        """Run the model on one encoded sentence pair.

        Args:
            instance: dict from ``BertInputBuilder.build_inputs`` containing
                ``"input_ids"`` and ``"token_type_ids"``. It is modified in
                place. An ``"attention_mask"`` is added, and all three entries
                are replaced by tensors with a leading batch dimension of 1,
                located on ``self.device``.

        Returns:
            The raw output of ``self.model``. Callers read
            ``output["nsp_logits"]``.
        """
        # Attend to every token. NOTE(review): the extra list nesting gives a
        # (1, 1, seq_len) mask when input_ids is a flat list. It is kept as in
        # the original; confirm that MultiHeadModel expects this shape.
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ("input_ids", "token_type_ids", "attention_mask"):
            # Batch size = 1. Tensor.to() is NOT in-place: it returns a new
            # tensor, so the result must be stored, or the inputs would stay on
            # the CPU while the model sits on self.device.
            instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device)
        return self.model(
            input_ids=instance["input_ids"],
            attention_mask=instance["attention_mask"],
            token_type_ids=instance["token_type_ids"],
            return_pooler_output=False,
        )

    def get_uptake_score(self, utterances, speakerA, speakerB):
        """Return the uptake score for one input item.

        Args:
            utterances: mapping holding the two texts of one input item.
            speakerA: key of the first text (the one checked against
                ``student_min_words`` in ``__call__``).
            speakerB: key of the second text.

        Returns:
            Softmax probability of class 1 of the ``"nsp"`` head, as a float.
        """
        text_a = self.get_clean_text(utterances[speakerA], remove_punct=False)
        text_b = self.get_clean_text(utterances[speakerB], remove_punct=False)
        instance = self.input_builder.build_inputs(
            [text_a], text_b, max_length=self.max_length, input_str=True
        )
        output = self.get_prediction(instance)
        # .tolist() also copies the logits back from the model's device.
        logits = output["nsp_logits"][0].tolist()
        # Plain Python float (not numpy.float64) so the result serializes cleanly.
        return float(softmax(logits)[1])

    def __call__(self, data: Dict[str, Any]) -> List[Optional[float]]:
        """Score every input item of an inference request.

        Args:
            data: request payload. ``inputs`` and ``parameters`` are popped
                from it. ``inputs`` is a list of mappings, each holding the
                two texts. If ``inputs`` is absent, ``data`` itself is used.
                ``parameters`` holds ``speaker_A`` and ``speaker_B`` (the keys
                of the two texts in each item) and ``student_min_words`` (the
                minimum ``get_num_words`` count of the ``speaker_A`` text for
                an item to be scored).

        Returns:
            One entry per input item: its uptake score, or ``None`` when the
            ``speaker_A`` text has fewer than ``student_min_words`` words.

        Raises:
            ValueError: if there are inputs but ``parameters`` is missing or
                lacks a required key.
        """
        inputs = data.pop("inputs", data)
        params = data.pop("parameters", None) or {}
        utterances = inputs
        if utterances:
            # Fail with a clear message instead of a TypeError/KeyError deep
            # inside the loops below.
            missing = [key for key in self._REQUIRED_PARAMS if key not in params]
            if missing:
                raise ValueError(
                    "missing required parameter(s): %s" % ", ".join(missing)
                )

        print("EXAMPLES")
        for utt_pair in utterances[:3]:
            print("speaker A: %s" % utt_pair[params["speaker_A"]])
            print("speaker B: %s" % utt_pair[params["speaker_B"]])
            print("----")
        print("Running inference on %d examples..." % len(utterances))

        self.model.eval()
        uptake_scores = []
        with torch.no_grad():
            for utt in utterances:
                prev_num_words = get_num_words(utt[params["speaker_A"]])
                if prev_num_words < params["student_min_words"]:
                    # speaker_A text below the minimum length: no score.
                    uptake_scores.append(None)
                    continue
                uptake_scores.append(
                    self.get_uptake_score(
                        utterances=utt,
                        speakerA=params["speaker_A"],
                        speakerB=params["speaker_B"],
                    )
                )
        return uptake_scores