|
|
|
"""M23CSA003_M23CSA017.ipynb |
|
|
|
Automatically generated by Colab. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1crxV2cbekIb0oyUata3m4VRqxx9mmcB4 |
|
|
|
# Dataset |
|
""" |
|
|
|
import glob
import os
import random

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageSketchDataset(Dataset): |
|
    def __init__(
        self, image_dir, sketch_dir, labels_df, transform_image, transform_sketch,
        paired=False):
        self.image_dir = image_dir
        self.sketch_dir = sketch_dir
        self.labels_df = pd.read_csv(labels_df)
        self.transform_image = transform_image
        self.transform_sketch = transform_sketch
        # All sketch/mask files in the sketch directory; used in the unpaired
        # case, where a random sketch is drawn per sample.
        self.all_sketches = glob.glob1(self.sketch_dir, "*.png")
        self.paired = paired
|
def __len__(self): |
|
return len(self.labels_df) |
|
|
|
    def __getitem__(self, index):
        label_cols = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]

        # In the paired case, skip samples whose segmentation mask is missing.
        while True:
            image_filename = self.labels_df.iloc[index]["image"]
            label = self.labels_df.loc[index, label_cols].values.astype("float32")
            image_path = os.path.join(self.image_dir, image_filename + ".jpg")

            if self.paired:
                sketch_filename = image_filename + "_segmentation.png"
                sketch_path = os.path.join(self.sketch_dir, sketch_filename)
                if not os.path.exists(sketch_path):
                    index = (index + 1) % self.__len__()
                    continue
            else:
                # Unpaired: draw any sketch from the sketch directory.
                sketch_filename = np.random.choice(self.all_sketches)
                sketch_path = os.path.join(self.sketch_dir, sketch_filename)
            break

        image = Image.open(image_path)
        sketch = Image.open(sketch_path)

        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_sketch:
            sketch = self.transform_sketch(sketch)

        image_np = np.array(image)

        # Binarise the mask: pixels at 255 become foreground (1.0).
        sketch_arr = np.array(sketch)
        sketch_np = np.zeros_like(sketch_arr, dtype=np.float32)
        sketch_np[sketch_arr == 255] = 1.0

        return (
            torch.from_numpy(image_np),
            torch.from_numpy(sketch_np),
            torch.from_numpy(label),
        )
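
# Each sample is an (image, mask, label) triple: the transformed RGB image,
# a float32 binary lesion mask derived from the sketch, and the 7-dimensional
# one-hot label row (MEL/NV/BCC/AKIEC/BKL/DF/VASC) from the labels CSV.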
|
|
|
"""# Segmentation Model""" |
|
|
|
import torch.nn as nn |
|
from torchvision import models |
|
class EncoderWithFeatures(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.features = encoder.features
        self.feature_outputs = []

    def forward(self, x):
        # Reset per call so intermediate features do not accumulate across
        # forward passes.
        self.feature_outputs = []
        for name, layer in self.features.named_children():
            x = layer(x)
            if name in ['3', '7', '11', '15']:
                self.feature_outputs.append(x)
        return x
|
class DoubleConv(nn.Module): |
|
def __init__(self, in_channels, out_channels): |
|
super().__init__() |
|
self.conv = nn.Sequential( |
|
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(out_channels), |
|
nn.ReLU(inplace=True), |
|
nn.Dropout2d(p=0.1), |
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), |
|
nn.BatchNorm2d(out_channels), |
|
nn.ReLU(inplace=True) |
|
) |
|
|
|
def forward(self, x): |
|
return self.conv(x) |
|
|
|
class Decoder(nn.Module): |
|
def __init__(self, num_encoder_features, num_classes): |
|
super().__init__() |
|
self.up1 = nn.ConvTranspose2d(num_encoder_features, num_encoder_features // 2, kernel_size=2, stride=2) |
|
self.conv1 = DoubleConv(num_encoder_features // 2, num_encoder_features // 2) |
|
|
|
self.up2 = nn.ConvTranspose2d(num_encoder_features // 2, num_encoder_features // 4, kernel_size=2, stride=2) |
|
self.conv2 = DoubleConv(num_encoder_features // 4, num_encoder_features // 4) |
|
|
|
|
|
self.up3 = nn.ConvTranspose2d(num_encoder_features // 4, num_encoder_features // 8, kernel_size=2, stride=2) |
|
self.conv3 = DoubleConv(num_encoder_features // 8, num_encoder_features // 8) |
|
|
|
|
|
self.up4 = nn.ConvTranspose2d(num_encoder_features // 8, num_encoder_features // 16, kernel_size=2, stride=2) |
|
self.conv4 = DoubleConv(num_encoder_features // 16, num_encoder_features // 16) |
|
|
|
self.up5 = nn.ConvTranspose2d(num_encoder_features // 16, num_encoder_features//16, kernel_size=2, stride=2) |
|
|
|
self.final_conv = nn.Conv2d(num_encoder_features // 16, num_classes, kernel_size=1) |
|
|
|
def forward(self, x): |
|
x1 = self.up1(x) |
|
|
|
x1 = self.conv1(x1) |
|
|
|
x2 = self.up2(x1) |
|
|
|
x2 = self.conv2(x2) |
|
|
|
x3 = self.up3(x2) |
|
|
|
x3 = self.conv3(x3) |
|
|
|
x4 = self.up4(x3) |
|
|
|
x4 = self.conv4(x4) |
|
|
|
x5 = self.up5(x4) |
|
|
|
output = self.final_conv(x5) |
|
|
|
return output |
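
# Shape walkthrough (assuming the 1280-channel 4x4 MobileNetV2 map produced
# from a 128x128 input): each transposed conv doubles the spatial size,
# 4 -> 8 -> 16 -> 32 -> 64 -> 128, while channels shrink
# 1280 -> 640 -> 320 -> 160 -> 80 -> 80 before the final 1x1 projection.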
|
|
|
class SegmentationModel(nn.Module): |
|
def __init__(self, encoder=None, decoder=None, num_classes=1,ngpu=0): |
|
super().__init__() |
|
self.ngpu = ngpu |
|
if encoder is None: |
|
base_model = models.mobilenet_v2(pretrained=True) |
|
base_model.classifier = nn.Identity() |
|
for param in base_model.parameters(): |
|
param.requires_grad = False |
|
self.encoder = EncoderWithFeatures(base_model) |
|
else: |
|
self.encoder = encoder |
|
|
|
if decoder is None: |
|
self.decoder = Decoder(num_encoder_features=1280, num_classes=num_classes) |
|
else: |
|
self.decoder = decoder |
|
|
|
def forward(self, x): |
|
x = self.encoder(x) |
|
return self.decoder(x) |
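
# A minimal shape check for the assembled model (left commented out because
# instantiating it downloads the pretrained MobileNetV2 weights):
# _seg = SegmentationModel()
# _x = torch.randn(1, 3, 128, 128)
# print(_seg(_x).shape)  # expected: torch.Size([1, 1, 128, 128])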
|
|
|
"""# Fine Tune Segmentation Model""" |
|
|
|
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm

from models.segmenatation_model import SegmentationModel, Decoder
from datasets import *
|
|
|
image_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train_data" |
|
sketch_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Paired_train_sketches" |
|
labels_df = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Train_labels.csv" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
labels_df = "/teamspace/studios/this_studio/DL_Assignment_4/CBNGAN/train_split.csv" |
|
image_dir_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train_data" |
|
sketch_dir_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Paired_train_sketches" |
|
labels_df_val = "/teamspace/studios/this_studio/DL_Assignment_4/CBNGAN/test_split.csv" |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
def i_over_u(y_true, y_pred): |
|
y_true = y_true.astype(np.float32) |
|
y_pred = y_pred.astype(np.float32) |
|
|
|
intersection = np.sum(y_true * y_pred) |
|
union = np.sum(y_true) + np.sum(y_pred) - intersection |
|
iou = (intersection + 1e-5) / (union + 1e-5) |
|
return iou |
|
|
|
def dice_coefficient(y_true, y_pred): |
|
y_true = y_true.astype(np.float32) |
|
y_pred = y_pred.astype(np.float32) |
|
|
|
intersection = np.sum(y_true * y_pred) |
|
smooth = 1.0 |
|
dice = (2. * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth) |
|
return dice |
|
|
|
|
|
def accuracy(preds, masks, threshold=0.5):
    # The model outputs raw logits, so apply a sigmoid before thresholding.
    preds = (torch.sigmoid(preds) > threshold).float()
    correct = (preds == masks).sum().item()
    total = masks.numel()
    acc = correct / total
    return acc
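
# Tiny worked example of the overlap metrics (illustrative arrays):
_y_t = np.array([[1, 1], [0, 0]], dtype=np.float32)  # 2 foreground pixels
_y_p = np.array([[1, 0], [0, 0]], dtype=np.float32)  # 1 foreground pixel
# 1 shared pixel / 2 pixels in the union -> IoU ~= 0.5
# (2 * 1 + 1) / (2 + 1 + 1) -> Dice = 0.75
assert abs(i_over_u(_y_t, _y_p) - 0.5) < 1e-3
assert abs(dice_coefficient(_y_t, _y_p) - 0.75) < 1e-3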
|
|
|
|
|
|
|
image_size = 128
batch_size = 64 * 2
stats_image = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # ImageNet mean/std
stats_sketch = (0.5,), (0.5,)  # unused: sketches are binarised, not normalised
|
|
|
|
|
|
|
|
|
|
|
transform_image = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
T.ToTensor(), |
|
T.Normalize(*stats_image), |
|
] |
|
) |
|
|
|
transform_sketch = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
|
|
|
|
] |
|
) |
|
dataset = ImageSketchDataset( |
|
image_dir, |
|
sketch_dir, |
|
labels_df, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
paired=True |
|
) |
|
|
|
val_dataset = ImageSketchDataset( |
|
image_dir_val, |
|
sketch_dir_val, |
|
labels_df_val, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
paired = True |
|
) |
|
|
|
dataloader = DataLoader( |
|
dataset, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
val_dataloader = DataLoader( |
|
val_dataset, |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
|
|
|
|
model_unfreeze = models.mobilenet_v2(pretrained=True) |
|
for param in model_unfreeze.parameters(): |
|
param.requires_grad = True |
|
|
|
model_unfreeze.classifier = nn.Identity() |
|
|
|
model_unfrozen = model_unfreeze.features |
|
|
|
decoder = Decoder(num_encoder_features=1280, num_classes=1) |
|
|
|
model = SegmentationModel(encoder=model_unfrozen, decoder=decoder).to(device) |
|
criterion = nn.BCEWithLogitsLoss() |
|
encoder_lr = 0.0001 |
|
decoder_lr = 0.001 |
|
|
|
encoder_optimizer = optim.Adam(model.encoder.parameters(), lr=encoder_lr) |
|
decoder_optimizer = optim.Adam(model.decoder.parameters(), lr=decoder_lr) |
|
|
|
encoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(encoder_optimizer, mode='min', factor=0.2, patience=5) |
|
decoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(decoder_optimizer, mode='min', factor=0.2, patience=5) |
|
|
|
num_epochs = 10 |
|
|
|
train_losses = [] |
|
val_losses = [] |
|
val_accuracies = [] |
|
train_iou = [] |
|
val_iou = [] |
|
train_dice = [] |
|
val_dice = [] |
|
|
|
for epoch in range(num_epochs): |
|
running_loss = 0.0 |
|
running_train_iou = 0.0 |
|
running_train_dice = 0.0 |
|
running_val_iou = 0.0 |
|
running_val_dice = 0.0 |
|
running_val_acc = 0.0 |
|
|
|
model.train() |
|
for i, (images, masks, _) in tqdm(enumerate(dataloader),total = len(dataloader)): |
|
images = images.to(device) |
|
masks = masks.to(device) |
|
encoder_optimizer.zero_grad() |
|
decoder_optimizer.zero_grad() |
|
|
|
outputs = model(images) |
|
outputs = outputs.squeeze(1) |
|
loss = criterion(outputs, masks) |
|
loss.backward() |
|
|
|
encoder_optimizer.step() |
|
decoder_optimizer.step() |
|
|
|
running_loss += loss.item() |
|
|
|
        batch_iou = 0.0
        batch_dice = 0.0
        # The network outputs logits; threshold sigmoid probabilities at 0.5.
        probs = torch.sigmoid(outputs).detach().cpu()
        for j in range(len(images)):
            iou = i_over_u(masks[j].cpu().numpy(), (probs[j] > 0.5).numpy())
            dice = dice_coefficient(masks[j].cpu().numpy(), (probs[j] > 0.5).numpy())
|
running_train_iou += iou |
|
running_train_dice += dice |
|
batch_iou += iou |
|
batch_dice += dice |
|
|
|
train_iou.append(batch_iou / len(images)) |
|
train_dice.append(batch_dice / len(images)) |
|
|
|
train_losses.append(running_loss / len(dataloader)) |
|
|
|
avg_train_iou = running_train_iou / len(dataset) |
|
avg_train_dice = running_train_dice / len(dataset) |
|
|
|
model.eval() |
|
val_loss = 0.0 |
|
with torch.no_grad(): |
|
for i, (images, masks, _) in tqdm(enumerate(val_dataloader),total = len(val_dataloader)): |
|
images = images.to(device) |
|
masks = masks.to(device) |
|
|
|
outputs = model(images) |
|
outputs = outputs.squeeze(1) |
|
loss = criterion(outputs, masks) |
|
            val_loss += loss.item()

            batch_iou = 0.0
            batch_dice = 0.0
            batch_acc = 0.0
            probs = torch.sigmoid(outputs).detach().cpu()
            for j in range(len(images)):
                iou = i_over_u(masks[j].cpu().numpy(), (probs[j] > 0.5).numpy())
                dice = dice_coefficient(masks[j].cpu().numpy(), (probs[j] > 0.5).numpy())
                acc = accuracy(outputs[j], masks[j])
|
running_val_iou += iou |
|
running_val_dice += dice |
|
running_val_acc += acc |
|
batch_iou += iou |
|
batch_dice += dice |
|
batch_acc += acc |
|
|
|
val_iou.append(batch_iou / len(images)) |
|
val_dice.append(batch_dice / len(images)) |
|
val_accuracies.append(batch_acc / len(images)) |
|
|
|
    val_losses.append(val_loss / len(val_dataloader))

    # Step the plateau schedulers once per epoch on the mean validation loss
    # rather than once per validation batch on the running sum.
    encoder_scheduler.step(val_losses[-1])
    decoder_scheduler.step(val_losses[-1])
|
|
|
avg_val_iou = running_val_iou / len(val_dataset) |
|
avg_val_dice = running_val_dice / len(val_dataset) |
|
avg_val_acc = running_val_acc / len(val_dataset) |
|
|
|
print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, Avg Train IoU : {avg_train_iou:.4f}, Avg Train Dice: {avg_train_dice:.4f}, Avg Val IoU : {avg_val_iou:.4f}, Avg Val Dice: {avg_val_dice:.4f}, Val Accuracy: {avg_val_acc:.4f}') |
|
|
|
save_model_path_task2 = 'segmentation_model.pth' |
|
torch.save(model.state_dict(), save_model_path_task2) |
|
|
|
"""# Generator Network""" |
|
|
|
import torch.nn as nn |
|
import torch |
|
class SandwichBatchNorm2d(nn.Module): |
|
def __init__(self, num_features, num_classes): |
|
super().__init__() |
|
self.num_features = num_features |
|
self.bn = nn.BatchNorm2d(num_features, affine=True) |
|
self.embed = nn.Embedding(num_classes, num_features * 2) |
|
self.embed.weight.data[:, :num_features].normal_( |
|
1, 0.02 |
|
) |
|
self.embed.weight.data[:, num_features:].zero_() |
|
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
|
out = self.bn(x) |
|
gamma, beta = self.embed(y).chunk(2, 1) |
|
out = gamma.reshape(gamma.size(0), self.num_features, 1, 1) *out + beta.reshape(beta.size(0), self.num_features, 1, 1) |
|
|
|
return out |
|
|
|
class CategoricalConditionalBatchNorm2d(nn.Module): |
|
def __init__(self, num_features, num_classes): |
|
super().__init__() |
|
self.num_features = num_features |
|
self.bn = nn.BatchNorm2d(num_features, affine=False) |
|
self.embed = nn.Embedding(num_classes, num_features * 2) |
|
self.embed.weight.data[:, :num_features].normal_( |
|
1, 0.02 |
|
) |
|
self.embed.weight.data[:, num_features:].zero_() |
|
|
|
def forward(self, x, y): |
|
out = self.bn(x) |
|
gamma, beta = self.embed(y).chunk(2, 1) |
|
out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view( |
|
-1, self.num_features, 1, 1 |
|
) |
|
return out |
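
# Both variants modulate the normalised activations with a per-class scale
# (gamma) and shift (beta) looked up from an embedding table. A minimal usage
# sketch (illustrative sizes, cheap to run):
_cbn = CategoricalConditionalBatchNorm2d(num_features=8, num_classes=7)
_x = torch.randn(4, 8, 16, 16)          # (batch, channels, H, W)
_y = torch.randint(0, 7, (4,))          # one class id per sample
assert _cbn(_x, _y).shape == _x.shape   # conditioning preserves the shape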
|
|
|
|
|
|
|
|
|
|
|
|
|
class double_convolution(nn.Module): |
|
def __init__(self, in_channels, out_channels, num_classes): |
|
super(double_convolution, self).__init__() |
|
|
|
self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) |
|
self.swn1 = SandwichBatchNorm2d(out_channels, num_classes) |
|
self.act1 = nn.ReLU(inplace=True) |
|
self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) |
|
self.swn2 = SandwichBatchNorm2d(out_channels, num_classes) |
|
self.act2 = nn.ReLU(inplace=True) |
|
def forward(self, x,y): |
|
x = self.c1(x) |
|
x = self.swn1(x,y) |
|
x = self.act1(x) |
|
x = self.c2(x) |
|
x = self.swn2(x,y) |
|
x = self.act2(x) |
|
return x |
|
|
|
|
|
|
|
class Generator(nn.Module): |
|
def __init__(self, ngpu, num_classes): |
|
super(Generator, self).__init__() |
|
self.ngpu = ngpu |
|
self.swn1 = SandwichBatchNorm2d(512, num_classes) |
|
self.swn2 = SandwichBatchNorm2d(256, num_classes) |
|
self.swn3 = SandwichBatchNorm2d(128, num_classes) |
|
self.swn4 = SandwichBatchNorm2d(64, num_classes) |
|
|
|
|
|
|
|
|
|
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
|
|
self.down_convolution__2 = double_convolution(8, 4, num_classes) |
|
self.down_convolution__1 = double_convolution(4, 8, num_classes) |
|
self.down_convolution_0 = double_convolution(8, 1, num_classes) |
|
self.down_convolution_1 = double_convolution(1, 64, num_classes) |
|
self.down_convolution_2 = double_convolution(64, 128, num_classes) |
|
self.down_convolution_3 = double_convolution(128, 256, num_classes) |
|
self.down_convolution_4 = double_convolution(256, 512, num_classes) |
|
self.down_convolution_5 = double_convolution(512, 1024, num_classes) |
|
|
|
|
|
self.up_transpose_1 = nn.ConvTranspose2d( |
|
in_channels=1024, out_channels=512, kernel_size=2, stride=2 |
|
) |
|
|
|
self.up_convolution_1 = double_convolution(1024, 512,num_classes) |
|
|
|
self.up_transpose_2 = nn.ConvTranspose2d( |
|
in_channels=512, out_channels=256, kernel_size=2, stride=2 |
|
) |
|
self.up_convolution_2 = double_convolution(512, 256,num_classes) |
|
|
|
self.up_transpose_3 = nn.ConvTranspose2d( |
|
in_channels=256, out_channels=128, kernel_size=2, stride=2 |
|
) |
|
self.up_convolution_3 = double_convolution(256, 128,num_classes) |
|
self.up_transpose_4 = nn.ConvTranspose2d( |
|
in_channels=128, out_channels=64, kernel_size=2, stride=2 |
|
) |
|
self.up_convolution_4 = double_convolution(128, 64,num_classes) |
|
|
|
self.out = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1) |
|
|
|
    def forward(self, x, y):
        down__2 = self.down_convolution__2(x, y)
        down__1 = self.down_convolution__1(down__2, y)
        down_0 = self.down_convolution_0(down__1, y)
        down_1 = self.down_convolution_1(down_0, y)
        down_2 = self.max_pool2d(down_1)
        down_3 = self.down_convolution_2(down_2, y)
        down_4 = self.max_pool2d(down_3)
        down_5 = self.down_convolution_3(down_4, y)
        down_6 = self.max_pool2d(down_5)
        down_7 = self.down_convolution_4(down_6, y)
        down_8 = self.max_pool2d(down_7)
        down_9 = self.down_convolution_5(down_8, y)

        # Decoder path with skip connections; the conditional norm layers
        # return the modulated tensor, so x must be reassigned.
        up_1 = self.up_transpose_1(down_9)
        x = self.up_convolution_1(torch.cat([down_7, up_1], 1), y)
        x = self.swn1(x, y)
        up_2 = self.up_transpose_2(x)
        x = self.up_convolution_2(torch.cat([down_5, up_2], 1), y)
        x = self.swn2(x, y)
        up_3 = self.up_transpose_3(x)
        x = self.up_convolution_3(torch.cat([down_3, up_3], 1), y)
        x = self.swn3(x, y)
        up_4 = self.up_transpose_4(x)
        x = self.up_convolution_4(torch.cat([down_1, up_4], 1), y)
        x = self.swn4(x, y)
        out = self.out(x)
        return out
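
# The generator consumes the binary sketch stacked with num_classes one-hot
# condition planes (built by Generate_Fakes below), i.e. 1 + 7 = 8 input
# channels, which is why down_convolution__2 takes 8 channels. For a 128x128
# input the U-Net bottleneck is 1024 channels at 8x8 and the output is a
# 3-channel image at the original resolution.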
|
|
|
"""# Discriminator Networ""" |
|
|
|
import torch
import torch.nn as nn
|
class single_convolution(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes):
        super(single_convolution, self).__init__()
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Plain (unconditional) batch norm; num_classes is kept only for
        # interface symmetry with double_convolution.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.c1(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x
|
|
|
class Discriminator(nn.Module): |
|
def __init__(self, num_classes, ngpu=0): |
|
super(Discriminator, self).__init__() |
|
self.ngpu = ngpu |
|
|
|
|
|
|
|
|
|
|
|
self.des_block1 = single_convolution(3, 64, num_classes) |
|
self.des_block2 = single_convolution(64, 128, num_classes) |
|
self.des_block3 = single_convolution(128, 256, num_classes) |
|
self.des_block4 = single_convolution(256, 512, num_classes) |
|
|
|
self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0, bias=False) |
|
|
|
self.flatten = nn.Flatten() |
|
|
|
|
|
|
|
        self.fc_dis = nn.Linear(3969, 1)            # real/fake validity head
        self.fc_aux = nn.Linear(3969, num_classes)  # auxiliary class head

        self.sigmoid = nn.Sigmoid()
        # The auxiliary criterion used in training is nn.NLLLoss, which
        # expects log-probabilities, so use LogSoftmax rather than Softmax.
        self.log_softmax = nn.LogSoftmax(dim=1)
|
|
|
def forward(self, x): |
|
x = self.des_block1(x) |
|
x = self.des_block2(x) |
|
x = self.des_block3(x) |
|
x = self.des_block4(x) |
|
x = self.conv5(x) |
|
x = self.flatten(x) |
|
|
|
realfake = self.sigmoid(self.fc_dis(x)).view(-1, 1).squeeze(1) |
|
|
|
|
|
        classes = self.log_softmax(self.fc_aux(x))
|
|
|
return realfake, classes |
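
# The flattened size 3969 follows from a 128x128 input: the four stride-1
# blocks keep 128x128, and conv5 (kernel 4, stride 2, no padding) yields
# floor((128 - 4) / 2) + 1 = 63, i.e. 1 x 63 x 63 = 3969 features. The two
# heads follow the ACGAN pattern: a sigmoid real/fake score plus per-class
# log-probabilities from the auxiliary head.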
|
|
|
"""# GAN Training""" |
|
|
|
import argparse
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torchvision.models import mobilenet_v2
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

import wandb
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training") |
|
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") |
|
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") |
|
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") |
|
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") |
|
parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation") |
|
parser.add_argument("--n_critic", type=int, default=1, help="number of training steps for discriminator per iter") |
|
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights") |
|
parser.add_argument("--sample_interval", type=int, default=100, help="interval betwen image samples") |
|
args = parser.parse_args() |
|
print(args) |
|
|
|
|
|
ngpu = torch.cuda.device_count() |
|
print('num gpus available: ', ngpu) |
|
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") |
|
|
|
wandb.init(project = 'DL_Assignment_4', entity='m23csa017') |
|
|
|
image_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train_data" |
|
sketch_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Contours" |
|
labels_df = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Train_labels.csv" |
|
|
|
image_dir_test = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test" |
|
sketch_dir_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_countours " |
|
labels_df_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_Labels.csv" |
|
|
|
lambda_seg = 2.0 |
|
|
|
image_size = 128 |
|
batch_size = args.batch_size |
|
stats_image = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) |
|
stats_sketch = (0.5,), (0.5,)
|
|
|
|
|
|
|
|
|
transform_image = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
T.ToTensor(), |
|
T.Normalize(*stats_image), |
|
] |
|
) |
|
|
|
transform_sketch = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
] |
|
) |
|
train_ds = ImageSketchDataset( |
|
image_dir, |
|
sketch_dir, |
|
labels_df, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
) |
|
|
|
val_ds = ImageSketchDataset( |
|
image_dir_test, |
|
sketch_dir_val, |
|
labels_df_val, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
) |
|
|
|
train_dl = DataLoader( |
|
train_ds, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
val_dl = DataLoader( |
|
val_ds, |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
def denorm(img_tensors):
    # Approximate de-normalisation using only the first channel's ImageNet
    # stats; close enough across channels for visualisation.
    return img_tensors * stats_image[1][0] + stats_image[0][0]
|
|
|
|
|
|
|
|
|
num_classes = len(train_ds.labels_df.columns) - 1 |
|
print('number of classes in dataset: ',num_classes) |
|
|
|
discriminator = Discriminator(num_classes=num_classes, ngpu=ngpu).to(device) |
|
generator = Generator(ngpu=ngpu, num_classes=num_classes).to(device) |
|
|
|
|
|
model_unfreeze = models.mobilenet_v2(pretrained=True) |
|
model_unfreeze.classifier = nn.Identity() |
|
model_unfrozen = model_unfreeze.features |
|
decoder = Decoder(num_encoder_features=1280, num_classes=1) |
|
seg_model = SegmentationModel(encoder=model_unfrozen, decoder=decoder) |
|
|
|
seg_model_saved = 'segmentation_model.pth' |
|
seg_model.load_state_dict(torch.load(seg_model_saved)) |
|
seg_model.to(device) |
|
|
|
if (device.type == 'cuda') and (ngpu > 1): |
|
generator = nn.DataParallel(generator, list(range(ngpu))) |
|
discriminator = nn.DataParallel(discriminator, list(range(ngpu))) |
|
seg_model = nn.DataParallel(seg_model, list(range(ngpu))) |
|
|
|
|
|
|
|
|
|
def Generate_Fakes(sketches):
    """Stack each sketch with spatial one-hot maps of a randomly drawn class."""
    noisy_sketchs_ = []
    fake_labels = torch.randint(0, num_classes, (sketches.size(0),), device=sketches.device)
    for noisy_sketch, fake_label in zip(sketches, fake_labels):
        # One constant plane per class; only the sampled class's plane is 1.
        channels = torch.zeros(
            size=(num_classes, *noisy_sketch.shape), device=noisy_sketch.device
        )
        channels[fake_label] = 1.0
        noisy_sketch = torch.cat((noisy_sketch.unsqueeze(0), channels), dim=0)
        noisy_sketchs_.append(noisy_sketch)

    noisy_sketchs = torch.stack(noisy_sketchs_)

    fake_labels = F.one_hot(fake_labels, num_classes=num_classes).float().to(device)

    return noisy_sketchs, fake_labels
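
# For a (B, 128, 128) batch of sketches this returns a (B, 1 + 7, 128, 128)
# conditioning tensor for the generator together with the matching one-hot
# labels consumed by the auxiliary classification loss.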
|
|
|
|
|
|
|
sample_dir = "generated_SBNGAN_images" |
|
os.makedirs(sample_dir, exist_ok=True) |
|
|
|
|
|
def save_samples(index, generator, train_dl, show=True): |
|
real_images, sketches, real_labels = next(iter(train_dl)) |
|
latent_input, gen_labels = Generate_Fakes(sketches=sketches) |
|
|
|
aux_fake_labels = torch.argmax(gen_labels, dim=1) |
|
aux_fake_labels = aux_fake_labels.type(torch.long) |
|
|
|
fake_images = generator(latent_input.to(device),aux_fake_labels) |
|
|
|
fake_fname = "generated-images-{0:0=4d}.png".format(index) |
|
save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8) |
|
print("Saving", fake_fname) |
|
if show: |
|
fig, ax = plt.subplots(figsize=(8, 8)) |
|
ax.set_xticks([]) |
|
ax.set_yticks([]) |
|
ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0)) |
|
|
|
adversarial_loss = torch.nn.BCELoss() |
|
aux_criterion = nn.NLLLoss() |
|
seg_criterion = nn.BCEWithLogitsLoss() |
|
Tensor = torch.cuda.FloatTensor if (device.type == 'cuda') else torch.FloatTensor |
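
# Training objective (ACGAN-style plus a segmentation-consistency term):
# the discriminator minimises BCE on real/fake validity plus NLL on its
# auxiliary class head for both real and generated batches, while the
# generator minimises the validity/auxiliary terms on its fakes plus
# lambda_seg * BCEWithLogits between the frozen segmentation model's mask
# prediction for the fake image and the conditioning sketch.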
|
|
|
|
|
def fit(mask_gen, epochs, lr, start_idx=1): |
|
|
|
torch.cuda.empty_cache() |
|
generator.train() |
|
discriminator.train() |
|
mask_gen.eval() |
|
|
|
losses_g = [] |
|
losses_d = [] |
|
real_scores = [] |
|
fake_scores = [] |
|
|
|
|
|
|
|
|
opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999)) |
|
opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999)) |
|
|
|
for epoch in range(epochs): |
|
|
|
for idx, (real_images, sketches, real_labels_onehot) in tqdm(enumerate(train_dl), |
|
desc= "Training", dynamic_ncols=True,total=len(train_dl)): |
|
|
|
            real_images = real_images.type(Tensor).to(device)
|
sketches = sketches.to(device) |
|
real_labels_onehot = real_labels_onehot.to(device) |
|
|
|
|
|
batch_size = real_images.shape[0] |
|
|
|
valid = torch.full((batch_size,), 1.0, dtype=torch.float, device=device) |
|
fake = torch.full((batch_size,), 0.0, dtype=torch.float, device=device) |
|
|
|
|
|
latent_input, gen_labels_onehot = Generate_Fakes(sketches=sketches) |
|
|
|
            latent_input = latent_input.to(device)
|
|
|
|
|
|
|
|
|
opt_d.zero_grad() |
|
|
|
|
|
aux_real_labels = torch.argmax(real_labels_onehot, dim=1) |
|
aux_fake_labels = torch.argmax(gen_labels_onehot, dim=1) |
|
|
|
gen_labels_onehot_long = aux_fake_labels.type(torch.long) |
|
real_labels_onehot_long = aux_real_labels.type(torch.long) |
|
|
|
fake_images = generator(latent_input,gen_labels_onehot_long) |
|
|
|
|
|
validity_real, real_aux_output = discriminator(real_images) |
|
|
|
validity_fake, fake_aux_output = discriminator(fake_images) |
|
|
|
loss_d_validity = adversarial_loss(validity_real, valid) + adversarial_loss(validity_fake, fake) |
|
|
|
|
|
loss_d_aux = aux_criterion(fake_aux_output, aux_fake_labels) + aux_criterion(real_aux_output, aux_real_labels) |
|
|
|
loss_d = loss_d_validity + loss_d_aux |
|
|
|
real_score =torch.mean(validity_real).item() |
|
|
|
|
|
fake_score = torch.mean(validity_fake).item() |
|
|
|
loss_d.backward() |
|
opt_d.step() |
|
|
|
|
|
if idx % args.n_critic == 0: |
|
|
|
|
|
|
|
opt_g.zero_grad() |
|
fake_images = generator(latent_input,gen_labels_onehot_long) |
|
validity_fake, fake_aux_output = discriminator(fake_images) |
|
generated_mask = mask_gen(fake_images) |
|
loss_g_adv = adversarial_loss(validity_fake, valid) + aux_criterion(fake_aux_output, aux_fake_labels) |
|
generated_mask = generated_mask.squeeze(1) |
|
loss_g_seg = lambda_seg * seg_criterion(generated_mask, sketches) |
|
loss_g = loss_g_adv + loss_g_seg |
|
loss_g.backward() |
|
opt_g.step() |
|
|
|
            wandb.log(
                {
                    "loss_g": loss_g.item(),
                    "loss_d": loss_d.item(),
                    "real_score": real_score,
                    "fake_score": fake_score,
                }
            )
|
print( |
|
"Epoch [{}/{}], Batch [{}/{}], loss_g:{:.4f}, loss_d:{:.4f}, real_scores:{:.4f}, fake_score:{:.4f}".format( |
|
epoch + 1, epochs, idx, len(train_dl), loss_g, loss_d, real_score, fake_score |
|
) |
|
) |
|
batches_done = epoch * len(train_dl) + idx |
|
if batches_done % args.sample_interval == 0: |
|
save_samples(batches_done, generator, train_dl, show=False) |
|
|
|
|
|
|
losses_d.append(loss_d.item()) |
|
losses_g.append(loss_g.item()) |
|
real_scores.append(real_score) |
|
fake_scores.append(fake_score) |
|
|
|
if epoch % 4 == 0: |
|
save_model_path_task2 = f'generator_model_{epoch}.pth' |
|
torch.save(generator.state_dict(), save_model_path_task2) |
|
return losses_g, losses_d, real_scores, fake_scores |
|
|
|
|
|
lr = args.lr |
|
epochs = args.n_epochs |
|
|
|
history = fit(seg_model,epochs, lr) |
|
|
|
losses_g, losses_d, real_scores, fake_scores = history |
|
|
|
"""# Generator Evaluation""" |
|
|
|
import argparse
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torchvision.models import mobilenet_v2
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from datasets import *
from models.segmenatation_model import *
from models.Generator import Generator
from models.Discriminator import Discriminator
|
ngpu = torch.cuda.device_count() |
|
print('num gpus available: ', ngpu) |
|
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") |
|
|
|
|
|
|
|
image_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train_data" |
|
sketch_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Contours" |
|
labels_df = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Train_labels.csv" |
|
|
|
image_dir_test = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test" |
|
sketch_dir_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_contours " |
|
labels_df_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_Labels.csv" |
|
|
|
lambda_seg = 2.0 |
|
num_classes = 7 |
|
image_size = 128 |
|
batch_size = 8 |
|
stats_image = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) |
|
stats_sketch = (0.5,), (0.5,)
|
|
|
|
|
def add_gaussian_noise(image, mean=0, stddev=1):
    # Unit-normal noise scaled by stddev and shifted by mean.
    noise = torch.randn_like(image) * stddev + mean
    noisy_image = image + noise
    return noisy_image
|
|
|
|
|
|
|
transform_image = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
T.ToTensor(), |
|
T.Normalize(*stats_image), |
|
] |
|
) |
|
|
|
transform_sketch = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
|
|
|
|
] |
|
) |
|
train_ds = ImageSketchDataset( |
|
image_dir, |
|
sketch_dir, |
|
labels_df, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
) |
|
|
|
val_ds = ImageSketchDataset( |
|
image_dir_test, |
|
sketch_dir_val, |
|
labels_df_val, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
) |
|
|
|
train_dl = DataLoader( |
|
train_ds, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
val_dl = DataLoader( |
|
val_ds, |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
|
|
|
generator = Generator(ngpu=ngpu, num_classes=7).to(device)
# NOTE: load a trained checkpoint before computing FID/IS; otherwise the
# scores describe a randomly initialised generator, e.g.:
# generator.load_state_dict(torch.load("generator_model_<epoch>.pth"))
generator.eval()
|
Tensor = torch.cuda.FloatTensor if (device.type == 'cuda') else torch.FloatTensor |
|
|
|
|
|
|
|
import torch |
|
from torchmetrics.image.inception import InceptionScore |
|
|
|
|
|
from torchmetrics.image.fid import FrechetInceptionDistance |
|
fid = FrechetInceptionDistance(feature=64) |
|
|
|
|
|
def calculate_fid_is_score(generator, num_classes, n_samples=2000, eps=1e-6): |
|
|
|
fake_images=[] |
|
real_images = [] |
|
    for idx, (real_image, sketches, real_labels_onehot) in tqdm(enumerate(train_dl),
                                                                desc="Generating samples for FID/IS", dynamic_ncols=True, total=len(train_dl)):
|
|
|
|
|
sketches = sketches.to(device) |
|
real_labels_onehot = real_labels_onehot.to(device) |
|
|
|
|
|
latent_input, gen_labels_onehot = Generate_Fakes(sketches=sketches) |
|
|
|
        latent_input = latent_input.to(device)
|
|
|
|
|
aux_real_labels = torch.argmax(real_labels_onehot, dim=1) |
|
aux_fake_labels = torch.argmax(gen_labels_onehot, dim=1) |
|
|
|
gen_labels_onehot_long = aux_fake_labels.type(torch.long) |
|
real_labels_onehot_long = aux_real_labels.type(torch.long) |
|
|
|
fake_image = generator(latent_input,gen_labels_onehot_long) |
|
fake_images.append(fake_image.detach().cpu()) |
|
|
|
|
|
real_images.append(real_image.detach().cpu()) |
|
|
|
if (idx+1) * batch_size > n_samples: |
|
break |
|
|
|
|
|
    fake_images = torch.cat(fake_images, dim=0)
    real_images = torch.cat(real_images, dim=0)

    # Inception-based metrics expect uint8 images in [0, 255]; undo the
    # ImageNet normalisation before casting (a direct uint8 cast of the
    # normalised floats would destroy the images).
    mean = torch.tensor(stats_image[0]).view(1, 3, 1, 1)
    std = torch.tensor(stats_image[1]).view(1, 3, 1, 1)
    fake_images = ((fake_images * std + mean).clamp(0, 1) * 255).type(torch.uint8)
    real_images = ((real_images * std + mean).clamp(0, 1) * 255).type(torch.uint8)

    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)
    fid_score = fid.compute()
|
|
|
inception = InceptionScore() |
|
|
|
|
|
inception.update(fake_images) |
|
incep_score = inception.compute() |
|
|
|
|
|
return fid_score.item(), incep_score[0].item() |
|
|
|
fid_score, incep_score = calculate_fid_is_score(generator, num_classes) |
|
print("FID Score:", fid_score) |
|
print("IS Score:", incep_score) |
|
|
|
"""# Classifier Training""" |
|
|
|
import argparse
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torchvision.models import mobilenet_v2
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

ngpu = torch.cuda.device_count()
print('num gpus available: ', ngpu)
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
|
|
|
|
|
|
|
image_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train_data" |
|
sketch_dir = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Contours" |
|
labels_df = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Train/Train_labels.csv" |
|
|
|
image_dir_test = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test" |
|
sketch_dir_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_contours" |
|
labels_df_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_Labels.csv" |
|
|
|
lambda_seg = 2.0 |
|
num_classes = 7 |
|
image_size = 128 |
|
batch_size = 32 |
|
stats_image = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) |
|
stats_sketch = (0.5,), (0.5,)
|
|
|
|
|
|
|
|
|
|
|
|
transform_image = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
T.ToTensor(), |
|
T.Normalize(*stats_image), |
|
] |
|
) |
|
|
|
transform_sketch = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
|
|
|
|
] |
|
) |
|
train_ds = ImageSketchDataset( |
|
image_dir, |
|
sketch_dir, |
|
labels_df, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
) |
|
|
|
val_ds = ImageSketchDataset( |
|
image_dir_test, |
|
sketch_dir_val, |
|
labels_df_val, |
|
transform_image=transform_image, |
|
transform_sketch=transform_sketch, |
|
) |
|
|
|
train_dl = DataLoader( |
|
train_ds, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
|
|
val_dl = DataLoader( |
|
val_ds, |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=6, |
|
pin_memory=True, |
|
) |
|
def display_digit(image, label): |
|
fig, ax = plt.subplots(figsize=(4, 4)) |
|
ax.imshow(image.numpy().squeeze(), cmap='gray') |
|
ax.axis('off') |
|
ax.set_title(f"Label: {label}", fontsize=10) |
|
plt.show() |
|
|
|
|
|
|
|
class ImageClassificationBase(nn.Module): |
|
|
|
def training_step(self, batch): |
|
images, _, labels = batch |
|
images = images.to(device) |
|
labels = labels.to(device) |
|
labels = torch.argmax(labels, dim=1).type(torch.long) |
|
out = self(images) |
|
loss = F.cross_entropy(out, labels) |
|
acc = self.accuracy(out, labels) |
|
return loss,acc.item() |
|
|
|
def training_epoch_end(self, outputs): |
|
batch_losses = [x['loss'] for x in outputs] |
|
epoch_losses = torch.stack(batch_losses).mean() |
|
batch_train_acc = [x['train_acc'] for x in outputs] |
|
epoch_train_acc = torch.stack(batch_train_acc).mean() |
|
return {'train_loss':epoch_losses.item(), 'train_acc':epoch_train_acc.item()} |
|
|
|
def validation_step(self, batch): |
|
images,_, labels = batch |
|
images = images.to(device) |
|
labels = labels.to(device) |
|
labels = torch.argmax(labels, dim=1).type(torch.long) |
|
out = self(images) |
|
loss = F.cross_entropy(out, labels) |
|
acc = self.accuracy(out, labels) |
|
return {'val_loss' : loss, 'val_acc' : acc} |
|
|
|
def validation_epoch_end(self, outputs): |
|
batch_losses = [x['val_loss'] for x in outputs] |
|
epoch_loss = torch.stack(batch_losses).mean() |
|
batch_accs = [x['val_acc'] for x in outputs] |
|
epoch_acc = torch.stack(batch_accs).mean() |
|
return {'val_loss' : epoch_loss.item(), 'val_acc' : epoch_acc.item()} |
|
|
|
def accuracy(self, outputs, labels): |
|
_, preds = torch.max(outputs, dim=1) |
|
return torch.tensor(torch.sum(preds == labels).item() / len(preds)) |
|
|
|
def epoch_end(self, epoch, result): |
|
print("Epoch [{}], train_loss:{:.4f}, val_loss:{:.4f}, val_acc:{:.4f}, train_acc:{:.4f}".format( |
|
epoch, result['train_loss'], result['val_loss'], result['val_acc'], result['train_acc'])) |
|
|
|
class MNISTCNNModel(ImageClassificationBase): |
|
def __init__(self, num_classes): |
|
super().__init__() |
|
self.num_classes = num_classes |
|
self.network = nn.Sequential ( |
|
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=7, stride=1, padding=3), |
|
nn.ReLU(), |
|
nn.MaxPool2d(kernel_size=2, stride=1), |
|
|
|
nn.Conv2d(in_channels=16, out_channels=8, kernel_size=5, stride=1, padding=2), |
|
nn.ReLU(), |
|
nn.MaxPool2d(kernel_size=2, stride=1), |
|
|
|
            nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=1),
            nn.Flatten(),
            # No Softmax here: F.cross_entropy applies log-softmax internally,
            # so the network should output raw logits.
            nn.Linear(15376, num_classes)
        )
|
|
|
def forward(self, x): |
|
return self.network(x) |
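
# Feature-map sizes for a 128x128 input: conv7/s1 -> 128, maxpool k2/s1 -> 127,
# conv5/s1 -> 127, maxpool k2/s1 -> 126, conv3/s2 -> 63, avgpool k2/s1 -> 62,
# so the flattened size is 4 * 62 * 62 = 15376, matching the Linear layer.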
|
|
|
def evaluate(model, val_loader): |
|
model.eval() |
|
outputs = [model.validation_step(batch) for batch in val_loader] |
|
return model.validation_epoch_end(outputs) |
|
|
|
def fit(epochs, lr, model, train_loader, val_loader, opt_func = torch.optim.Adam): |
|
history = [] |
|
optimizer = opt_func(model.parameters(), lr) |
|
scheduler = StepLR(optimizer, step_size=3, gamma=0.5) |
|
for epoch in range(epochs): |
|
model.train() |
|
train_losses = [] |
|
train_acc = [] |
|
for idx, batch in enumerate(tqdm(train_loader)): |
|
loss, acc = model.training_step(batch) |
|
train_losses.append(loss.item()) |
|
train_acc.append(acc) |
|
loss.backward() |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
|
|
scheduler.step() |
|
result = evaluate(model, val_loader) |
|
result['train_loss'] = (torch.tensor(train_losses)).mean().item() |
|
result['train_acc'] = (torch.tensor(train_acc)).mean().item() |
|
model.epoch_end(epoch, result) |
|
history.append(result) |
|
torch.save(model.state_dict(), f"classifier_{epoch}.pth") |
|
return history |
|
|
|
|
|
num_epochs = 10 |
|
lr = 0.001 |
|
opt_func = torch.optim.Adam |
|
|
|
model_10 = MNISTCNNModel(num_classes=7) |
|
model_10.to(device) |
|
|
|
history_10 = fit(num_epochs, lr, model_10, train_dl, val_dl, opt_func) |
|
|
|
"""# Classifier Evaluation""" |
|
|
|
import glob
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset

from models.Generator import Generator
|
|
|
sketch_dir_val = "/teamspace/studios/this_studio/DL_Assignment_4/Dataset/Test/Test_contours" |
|
ngpu = torch.cuda.device_count() |
|
all_sketches = glob.glob1(sketch_dir_val, "*.png") |
|
|
|
num_samples = 60 |
|
|
|
lambda_seg = 2.0 |
|
num_classes = 7 |
|
image_size = 128 |
|
batch_size = 32 |
|
stats_image = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) |
|
stats_sketch = (0.5,), (0.5,)
|
transform_sketch = T.Compose( |
|
[ |
|
T.Resize(image_size), |
|
T.CenterCrop(image_size), |
|
|
|
|
|
] |
|
) |
|
|
|
|
def Generate_Fakes(sketches,classof): |
|
|
|
noisy_sketchs = sketches |
|
noisy_sketchs_ = [] |
|
fake_labels = torch.ones(sketches.size(0) , device=sketches.device,dtype=torch.long) * classof |
|
for noisy_sketch, fake_label in zip(noisy_sketchs, fake_labels): |
|
channels = torch.zeros( |
|
size=(num_classes, *noisy_sketch.shape), device=noisy_sketch.device |
|
) |
|
channels[fake_label] = 1.0 |
|
noisy_sketch = torch.cat((noisy_sketch.unsqueeze(0), channels), dim=0) |
|
noisy_sketchs_.append(noisy_sketch) |
|
|
|
noisy_sketchs = torch.stack(noisy_sketchs_) |
|
|
|
|
|
    fake_labels = F.one_hot(fake_labels, num_classes=num_classes).float().to(device)
|
|
|
return noisy_sketchs, fake_labels |
|
|
|
def accuracy( outputs, labels): |
|
_, preds = torch.max(outputs, dim=1) |
|
return torch.tensor(torch.sum(preds == labels).item() / len(preds)) |
|
|
|
def load_sketch(sketch_path):
    sketch = transform_sketch(Image.open(sketch_path))
    # Binarise the mask exactly as in ImageSketchDataset.
    sketch_arr = np.array(sketch)
    sketch_np = np.zeros_like(sketch_arr, dtype=np.float32)
    sketch_np[sketch_arr == 255] = 1.0
    return torch.from_numpy(sketch_np).unsqueeze(0)
|
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") |
|
|
|
classifier = MNISTCNNModel(num_classes=7) |
|
classifier.load_state_dict(torch.load("/teamspace/studios/this_studio/DL_Assignment_4/CBNGAN/classifier_9.pth")) |
|
classifier.to(device) |
|
classifier.eval() |
|
|
|
generator = Generator(ngpu=1, num_classes=7).to(device)
# NOTE: load a trained generator checkpoint here as well, e.g.:
# generator.load_state_dict(torch.load("generator_model_<epoch>.pth"))
generator.eval()
|
|
|
|
|
for class_label, class_name in enumerate(["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]):
    sketch_filenames = np.random.choice(all_sketches, num_samples)
    sketches = []
    for sketch_filename in sketch_filenames:
        sketches.append(load_sketch(os.path.join(sketch_dir_val, sketch_filename)))

    sketches = torch.cat(sketches)

    # Condition every sketch on the same target class and check whether the
    # trained classifier recovers that class from the generated images.
    latent_input, gen_labels = Generate_Fakes(sketches, class_label)
    aux_fake_labels = torch.argmax(gen_labels, dim=1).type(torch.long).to(device)
    with torch.no_grad():
        fake_images = generator(latent_input.to(device), aux_fake_labels)
        pred_class = classifier(fake_images)
    acc = accuracy(pred_class, aux_fake_labels)
    print(f"acc of class {class_name} ({class_label}): {acc:.4f}")