"""Train, validate and test a BERT-based sequence classifier.

Data reading, preprocessing and the train/test split are delegated to
``transactify.data_preprocessing``; this module defines the model, the
training/evaluation loops and a ``main`` entry point.
"""

# Import Required Libraries
import sys

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel

# Import functions from the preprocessing module
from transactify.data_preprocessing import preprocessing_data, read_data, split_data


class BertClassifier(nn.Module):
    """``bert-base-uncased`` encoder followed by dropout and a linear head.

    Args:
        num_labels: Number of output classes (size of the logit vector).
        dropout_rate: Dropout probability applied to BERT's pooled output.
    """

    def __init__(self, num_labels, dropout_rate=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits, one row per input sequence."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # Pooler output (CLS token)
        output = self.dropout(pooled_output)
        return self.classifier(output)


def _move_batch(batch, device):
    """Unpack an ``(input_ids, attention_mask, labels)`` batch onto ``device``.

    Labels are cast to ``long`` in every loop (train, validation, test)
    because ``nn.CrossEntropyLoss`` expects integer class indices as targets.
    """
    b_input_ids, b_input_mask, b_labels = batch
    return (
        b_input_ids.to(device),
        b_input_mask.to(device),
        b_labels.to(device).long(),
    )


def _evaluate(model, dataloader, device, loss_fn):
    """Return ``(mean batch loss, accuracy)`` of ``model`` over ``dataloader``.

    Both values are ``nan`` when the loader yields no batches/examples,
    instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for batch in dataloader:
            b_input_ids, b_input_mask, b_labels = _move_batch(batch, device)
            outputs = model(b_input_ids, b_input_mask)
            total_loss += loss_fn(outputs, b_labels).item()
            preds = torch.argmax(outputs, dim=1)
            total_correct += (preds == b_labels).sum().item()

    num_batches = len(dataloader)
    num_examples = len(dataloader.dataset)
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = total_correct / num_examples if num_examples else float("nan")
    return avg_loss, accuracy


# Training the model
def train_model(model, train_dataloader, val_dataloader, device, epochs=3, lr=2e-5):
    """Fine-tune ``model`` and print train/validation metrics after each epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` -> logits.
        train_dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.
        val_dataloader: Same batch layout, used for per-epoch evaluation.
        device: ``torch.device`` the batches are moved to.
        epochs: Number of passes over ``train_dataloader``.
        lr: AdamW learning rate.
    """
    # torch.optim.AdamW replaces the deprecated (and, in recent releases,
    # removed) transformers.AdamW. eps/weight_decay are pinned to that class's
    # former defaults because torch's defaults (1e-8 / 0.01) differ.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0
    )
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        for batch in train_dataloader:
            b_input_ids, b_input_mask, b_labels = _move_batch(batch, device)

            optimizer.zero_grad()
            outputs = model(b_input_ids, b_input_mask)
            loss = loss_fn(outputs, b_labels)
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()

        num_train_batches = len(train_dataloader)
        avg_train_loss = (
            total_train_loss / num_train_batches if num_train_batches else float("nan")
        )
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss}")

        avg_val_loss, avg_val_accuracy = _evaluate(
            model, val_dataloader, device, loss_fn
        )
        print(f"Validation Loss: {avg_val_loss}, Validation Accuracy: {avg_val_accuracy}")


# Testing the model
def test_model(model, test_dataloader, device):
    """Print and return the accuracy of ``model`` on ``test_dataloader``.

    Returns:
        The accuracy as computed by ``sklearn.metrics.accuracy_score``, or
        ``None`` when the loader yields no batches (``np.concatenate`` would
        otherwise raise on an empty list).
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in test_dataloader:
            b_input_ids, b_input_mask, b_labels = _move_batch(batch, device)
            outputs = model(b_input_ids, b_input_mask)
            preds = torch.argmax(outputs, dim=1)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(b_labels.cpu().numpy())

    if not all_preds:
        print("Test Accuracy: no test data")
        return None

    accuracy = accuracy_score(np.concatenate(all_labels), np.concatenate(all_preds))
    print(f"Test Accuracy: {accuracy}")
    return accuracy


# Main function to train, validate, and test the model
def main(data_path, epochs=3, batch_size=16):
    """Read ``data_path``, fine-tune a ``BertClassifier`` on it and evaluate it.

    Returns the value returned by ``test_model`` (the held-out accuracy), or
    ``None`` when ``read_data`` returns ``None``.
    """
    # Read and preprocess data
    data = read_data(data_path)
    if data is None:
        return None

    input_ids, attention_masks, labels, labelencoder = preprocessing_data(data)
    X_train_ids, X_test_ids, X_train_masks, X_test_masks, y_train, y_test = split_data(
        input_ids, attention_masks, labels
    )

    # Determine the number of labels
    num_labels = len(labelencoder.classes_)

    # Create the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertClassifier(num_labels)
    model.to(device)

    # Create dataloaders. Training batches are reshuffled every epoch so the
    # optimiser does not see the examples in one fixed order.
    train_dataset = TensorDataset(X_train_ids, X_train_masks, y_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataset = TensorDataset(X_test_ids, X_test_masks, y_test)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    # Train the model
    train_model(model, train_dataloader, val_dataloader, device, epochs=epochs)

    # Test the model.
    # NOTE(review): split_data yields a single held-out split, so the test
    # accuracy is computed on the same examples reported as validation above;
    # a separate test split is needed for an unbiased final estimate.
    return test_model(model, val_dataloader, device)


if __name__ == "__main__":
    # An optional first command-line argument overrides the default dataset path.
    data_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"E:\transactify\transactify\Dataset\transaction_data.csv"
    )
    main(data_path)