import pickle
from typing import Any, Dict, List

import numpy as np
import torch
from transformers import BertModel, BertTokenizer


def unpickle_obj(filepath):
    """Load and return the object pickled at *filepath*.

    SECURITY: ``pickle.load`` can execute arbitrary code while deserializing.
    Only call this on trusted files (e.g. model artifacts you produced).
    """
    with open(filepath, 'rb') as f_in:
        data = pickle.load(f_in)
    print(f"unpickled {filepath}")
    return data


class EndpointHandler:
    """Score (text, query) pairs with a pickled model over BERT embeddings."""

    def __init__(self, path=""):
        """Load the pickled model at *path* and the ``bert-base-uncased`` encoder.

        Args:
            path: Filesystem path of the pickled model. The model must expose
                ``predict_proba``, which ``__call__`` invokes.
        """
        self.model = unpickle_obj(path)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert = BertModel.from_pretrained('bert-base-uncased').to(self.device)
        # from_pretrained already returns the model in eval mode (dropout off);
        # stated explicitly so inference is deterministic regardless of version.
        self.bert.eval()

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* as mean-pooled BERT token vectors.

        Only real tokens are averaged. Padding positions are excluded via the
        tokenizer's attention mask, so a text's embedding does not depend on
        which other texts share its batch.

        Args:
            texts: Strings to embed; each is truncated to 512 tokens.

        Returns:
            A CPU float array with one row per input string and
            ``hidden_size`` columns.
        """
        inputs = self.tokenizer(texts, return_tensors='pt', truncation=True,
                                padding=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.bert(**inputs)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
        # (batch, seq_len, 1): 1.0 for real tokens, 0.0 for padding positions.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp only guards the division; [CLS]/[SEP] make every count >= 1.
        counts = mask.sum(dim=1).clamp(min=1)
        return (summed / counts).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run the model on the difference vector of every (text, query) pair.

        Args:
            data: Mapping with ``'queries'`` and ``'texts'`` entries,
                presumably lists of strings (TODO confirm against callers).

        Returns:
            Whatever ``self.model.predict_proba`` returns for
            ``len(texts) * len(queries)`` rows, in text-major order:
            row ``i * len(queries) + j`` is ``emb(texts[i]) - emb(queries[j])``.
            For scikit-learn-style models this is an ndarray of shape
            ``(n_pairs, n_classes)`` (assumption; verify for the pickled model).

        Raises:
            KeyError: If ``'queries'`` or ``'texts'`` is missing from *data*.
        """
        queries = data['queries']
        texts = data['texts']
        queries_vec = self.get_embeddings(queries)  # (Q, D)
        texts_vec = self.get_embeddings(texts)      # (T, D)
        # Broadcast (T, 1, D) - (Q, D) -> (T, Q, D), then flatten pairs to rows.
        diff = (texts_vec[:, np.newaxis, :] - queries_vec).reshape(-1, texts_vec.shape[1])
        return self.model.predict_proba(diff)