Spaces:
Sleeping
Sleeping
init commit
Browse files- README.md +3 -1
- app.py +130 -0
- data_loader.py +48 -0
- examples/airplane.jpg +0 -0
- examples/automobile.jpg +0 -0
- examples/cat.jpg +0 -0
- examples/deer.jpg +0 -0
- examples/dog.jpg +0 -0
- examples/frog.jpg +0 -0
- examples/horse.jpg +0 -0
- examples/ship.jpg +0 -0
- examples/truck.jpg +0 -0
- misclassified_images.pt +0 -0
- model.py +199 -0
- requirements.txt +0 -0
README.md
CHANGED
@@ -10,4 +10,6 @@ pinned: false
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
This application classifies the given images into one of the ten classes in the cifar 10 dataset. It provides the sample misclasification images done by the model in the test dataset.
|
14 |
+
|
15 |
+
The app also provides the option of visualizing the GradCAM (Gradient based Class activation mapping) output for model explainability. The user has the option to choose the layer
|
app.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import operator
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms
|
4 |
+
import gradio as gr
|
5 |
+
from pytorch_grad_cam import GradCAM
|
6 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
7 |
+
import gradio as gr
|
8 |
+
import model
|
9 |
+
from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM
|
10 |
+
import matplotlib
|
11 |
+
matplotlib.use('agg')
|
12 |
+
from matplotlib import pyplot as plt
|
13 |
+
|
14 |
+
resnet_18 = model.LitResnet()
|
15 |
+
state_dict = torch.load("saved_model.pth", map_location=torch.device('cpu'))
|
16 |
+
resnet_18.load_state_dict(state_dict)
|
17 |
+
resnet_18_model = resnet_18.model
|
18 |
+
|
19 |
+
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
20 |
+
'dog', 'frog', 'horse', 'ship', 'truck')
|
21 |
+
|
22 |
+
|
23 |
+
def inference(input_img, n_top_classes,
|
24 |
+
apply_gradcam, transparency=0.5,
|
25 |
+
target_layer_number = -1):
|
26 |
+
org_img = input_img
|
27 |
+
input_img = TEST_TRANSFORM(image=input_img)['image']
|
28 |
+
input_img = input_img.unsqueeze(0)
|
29 |
+
outputs = resnet_18_model(input_img)
|
30 |
+
softmax = torch.nn.Softmax(dim=0)
|
31 |
+
o = softmax(outputs.flatten())
|
32 |
+
y = {classes[i]: float(o[i]) for i in range(10)}
|
33 |
+
sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
|
34 |
+
sorted_pred = sorted_pred[: n_top_classes]
|
35 |
+
confidences = {klass: prob for klass, prob in sorted_pred}
|
36 |
+
if apply_gradcam:
|
37 |
+
target_layers = [resnet_18_model.layer3[target_layer_number]]
|
38 |
+
cam = GradCAM(model=resnet_18_model, target_layers=target_layers, use_cuda=False)
|
39 |
+
grayscale_cam = cam(input_tensor=input_img, targets=None)
|
40 |
+
grayscale_cam = grayscale_cam[0, :]
|
41 |
+
visualization = show_cam_on_image(
|
42 |
+
org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
|
43 |
+
return (gr.update(value= confidences),
|
44 |
+
gr.update(value=visualization, visible=True))
|
45 |
+
return (gr.update(value=confidences),
|
46 |
+
gr.update(visible=False))
|
47 |
+
|
48 |
+
def show_misclasif(see_misclassif, n_images):
|
49 |
+
if see_misclassif:
|
50 |
+
subset = torch.load('misclassified_images.pt')
|
51 |
+
images, actuals, preds = torch.tensor(subset[0])[:20], subset[1], subset[2]
|
52 |
+
figsize=(n_images, 4)
|
53 |
+
nrows=2
|
54 |
+
ncols=n_images//2
|
55 |
+
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
|
56 |
+
fig.suptitle('misclassified images', weight='bold', size=10)
|
57 |
+
axes = axes.ravel()
|
58 |
+
for img, actual, pred, ax in zip(images, actuals, preds, axes):
|
59 |
+
ax.imshow(img)
|
60 |
+
ax.set_title(
|
61 |
+
f'Prediction={CIFAR_CLASS_LABELS[pred]}\n Actual={CIFAR_CLASS_LABELS[actual]}',
|
62 |
+
fontsize=8)
|
63 |
+
ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
|
64 |
+
ax.axis('off')
|
65 |
+
image_path = "plot.png"
|
66 |
+
fig.savefig(image_path)
|
67 |
+
plt.close()
|
68 |
+
return gr.update(value=image_path, visible=True)
|
69 |
+
|
70 |
+
|
71 |
+
with gr.Blocks() as demo:
|
72 |
+
with gr.Row():
|
73 |
+
with gr.Column():
|
74 |
+
input_image = gr.Image(shape=(32, 32), label="Input Image")
|
75 |
+
n_top_classes = gr.Slider(maximum=10, minimum=1, value=3, step=1,
|
76 |
+
label="Top n classes to show", interactive=True)
|
77 |
+
require_gradcam = gr.Checkbox(label="Apply GradCAM",
|
78 |
+
info="Do you want see the GRAD-CAM visualization")
|
79 |
+
opacity_gradcam = gr.Slider(0, 1, value=0.5,
|
80 |
+
label="Opacity of GradCAM")
|
81 |
+
layer_gradcam = gr.Slider(-2, -1, value=-2, step=1,
|
82 |
+
label="Which Layer?")
|
83 |
+
submit = gr.Button("Submit")
|
84 |
+
with gr.Column():
|
85 |
+
pred_classes = gr.Label()
|
86 |
+
grad_cam = gr.Image(shape=(32, 32),
|
87 |
+
label="Output",visible=False)\
|
88 |
+
.style(width=128, height=128)
|
89 |
+
with gr.Row():
|
90 |
+
with gr.Column():
|
91 |
+
see_misclassif = gr.Checkbox(label="View misclassified images",
|
92 |
+
info="Do you want see the miscassified images in the test dataset")
|
93 |
+
n_misclasif = gr.Slider(maximum=20, minimum=2, value=10, step=2,
|
94 |
+
label="Number of misclassified images to show",
|
95 |
+
interactive=True, visible=False)
|
96 |
+
render = gr.Button("Render", visible=False)
|
97 |
+
misclasif_display = gr.Image(visible=False)
|
98 |
+
|
99 |
+
n_top_classes.postprocess(n_top_classes.value)
|
100 |
+
submit.click(inference,
|
101 |
+
inputs=[input_image, n_top_classes, require_gradcam,
|
102 |
+
opacity_gradcam, layer_gradcam],
|
103 |
+
outputs=[pred_classes, grad_cam]
|
104 |
+
)
|
105 |
+
def turn_on_misclasif(see_misclassif):
|
106 |
+
if see_misclassif:
|
107 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
|
108 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
109 |
+
|
110 |
+
see_misclassif.change(turn_on_misclasif, see_misclassif, [n_misclasif, render, misclasif_display])
|
111 |
+
render.click(show_misclasif, [see_misclassif, n_misclasif], misclasif_display)
|
112 |
+
|
113 |
+
gr.Examples(
|
114 |
+
examples=[
|
115 |
+
["examples/truck.jpg", 3, True],
|
116 |
+
["examples/ship.jpg", 3, True],
|
117 |
+
["examples/dog.jpg", 3, True],
|
118 |
+
["examples/cat.jpg", 3, True],
|
119 |
+
["examples/horse.jpg", 3, True],
|
120 |
+
["examples/airplane.jpg", 3, True],
|
121 |
+
["examples/parrot.jpg", 3, True],
|
122 |
+
["examples/automobile.jpg", 3, True],
|
123 |
+
["examples/deer.jpg", 3, True],
|
124 |
+
["examples/frog.jpg", 3, True],
|
125 |
+
],
|
126 |
+
inputs=[input_image, n_top_classes, require_gradcam],
|
127 |
+
outputs=[pred_classes, grad_cam],
|
128 |
+
fn=inference,
|
129 |
+
)
|
130 |
+
demo.launch()
|
data_loader.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision import datasets
|
2 |
+
import albumentations as A
|
3 |
+
from albumentations.pytorch import ToTensorV2
|
4 |
+
|
5 |
+
|
6 |
+
NORM_DATA_MEAN = (0.49139968, 0.48215841, 0.44653091)
|
7 |
+
NORM_DATA_STD = (0.24703223, 0.24348513, 0.26158784)
|
8 |
+
|
9 |
+
CIFAR_CLASS_LABELS = [
|
10 |
+
'airplane', 'automobile', 'bird', 'cat', 'deer',
|
11 |
+
'dog', 'frog', 'horse', 'ship', 'truck'
|
12 |
+
]
|
13 |
+
|
14 |
+
TRAIN_TRANSFORM = A.Compose([
|
15 |
+
A.Normalize(
|
16 |
+
mean=NORM_DATA_MEAN,
|
17 |
+
std=NORM_DATA_STD,
|
18 |
+
),
|
19 |
+
A.HorizontalFlip(),
|
20 |
+
A.Compose([
|
21 |
+
A.PadIfNeeded(min_height=40, min_width=40, p=1.0),
|
22 |
+
A.CoarseDropout(max_holes=1, max_height=16, max_width=16,
|
23 |
+
min_holes=1, min_height=16, min_width=16,
|
24 |
+
fill_value=NORM_DATA_MEAN, mask_fill_value=None, p=1.0),
|
25 |
+
A.RandomCrop(p=1.0, height=32, width=32)
|
26 |
+
]),
|
27 |
+
ToTensorV2(),
|
28 |
+
])
|
29 |
+
|
30 |
+
TEST_TRANSFORM = A.Compose([
|
31 |
+
A.Normalize(
|
32 |
+
mean=NORM_DATA_MEAN,
|
33 |
+
std=NORM_DATA_STD,
|
34 |
+
),
|
35 |
+
ToTensorV2(),
|
36 |
+
])
|
37 |
+
|
38 |
+
class CifarAlbumentationsDataset(datasets.CIFAR10):
|
39 |
+
def __init__(self, *args, **kwargs):
|
40 |
+
super().__init__(*args, **kwargs)
|
41 |
+
def __getitem__(self, idx):
|
42 |
+
img, target = self.data[idx], self.targets[idx]
|
43 |
+
if self.transform:
|
44 |
+
augmented = self.transform(image=img)
|
45 |
+
image = augmented['image']
|
46 |
+
return image, target
|
47 |
+
|
48 |
+
|
examples/airplane.jpg
ADDED
![]() |
examples/automobile.jpg
ADDED
![]() |
examples/cat.jpg
ADDED
![]() |
examples/deer.jpg
ADDED
![]() |
examples/dog.jpg
ADDED
![]() |
examples/frog.jpg
ADDED
![]() |
examples/horse.jpg
ADDED
![]() |
examples/ship.jpg
ADDED
![]() |
examples/truck.jpg
ADDED
![]() |
misclassified_images.pt
ADDED
Binary file (450 kB). View file
|
|
model.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
https://github.com/kuangliu/pytorch-cifar
|
3 |
+
|
4 |
+
ResNet in PyTorch.
|
5 |
+
|
6 |
+
For Pre-activation ResNet, see 'preact_resnet.py'.
|
7 |
+
|
8 |
+
Reference:
|
9 |
+
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
10 |
+
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
11 |
+
'''
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
from torch_lr_finder import LRFinder
|
16 |
+
|
17 |
+
|
18 |
+
class BasicBlock(nn.Module):
|
19 |
+
expansion = 1
|
20 |
+
|
21 |
+
def __init__(self, in_planes, planes, stride=1):
|
22 |
+
super(BasicBlock, self).__init__()
|
23 |
+
self.conv1 = nn.Conv2d(
|
24 |
+
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
25 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
26 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
|
27 |
+
stride=1, padding=1, bias=False)
|
28 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
29 |
+
|
30 |
+
self.shortcut = nn.Sequential()
|
31 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
32 |
+
self.shortcut = nn.Sequential(
|
33 |
+
nn.Conv2d(in_planes, self.expansion*planes,
|
34 |
+
kernel_size=1, stride=stride, bias=False),
|
35 |
+
nn.BatchNorm2d(self.expansion*planes)
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
40 |
+
out = self.bn2(self.conv2(out))
|
41 |
+
out += self.shortcut(x)
|
42 |
+
out = F.relu(out)
|
43 |
+
return out
|
44 |
+
|
45 |
+
class ResNet(nn.Module):
|
46 |
+
def __init__(self, block, num_blocks, num_classes=10):
|
47 |
+
super(ResNet, self).__init__()
|
48 |
+
self.in_planes = 64
|
49 |
+
|
50 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
|
51 |
+
stride=1, padding=1, bias=False)
|
52 |
+
self.bn1 = nn.BatchNorm2d(64)
|
53 |
+
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
54 |
+
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
55 |
+
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
56 |
+
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
57 |
+
self.linear = nn.Linear(512*block.expansion, num_classes)
|
58 |
+
|
59 |
+
def _make_layer(self, block, planes, num_blocks, stride):
|
60 |
+
strides = [stride] + [1]*(num_blocks-1)
|
61 |
+
layers = []
|
62 |
+
for stride in strides:
|
63 |
+
layers.append(block(self.in_planes, planes, stride))
|
64 |
+
self.in_planes = planes * block.expansion
|
65 |
+
return nn.Sequential(*layers)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
69 |
+
out = self.layer1(out)
|
70 |
+
out = self.layer2(out)
|
71 |
+
out = self.layer3(out)
|
72 |
+
out = self.layer4(out)
|
73 |
+
out = F.avg_pool2d(out, 4)
|
74 |
+
out = out.view(out.size(0), -1)
|
75 |
+
out = self.linear(out)
|
76 |
+
return out
|
77 |
+
|
78 |
+
|
79 |
+
def ResNet18():
|
80 |
+
return ResNet(BasicBlock, [2, 2, 2, 2])
|
81 |
+
|
82 |
+
import torch.nn as nn
|
83 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
84 |
+
from torch.utils.data import DataLoader
|
85 |
+
import matplotlib.pyplot as plt
|
86 |
+
|
87 |
+
from data_loader import CifarAlbumentationsDataset,\
|
88 |
+
CIFAR_CLASS_LABELS, TRAIN_TRANSFORM, TEST_TRANSFORM
|
89 |
+
import model
|
90 |
+
from torch_lr_finder import LRFinder
|
91 |
+
|
92 |
+
import torch
|
93 |
+
import torch.nn as nn
|
94 |
+
import torch.nn.functional as F
|
95 |
+
from pytorch_lightning import LightningModule
|
96 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
97 |
+
from torchmetrics.functional import accuracy
|
98 |
+
|
99 |
+
class LitResnet(LightningModule):
|
100 |
+
def __init__(self, lr=0.03, batch_size=512):
|
101 |
+
super().__init__()
|
102 |
+
|
103 |
+
self.save_hyperparameters()
|
104 |
+
self.criterion = nn.CrossEntropyLoss()
|
105 |
+
self.model = ResNet18()
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return self.model(x)
|
109 |
+
|
110 |
+
def training_step(self, batch, batch_idx):
|
111 |
+
x, y = batch
|
112 |
+
output = self.forward(x)
|
113 |
+
loss = self.criterion(output, y)
|
114 |
+
self.log("train_loss", loss)
|
115 |
+
acc = accuracy(torch.argmax(output, dim=1),
|
116 |
+
y, 'multiclass', num_classes=10)
|
117 |
+
self.log(f"train_acc", acc, prog_bar=True)
|
118 |
+
return loss
|
119 |
+
|
120 |
+
def evaluate(self, batch, stage=None):
|
121 |
+
x, y = batch
|
122 |
+
output = self.forward(x)
|
123 |
+
loss = self.criterion(output, y)
|
124 |
+
preds = torch.argmax(output, dim=1)
|
125 |
+
acc = accuracy(preds, y, 'multiclass', num_classes=10)
|
126 |
+
|
127 |
+
if stage:
|
128 |
+
self.log(f"{stage}_loss", loss, prog_bar=True)
|
129 |
+
self.log(f"{stage}_acc", acc, prog_bar=True)
|
130 |
+
|
131 |
+
def validation_step(self, batch, batch_idx):
|
132 |
+
self.evaluate(batch, "val")
|
133 |
+
|
134 |
+
def test_step(self, batch, batch_idx):
|
135 |
+
self.evaluate(batch, "test")
|
136 |
+
|
137 |
+
# todo
|
138 |
+
# change the default for num_iter
|
139 |
+
def lr_finder(self, optimizer, num_iter=200,):
|
140 |
+
lr_finder = LRFinder(self, optimizer, self.criterion,
|
141 |
+
device=self.device)
|
142 |
+
lr_finder.range_test(
|
143 |
+
self.train_dataloader(), end_lr=1,
|
144 |
+
num_iter=num_iter, step_mode='exp',
|
145 |
+
)
|
146 |
+
ax, suggested_lr = lr_finder.plot(suggest_lr=True)
|
147 |
+
# todo
|
148 |
+
# how to log maplotlib images
|
149 |
+
# self.logger.experiment.add_image('lr_finder', plt.gcf(), 0)
|
150 |
+
lr_finder.reset()
|
151 |
+
return suggested_lr
|
152 |
+
def configure_optimizers(self):
|
153 |
+
optimizer = torch.optim.SGD(
|
154 |
+
self.parameters(),
|
155 |
+
lr=self.hparams.lr,
|
156 |
+
momentum=0.9,
|
157 |
+
weight_decay=5e-4,
|
158 |
+
)
|
159 |
+
suggested_lr = self.lr_finder(optimizer)
|
160 |
+
steps_per_epoch = len(self.train_dataloader())
|
161 |
+
scheduler_dict = {
|
162 |
+
"scheduler": OneCycleLR(
|
163 |
+
optimizer, max_lr=suggested_lr,
|
164 |
+
steps_per_epoch=steps_per_epoch,
|
165 |
+
epochs=self.trainer.max_epochs,
|
166 |
+
pct_start=5/self.trainer.max_epochs,
|
167 |
+
three_phase=False,
|
168 |
+
div_factor=100,
|
169 |
+
final_div_factor=100,
|
170 |
+
anneal_strategy='linear',
|
171 |
+
),
|
172 |
+
"interval": "step",
|
173 |
+
}
|
174 |
+
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
|
175 |
+
####################
|
176 |
+
# DATA RELATED HOOKS
|
177 |
+
####################
|
178 |
+
|
179 |
+
def prepare_data(self, data_path='../data'):
|
180 |
+
CifarAlbumentationsDataset(
|
181 |
+
data_path, train=True, download=True)
|
182 |
+
CifarAlbumentationsDataset(
|
183 |
+
data_path, train=False, download=True)
|
184 |
+
|
185 |
+
def setup(self, stage=None, data_dir='../data'):
|
186 |
+
|
187 |
+
if stage == "fit" or stage is None:
|
188 |
+
self.train_dataset = CifarAlbumentationsDataset(data_dir, train=True, transform=TRAIN_TRANSFORM)
|
189 |
+
self.test_dataset = CifarAlbumentationsDataset(data_dir, train=False, transform=TEST_TRANSFORM)
|
190 |
+
|
191 |
+
|
192 |
+
def train_dataloader(self):
|
193 |
+
return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size,
|
194 |
+
shuffle=True, pin_memory=True) #num_workers=4,
|
195 |
+
|
196 |
+
def val_dataloader(self):
|
197 |
+
return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size,
|
198 |
+
shuffle=False, pin_memory=True)
|
199 |
+
|
requirements.txt
ADDED
File without changes
|