import itertools

import torch
import torch.nn as nn
import lightning as pl
import wandb
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from src.classifier import Classifier
from src.dataset import CustomDataset


class AttentionGate(nn.Module):
    """Attention gate that reweights features `x` by a softmax map computed from gate `g`.

    Both inputs must have `in_channels` channels; the output has `out_channels`.

    NOTE(review): the softmax runs over the channel dimension (dim=1), not
    spatially — confirm this is the intended attention formulation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_gate = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_x = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, g):
        """Project both inputs to `out_channels` and scale x by softmax(gate)."""
        gate = self.conv_gate(g)
        x = self.conv_x(x)
        attention = self.softmax(gate)
        return x * attention


class ResUNetGenerator(nn.Module):
    """Four-level encoder/decoder generator with attention gates on the skip paths.

    Input and output are `channels`-channel images; output is Tanh-bounded to
    [-1, 1]. The four stride-2 stages each way require the spatial size to be
    a multiple of 16.
    """

    def __init__(self, gf, channels):
        super().__init__()
        self.channels = channels

        def down(in_ch, out_ch):
            # NOTE(review): GroupNorm is applied *after* the activation,
            # mirroring the original layer order — confirm this is intended.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.GroupNorm(num_groups=1, num_channels=out_ch),
            )

        def up(in_ch, out_ch):
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.GroupNorm(num_groups=1, num_channels=out_ch),
            )

        # Downsampling path: channels -> gf -> 2gf -> 4gf -> 8gf.
        self.conv1 = down(channels, gf)
        self.conv2 = down(gf, gf * 2)
        self.conv3 = down(gf * 2, gf * 4)
        self.conv4 = down(gf * 4, gf * 8)

        # Attention gates for the three skip connections; gate i maps
        # gf * 2**i channels in to gf * 2**(i+1) channels out.
        self.attn_layer = nn.ModuleList(
            [AttentionGate(gf * 2 ** i, gf * 2 ** (i + 1)) for i in range(3)]
        )

        # Upsampling path. Each attention gate doubles the channel count of
        # its decoder feature, which is why deconv2/deconv3 take 8gf/4gf in.
        self.deconv1 = up(gf * 8, gf * 4)
        self.deconv2 = up(gf * 8, gf * 2)
        self.deconv3 = up(gf * 4, gf)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(gf * 2, channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode, then decode while gating each encoder feature by the decoder state."""
        # Encoder.
        d1 = self.conv1(x)
        d2 = self.conv2(d1)
        d3 = self.conv3(d2)
        d4 = self.conv4(d3)
        # Decoder; the gated encoder feature *replaces* (not concatenates
        # with) the decoder feature at each level.
        u1 = self.attn_layer[2](d3, self.deconv1(d4))
        u2 = self.attn_layer[1](d2, self.deconv2(u1))
        u3 = self.attn_layer[0](d1, self.deconv3(u2))
        return self.deconv4(u3)

    def configure_optimizers(self):
        """Adam with the standard GAN betas (only used if trained standalone)."""
        return torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))


class Discriminator(pl.LightningModule):
    """PatchGAN-style discriminator: four stride-2 conv stages, then a 1-channel validity map."""

    def __init__(self, df):
        super().__init__()
        self.df = df
        # NOTE(review): the first conv hard-codes 1 input channel rather than
        # accepting a `channels` argument like the generator — fine for
        # grayscale, breaks for multi-channel images.
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(1 if i == 0 else df * 2 ** (i - 1), df * 2 ** i,
                          kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.GroupNorm(8, df * 2 ** i),
            )
            for i in range(4)
        ])
        self.final_conv = nn.Conv2d(df * 8, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        """Return a per-patch validity map (no sigmoid: trained with LSGAN/MSE loss)."""
        out = x
        for conv_layer in self.conv_layers:
            out = conv_layer(out)
        return self.final_conv(out)

    def configure_optimizers(self):
        """Adam with the standard GAN betas (only used if trained standalone)."""
        return torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))


class CycleGAN(pl.LightningModule):
    """CycleGAN (domains N and P) with attention-gated ResUNet generators.

    Uses manual optimization: `training_step` runs one generator step and one
    discriminator step per batch. A frozen, pretrained `Classifier` provides an
    additional loss pushing fake_N toward label 1 and fake_P toward label 0.
    """

    def __init__(self, train_dir, val_dir, test_dataloader, classifier_path,
                 checkpoint_dir, image_size=512, batch_size=4, channels=1,
                 gf=32, df=64, lambda_cycle=10.0, lambda_id=0.1,
                 classifier_weight=1):
        super().__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        self.channels = channels
        self.gf = gf
        self.df = df
        self.lambda_cycle = lambda_cycle
        # Identity weight is expressed relative to the cycle weight.
        self.lambda_id = lambda_id * lambda_cycle
        self.classifier_path = classifier_path
        self.classifier_weight = classifier_weight
        self.lowest_val_loss = float('inf')
        self.validation_step_outputs = []
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.test_dataloader = test_dataloader
        self.checkpoint_dir = checkpoint_dir

        # Two generators (N->P and P->N) and one discriminator per domain.
        self.g_NP = ResUNetGenerator(gf, channels=self.channels)
        self.g_PN = ResUNetGenerator(gf, channels=self.channels)
        self.d_N = Discriminator(df)
        self.d_P = Discriminator(df)

        # GAN training alternates optimizers by hand.
        self.automatic_optimization = False

        # Frozen auxiliary classifier used as a label-consistency constraint.
        self.classifier = Classifier()
        # BUGFIX: map_location='cpu' so a GPU-saved checkpoint also loads on a
        # CPU-only machine; Lightning moves the module to the right device later.
        checkpoint = torch.load(classifier_path, map_location='cpu')
        self.classifier.load_state_dict(checkpoint['state_dict'])
        self.classifier.eval()
        self.freeze_classifier()

    def freeze_classifier(self):
        """Disable gradients for every classifier parameter so it stays fixed."""
        print("freezing Classifier...")
        for p in self.classifier.parameters():
            p.requires_grad = False

    def _generator_losses(self, img_N, img_P):
        """Compute all generator-side losses for one (img_N, img_P) batch.

        Shared by `generator_training_step` and `validation_step` so the two
        paths cannot drift apart. Returns a dict with keys
        'adversarial', 'cycle', 'identity', 'class', 'total'.
        """
        mse = nn.MSELoss()
        l1 = nn.L1Loss()

        # N -> P -> N and P -> N -> P translation cycles.
        fake_P = self.g_NP(img_N)
        fake_N = self.g_PN(img_P)
        reconstr_N = self.g_PN(fake_P)
        reconstr_P = self.g_NP(fake_N)

        # Identity mappings: a generator fed its own target domain should be
        # close to the identity function.
        img_N_id = self.g_PN(img_N)
        img_P_id = self.g_NP(img_P)

        # LSGAN adversarial loss: fakes should be scored as real (ones).
        valid_N = self.d_N(fake_N)
        valid_P = self.d_P(fake_P)
        valid_target = torch.ones_like(valid_N)
        adversarial_loss = mse(valid_N, valid_target) + mse(valid_P, valid_target)

        cycle_loss = l1(reconstr_N, img_N) + l1(reconstr_P, img_P)
        identity_loss = l1(img_N_id, img_N) + l1(img_P_id, img_P)

        # NOTE(review): targets mirror the original code — fake_N is pushed
        # toward ones and fake_P toward zeros; verify against the classifier's
        # label convention.
        cls_N = self.classifier(fake_N)
        cls_P = self.classifier(fake_P)
        class_loss = mse(cls_N, torch.ones_like(cls_N)) + mse(cls_P, torch.zeros_like(cls_P))

        total_loss = (adversarial_loss
                      + self.lambda_cycle * cycle_loss
                      + self.lambda_id * identity_loss
                      + self.classifier_weight * class_loss)
        return {
            'adversarial': adversarial_loss,
            'cycle': cycle_loss,
            'identity': identity_loss,
            'class': class_loss,
            'total': total_loss,
        }

    def generator_training_step(self, img_N, img_P, opt):
        """One optimization step for both generators.

        Returns (total_loss, adversarial_loss, cycle_loss).
        """
        self.toggle_optimizer(opt)
        losses = self._generator_losses(img_N, img_P)

        self.log('adversarial_loss', losses['adversarial'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('reconstruction_loss', losses['cycle'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('identity_loss', losses['identity'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('class_loss', losses['class'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('generator_loss', losses['total'], on_step=True, on_epoch=True, prog_bar=True, logger=True)

        opt.zero_grad()
        self.manual_backward(losses['total'])
        opt.step()
        self.untoggle_optimizer(opt)
        return losses['total'], losses['adversarial'], losses['cycle']

    def discriminator_training_step(self, img_N, img_P, opt):
        """One optimization step for both discriminators (LSGAN targets).

        Generator outputs are detached so no graph is kept through the
        generators. BUGFIX: the original called `manual_backward(mse_fake_P)`,
        which left d_N completely untrained and d_P untrained on real images;
        the full `dis_loss` is now backpropagated.
        """
        self.toggle_optimizer(opt)
        mse = nn.MSELoss()

        # d_N: real N images vs generated N images (g_PN maps P -> N).
        pred_real_N = self.d_N(img_N)
        mse_real_N = mse(pred_real_N, torch.ones_like(pred_real_N))
        fake_N = self.g_PN(img_P).detach()
        pred_fake_N = self.d_N(fake_N)
        mse_fake_N = mse(pred_fake_N, torch.zeros_like(pred_fake_N))

        # d_P: real P images vs generated P images (g_NP maps N -> P).
        pred_real_P = self.d_P(img_P)
        mse_real_P = mse(pred_real_P, torch.ones_like(pred_real_P))
        fake_P = self.g_NP(img_N).detach()
        pred_fake_P = self.d_P(fake_P)
        mse_fake_P = mse(pred_fake_P, torch.zeros_like(pred_fake_P))

        dis_loss = 0.5 * (mse_real_N + mse_fake_N + mse_real_P + mse_fake_P)

        opt.zero_grad()
        self.manual_backward(dis_loss)
        opt.step()
        self.untoggle_optimizer(opt)

        self.log('mse_fake_N', mse_fake_N, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('mse_fake_P', mse_fake_P, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('discriminator_loss', dis_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return dis_loss, mse_fake_N, mse_fake_P

    def training_step(self, batch, batch_idx):
        """Run one generator step then one discriminator step on the batch."""
        img_N, img_P = batch
        optD, optG = self.optimizers()
        total_loss, adversarial_loss, cycle_loss = self.generator_training_step(img_N, img_P, optG)
        dis_loss, mse_fake_N, mse_fake_P = self.discriminator_training_step(img_N, img_P, optD)
        return {
            "generator_loss": total_loss,
            "adversarial_loss": adversarial_loss,
            "reconstruction_loss": cycle_loss,
            "discriminator_loss": dis_loss,
            "mse_fake_N": mse_fake_N,
            "mse_fake_P": mse_fake_P,
        }

    def validation_step(self, batch, batch_idx):
        """Compute the full generator loss on a validation batch and log it."""
        img_N, img_P = batch
        losses = self._generator_losses(img_N, img_P)
        total_loss = losses['total']
        self.validation_step_outputs.append(total_loss)

        self.log('val_adversarial_loss', losses['adversarial'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_cycle_loss', losses['cycle'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_identity_loss', losses['identity'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_class_loss', losses['class'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_generator_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return total_loss

    def on_validation_end(self):
        """Checkpoint both generators whenever the epoch's mean validation loss improves."""
        # Guard against an empty buffer (e.g. a sanity check with no batches).
        if not self.validation_step_outputs:
            return
        avg_val_loss = torch.stack(self.validation_step_outputs).mean()
        # BUGFIX: reset the buffer so each epoch is averaged on its own and
        # the list does not grow without bound across epochs.
        self.validation_step_outputs.clear()
        if avg_val_loss < self.lowest_val_loss:
            self.lowest_val_loss = avg_val_loss
            torch.save(self.g_NP.state_dict(), f"{self.checkpoint_dir}/g_NP_best.ckpt")
            torch.save(self.g_PN.state_dict(), f"{self.checkpoint_dir}/g_PN_best.ckpt")
            print(f"Model saved! loss reduced to {self.lowest_val_loss}")

    def configure_optimizers(self):
        """Return (discriminator optimizer, generator optimizer), matching training_step's unpack order."""
        optG = torch.optim.Adam(
            itertools.chain(self.g_NP.parameters(), self.g_PN.parameters()),
            lr=2e-4, betas=(0.5, 0.999))
        optD = torch.optim.Adam(
            itertools.chain(self.d_N.parameters(), self.d_P.parameters()),
            lr=2e-4, betas=(0.5, 0.999))
        # Linear decay toward zero starting at epoch 100.
        gamma = lambda epoch: 1 - max(0, epoch + 1 - 100) / 101
        # NOTE(review): schD is built but never returned or stepped, so the
        # decay has no effect under manual optimization — return and step it
        # if LR decay is actually wanted.
        schD = LambdaLR(optD, lr_lambda=gamma)
        return optD, optG

    def _make_dataloader(self, root_dir, shuffle):
        """Build a CustomDataset DataLoader for `root_dir` (class dirs '0' = N, '1' = P)."""
        dataset = CustomDataset(
            root_dir=root_dir,
            train_N="0",
            train_P="1",
            img_res=(self.image_size, self.image_size),
        )
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=4, pin_memory=True)

    def train_dataloader(self):
        """Shuffled training loader over self.train_dir."""
        return self._make_dataloader(self.train_dir, shuffle=True)

    def val_dataloader(self):
        """Deterministic validation loader over self.val_dir."""
        return self._make_dataloader(self.val_dir, shuffle=False)

    @staticmethod
    def _to_img(t):
        """Convert a (1, C, H, W) tensor to an (H, W, C) numpy array for imshow."""
        return t.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Every 100 batches, translate one random test pair and log a 2x3 image grid to W&B."""
        if batch_idx % 100 != 0:
            return

        img_N, img_P = next(iter(self.test_dataloader))
        idx = np.random.randint(img_N.size(0))
        # BUGFIX: use self.device instead of the hard-coded 'cuda' so this
        # also works on CPU or non-CUDA accelerators.
        img_N = img_N[idx].unsqueeze(0).to(self.device)
        img_P = img_P[idx].unsqueeze(0).to(self.device)

        # Inference only — no graph needed for plotting.
        with torch.no_grad():
            fake_P = self.g_NP(img_N)
            fake_N = self.g_PN(img_P)
            reconstr_N = self.g_PN(fake_P)
            reconstr_P = self.g_NP(fake_N)

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        panels = [
            (0, 0, img_N, "Real N"),
            (0, 1, fake_P, "Translated P"),
            (0, 2, reconstr_N, "Reconstructed N"),
            (1, 0, img_P, "Real P"),
            (1, 1, fake_N, "Translated N"),
            (1, 2, reconstr_P, "Reconstructed P"),
        ]
        for row, col, img, title in panels:
            ax = axes[row, col]
            ax.imshow(self._to_img(img), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        wandb.log({"test_images": wandb.Image(fig)})
        plt.close(fig)