File size: 4,508 Bytes
65eeb0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from torchvision import transforms
from torch.utils.data import DataLoader
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import lightning as pl
import wandb

from src.dataset import ClassifierDataset, CustomDataset
from src.classifier import Classifier
from src.models import CycleGAN
from src.config import CFG

def train_classifier(image_size,
                     batch_size,
                     epochs,
                     resume_ckpt_path,
                     train_dir,
                     val_dir,
                     checkpoint_dir,
                     project,
                     job_name,
                     num_workers=4,
                     accumulate_grad_batches=10):
    """Train ``Classifier(transfer=True)`` with PyTorch Lightning, logging to W&B.

    Builds train/validation ``ClassifierDataset``s, fits the classifier and
    keeps the three checkpoints with the lowest ``val_loss`` in
    ``checkpoint_dir``. The W&B run is always closed on return; if training
    raises, the run is marked as failed (exit code 1) and the exception
    propagates.

    Args:
        image_size: Side length in pixels; images are resized to
            ``image_size x image_size``.
        batch_size: Batch size for both the train and validation loaders.
        epochs: Maximum number of training epochs.
        resume_ckpt_path: Lightning checkpoint to resume from, or ``None``
            to start from scratch.
        train_dir: Root directory of the training split.
        val_dir: Root directory of the validation split.
        checkpoint_dir: Directory where ``ModelCheckpoint`` writes files.
        project: W&B project name.
        job_name: W&B run name.
        num_workers: DataLoader worker processes per loader (default 4).
        accumulate_grad_batches: Number of batches whose gradients are
            accumulated before each optimizer step (default 10).
    """
    clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # square image_size x image_size
        transforms.ToTensor(),
        # A single mean/std value, i.e. single-channel normalisation.
        # NOTE(review): assumes ClassifierDataset yields 1-channel images — confirm.
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])

    train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
    val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
    print("Total Training Images: ", len(train_dataset))
    print("Total Validation Images: ", len(val_dataset))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              pin_memory=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            pin_memory=True, num_workers=num_workers)

    clf = Classifier(transfer=True)

    # Keep the 3 lowest-val_loss checkpoints. The filename template already
    # contains the metric labels, so Lightning must not prepend "name=" again.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=checkpoint_dir,
        filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
        auto_insert_metric_name=False,
        save_weights_only=False,
        save_top_k=3,
        mode='min',
    )

    # "auto" lets Lightning choose the accelerator and device count available.
    trainer = pl.Trainer(
        devices="auto",
        precision="16-mixed",
        accelerator="auto",
        max_epochs=epochs,
        accumulate_grad_batches=accumulate_grad_batches,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        benchmark=True,
        logger=clf_wandb_logger,
        callbacks=[checkpoint_callback],
    )

    try:
        trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
    except BaseException:
        # Close the run as failed so it is not left dangling (e.g. in a
        # notebook, where a later run would otherwise attach to it).
        wandb.finish(exit_code=1)
        raise
    wandb.finish()


def train_cyclegan(image_size,
                   batch_size,
                   epochs,
                   classifier_path,
                   resume_ckpt_path,
                   train_dir,
                   val_dir,
                   test_dir,
                   checkpoint_dir,
                   project,
                   job_name,
                   num_workers=4,
):
    """Train the ``CycleGAN`` model with PyTorch Lightning, logging to W&B.

    A test DataLoader over ``test_dir`` is built here and handed to
    ``CycleGAN``, which itself receives the train/validation directories.
    The three checkpoints with the lowest ``val_generator_loss`` plus
    ``last.ckpt`` are kept in ``checkpoint_dir``. The W&B run is always
    closed on return; if training raises, the run is marked as failed
    (exit code 1) and the exception propagates.

    Args:
        image_size: Side length in pixels; passed as ``img_res=(image_size,
            image_size)`` to the test dataset.
        batch_size: Batch size of the test DataLoader.
        epochs: Maximum number of training epochs.
        classifier_path: Classifier checkpoint forwarded to ``CycleGAN``.
        resume_ckpt_path: Lightning checkpoint to resume from, or ``None``
            to start from scratch.
        train_dir: Training data directory forwarded to ``CycleGAN``.
        val_dir: Validation data directory forwarded to ``CycleGAN``.
        test_dir: Root directory of the test ``CustomDataset``.
        checkpoint_dir: Directory for checkpoints (also forwarded to ``CycleGAN``).
        project: W&B project name.
        job_name: W&B run name.
        num_workers: Worker processes for the test DataLoader (default 4).
    """
    img_res = (image_size, image_size)

    # NOTE(review): "0"/"1" presumably name the two class sub-folders
    # (N = negative, P = positive) that CustomDataset reads — confirm.
    test_dataset = CustomDataset(root_dir=test_dir, train_N="0", train_P="1", img_res=img_res)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=True)

    wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    print(classifier_path)
    cyclegan = CycleGAN(train_dir=train_dir,
                        val_dir=val_dir,
                        test_dataloader=test_dataloader,
                        classifier_path=classifier_path,
                        checkpoint_dir=checkpoint_dir,
                        gf=CFG.GAN_FILTERS,
                        df=CFG.DIS_FILTERS)

    # The template already contains "epoch_" / "vloss_" labels; with the
    # default auto_insert_metric_name=True Lightning would produce
    # "cyclegan-epoch_epoch=3-vloss_val_generator_loss=0.12.ckpt".
    gan_checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
        auto_insert_metric_name=False,
        monitor='val_generator_loss',
        save_top_k=3,
        save_last=True,
        save_weights_only=False,
        verbose=True,
        mode='min',
    )

    trainer = pl.Trainer(
        accelerator="auto",
        precision="16-mixed",
        max_epochs=epochs,
        log_every_n_steps=1,
        benchmark=True,
        devices="auto",
        logger=wandb_logger,
        callbacks=[gan_checkpoint_callback],
    )

    try:
        trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
    except BaseException:
        # Mirror train_classifier: never leave the W&B run dangling.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()