anindya-hf-2002 committed
Commit: 634fc83
Parent(s): 8e3c016

upload 3 files

Browse files:
- src/dataset.py +65 -0
- src/generate_images.py +96 -0
- src/train.py +124 -0
src/dataset.py (ADDED) @@ -0,0 +1,65 @@
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os

class ClassifierDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.classes = ['0', '1']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = self._make_dataset()

    def _make_dataset(self):
        samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                samples.append((img_path, self.class_to_idx[class_name]))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('L')  # Convert to grayscale
        if self.transform:
            img = self.transform(img)
        return img, label


class CustomDataset(Dataset):
    def __init__(self, root_dir, train_N, train_P, img_res):
        self.root_dir = root_dir
        self.train_N = train_N
        self.train_P = train_P
        self.img_res = img_res
        self.transforms = transforms.Compose([
            transforms.Resize(img_res),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # Assuming grayscale images
        ])

    def __len__(self):
        return min(len(os.listdir(os.path.join(self.root_dir, self.train_N))),
                   len(os.listdir(os.path.join(self.root_dir, self.train_P))))

    def __getitem__(self, idx):
        normal_path = os.path.join(self.root_dir, self.train_N, os.listdir(os.path.join(self.root_dir, self.train_N))[idx])
        pneumo_path = os.path.join(self.root_dir, self.train_P, os.listdir(os.path.join(self.root_dir, self.train_P))[idx])

        normal_img = Image.open(normal_path).convert("L")  # Load as grayscale
        pneumo_img = Image.open(pneumo_path).convert("L")  # Load as grayscale

        normal_img = self.transforms(normal_img)
        pneumo_img = self.transforms(pneumo_img)

        return normal_img, pneumo_img
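A minimal usage sketch, not part of the commit: it wires both datasets into loaders, assuming the class-folder layout (<root>/0 and <root>/1) these classes expect. The paths and sizes below are illustrative assumptions.

# Hypothetical usage sketch; paths and image size are assumptions.
from torch.utils.data import DataLoader
from torchvision import transforms
from src.dataset import ClassifierDataset, CustomDataset

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
train_ds = ClassifierDataset(root_dir='data/train', transform=transform)  # assumed path
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)          # yields (image, label)

paired_ds = CustomDataset(root_dir='data/train', train_N='0', train_P='1', img_res=(256, 256))
normal_img, pneumo_img = paired_ds[0]  # one tensor from each class, index-paired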
src/generate_images.py (ADDED) @@ -0,0 +1,96 @@
import os
import torch
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

from src.models import ResUNetGenerator

# Custom Dataset
class ImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image, img_path

# Function to save image
def save_image(tensor, path):
    if tensor.is_cuda:
        tensor = tensor.cpu()

    array = tensor.permute(1, 2, 0).detach().numpy()
    array = (array * 0.5 + 0.5) * 255  # de-normalize assuming mean=0.5, std=0.5
    array = array.astype(np.uint8)
    if array.shape[2] == 1:
        array = array.squeeze(2)
        image = Image.fromarray(array, mode='L')
    else:
        image = Image.fromarray(array)
    image.save(path)

# Function to load model
def load_model(checkpoint_path, model_class, device):
    model = model_class().to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    model.eval()
    return model

def generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint, output_dir='data/translated_images', batch_size=16):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load models
    g_NP = load_model(g_NP_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
    g_PN = load_model(g_PN_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)

    # Create output directories
    os.makedirs(os.path.join(output_dir, '0'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, '1'), exist_ok=True)

    # Collect image paths
    image_paths_0 = [os.path.join(image_folder, '0', fname) for fname in os.listdir(os.path.join(image_folder, '0')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
    image_paths_1 = [os.path.join(image_folder, '1', fname) for fname in os.listdir(os.path.join(image_folder, '1')) if fname.endswith(('.png', '.jpg', '.jpeg'))]

    # Prepare dataset and dataloader
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485], std=[0.229])])
    dataset_0 = ImageDataset(image_paths_0, transform)
    dataset_1 = ImageDataset(image_paths_1, transform)
    dataloader_0 = DataLoader(dataset_0, batch_size=batch_size, shuffle=False)
    dataloader_1 = DataLoader(dataset_1, batch_size=batch_size, shuffle=False)

    # Process images from negative (0) to positive (1)
    with torch.no_grad():
        for batch, paths in tqdm(dataloader_0, desc="Converting N to P: "):
            batch = batch.to(device)
            translated_images = g_NP(batch)
            translated_images = g_PN(translated_images)  # applies both generators: a full N->P->N cycle pass
            for img, path in zip(translated_images, paths):
                save_path = os.path.join(output_dir, '1', os.path.basename(path))
                save_image(img, save_path)

        # Process images from positive (1) to negative (0)
        for batch, paths in tqdm(dataloader_1, desc="Converting P to N: "):
            batch = batch.to(device)
            translated_images = g_PN(batch)
            translated_images = g_NP(translated_images)  # applies both generators: a full P->N->P cycle pass
            for img, path in zip(translated_images, paths):
                save_path = os.path.join(output_dir, '0', os.path.basename(path))
                save_image(img, save_path)

if __name__ == '__main__':
    image_folder = r'data\rsna-pneumonia-dataset\train'
    g_NP_checkpoint = r'models\g_NP_best.ckpt'  # raw strings so the backslashes are not treated as escapes
    g_PN_checkpoint = r'models\g_PN_best.ckpt'

    generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint)
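One observation, not part of the commit: ImageDataset normalizes inputs with mean=0.485, std=0.229, while save_image de-normalizes assuming mean=std=0.5, so saved pixels are not an exact inverse of the loaded values. A sketch of a de-normalization matched to the loader's statistics, under that assumption:

# Hypothetical helper: invert the loader's Normalize(mean=0.485, std=0.229)
# instead of the 0.5/0.5 pair used in save_image above.
import numpy as np
import torch
from PIL import Image

MEAN, STD = 0.485, 0.229  # must match the transforms.Normalize call above

def save_image_matched(tensor: torch.Tensor, path: str) -> None:
    array = tensor.detach().cpu().permute(1, 2, 0).numpy()        # (H, W, 1)
    array = np.clip((array * STD + MEAN) * 255.0, 0, 255).astype(np.uint8)
    Image.fromarray(array.squeeze(2), mode='L').save(path)        # grayscale PNG/JPEG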
src/train.py (ADDED) @@ -0,0 +1,124 @@
from torchvision import transforms
from torch.utils.data import DataLoader
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import lightning as pl
import wandb

from src.dataset import ClassifierDataset, CustomDataset
from src.classifier import Classifier
from src.models import CycleGAN
from src.config import CFG

def train_classifier(image_size,
                     batch_size,
                     epochs,
                     resume_ckpt_path,
                     train_dir,
                     val_dir,
                     checkpoint_dir,
                     project,
                     job_name):

    clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # Resize image to image_size x image_size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229])  # Normalize image
    ])

    # Define dataset paths
    # train_dir = "/kaggle/working/CycleGan-CFE/train-data/train"
    # val_dir = "/kaggle/working/CycleGan-CFE/train-data/val"

    # Create datasets
    train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
    val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
    print("Total Training Images: ", len(train_dataset))
    print("Total Validation Images: ", len(val_dataset))

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)

    # Instantiate the classifier model
    clf = Classifier(transfer=True)

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=checkpoint_dir,
        filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
        auto_insert_metric_name=False,
        save_weights_only=False,
        save_top_k=3,
        mode='min'
    )

    # Set up the PyTorch Lightning Trainer (multi-GPU capable, with a tqdm progress bar)
    trainer = pl.Trainer(
        devices="auto",
        precision="16-mixed",
        accelerator="auto",
        max_epochs=epochs,
        accumulate_grad_batches=10,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        benchmark=True,
        logger=clf_wandb_logger,
        callbacks=[checkpoint_callback],
    )

    # Train the classifier
    trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
    wandb.finish()


def train_cyclegan(image_size,
                   batch_size,
                   epochs,
                   classifier_path,
                   resume_ckpt_path,
                   train_dir,
                   val_dir,
                   test_dir,
                   checkpoint_dir,
                   project,
                   job_name,
                   ):

    testdata_dir = test_dir
    train_N = "0"
    train_P = "1"
    img_res = (image_size, image_size)

    test_dataset = CustomDataset(root_dir=testdata_dir, train_N=train_N, train_P=train_P, img_res=img_res)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    print(classifier_path)
    cyclegan = CycleGAN(train_dir=train_dir, val_dir=val_dir, test_dataloader=test_dataloader,
                        classifier_path=classifier_path, checkpoint_dir=checkpoint_dir,
                        gf=CFG.GAN_FILTERS, df=CFG.DIS_FILTERS)

    gan_checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,
                                              filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
                                              monitor='val_generator_loss',
                                              save_top_k=3,
                                              save_last=True,
                                              save_weights_only=False,
                                              verbose=True,
                                              mode='min')

    # Create the trainer
    trainer = pl.Trainer(
        accelerator="auto",
        precision="16-mixed",
        max_epochs=epochs,
        log_every_n_steps=1,
        benchmark=True,
        devices="auto",
        logger=wandb_logger,
        callbacks=[gan_checkpoint_callback]
    )

    # Train the CycleGAN model
    trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
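A hypothetical driver, not part of the commit, showing how the two entry points might be invoked end to end. Every path, project name, and hyperparameter below is an illustrative assumption.

# Hypothetical driver script; all paths and hyperparameters are assumptions.
from src.train import train_classifier, train_cyclegan

if __name__ == '__main__':
    # Stage 1: train the pneumonia classifier used to guide the CycleGAN.
    train_classifier(image_size=256,
                     batch_size=16,
                     epochs=20,
                     resume_ckpt_path=None,            # None starts from scratch
                     train_dir='data/train',           # assumed layout: <dir>/0 and <dir>/1
                     val_dir='data/val',
                     checkpoint_dir='models/classifier',
                     project='cyclegan-cfe',
                     job_name='classifier-baseline')

    # Stage 2: train the CycleGAN, pointing it at a classifier checkpoint.
    train_cyclegan(image_size=256,
                   batch_size=4,
                   epochs=100,
                   classifier_path='models/classifier/efficientnet_b2-best.ckpt',  # assumed name
                   resume_ckpt_path=None,
                   train_dir='data/train',
                   val_dir='data/val',
                   test_dir='data/test',
                   checkpoint_dir='models/cyclegan',
                   project='cyclegan-cfe',
                   job_name='cyclegan-run')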