Spaces:
Sleeping
Sleeping
from torchvision import transforms | |
from torch.utils.data import DataLoader | |
from lightning.pytorch.loggers.wandb import WandbLogger | |
from lightning.pytorch.callbacks import ModelCheckpoint | |
import lightning as pl | |
import wandb | |
from src.dataset import ClassifierDataset, CustomDataset | |
from src.classifier import Classifier | |
from src.models import CycleGAN | |
from src.config import CFG | |
def train_classifier(image_size,
                     batch_size,
                     epochs,
                     resume_ckpt_path,
                     train_dir,
                     val_dir,
                     checkpoint_dir,
                     project,
                     job_name):
    """Train the ``Classifier`` model and log the run to Weights & Biases.

    Args:
        image_size: Edge length in pixels; every image is resized to
            ``(image_size, image_size)`` before being converted to a tensor.
        batch_size: Batch size used by both the training and validation
            data loaders.
        epochs: Forwarded to the trainer as ``max_epochs``.
        resume_ckpt_path: Forwarded unchanged as ``ckpt_path`` to
            ``Trainer.fit``.
        train_dir: Root directory handed to ``ClassifierDataset`` for the
            training split.
        val_dir: Root directory handed to ``ClassifierDataset`` for the
            validation split.
        checkpoint_dir: Directory in which ``ModelCheckpoint`` stores files.
        project: W&B project name given to ``WandbLogger``.
        job_name: W&B run name given to ``WandbLogger``.

    Raises:
        Exception: Anything raised by ``Trainer.fit`` is re-raised after the
            W&B run has been closed with a non-zero exit code.
    """
    clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")

    transform = transforms.Compose([
        # Resize to the square resolution requested by the caller.
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # NOTE(review): a single mean/std value is supplied here; confirm it
        # matches the channel layout that ClassifierDataset produces.
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])

    # Both splits share the same preprocessing pipeline.
    train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
    val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
    print("Total Training Images: ",len(train_dataset))
    print("Total Validation Images: ",len(val_dataset))

    # Only the training loader shuffles; validation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)

    clf = Classifier(transfer=True)

    # Keep the three checkpoints with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=checkpoint_dir,
        filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
        auto_insert_metric_name=False,
        save_weights_only=False,
        save_top_k=3,
        mode='min',
    )

    # Device and accelerator selection is left to Lightning ("auto").
    trainer = pl.Trainer(
        devices="auto",
        precision="16-mixed",
        accelerator="auto",
        max_epochs=epochs,
        accumulate_grad_batches=10,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        benchmark=True,
        logger=clf_wandb_logger,
        callbacks=[checkpoint_callback],
    )

    try:
        trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
    except Exception:
        # Close the W&B run even when training fails; otherwise the run stays
        # open and a later logger in the same process may attach to it.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()
def train_cyclegan(image_size,
                   batch_size,
                   epochs,
                   classifier_path,
                   resume_ckpt_path,
                   train_dir,
                   val_dir,
                   test_dir,
                   checkpoint_dir,
                   project,
                   job_name,
                   ):
    """Train the ``CycleGAN`` model and log the run to Weights & Biases.

    Args:
        image_size: Edge length in pixels; passed to ``CustomDataset`` as the
            square resolution ``(image_size, image_size)``.
        batch_size: Batch size of the test data loader built here.
        epochs: Forwarded to the trainer as ``max_epochs``.
        classifier_path: Printed, then handed to ``CycleGAN`` as
            ``classifier_path``.
        resume_ckpt_path: Forwarded unchanged as ``ckpt_path`` to
            ``Trainer.fit``.
        train_dir: Handed to ``CycleGAN`` as ``train_dir``.
        val_dir: Handed to ``CycleGAN`` as ``val_dir``.
        test_dir: Root directory of the ``CustomDataset`` used for testing.
        checkpoint_dir: Directory in which ``ModelCheckpoint`` stores files;
            also handed to ``CycleGAN``.
        project: W&B project name given to ``WandbLogger``.
        job_name: W&B run name given to ``WandbLogger``.

    Raises:
        Exception: Anything raised by ``Trainer.fit`` is re-raised after the
            W&B run has been closed with a non-zero exit code.
    """
    # "0" and "1" are passed as train_N / train_P; presumably these are the
    # class sub-folder names expected by CustomDataset -- confirm there.
    test_dataset = CustomDataset(
        root_dir=test_dir,
        train_N="0",
        train_P="1",
        img_res=(image_size, image_size),
    )
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    print(classifier_path)

    # No loaders are passed to fit() below; the model receives the data
    # directories and the test loader directly.
    cyclegan = CycleGAN(
        train_dir=train_dir,
        val_dir=val_dir,
        test_dataloader=test_dataloader,
        classifier_path=classifier_path,
        checkpoint_dir=checkpoint_dir,
        gf=CFG.GAN_FILTERS,
        df=CFG.DIS_FILTERS,
    )

    # Keep the three checkpoints with the lowest validation generator loss,
    # plus the most recent one.
    gan_checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
        monitor='val_generator_loss',
        save_top_k=3,
        save_last=True,
        save_weights_only=False,
        verbose=True,
        mode='min',
    )

    trainer = pl.Trainer(
        accelerator="auto",
        precision="16-mixed",
        max_epochs=epochs,
        log_every_n_steps=1,
        benchmark=True,
        devices="auto",
        logger=wandb_logger,
        callbacks=[gan_checkpoint_callback],
    )

    try:
        trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
    except Exception:
        # Close the W&B run even when training fails; otherwise the run stays
        # open and a later logger in the same process may attach to it.
        wandb.finish(exit_code=1)
        raise
    else:
        # Mirror train_classifier: end the run once training has completed.
        wandb.finish()