import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
from loss import YoloLoss
import config as cfg

"""
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride)
Every conv is a same convolution.
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""

config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]


class CNNBlock(nn.Module):
    """Conv2d optionally followed by BatchNorm2d + LeakyReLU(0.1).

    When ``bn_act`` is False the block is a raw convolution (used for the
    final prediction head, which must emit unbounded logits).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        # Bias is redundant when BatchNorm follows the conv, so disable it then.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Stack of ``num_repeats`` bottleneck pairs (1x1 squeeze, 3x3 expand).

    With ``use_residual`` the input is added back after each pair
    (Darknet-53 style); without it the pairs are plain sequential convs.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_repeats):
            self.layers.append(
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            )
        self.use_residual = use_residual
        # num_repeats is also inspected by YOLOv3LightningModel.forward to
        # recognize the two route connections (the repeats == 8 blocks).
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x) if self.use_residual else layer(x)
        return x


class ScalePrediction(nn.Module):
    """Per-scale detection head.

    Produces a tensor of shape (batch, 3 anchors, grid_y, grid_x,
    num_classes + 5), where the 5 is [objectness, x, y, w, h].
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            # Raw logits head: 3 anchors * (classes + obj + 4 box coords).
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )


class YOLOv3LightningModel(pl.LightningModule):
    """YOLOv3 (Darknet-53 backbone + 3 scale heads) as a LightningModule.

    Args:
        in_channels: channels of the input image (default RGB).
        num_classes: number of object classes.
        anchors: per-scale normalized anchor boxes (3 scales x 3 anchors x 2),
            or None to build the bare architecture without a loss setup.
        S: grid sizes for the three scales (e.g. [13, 26, 52]), or None.
    """

    def __init__(self, in_channels=3, num_classes=20, anchors=None, S=None):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()
        if anchors is not None and S is not None:
            # Scale the normalized anchors up to each grid size S.
            self.anchor_list = (
                torch.tensor(anchors)
                * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
            )
        else:
            # Fix: the previous code called torch.tensor(None) for the default
            # arguments, which raised. Allow constructing the bare network
            # (e.g. for sanity_check) without anchors; training requires them.
            self.anchor_list = None
        self.criterion = YoloLoss()
        # Running per-epoch bookkeeping consumed by on_validation_epoch_end.
        # NOTE(review): the stored loss/acc dicts may hold graph-attached
        # tensors until epoch end — consider detaching in *_step; confirm
        # against YoloLoss's return types before changing.
        self.metric = dict(
            total_train_steps=0,
            epoch_train_loss=[],
            epoch_train_acc=[],
            epoch_train_steps=0,
            total_val_steps=0,
            epoch_val_loss=[],
            epoch_val_acc=[],
            epoch_val_steps=0,
            train_loss=[],
            val_loss=[],
            train_acc=[],
            val_acc=[],
        )

    def forward(self, x):
        """Run the backbone/heads; returns a list of 3 per-scale outputs."""
        outputs = []  # for each scale
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue
            x = layer(x)
            # The two 8-repeat residual blocks mark the feature maps that are
            # later concatenated after upsampling (FPN-style routes).
            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()
        return outputs

    def _create_conv_layers(self):
        """Translate the module-level ``config`` spec into an nn.ModuleList."""
        layers = nn.ModuleList()
        in_channels = self.in_channels
        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        # "Same" convolution: pad only the 3x3 kernels.
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels
            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
            elif isinstance(module, str):
                if module == "S":
                    # Scale head: non-residual bottleneck, channel squeeze,
                    # then the prediction layer; the squeezed width carries on.
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2
                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2))
                    # After upsampling, the route concatenation triples the
                    # channel count (x + skip with 2x channels).
                    in_channels = in_channels * 3
        return layers

    def get_layer(self, idx):
        """Return self.layers[idx], or None when idx is out of range."""
        if 0 <= idx < len(self.layers):
            return self.layers[idx]

    def training_step(self, train_batch, batch_idx):
        x, target = train_batch
        output = self.forward(x)
        loss = self.criterion(
            output, target, loss_dict=True, anchor_list=self.anchor_list
        )
        acc = self.criterion.check_class_accuracy(output, target, cfg.CONF_THRESHOLD)
        self.metric["total_train_steps"] += 1
        self.metric["epoch_train_steps"] += 1
        self.metric["epoch_train_loss"].append(loss)
        self.metric["epoch_train_acc"].append(acc)
        self.log_dict({"train_loss": loss["total_loss"]})
        return loss["total_loss"]

    def validation_step(self, val_batch, batch_idx):
        x, target = val_batch
        output = self.forward(x)
        loss = self.criterion(
            output, target, loss_dict=True, anchor_list=self.anchor_list
        )
        acc = self.criterion.check_class_accuracy(output, target, cfg.CONF_THRESHOLD)
        self.metric["total_val_steps"] += 1
        self.metric["epoch_val_steps"] += 1
        self.metric["epoch_val_loss"].append(loss)
        self.metric["epoch_val_acc"].append(acc)
        self.log_dict({"val_loss": loss["total_loss"]})

    @staticmethod
    def _sum_epoch_metrics(losses, accs, steps):
        """Sum per-step loss/accuracy dicts accumulated over one epoch."""
        epoch_loss = 0
        epoch_acc = dict(
            correct_class=0,
            correct_noobj=0,
            correct_obj=0,
            total_class_preds=0,
            total_noobj=0,
            total_obj=0,
        )
        for i in range(steps):
            epoch_loss += losses[i]["total_loss"]
            acc = accs[i]
            for key in epoch_acc:
                epoch_acc[key] += acc[key]
        return epoch_loss, epoch_acc

    @staticmethod
    def _print_epoch_report(title, epoch_loss, epoch_acc, num_losses):
        """Print the accuracy/loss summary for one epoch split."""
        # Fix: the original used ':2f' (field *width* 2, default 6-digit
        # precision); ':.2f' — two decimal places — is the intended format.
        print(title)
        print(
            f"Class accuracy is: {(epoch_acc['correct_class'] / (epoch_acc['total_class_preds'] + 1e-16)) * 100:.2f}%"
        )
        print(
            f"No obj accuracy is: {(epoch_acc['correct_noobj'] / (epoch_acc['total_noobj'] + 1e-16)) * 100:.2f}%"
        )
        print(
            f"Obj accuracy is: {(epoch_acc['correct_obj'] / (epoch_acc['total_obj'] + 1e-16)) * 100:.2f}%"
        )
        print(f"Total loss: {(epoch_loss / (num_losses + 1e-16)):.2f}")

    def on_validation_epoch_end(self):
        """Report train/val epoch metrics, reset accumulators, checkpoint."""
        # Skip the report during Lightning's sanity-check validation pass,
        # which runs before any training step has happened.
        if self.metric["total_train_steps"]:
            print("Epoch ", self.current_epoch)

            epoch_loss, epoch_acc = self._sum_epoch_metrics(
                self.metric["epoch_train_loss"],
                self.metric["epoch_train_acc"],
                self.metric["epoch_train_steps"],
            )
            self._print_epoch_report(
                "Train -", epoch_loss, epoch_acc, len(self.metric["epoch_train_loss"])
            )
            self.metric["epoch_train_loss"] = []
            self.metric["epoch_train_acc"] = []
            self.metric["epoch_train_steps"] = 0

            epoch_loss, epoch_acc = self._sum_epoch_metrics(
                self.metric["epoch_val_loss"],
                self.metric["epoch_val_acc"],
                self.metric["epoch_val_steps"],
            )
            self._print_epoch_report(
                "Validation -", epoch_loss, epoch_acc, len(self.metric["epoch_val_loss"])
            )
            self.metric["epoch_val_loss"] = []
            self.metric["epoch_val_acc"] = []
            self.metric["epoch_val_steps"] = 0

            print("Creating checkpoint...")
            self.trainer.save_checkpoint(cfg.CHECKPOINT_FILE)

    def test_step(self, test_batch, batch_idx):
        # Test shares the validation logic/metrics.
        self.validation_step(test_batch, batch_idx)

    def train_dataloader(self):
        # Lazily trigger Lightning's data setup so the dataloader exists even
        # when this is called before fit (e.g. from configure_optimizers).
        if not self.trainer.train_dataloader:
            self.trainer.fit_loop.setup_data()
        return self.trainer.train_dataloader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY
        )
        # NOTE(review): pct_start = 8 / max_epochs assumes max_epochs >= 8;
        # OneCycleLR raises for pct_start outside (0, 1) — confirm trainer config.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=cfg.LEARNING_RATE,
            epochs=self.trainer.max_epochs,
            steps_per_epoch=len(self.train_dataloader()),
            pct_start=8 / self.trainer.max_epochs,
            div_factor=100,
            final_div_factor=100,
            three_phase=False,
            verbose=False,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # or 'epoch'
                "frequency": 1,
            },
        }

    def plot_grad_cam(self, img, target_layers, grad_opacity=1.0):
        """Overlay Grad-CAM heatmaps for ``img`` on a matplotlib grid.

        NOTE(review): this method references names that are not defined in
        this file (``count``, ``pred_dict``, ``denormalize``,
        ``get_data_label_name``) — it appears to have been pasted from a
        classification project and cannot run as-is; wire those up or rewrite
        it around the YOLO outputs before use.
        """
        mean, std = [0, 0, 0], [1, 1, 1]
        cam = GradCAM(model=self, target_layers=target_layers)
        transform = transforms.ToTensor()
        img = transform(img)
        if self.device != img.device:
            img = img.to(self.device)
        x = img.unsqueeze(0)
        # Fix: forward expects a batched 4-D tensor; the original passed the
        # un-batched 3-D ``img``.
        out = self.forward(x)
        bboxes = []
        for i in range(count):
            plt.subplot(int(count / 5), 5, i + 1)
            plt.tight_layout()
            targets = [
                ClassifierOutputTarget(pred_dict["ground_truths"][i].cpu().item())
            ]
            grayscale_cam = cam(
                input_tensor=pred_dict["images"][i][None, :].cpu(), targets=targets
            )
            x = denormalize(pred_dict["images"][i].cpu(), mean, std)
            image = np.array(255 * x, np.int16).transpose(1, 2, 0)
            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)
            visualization = show_cam_on_image(
                img_tensor,
                grayscale_cam.transpose(1, 2, 0),
                use_rgb=True,
                image_weight=(1.0 - grad_opacity),
            )
            plt.imshow(image, vmin=0, vmax=255)
            plt.imshow(visualization, vmin=0, vmax=255, alpha=grad_opacity)
            plt.xticks([])
            plt.yticks([])
            title = (
                get_data_label_name(pred_dict["ground_truths"][i].item())
                + " / "
                + get_data_label_name(pred_dict["predicted_vals"][i].item())
            )
            plt.title(title, fontsize=8)


def sanity_check(model):
    """Assert the three scale outputs have the expected YOLO shapes."""
    x = torch.randn((2, 3, cfg.IMAGE_SIZE, cfg.IMAGE_SIZE))
    # Fix: run the forward pass once instead of four times.
    out = model(x)
    assert out[0].shape == (
        2, 3, cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 32, cfg.NUM_CLASSES + 5,
    )
    assert out[1].shape == (
        2, 3, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 16, cfg.NUM_CLASSES + 5,
    )
    assert out[2].shape == (
        2, 3, cfg.IMAGE_SIZE // 8, cfg.IMAGE_SIZE // 8, cfg.NUM_CLASSES + 5,
    )
    print("Success!")