lr-bert-base-uncased / handler.py
Uan Sholanbayev
added model
9edfae2
raw
history blame
1.45 kB
from typing import List, Dict, Any
import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import pickle
import os
def unpickle_obj(filepath):
    """Load and return the single pickled object stored at *filepath*.

    A confirmation line is printed to stdout once the object has been read.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files from a trusted source.

    Args:
        filepath: Path of the pickle file to read (opened in binary mode).

    Returns:
        The object deserialized from the file.
    """
    with open(filepath, mode='rb') as stream:
        loaded = pickle.Unpickler(stream).load()
    # Report only after the load succeeded, so a failure never logs success.
    print(f"unpickled {filepath}")
    return loaded
class EndpointHandler():
    """Inference handler that scores every (text, query) pair.

    Texts and queries are embedded with BERT (mean of the last hidden
    state), the difference ``text - query`` is formed for every pair, and
    the unpickled classifier's ``predict_proba`` is applied to those
    difference vectors.
    """

    def __init__(self, path=""):
        """Load the classifier, the tokenizer and the BERT encoder.

        Args:
            path: Directory containing ``bert_lr.pkl`` and the BERT
                tokenizer/model files. When empty (the default) the
                current working directory is used, which was the only
                location consulted before ``path`` was honoured.
        """
        model_dir = self._resolve_model_dir(path)
        # NOTE: unpickling can execute arbitrary code; bert_lr.pkl must come
        # from a trusted source (it ships alongside this handler).
        self.model = unpickle_obj(f"{model_dir}/bert_lr.pkl")
        self.tokenizer = BertTokenizer.from_pretrained(
            model_dir, local_files_only=True
        )
        # Prefer the GPU when one is visible; otherwise run on the CPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
        # local_files_only mirrors the tokenizer call: both artifacts are
        # expected to live in model_dir, never to be fetched remotely.
        self.bert = BertModel.from_pretrained(
            model_dir, local_files_only=True
        ).to(self.device)

    @staticmethod
    def _resolve_model_dir(path):
        """Return *path*, or the current working directory if it is empty."""
        return path or os.getcwd()

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* with BERT and return the result as a NumPy array.

        Inputs are tokenized as one padded batch, truncated to at most 512
        tokens, and run through the encoder without gradient tracking.

        Args:
            texts: Strings to embed.

        Returns:
            The mean over dimension 1 of ``last_hidden_state``, moved to
            the CPU and converted to NumPy (presumably one row per input
            text -- confirm against the encoder's output shape).
        """
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.bert(**inputs)
        # NOTE(review): this is a plain mean over every position, so padded
        # positions are averaged in as well and an embedding can depend on
        # what else is in the batch. Left unchanged because the pickled
        # classifier was presumably fitted on embeddings pooled this way.
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every text against every query.

        Args:
            data: Request payload; must contain ``'queries'`` and
                ``'texts'`` (each passed to :meth:`get_embeddings`).

        Returns:
            A one-element list ``[{"outputs": proba}]`` where ``proba`` is
            whatever the classifier's ``predict_proba`` returns for the
            pairwise difference vectors. Rows are ordered text-major: all
            queries for the first text, then all queries for the second
            text, and so on.

        Raises:
            KeyError: If ``'queries'`` or ``'texts'`` is missing.
        """
        queries = data['queries']
        texts = data['texts']
        # asarray avoids copying when get_embeddings already returns arrays.
        queries_vec = np.asarray(self.get_embeddings(queries))
        texts_vec = np.asarray(self.get_embeddings(texts))
        # Broadcasting (n_texts, 1, dim) - (n_queries, dim) yields every
        # text/query difference; flatten the pairs into one 2-D batch.
        diff = (texts_vec[:, np.newaxis] - queries_vec).reshape(
            -1, queries_vec.shape[-1]
        )
        return [{
            "outputs": self.model.predict_proba(diff)
        }]