import torch
import torch.nn as nn
import lightning as pl
import torchvision.models as models
from torchmetrics import Accuracy
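
# A LightningModule that adapts 1-channel inputs to a pretrained
# EfficientNet-B1 and fine-tunes a small two-class head on top.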
class Classifier(pl.LightningModule):
    def __init__(self, transfer=True):
        super().__init__()
        # Map the 1-channel input to the 3 channels EfficientNet expects.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.model = models.efficientnet_b1(weights='IMAGENET1K_V1')
        if transfer:
            # eval() only switches BatchNorm/Dropout to inference mode;
            # the actual freezing is done below by disabling gradients.
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        num_ftrs = 1280  # EfficientNet-B1 feature dimension
        # Note: no Softmax in the head. CrossEntropyLoss expects raw logits
        # and applies log-softmax internally; stacking Softmax on top would
        # squash the gradients.
        self.model.classifier = nn.Sequential(
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=num_ftrs, out_features=2),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.val_accuracy = Accuracy(task='binary')

    def forward(self, x):
        x = self.conv(x)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        # Calculate and log accuracy
        _, preds = torch.max(outputs, 1)
        acc = self.train_accuracy(preds, labels)
        self.log('train_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        # Calculate and log accuracy
        _, preds = torch.max(outputs, 1)
        acc = self.val_accuracy(preds, labels)
        self.log('val_acc', acc, prog_bar=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        self.val_accuracy.reset()

    def configure_optimizers(self):
        # Only optimize the trainable parameters (the stem conv and the new head);
        # the frozen backbone is skipped.
        params = filter(lambda p: p.requires_grad, self.parameters())
        optimizer = torch.optim.Adam(params, lr=0.0001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
        }
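
# Minimal usage sketch: a smoke test with random stand-in data. The tensor
# shapes, batch sizes, and Trainer settings here are assumptions for
# illustration, not the Space's real data pipeline or training config.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Fake 1-channel 224x224 images with binary labels, just to exercise the module.
    images = torch.randn(8, 1, 224, 224)
    labels = torch.randint(0, 2, (8,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    model = Classifier(transfer=True)
    trainer = pl.Trainer(max_epochs=1, accelerator="auto",
                         logger=False, enable_checkpointing=False)
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)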