import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizerFast


def similarity(embeddings_1, embeddings_2):
    """Pairwise cosine similarity between two batches of embeddings."""
    normalized_embeddings_1 = F.normalize(embeddings_1, p=2, dim=1)
    normalized_embeddings_2 = F.normalize(embeddings_2, p=2, dim=1)
    # Dot products of unit vectors form an (n1, n2) cosine-similarity matrix.
    return torch.matmul(
        normalized_embeddings_1, normalized_embeddings_2.transpose(0, 1)
    )

class LaBSE:
    def __init__(self):
        self.tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
        self.model = BertModel.from_pretrained("setu4993/LaBSE")
        # Optionally move the model to GPU:
        # self.model = self.model.to("cuda")
        self.model.eval()

    @torch.no_grad()
    def __call__(self, sentences):
        # Accept a single sentence or a list of sentences.
        if not isinstance(sentences, list):
            sentences = [sentences]
        tokens = self.tokenizer(sentences, return_tensors="pt", padding=True)
        # If the model is on GPU, move the inputs there as well:
        # tokens = tokens.to("cuda")
        outputs = self.model(**tokens)
        # LaBSE uses the pooler output as the sentence embedding.
        embeddings = outputs.pooler_output
        # L2-normalize so that dot products are cosine similarities.
        return F.normalize(embeddings, p=2, dim=1).cpu()

    @property
    def dim(self):
        return 768

if __name__ == "__main__":
    model = LaBSE()
    sents = ["arma virumque cano", "arma gravi numero violentaque bella parabam"]
    output = model(sents)
    print("end", output.shape)