# transactify / bert_model.py
# File-view header captured with this source: author ai-venkat-r; commit e1a89b3
# ("Model creation (#8)", verified); file size 4.97 kB.
# Import Required Libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel, AdamW
from sklearn.metrics import accuracy_score
import numpy as np
# Import functions from the preprocessing module
from transactify.data_preprocessing import preprocessing_data, split_data, read_data
# Define a BERT-based classification model
class BertClassifier(nn.Module):
    """Sequence classifier: pretrained BERT encoder -> dropout -> linear head.

    Args:
        num_labels: Number of output classes (width of the returned logits).
        dropout_rate: Dropout probability applied to BERT's pooled output.
        pretrained_model_name: Checkpoint name or local path handed to
            ``BertModel.from_pretrained``. Defaults to ``"bert-base-uncased"``,
            the checkpoint this class always used before it became configurable.
    """

    def __init__(self, num_labels, dropout_rate=0.3, pretrained_model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout_rate)
        # The head's input width is read from the encoder config, so any BERT-sized
        # checkpoint can be plugged in without touching this line.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits for one batch.

        Args:
            input_ids: Token-id tensor; presumably (batch, seq_len) as produced by
                ``preprocessing_data`` -- TODO confirm.
            attention_mask: Mask tensor aligned with ``input_ids``.

        Returns:
            Logits whose last dimension is ``num_labels``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 1 is BERT's pooler output (derived from the CLS token); positional
        # indexing works whether the encoder returns a tuple or a ModelOutput.
        pooled_output = outputs[1]
        return self.classifier(self.dropout(pooled_output))
# Training the model
def _run_validation(model, val_dataloader, loss_fn, device):
    """Return ``(avg_loss, accuracy)`` of ``model`` over ``val_dataloader`` without gradients.

    Either value is ``nan`` when there are no batches / samples to average over,
    instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for b_input_ids, b_input_mask, b_labels in val_dataloader:
            b_input_ids = b_input_ids.to(device)
            b_input_mask = b_input_mask.to(device)
            # Same cast as in training: CrossEntropyLoss needs int64 class indices, and
            # without it validation crashes on e.g. int32 labels that training accepted.
            b_labels = b_labels.to(device).long()
            logits = model(b_input_ids, b_input_mask)
            total_loss += loss_fn(logits, b_labels).item()
            correct += (torch.argmax(logits, dim=1) == b_labels).sum().item()
    n_batches = len(val_dataloader)
    n_samples = len(val_dataloader.dataset)
    avg_loss = total_loss / n_batches if n_batches else float("nan")
    accuracy = correct / n_samples if n_samples else float("nan")
    return avg_loss, accuracy


def train_model(model, train_dataloader, val_dataloader, device, epochs=3, lr=2e-5):
    """Fine-tune ``model`` with AdamW and report validation metrics after every epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning logits.
        train_dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.
        val_dataloader: Same batch layout; evaluated without gradients after each epoch.
        device: Device each batch is moved to (presumably where ``model`` lives).
        epochs: Number of passes over ``train_dataloader``.
        lr: AdamW learning rate.

    Returns:
        A list with one dict per epoch holding ``"train_loss"``, ``"val_loss"`` and
        ``"val_accuracy"`` -- the same numbers that are printed.
    """
    # torch.optim.AdamW replaces transformers.AdamW, which transformers deprecated in
    # favour of the PyTorch optimizer. eps/weight_decay are pinned to transformers.AdamW's
    # defaults (1e-6 / 0.0) so the optimisation matches the previous behaviour.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    loss_fn = nn.CrossEntropyLoss()
    history = []
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        for b_input_ids, b_input_mask, b_labels in train_dataloader:
            b_input_ids = b_input_ids.to(device)
            b_input_mask = b_input_mask.to(device)
            b_labels = b_labels.to(device).long()  # CrossEntropyLoss wants int64 class indices
            optimizer.zero_grad()
            logits = model(b_input_ids, b_input_mask)
            loss = loss_fn(logits, b_labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        n_batches = len(train_dataloader)
        avg_train_loss = total_train_loss / n_batches if n_batches else float("nan")
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss}")

        avg_val_loss, avg_val_accuracy = _run_validation(model, val_dataloader, loss_fn, device)
        print(f"Validation Loss: {avg_val_loss}, Validation Accuracy: {avg_val_accuracy}")
        history.append(
            {
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_accuracy": avg_val_accuracy,
            }
        )
    return history
# Testing the model
def test_model(model, test_dataloader, device):
    """Evaluate ``model`` on ``test_dataloader``; print and return its accuracy.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning logits.
        test_dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.
        device: Device the model inputs are moved to.

    Returns:
        Fraction of samples whose argmax prediction equals the label, as a float,
        or ``nan`` when the loader yields no samples. Previously an empty loader made
        ``np.concatenate`` raise ``ValueError``.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for b_input_ids, b_input_mask, b_labels in test_dataloader:
            logits = model(b_input_ids.to(device), b_input_mask.to(device))
            pred_chunks.append(torch.argmax(logits, dim=1).cpu().numpy())
            # Labels are only compared on the host, so they need not visit the device.
            label_chunks.append(b_labels.cpu().numpy())
    if pred_chunks:
        all_preds = np.concatenate(pred_chunks)
        all_labels = np.concatenate(label_chunks)
        # Plain element-wise match rate -- the same metric train_model reports and
        # what sklearn's accuracy_score computes for 1-D label arrays.
        accuracy = float(np.mean(all_preds == all_labels)) if all_preds.size else float("nan")
    else:
        accuracy = float("nan")
    print(f"Test Accuracy: {accuracy}")
    return accuracy


# The name starts with "test_", so pytest would collect it as a test whenever this
# module's namespace is imported into a test file; opt it out explicitly.
test_model.__test__ = False
# Main function to train, validate, and test the model
def main(data_path, epochs=3, batch_size=16):
    """Load the dataset, fine-tune a ``BertClassifier`` on it and report metrics.

    Args:
        data_path: Path handed to ``read_data``.
        epochs: Number of training epochs.
        batch_size: Batch size used by every DataLoader.

    Returns:
        ``(model, labelencoder)`` after training, so callers can reuse the model and
        map predicted indices back to labels; ``None`` if ``read_data`` returned ``None``.
    """
    data = read_data(data_path)
    if data is None:
        return None
    input_ids, attention_masks, labels, labelencoder = preprocessing_data(data)
    X_train_ids, X_test_ids, X_train_masks, X_test_masks, y_train, y_test = split_data(
        input_ids, attention_masks, labels
    )
    num_labels = len(labelencoder.classes_)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertClassifier(num_labels)
    model.to(device)

    # Reshuffle training samples every epoch. Iterating in file order (the previous
    # behaviour) gives correlated batches if the data happens to be grouped by label.
    train_dataset = TensorDataset(X_train_ids, X_train_masks, y_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # NOTE(review): split_data only yields train/test, so the held-out split serves as
    # both the per-epoch validation set and the final test set; the reported test
    # accuracy is measured on the same data as the last validation pass. Consider a
    # separate validation split. One loader is reused since both passes were identical.
    eval_dataset = TensorDataset(X_test_ids, X_test_masks, y_test)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

    train_model(model, train_dataloader, eval_dataloader, device, epochs=epochs)
    test_model(model, eval_dataloader, device)
    return model, labelencoder
if __name__ == "__main__":
    # Command-line entry point. The data path used to be hard-coded to one machine's
    # drive; it is now an optional positional argument with that path as the default.
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune a BERT transaction classifier.")
    parser.add_argument(
        "data_path",
        nargs="?",
        default=r"E:\transactify\transactify\Dataset\transaction_data.csv",
        help="CSV file with the transaction data (default: %(default)s)",
    )
    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=16, help="DataLoader batch size")
    args = parser.parse_args()
    main(args.data_path, epochs=args.epochs, batch_size=args.batch_size)