import pickle
from typing import Any, Dict, List, Union

import numpy as np
import torch
from transformers import BertModel, BertTokenizer


def unpickle_obj(filepath):
    """Load and return the object stored in the pickle file at *filepath*.

    Prints ``unpickled <filepath>`` to stdout once the load has succeeded.

    Args:
        filepath: Path of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    # SECURITY: unpickling can execute arbitrary code embedded in the file.
    # Only call this on files from a trusted source.
    with open(filepath, 'rb') as f_in:
        data = pickle.load(f_in)
    print(f"unpickled {filepath}")
    return data


class EndpointHandler:
    """Score every (text, query) pair with a pickled classifier on BERT features.

    Each input string is embedded by mean-pooling BERT's last hidden state.
    For every pair the classifier receives the element-wise difference
    ``text_embedding - query_embedding``.
    """

    # Tokenizer and encoder must be loaded from the same checkpoint.
    BERT_CHECKPOINT = 'bert-base-uncased'
    # Value passed as ``max_length`` to the tokenizer (used with truncation=True).
    MAX_TOKENS = 512

    def __init__(self, path=""):
        """Load the classifier and the BERT tokenizer/encoder.

        Args:
            path: Passed unchanged to ``unpickle_obj``, which opens it as a
                pickle file.
        """
        # NOTE(review): `path` is opened directly as a file. If the serving
        # runtime passes a model *directory* here (TODO confirm), the pickle's
        # file name has to be joined onto it first.
        self.model = unpickle_obj(path)
        self.tokenizer = BertTokenizer.from_pretrained(self.BERT_CHECKPOINT)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert = BertModel.from_pretrained(self.BERT_CHECKPOINT).to(self.device)

    @staticmethod
    def _as_batch(value: Union[str, List[str]]) -> List[str]:
        """Return *value* as a list, wrapping a single string as a one-item batch."""
        if isinstance(value, str):
            return [value]
        return list(value)

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* by averaging BERT's last hidden state over axis 1.

        Args:
            texts: Strings to embed; tokenized as one padded, truncated batch.

        Returns:
            A NumPy array on the CPU, presumably of shape
            ``(len(texts), hidden_size)``.
        """
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=self.MAX_TOKENS,
        ).to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.bert(**inputs)
        # NOTE(review): mean(dim=1) averages over every position, so padding
        # positions added by padding=True are included and an embedding can
        # vary with the other strings in the batch. Confirm how the
        # classifier's training features were pooled before switching to an
        # attention-mask-weighted mean.
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Return the classifier's ``predict_proba`` output for all pairs.

        Args:
            data: Mapping with keys ``'queries'`` and ``'texts'``. Each is a
                list of strings; a single string is treated as a one-item
                batch.

        Returns:
            Whatever ``self.model.predict_proba`` returns for the pair
            features. Rows are text-major: all queries for the first text,
            then all queries for the second text, and so on.

        Raises:
            KeyError: If ``'queries'`` or ``'texts'`` is missing from *data*.
            ValueError: If either batch is empty.
        """
        queries = self._as_batch(data['queries'])
        texts = self._as_batch(data['texts'])
        if not queries or not texts:
            raise ValueError(
                "'queries' and 'texts' must each contain at least one string"
            )

        queries_vec = np.asarray(self.get_embeddings(queries))
        texts_vec = np.asarray(self.get_embeddings(texts))

        # (T, 1, D) - (Q, D) broadcasts to (T, Q, D). Flattening the first two
        # axes gives one row per (text, query) pair in text-major order.
        diff = (texts_vec[:, np.newaxis] - queries_vec).reshape(
            -1, queries_vec.shape[-1]
        )
        return self.model.predict_proba(diff)