cfe-gen / src /classifier.py
anindya-hf-2002's picture
upload application files
65eeb0e verified
raw
history blame
2.6 kB
import torch
import torch.nn as nn
import lightning as pl
import torchvision.models as models
from torchmetrics import Accuracy
class Classifier(pl.LightningModule):
    """Two-class image classifier on top of an ImageNet-pretrained EfficientNet-B1.

    Inputs are single-channel images; a learnable 3x3 convolution maps them to
    the 3 channels the pretrained backbone expects. The backbone's original
    classification layer is replaced with a small head producing 2 class
    scores.

    ``forward`` returns softmax probabilities. Training and validation compute
    ``nn.CrossEntropyLoss`` on the raw logits instead, because that loss
    applies log-softmax internally; feeding it probabilities would apply
    softmax twice and flatten the gradients.

    Args:
        transfer: If True (default), freeze all pretrained EfficientNet
            parameters so that only the channel-adapter convolution and the
            new classification head are trained. If False, fine-tune the
            whole network.
    """

    def __init__(self, transfer=True):
        super().__init__()
        # Map 1-channel input to the 3 channels the ImageNet backbone expects.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.model = models.efficientnet_b1(weights='IMAGENET1K_V1')
        if transfer:
            # eval() only switches BatchNorm/Dropout behaviour in the backbone;
            # any later .train() call on this module (which the Trainer may
            # issue, depending on the Lightning version) undoes it. The actual
            # weight freeze is requires_grad=False below.
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        # Read the feature width from the pretrained head (1280 for B1) instead
        # of hard-coding it. The new head is created after freezing, so its
        # parameters stay trainable.
        num_ftrs = self.model.classifier[-1].in_features
        # No Softmax in the head: it outputs logits and forward() applies the
        # softmax. Softmax has no parameters, so checkpoint keys are unchanged.
        self.model.classifier = nn.Sequential(
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=num_ftrs, out_features=2),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.val_accuracy = Accuracy(task='binary')

    def _logits(self, x):
        """Return raw (pre-softmax) class scores, shape (batch, 2)."""
        return self.model(self.conv(x))

    def forward(self, x):
        """Return class probabilities for a batch of single-channel images.

        Args:
            x: Image tensor with 1 channel; presumably (batch, 1, H, W) --
                confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, 2) holding softmax probabilities over dim 1.
        """
        return torch.softmax(self._logits(x), dim=1)

    def _shared_step(self, batch):
        """Return ``(loss, predicted labels, target labels)`` for one batch.

        Assumes ``batch`` unpacks to ``(images, labels)`` with class-index
        labels (0 or 1) -- TODO confirm against the dataloader.
        """
        images, labels = batch
        logits = self._logits(images)
        # CrossEntropyLoss applies log-softmax itself, so it must receive logits.
        loss = self.criterion(logits, labels)
        # Softmax is monotonic, so argmax over logits matches argmax over probs.
        preds = logits.argmax(dim=1)
        return loss, preds, labels

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss; also log accuracy."""
        loss, preds, labels = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        acc = self.train_accuracy(preds, labels)
        self.log('train_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss; also log accuracy."""
        loss, preds, labels = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        acc = self.val_accuracy(preds, labels)
        self.log('val_acc', acc, prog_bar=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Clear the accumulated training-accuracy state for the next epoch."""
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        """Clear the accumulated validation-accuracy state for the next run."""
        self.val_accuracy.reset()

    def configure_optimizers(self):
        """Return Adam (lr=1e-4) plus a ReduceLROnPlateau scheduler on ``val_loss``.

        The learning rate is multiplied by 0.2 once ``val_loss`` has not
        improved for more than 5 scheduler steps (epochs, under Lightning's
        default scheduler interval).
        """
        # All parameters are passed: frozen ones never receive gradients, so
        # Adam skips them. Keeping the full list also preserves the optimizer
        # state layout of existing checkpoints.
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        # `verbose` is deprecated for LR schedulers since PyTorch 2.2; use a
        # LearningRateMonitor callback to track the learning rate instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
            'monitor': 'val_loss',
        }