venkyyuvy commited on
Commit
7e9b1a3
·
1 Parent(s): d104b05

pytorch lightning removal

Browse files
Files changed (4) hide show
  1. app.py +5 -6
  2. model.py +0 -118
  3. requirements.txt +0 -1
  4. resnet.pth → resnet18.pth +2 -2
app.py CHANGED
@@ -1,20 +1,19 @@
1
  import operator
2
  import torch
3
- from torchvision import transforms
4
  import gradio as gr
5
  from pytorch_grad_cam import GradCAM
6
  from pytorch_grad_cam.utils.image import show_cam_on_image
7
  import gradio as gr
8
- import model
9
  from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM
10
  import matplotlib
 
11
  matplotlib.use('agg')
12
  from matplotlib import pyplot as plt
13
 
14
- resnet_18 = model.LitResnet()
15
- state_dict = torch.load("resnet.pth", map_location=torch.device('cpu'))
16
- resnet_18.load_state_dict(state_dict)
17
- resnet_18_model = resnet_18.model
18
 
19
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
20
  'dog', 'frog', 'horse', 'ship', 'truck')
 
1
  import operator
2
  import torch
 
3
  import gradio as gr
4
  from pytorch_grad_cam import GradCAM
5
  from pytorch_grad_cam.utils.image import show_cam_on_image
6
  import gradio as gr
 
7
  from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM
8
  import matplotlib
9
+ from model import ResNet18
10
  matplotlib.use('agg')
11
  from matplotlib import pyplot as plt
12
 
13
+
14
+ resnet_18_model = ResNet18()
15
+ resnet_18_model.load_state_dict(torch.load('resnet18.pth'))
16
+ resnet_18_model.eval()
17
 
18
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
19
  'dog', 'frog', 'horse', 'ship', 'truck')
model.py CHANGED
@@ -79,121 +79,3 @@ class ResNet(nn.Module):
79
  def ResNet18():
80
  return ResNet(BasicBlock, [2, 2, 2, 2])
81
 
82
- import torch.nn as nn
83
- from torch.optim.lr_scheduler import OneCycleLR
84
- from torch.utils.data import DataLoader
85
- import matplotlib.pyplot as plt
86
-
87
- from data_loader import CifarAlbumentationsDataset,\
88
- CIFAR_CLASS_LABELS, TRAIN_TRANSFORM, TEST_TRANSFORM
89
- import model
90
- from torch_lr_finder import LRFinder
91
-
92
- import torch
93
- import torch.nn as nn
94
- import torch.nn.functional as F
95
- from pytorch_lightning import LightningModule
96
- from torch.optim.lr_scheduler import OneCycleLR
97
- from torchmetrics.functional import accuracy
98
-
99
- class LitResnet(LightningModule):
100
- def __init__(self, lr=0.03, batch_size=512):
101
- super().__init__()
102
-
103
- self.save_hyperparameters()
104
- self.criterion = nn.CrossEntropyLoss()
105
- self.model = ResNet18()
106
-
107
- def forward(self, x):
108
- return self.model(x)
109
-
110
- def training_step(self, batch, batch_idx):
111
- x, y = batch
112
- output = self.forward(x)
113
- loss = self.criterion(output, y)
114
- self.log("train_loss", loss)
115
- acc = accuracy(torch.argmax(output, dim=1),
116
- y, 'multiclass', num_classes=10)
117
- self.log(f"train_acc", acc, prog_bar=True)
118
- return loss
119
-
120
- def evaluate(self, batch, stage=None):
121
- x, y = batch
122
- output = self.forward(x)
123
- loss = self.criterion(output, y)
124
- preds = torch.argmax(output, dim=1)
125
- acc = accuracy(preds, y, 'multiclass', num_classes=10)
126
-
127
- if stage:
128
- self.log(f"{stage}_loss", loss, prog_bar=True)
129
- self.log(f"{stage}_acc", acc, prog_bar=True)
130
-
131
- def validation_step(self, batch, batch_idx):
132
- self.evaluate(batch, "val")
133
-
134
- def test_step(self, batch, batch_idx):
135
- self.evaluate(batch, "test")
136
-
137
- # todo
138
- # change the default for num_iter
139
- def lr_finder(self, optimizer, num_iter=200,):
140
- lr_finder = LRFinder(self, optimizer, self.criterion,
141
- device=self.device)
142
- lr_finder.range_test(
143
- self.train_dataloader(), end_lr=1,
144
- num_iter=num_iter, step_mode='exp',
145
- )
146
- ax, suggested_lr = lr_finder.plot(suggest_lr=True)
147
- # todo
148
- # how to log maplotlib images
149
- # self.logger.experiment.add_image('lr_finder', plt.gcf(), 0)
150
- lr_finder.reset()
151
- return suggested_lr
152
- def configure_optimizers(self):
153
- optimizer = torch.optim.SGD(
154
- self.parameters(),
155
- lr=self.hparams.lr,
156
- momentum=0.9,
157
- weight_decay=5e-4,
158
- )
159
- suggested_lr = self.lr_finder(optimizer)
160
- steps_per_epoch = len(self.train_dataloader())
161
- scheduler_dict = {
162
- "scheduler": OneCycleLR(
163
- optimizer, max_lr=suggested_lr,
164
- steps_per_epoch=steps_per_epoch,
165
- epochs=self.trainer.max_epochs,
166
- pct_start=5/self.trainer.max_epochs,
167
- three_phase=False,
168
- div_factor=100,
169
- final_div_factor=100,
170
- anneal_strategy='linear',
171
- ),
172
- "interval": "step",
173
- }
174
- return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
175
- ####################
176
- # DATA RELATED HOOKS
177
- ####################
178
-
179
- def prepare_data(self, data_path='../data'):
180
- CifarAlbumentationsDataset(
181
- data_path, train=True, download=True)
182
- CifarAlbumentationsDataset(
183
- data_path, train=False, download=True)
184
-
185
- def setup(self, stage=None, data_dir='../data'):
186
-
187
- if stage == "fit" or stage is None:
188
- self.train_dataset = CifarAlbumentationsDataset(data_dir, train=True, transform=TRAIN_TRANSFORM)
189
- self.test_dataset = CifarAlbumentationsDataset(data_dir, train=False, transform=TEST_TRANSFORM)
190
-
191
-
192
- def train_dataloader(self):
193
- return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size,
194
- shuffle=True, pin_memory=True) #num_workers=4,
195
-
196
- def val_dataloader(self):
197
- return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size,
198
- shuffle=False, pin_memory=True)
199
-
 
79
  def ResNet18():
80
  return ResNet(BasicBlock, [2, 2, 2, 2])
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -5,4 +5,3 @@ grad-cam
5
  pillow
6
  numpy
7
  albumentations
8
- pytorch-lightning
 
5
  pillow
6
  numpy
7
  albumentations
 
resnet.pth → resnet18.pth RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9bfd0a13c1f17b0977282a76b56aace7b682c1e3ac6ea12716ccbba97aa7afd3
3
- size 44774390
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5cc1136d85d13f5a4152098208cdddf533cf344ffbfb30a06dcb698dddf862f
3
+ size 44772860