from transformers import BertTokenizer, BertModel
import torch
import numpy as np


def get_bert_embeddings_from_texts(positive_text, unlabelled_text, batch_size=32):
    """Encode positive and unlabelled texts as mean-pooled BERT embeddings."""
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')

    # Run on GPU when available; the model and the input tensors must live
    # on the same device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    bert_model.to(device)
    bert_model.eval()

    def get_bert_embeddings(texts, tokenizer, model, batch_size=32):
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize the batch; padding/truncation gives fixed-length tensors.
            tokens = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
            tokens = {k: v.to(device) for k, v in tokens.items()}

            # Mean-pool the last hidden state over the token dimension to get
            # one fixed-size vector per input text.
            with torch.no_grad():
                embeddings = model(**tokens)[0]
                embeddings = embeddings.mean(dim=1)

            all_embeddings.append(embeddings.cpu())

        all_embeddings = torch.cat(all_embeddings, dim=0)
        return all_embeddings

    bert_embeddings_positive = get_bert_embeddings(positive_text, bert_tokenizer, bert_model, batch_size)
    bert_embeddings_unlabeled = get_bert_embeddings(unlabelled_text, bert_tokenizer, bert_model, batch_size)

    return bert_embeddings_positive, bert_embeddings_unlabeled
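

# Example usage (a minimal sketch: the sample sentences and the __main__ guard
# are illustrative additions, not part of the original snippet).
if __name__ == "__main__":
    positive_examples = ["great product, works perfectly", "absolutely loved it"]
    unlabelled_examples = ["arrived on tuesday", "the box was blue", "not sure yet"]

    pos_emb, unl_emb = get_bert_embeddings_from_texts(positive_examples, unlabelled_examples)
    print(pos_emb.shape)  # torch.Size([2, 768]) for bert-base-uncased
    print(unl_emb.shape)  # torch.Size([3, 768])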