Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import lightning as pl | |
import torchvision.models as models | |
from torchmetrics import Accuracy | |
class Classifier(pl.LightningModule):
    """Two-class image classifier on an ImageNet-pretrained EfficientNet-B1.

    A 3x3 convolution maps single-channel input to the three channels fed to
    the backbone, and the backbone's stock classifier is replaced by a
    two-output head.

    Args:
        transfer: When True, the pretrained backbone is put in eval mode and
            all of its parameters get ``requires_grad = False``. The input
            convolution and the replacement head are created outside that
            freeze and remain trainable.
    """

    def __init__(self, transfer=True):
        super().__init__()
        # Expand 1-channel input to 3 channels for the backbone.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.model = models.efficientnet_b1(weights='IMAGENET1K_V1')
        if transfer:
            # eval() only switches layer modes (dropout, batch-norm statistics);
            # the actual freezing is done by disabling gradients below.
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False

        # Read the feature width from the backbone's original final Linear
        # layer instead of hard-coding it (1280 for EfficientNet-B1).
        num_ftrs = self.model.classifier[-1].in_features
        # The head emits raw logits: CrossEntropyLoss applies log-softmax
        # itself, so a Softmax layer here would normalise twice and flatten
        # the gradients. Probabilities are produced in forward() instead.
        # Layer indices 0-2 match the previous head, so existing checkpoints
        # still load.
        self.model.classifier = nn.Sequential(
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=num_ftrs, out_features=2),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.val_accuracy = Accuracy(task='binary')

    def _logits(self, x):
        """Return the unnormalised two-class scores for a batch ``x``."""
        x = self.conv(x)
        return self.model(x)

    def forward(self, x):
        """Return per-class probabilities (softmax over dim 1) for ``x``."""
        return torch.softmax(self._logits(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the training loss on one batch and log loss and accuracy.

        ``batch`` is unpacked as ``(images, labels)``; labels are presumably
        integer class indices (0 or 1) - confirm against the dataloader.
        """
        images, labels = batch
        logits = self._logits(images)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        # Predicted class is the index of the largest logit.
        preds = logits.argmax(dim=1)
        acc = self.train_accuracy(preds, labels)
        self.log('train_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss on one batch and log loss and accuracy."""
        images, labels = batch
        logits = self._logits(images)
        loss = self.criterion(logits, labels)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        # Predicted class is the index of the largest logit.
        preds = logits.argmax(dim=1)
        acc = self.val_accuracy(preds, labels)
        self.log('val_acc', acc, prog_bar=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Clear the accumulated training-accuracy state."""
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        """Clear the accumulated validation-accuracy state."""
        self.val_accuracy.reset()

    def configure_optimizers(self):
        """Return Adam (lr=1e-4) with a plateau scheduler keyed on ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        # `verbose` is omitted: it is deprecated in recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
        }