# Arxiv__Recommendations / bert_embeddings.py
from transformers import BertTokenizer, BertModel
import torch


def get_bert_embeddings_from_texts(positive_text, unlabelled_text, batch_size=32):
    """Compute mean-pooled BERT embeddings for positive and unlabelled texts."""
    # Initialize the BERT tokenizer and model; move the model to GPU if available
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    bert_model = bert_model.to(device)
    bert_model.eval()

    def get_bert_embeddings(texts, tokenizer, model, batch_size=32):
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            # Tokenize the batch of texts
            tokens = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
            # Move input tensors to the same device as the model
            tokens = {k: v.to(device) for k, v in tokens.items()}
            # Get the BERT embeddings for the batch: last hidden state of
            # shape (batch, seq_len, hidden), mean-pooled over the token axis
            with torch.no_grad():
                embeddings = model(**tokens)[0]
            embeddings = embeddings.mean(dim=1)
            all_embeddings.append(embeddings.cpu())
        return torch.cat(all_embeddings, dim=0)

    # Get BERT embeddings for the positive (labelled) texts
    bert_embeddings_positive = get_bert_embeddings(positive_text, bert_tokenizer, bert_model, batch_size)
    # Get BERT embeddings for the unlabelled texts
    bert_embeddings_unlabeled = get_bert_embeddings(unlabelled_text, bert_tokenizer, bert_model, batch_size)
    return bert_embeddings_positive, bert_embeddings_unlabeled
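

# A minimal usage sketch (the sample texts below are hypothetical, for
# illustration only): embed two small lists of abstracts and print the
# shapes of the resulting embedding matrices.
if __name__ == "__main__":
    positive = [
        "Attention-based models for citation recommendation.",
        "Transformer encoders for scientific document retrieval.",
    ]
    unlabelled = [
        "A survey of graph neural networks.",
        "Stochastic optimization methods for deep learning.",
    ]
    pos_emb, unl_emb = get_bert_embeddings_from_texts(positive, unlabelled, batch_size=2)
    print(pos_emb.shape)  # torch.Size([2, 768]) for bert-base-uncased
    print(unl_emb.shape)  # torch.Size([2, 768])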