# Snapshot metadata: revision 65eeb0e, 2,601 bytes.
import torch
import torch.nn as nn
import lightning as pl
import torchvision.models as models
from torchmetrics import Accuracy


class Classifier(pl.LightningModule):
    """Two-class image classifier on top of an ImageNet-pretrained EfficientNet-B1.

    Inputs are single-channel images (presumably shaped (batch, 1, H, W) —
    TODO confirm against the datamodule). A learnable 3x3 convolution maps
    them to the 3 channels the pretrained backbone expects. The backbone's
    stock head is replaced by a small 2-way head (LeakyReLU -> Dropout ->
    Linear).

    ``forward`` returns softmax probabilities over the 2 classes. The
    training and validation steps compute the loss on the raw logits,
    because ``nn.CrossEntropyLoss`` applies log-softmax internally.

    Args:
        transfer: If True (default), freeze every pretrained backbone
            parameter and keep the backbone's feature extractor in eval mode
            for the whole run. In that case only the input adapter
            convolution and the new head are trained.
    """

    def __init__(self, transfer=True):
        super().__init__()
        # Learnable adapter: 1 input channel -> 3 channels for the backbone.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.model = models.efficientnet_b1(weights='IMAGENET1K_V1')
        # Read by train() to keep the frozen feature extractor in eval mode.
        self._freeze_backbone = bool(transfer)
        if transfer:
            # Freeze the pretrained weights.
            for p in self.model.parameters():
                p.requires_grad = False
            # Put the frozen feature extractor in eval mode. Doing this once
            # here is not enough on its own: the trainer switches the module
            # back to training mode when fitting, so train() below re-applies it.
            self.model.features.eval()
        # Take the feature width from the stock head (1280 for B1) instead of
        # hard-coding it.
        num_ftrs = self.model.classifier[-1].in_features
        # The head emits raw logits; no Softmax here, since CrossEntropyLoss
        # normalizes internally. The Linear stays at index 2, so the
        # state_dict keys (model.classifier.2.*) are unchanged.
        self.model.classifier = nn.Sequential(
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=num_ftrs, out_features=2),
        )

        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.val_accuracy = Accuracy(task='binary')

    def train(self, mode=True):
        """Set training/eval mode, keeping a frozen backbone in eval mode.

        When ``transfer`` was requested, the backbone feature extractor is
        returned to eval mode after every switch into training mode. Without
        this, the frozen backbone's mode-dependent layers (e.g.
        normalization) would run in training mode even though their weights
        are frozen.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        if mode and getattr(self, '_freeze_backbone', False):
            self.model.features.eval()
        return self

    def _logits(self, x):
        """Return unnormalized 2-class scores for a batch of images."""
        return self.model(self.conv(x))

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for a batch of images."""
        return torch.softmax(self._logits(x), dim=1)

    def _shared_step(self, batch):
        """Compute loss and predicted class indices for one (images, labels) batch.

        Presumably ``labels`` are integer class indices in {0, 1}, as
        ``nn.CrossEntropyLoss`` and binary accuracy expect here — TODO confirm.
        """
        images, labels = batch
        logits = self._logits(images)
        # CrossEntropyLoss must see logits, not probabilities.
        loss = self.criterion(logits, labels)
        preds = logits.argmax(dim=1)
        return loss, preds, labels

    def training_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        acc = self.train_accuracy(preds, labels)
        self.log('train_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        acc = self.val_accuracy(preds, labels)
        self.log('val_acc', acc, prog_bar=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        # Clear accumulated metric state so each epoch starts fresh.
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        self.val_accuracy.reset()

    def configure_optimizers(self):
        """Adam (lr=1e-4) with ReduceLROnPlateau on ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        # `verbose` is deprecated on PyTorch LR schedulers, so it is omitted.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
        }