shubham142000 committed on
Commit
e1efc31
1 Parent(s): 3f51766

Update bert_embeddings.py

Files changed (1)
  1. bert_embeddings.py +40 -0
bert_embeddings.py CHANGED
@@ -0,0 +1,40 @@
+ from transformers import BertTokenizer, BertModel
+ import torch
+
+ def get_bert_embeddings_from_texts(positive_text, unlabelled_text, batch_size=32):
+     # Initialize the BERT tokenizer and model
+     bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+     bert_model = BertModel.from_pretrained('bert-base-uncased')
+
+     # Move the model to GPU if available and switch to inference mode
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     bert_model.to(device)
+     bert_model.eval()
+
+     def get_bert_embeddings(texts, tokenizer, model, batch_size=32):
+         all_embeddings = []
+
+         for i in range(0, len(texts), batch_size):
+             batch_texts = texts[i:i+batch_size]
+
+             # Tokenize the batch of texts
+             tokens = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
+
+             # Move input tensors to the same device as the model
+             tokens = {k: v.to(device) for k, v in tokens.items()}
+
+             # Get the BERT embeddings for the batch, mean-pooled over all
+             # token positions (note: this includes padding tokens)
+             with torch.no_grad():
+                 embeddings = model(**tokens)[0]
+                 embeddings = embeddings.mean(dim=1)
+
+             all_embeddings.append(embeddings.cpu())
+
+         all_embeddings = torch.cat(all_embeddings, dim=0)
+         return all_embeddings
+
+     # Get BERT embeddings for the positive labeled data
+     bert_embeddings_positive = get_bert_embeddings(positive_text, bert_tokenizer, bert_model, batch_size)
+
+     # Get BERT embeddings for the unlabeled data
+     bert_embeddings_unlabeled = get_bert_embeddings(unlabelled_text, bert_tokenizer, bert_model, batch_size)
+
+     return bert_embeddings_positive, bert_embeddings_unlabeled
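For reference, a minimal usage sketch of the function added in this commit. The example sentences and the `pos_emb`/`unl_emb` names are hypothetical illustrations, not part of the commit; each input text comes back as one 768-dimensional mean-pooled vector from bert-base-uncased.

from bert_embeddings import get_bert_embeddings_from_texts

# Hypothetical example data; any two lists of strings will do.
positive_text = ["great product, works as advertised", "excellent build quality"]
unlabelled_text = ["arrived yesterday", "the box was blue", "battery lasts two days"]

pos_emb, unl_emb = get_bert_embeddings_from_texts(positive_text, unlabelled_text, batch_size=2)
print(pos_emb.shape)  # torch.Size([2, 768])
print(unl_emb.shape)  # torch.Size([3, 768])

Note on the pooling choice: the function averages over every token position, padding included. Common alternatives are taking the [CLS] vector or computing an attention-mask-weighted mean so padding tokens do not dilute the embedding.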