import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import matplotlib.pyplot as plt
from loss import YoloLoss
import config as cfg
""" | |
Information about architecture config: | |
Tuple is structured by (filters, kernel_size, stride) | |
Every conv is a same convolution. | |
List is structured by "B" indicating a residual block followed by the number of repeats | |
"S" is for scale prediction block and computing the yolo loss | |
"U" is for upsampling the feature map and concatenating with a previous layer | |
""" | |

config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # Everything up to this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]
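

# A minimal sketch (not used by the model) of how each config entry type is
# interpreted; _create_conv_layers below performs the real construction.
def describe_config_entry(entry):
    if isinstance(entry, tuple):
        filters, kernel_size, stride = entry
        return f"conv: {filters} filters, {kernel_size}x{kernel_size}, stride {stride}"
    if isinstance(entry, list):
        return f"residual block, {entry[1]} repeat(s)"
    return {"S": "scale prediction (YOLO head)", "U": "upsample + concat"}[entry]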


class CNNBlock(nn.Module):
    """Conv2d -> BatchNorm -> LeakyReLU (norm/activation can be disabled)."""

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        return self.conv(x)


class ResidualBlock(nn.Module):
    """1x1 squeeze -> 3x3 expand, repeated, with optional skip connections."""

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_repeats):
            self.layers.append(
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            )
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x) if self.use_residual else layer(x)
        return x
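

# Shape sketch: a ResidualBlock preserves both channel count and spatial size,
# so it can be stacked freely, e.g.:
#     block = ResidualBlock(64, num_repeats=2)
#     block(torch.randn(1, 64, 208, 208)).shape  # torch.Size([1, 64, 208, 208])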


class ScalePrediction(nn.Module):
    """Per-scale YOLO head: 3 anchors x (num_classes + 5) values per cell."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        # (N, 3 * (C + 5), S, S) -> (N, 3, S, S, C + 5)
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )
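

# Example (standalone sketch): for a 13x13 feature map with 512 channels and
# 20 classes,
#     head = ScalePrediction(512, num_classes=20)
#     head(torch.randn(2, 512, 13, 13)).shape
# gives torch.Size([2, 3, 13, 13, 25]), i.e. (N, anchors, S, S, classes + 5).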


class YOLOv3LightningModel(pl.LightningModule):
    def __init__(self, in_channels=3, num_classes=20, anchors=None, S=None):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()
        # Scale the normalized anchors to each grid size; `anchors` and `S`
        # are required (e.g. the ANCHORS / S lists from the config module).
        self.anchor_list = (
            torch.tensor(anchors)
            * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        )
        self.criterion = YoloLoss()
        self.metric = dict(
            total_train_steps=0,
            epoch_train_loss=[],
            epoch_train_acc=[],
            epoch_train_steps=0,
            total_val_steps=0,
            epoch_val_loss=[],
            epoch_val_acc=[],
            epoch_val_steps=0,
            train_loss=[],
            val_loss=[],
            train_acc=[],
            val_acc=[],
        )

    def forward(self, x):
        outputs = []  # one prediction per scale
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue
            x = layer(x)
            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()
        return outputs
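
    # For a 416x416 input, forward() returns predictions coarsest scale first:
    #     [(N, 3, 13, 13, C + 5), (N, 3, 26, 26, C + 5), (N, 3, 52, 52, C + 5)]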

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels
        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels
            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2
                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2))
                    # After upsampling, the feature map is concatenated with a
                    # route connection that has twice its channel count.
                    in_channels = in_channels * 3
        return layers

    def get_layer(self, idx):
        if 0 <= idx < len(self.layers):
            return self.layers[idx]

    def training_step(self, train_batch, batch_idx):
        x, target = train_batch
        output = self.forward(x)
        loss = self.criterion(output, target, loss_dict=True, anchor_list=self.anchor_list)
        acc = self.criterion.check_class_accuracy(output, target, cfg.CONF_THRESHOLD)
        self.metric['total_train_steps'] += 1
        self.metric['epoch_train_steps'] += 1
        # Store a detached copy so the autograd graph is not kept alive for
        # the whole epoch.
        self.metric['epoch_train_loss'].append(
            {k: v.detach() if torch.is_tensor(v) else v for k, v in loss.items()}
        )
        self.metric['epoch_train_acc'].append(acc)
        self.log_dict({'train_loss': loss['total_loss']})
        return loss['total_loss']

    def validation_step(self, val_batch, batch_idx):
        x, target = val_batch
        output = self.forward(x)
        loss = self.criterion(output, target, loss_dict=True, anchor_list=self.anchor_list)
        acc = self.criterion.check_class_accuracy(output, target, cfg.CONF_THRESHOLD)
        self.metric['total_val_steps'] += 1
        self.metric['epoch_val_steps'] += 1
        self.metric['epoch_val_loss'].append(loss)
        self.metric['epoch_val_acc'].append(acc)
        self.log_dict({'val_loss': loss['total_loss']})

    def on_validation_epoch_end(self):
        if self.metric['total_train_steps']:
            print('Epoch ', self.current_epoch)
            epoch_loss = 0
            epoch_acc = dict(
                correct_class=0,
                correct_noobj=0,
                correct_obj=0,
                total_class_preds=0,
                total_noobj=0,
                total_obj=0,
            )
            for i in range(self.metric['epoch_train_steps']):
                lo = self.metric['epoch_train_loss'][i]
                epoch_loss += lo['total_loss']
                acc = self.metric['epoch_train_acc'][i]
                epoch_acc['correct_class'] += acc['correct_class']
                epoch_acc['correct_noobj'] += acc['correct_noobj']
                epoch_acc['correct_obj'] += acc['correct_obj']
                epoch_acc['total_class_preds'] += acc['total_class_preds']
                epoch_acc['total_noobj'] += acc['total_noobj']
                epoch_acc['total_obj'] += acc['total_obj']
            print("Train -")
            print(f"Class accuracy is: {(epoch_acc['correct_class'] / (epoch_acc['total_class_preds'] + 1e-16)) * 100:.2f}%")
            print(f"No obj accuracy is: {(epoch_acc['correct_noobj'] / (epoch_acc['total_noobj'] + 1e-16)) * 100:.2f}%")
            print(f"Obj accuracy is: {(epoch_acc['correct_obj'] / (epoch_acc['total_obj'] + 1e-16)) * 100:.2f}%")
            print(f"Total loss: {epoch_loss / (len(self.metric['epoch_train_loss']) + 1e-16):.2f}")
            self.metric['epoch_train_loss'] = []
            self.metric['epoch_train_acc'] = []
            self.metric['epoch_train_steps'] = 0
            # --- same summary for validation ---
            epoch_loss = 0
            epoch_acc = dict(
                correct_class=0,
                correct_noobj=0,
                correct_obj=0,
                total_class_preds=0,
                total_noobj=0,
                total_obj=0,
            )
            for i in range(self.metric['epoch_val_steps']):
                lo = self.metric['epoch_val_loss'][i]
                epoch_loss += lo['total_loss']
                acc = self.metric['epoch_val_acc'][i]
                epoch_acc['correct_class'] += acc['correct_class']
                epoch_acc['correct_noobj'] += acc['correct_noobj']
                epoch_acc['correct_obj'] += acc['correct_obj']
                epoch_acc['total_class_preds'] += acc['total_class_preds']
                epoch_acc['total_noobj'] += acc['total_noobj']
                epoch_acc['total_obj'] += acc['total_obj']
            print("Validation -")
            print(f"Class accuracy is: {(epoch_acc['correct_class'] / (epoch_acc['total_class_preds'] + 1e-16)) * 100:.2f}%")
            print(f"No obj accuracy is: {(epoch_acc['correct_noobj'] / (epoch_acc['total_noobj'] + 1e-16)) * 100:.2f}%")
            print(f"Obj accuracy is: {(epoch_acc['correct_obj'] / (epoch_acc['total_obj'] + 1e-16)) * 100:.2f}%")
            print(f"Total loss: {epoch_loss / (len(self.metric['epoch_val_loss']) + 1e-16):.2f}")
            self.metric['epoch_val_loss'] = []
            self.metric['epoch_val_acc'] = []
            self.metric['epoch_val_steps'] = 0
            print("Creating checkpoint...")
            self.trainer.save_checkpoint(cfg.CHECKPOINT_FILE)

    def test_step(self, test_batch, batch_idx):
        self.validation_step(test_batch, batch_idx)

    def train_dataloader(self):
        if not self.trainer.train_dataloader:
            self.trainer.fit_loop.setup_data()
        return self.trainer.train_dataloader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY
        )
        # One-cycle policy: ramp up over the first 8 epochs, then anneal.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=cfg.LEARNING_RATE,
            epochs=self.trainer.max_epochs,
            steps_per_epoch=len(self.train_dataloader()),
            pct_start=8 / self.trainer.max_epochs,
            div_factor=100,
            final_div_factor=100,
            three_phase=False,
            verbose=False,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # step the scheduler every batch, not every epoch
                "frequency": 1,
            },
        }
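
    # Minimal usage sketch (hedged: assumes cfg also defines ANCHORS, S and
    # NUM_EPOCHS, and that the dataloaders come from the project's pipeline):
    #     model = YOLOv3LightningModel(num_classes=cfg.NUM_CLASSES,
    #                                  anchors=cfg.ANCHORS, S=cfg.S)
    #     trainer = pl.Trainer(max_epochs=cfg.NUM_EPOCHS)
    #     trainer.fit(model, train_loader, val_loader)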

    def plot_grad_cam(self, pred_dict, target_layers, count, grad_opacity=1.0):
        """Overlay Grad-CAM heatmaps for `count` entries of `pred_dict`.

        `pred_dict` is expected to provide 'images', 'ground_truths' and
        'predicted_vals'; `denormalize` and `get_data_label_name` are assumed
        to be project utility functions.
        """
        mean, std = [0, 0, 0], [1, 1, 1]
        cam = GradCAM(model=self, target_layers=target_layers)
        for i in range(count):
            plt.subplot(int(count / 5), 5, i + 1)
            plt.tight_layout()
            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]
            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)
            x = denormalize(pred_dict['images'][i].cpu(), mean, std)
            image = np.array(255 * x, np.int16).transpose(1, 2, 0)
            img_tensor = np.array(x, np.float32).transpose(1, 2, 0)
            visualization = show_cam_on_image(
                img_tensor, grayscale_cam[0, :], use_rgb=True,
                image_weight=(1.0 - grad_opacity),
            )
            plt.imshow(image, vmin=0, vmax=255)
            plt.imshow(visualization, vmin=0, vmax=255, alpha=grad_opacity)
            plt.xticks([])
            plt.yticks([])
            title = (get_data_label_name(pred_dict['ground_truths'][i].item()) + ' / '
                     + get_data_label_name(pred_dict['predicted_vals'][i].item()))
            plt.title(title, fontsize=8)


def sanity_check(model):
    # The three heads should predict at strides 32, 16 and 8 respectively.
    x = torch.randn((2, 3, cfg.IMAGE_SIZE, cfg.IMAGE_SIZE))
    out = model(x)
    assert out[0].shape == (2, 3, cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 32, cfg.NUM_CLASSES + 5)
    assert out[1].shape == (2, 3, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 16, cfg.NUM_CLASSES + 5)
    assert out[2].shape == (2, 3, cfg.IMAGE_SIZE // 8, cfg.IMAGE_SIZE // 8, cfg.NUM_CLASSES + 5)
    print("Success!")