# Arxiv__Recommendations / bert_embeddings.py
# shubham142000's picture
# Update bert_embeddings.py
# 48d1b47 verified
# raw
# history blame
# 510 Bytes
import functools

import numpy as np
from sentence_transformers import SentenceTransformer
_DEFAULT_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'


@functools.lru_cache(maxsize=2)
def _load_model(model_name):
    """Load and memoize a SentenceTransformer so repeated calls reuse its weights."""
    # Bounded cache: each entry keeps a full model resident in memory.
    return SentenceTransformer(model_name)


def get_sentence_embeddings(positive_text, unlabelled_text,
                            model_name=_DEFAULT_MODEL_NAME):
    """Encode positive and unlabelled text with a SentenceTransformer model.

    The model is loaded once per ``model_name`` and reused across calls,
    instead of being rebuilt (and its weights reloaded) on every invocation.

    Args:
        positive_text: A string or list of strings to embed as the positive set.
        unlabelled_text: A string or list of strings to embed as the unlabelled set.
        model_name: Hugging Face / sentence-transformers model identifier.
            Defaults to 'sentence-transformers/all-MiniLM-L6-v2'.

    Returns:
        A tuple ``(positive_embeddings, unlabelled_embeddings)``. Each element
        is exactly what ``SentenceTransformer.encode`` returns for the
        corresponding input (per the library docs, a NumPy array by default).
    """
    model = _load_model(model_name)
    positive_embeddings = model.encode(positive_text)
    unlabelled_embeddings = model.encode(unlabelled_text)
    return positive_embeddings, unlabelled_embeddings