shubham142000 committed
Commit 48d1b47 · verified · 1 Parent(s): e1efc31

Update bert_embeddings.py

Files changed (1): bert_embeddings.py (+9 −35)
bert_embeddings.py CHANGED
@@ -1,40 +1,14 @@
-from transformers import BertTokenizer, BertModel
-import torch
+from sentence_transformers import SentenceTransformer
 import numpy as np
 
-def get_bert_embeddings_from_texts(positive_text, unlabelled_text, batch_size=32):
-    # Initialize BERT tokenizer and model
-    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-    bert_model = BertModel.from_pretrained('bert-base-uncased')
-
-    def get_bert_embeddings(texts, tokenizer, model, batch_size=32):
-        all_embeddings = []
-
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i+batch_size]
-
-            # Tokenize the batch of texts
-            tokens = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
-
-            # Move input tensors to GPU if available
-            if torch.cuda.is_available():
-                tokens = {k: v.to('cuda') for k, v in tokens.items()}
-
-            # Get the BERT embeddings for the batch
-            with torch.no_grad():
-                embeddings = model(**tokens)[0]
-                embeddings = embeddings.mean(dim=1)
-
-            all_embeddings.append(embeddings.cpu())
-
-        all_embeddings = torch.cat(all_embeddings, dim=0)
-        return all_embeddings
-
-    # Get BERT embeddings for positive labeled data
-    bert_embeddings_positive = get_bert_embeddings(positive_text, bert_tokenizer, bert_model)
-
-    # Get BERT embeddings for unlabeled data
-    bert_embeddings_unlabeled = get_bert_embeddings(unlabelled_text, bert_tokenizer, bert_model)
-
-    return bert_embeddings_positive, bert_embeddings_unlabeled
+def get_sentence_embeddings(positive_text, unlabelled_text):
+    # Initialize SentenceTransformer model
+    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
+    # Generate embeddings for positive text
+    positive_embeddings = model.encode(positive_text)
+
+    # Generate embeddings for unlabelled text
+    unlabelled_embeddings = model.encode(unlabelled_text)
+
+    return positive_embeddings, unlabelled_embeddings
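
For context, a minimal usage sketch of the updated helper, assuming the module is importable as bert_embeddings (the example texts are illustrative, not part of the commit):

# Hypothetical usage of get_sentence_embeddings (illustrative inputs).
from bert_embeddings import get_sentence_embeddings

positive_text = ["great battery life", "excellent camera quality"]
unlabelled_text = ["arrived two days late", "the screen is very bright"]

positive_embeddings, unlabelled_embeddings = get_sentence_embeddings(positive_text, unlabelled_text)

# SentenceTransformer.encode returns NumPy arrays; all-MiniLM-L6-v2 produces
# 384-dimensional vectors, so each result has shape (n_texts, 384).
print(positive_embeddings.shape)   # (2, 384)
print(unlabelled_embeddings.shape) # (2, 384)

One note on the design choice: model.encode handles tokenization, batching, and pooling internally, which is why the commit can drop the manual batching, GPU transfer, and mean-pooling logic of the BERT version.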