import os
import copy

import torch
from torch import nn, optim
from tqdm import tqdm
from sklearn.metrics import (precision_score, recall_score, f1_score,
                             roc_auc_score, cohen_kappa_score,
                             matthews_corrcoef, confusion_matrix)

from modeling_sagvit import SAGViTClassifier
from data_loader import get_dataloaders

#####################################################################
# This file provides the training loop and metric computation. It uses
# the SAG-ViT model defined in modeling_sagvit.py, and the data from
# data_loader.py. The training loop implements early stopping and
# tracks a range of classification metrics.
#####################################################################


def train_model(model, model_name, train_loader, val_loader, num_epochs,
                criterion, optimizer, device, patience=8, verbose=True):
    """
    Trains the SAG-ViT model and evaluates it on the validation set.
    Implements early stopping based on validation loss.

    Parameters:
    - model (nn.Module): The SAG-ViT model.
    - model_name (str): A name to identify the model (used for saving checkpoints).
    - train_loader, val_loader: DataLoaders for training and validation.
    - num_epochs (int): Maximum number of epochs.
    - criterion (nn.Module): Loss function.
    - optimizer (torch.optim.Optimizer): Optimization algorithm.
    - device (torch.device): Device to run the computations on (CPU/GPU).
    - patience (int): Number of epochs without validation-loss improvement
      before training stops early.
    - verbose (bool): Whether to print a per-epoch summary.

    Returns:
    - history (dict): Dictionary containing training and validation metrics per epoch.
    """
    history = {
        'train_loss': [], 'train_acc': [], 'train_prec': [], 'train_rec': [],
        'train_f1': [], 'train_auc': [], 'train_mcc': [],
        'train_cohen_kappa': [], 'train_confusion_matrix': [],
        'val_loss': [], 'val_acc': [], 'val_prec': [], 'val_rec': [],
        'val_f1': [], 'val_auc': [], 'val_mcc': [],
        'val_cohen_kappa': [], 'val_confusion_matrix': []
    }

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        model.train()
        train_loss_total, correct, total = 0, 0, 0
        all_preds, all_labels, all_probs = [], [], []

        # Training loop
        for batch_idx, (X, y) in enumerate(tqdm(train_loader)):
            inputs, labels = X.to(device), y.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss_total += loss.item()
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())

        # Compute training metrics
        train_acc = correct / total
        train_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        train_rec = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        train_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        train_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
        train_mcc = matthews_corrcoef(all_labels, all_preds)
        train_confusion = confusion_matrix(all_labels, all_preds)
        # Macro one-vs-rest AUC over the softmax probabilities. Assumes a
        # multi-class problem; for a binary task pass only the positive-class
        # probabilities instead.
        train_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')

        history['train_loss'].append(train_loss_total / len(train_loader))
        history['train_acc'].append(train_acc)
        history['train_prec'].append(train_prec)
        history['train_rec'].append(train_rec)
        history['train_f1'].append(train_f1)
        history['train_auc'].append(train_auc)
        history['train_cohen_kappa'].append(train_cohen_kappa)
        history['train_mcc'].append(train_mcc)
        history['train_confusion_matrix'].append(train_confusion)

        # Validation
        model.eval()
        val_loss_total, correct, total = 0, 0, 0
        all_preds, all_labels, all_probs = [], [], []

        with torch.no_grad():
            for batch_idx, (X, y) in enumerate(tqdm(val_loader)):
                inputs, labels = X.to(device), y.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss_total += loss.item()
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        # Compute validation metrics
        val_acc = correct / total
        val_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        val_rec = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        val_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        val_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
        val_mcc = matthews_corrcoef(all_labels, all_preds)
        val_confusion = confusion_matrix(all_labels, all_preds)
        val_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')

        history['val_loss'].append(val_loss_total / len(val_loader))
        history['val_acc'].append(val_acc)
        history['val_prec'].append(val_prec)
        history['val_rec'].append(val_rec)
        history['val_f1'].append(val_f1)
        history['val_auc'].append(val_auc)
        history['val_cohen_kappa'].append(val_cohen_kappa)
        history['val_mcc'].append(val_mcc)
        history['val_confusion_matrix'].append(val_confusion)

        # Print epoch summary
        if verbose:
            print(f"Train Loss: {history['train_loss'][-1]:.4f}, Train Acc: {history['train_acc'][-1]:.4f}, "
                  f"Val Loss: {history['val_loss'][-1]:.4f}, Val Acc: {history['val_acc'][-1]:.4f}")

        # Early stopping. Keep a deep copy of the best weights: state_dict()
        # returns references to the live tensors, so later optimizer steps
        # would otherwise mutate the saved state in place.
        current_val_loss = history['val_loss'][-1]
        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"Patience counter: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Restore and checkpoint the best weights seen during training.
    model.load_state_dict(best_model_state)
    torch.save(model.state_dict(), f'{model_name}.pth')
    return history

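
def smoke_test(num_classes=3, n_samples=64):
    """
    Minimal sanity check for train_model on random tensors. This is a sketch:
    the helper name, sample counts, and 224x224 input size are illustrative
    assumptions (the image size mirrors the loaders in __main__), not part of
    the project. Useful for exercising the loop, metric bookkeeping, and
    checkpointing without a real dataset.
    """
    from torch.utils.data import DataLoader, TensorDataset

    X = torch.randn(n_samples, 3, 224, 224)
    y = torch.randint(0, num_classes, (n_samples,))
    loader = DataLoader(TensorDataset(X, y), batch_size=8)

    model = SAGViTClassifier(num_classes=num_classes)
    return train_model(
        model, 'smoke-test', loader, loader, num_epochs=1,
        criterion=nn.CrossEntropyLoss(),
        optimizer=optim.Adam(model.parameters(), lr=1e-4),
        device=torch.device('cpu')
    )
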
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    data_dir = "data/PlantVillage"  # "path/to/data/dir"
    # Count only class subdirectories, ignoring any stray files in data_dir.
    num_classes = len([d for d in os.listdir(data_dir)
                       if os.path.isdir(os.path.join(data_dir, d))])

    # Minimum image size should be at least (49, 49).
    train_loader, val_loader = get_dataloaders(data_dir=data_dir, img_size=224, batch_size=32)

    model = SAGViTClassifier(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    num_epochs = 100

    history = train_model(
        model, 'SAG-ViT', train_loader, val_loader,
        num_epochs, criterion, optimizer, device
    )

    # You may save history to a CSV or analyze it further as needed.
    # Example:
    # import pandas as pd
    # history_df = pd.DataFrame(history)
    # history_df.to_csv("training_history.csv", index=False)

    # Load the saved model back (best practice before pushing)
    model.load_state_dict(torch.load("SAG-ViT.pth"))
    model.eval()

    # Push the model to the Hugging Face Hub
    model.push_to_hub("shravvvv/SAG-ViT",
                      commit_message="Initial model push",
                      private=True,
                      trust_remote_code=True)
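
    # To reload the pushed checkpoint elsewhere (a sketch: this assumes
    # SAGViTClassifier mixes in huggingface_hub's PyTorchModelHubMixin,
    # which is what provides the push_to_hub method used above):
    # Example:
    # loaded = SAGViTClassifier.from_pretrained("shravvvv/SAG-ViT")
    # loaded.eval()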