# cfe-gen/src/train.py
from torchvision import transforms
from torch.utils.data import DataLoader
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import lightning as pl
import wandb
from src.dataset import ClassifierDataset, CustomDataset
from src.classifier import Classifier
from src.models import CycleGAN
from src.config import CFG


def train_classifier(image_size,
                     batch_size,
                     epochs,
                     resume_ckpt_path,
                     train_dir,
                     val_dir,
                     checkpoint_dir,
                     project,
                     job_name):
    # Log the classifier run to Weights & Biases.
    clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),     # Resize image to image_size x image_size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229])  # Normalize single-channel images
    ])

    # Define dataset paths
    # train_dir = "/kaggle/working/CycleGan-CFE/train-data/train"
    # val_dir = "/kaggle/working/CycleGan-CFE/train-data/val"

    # Create datasets
    train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
    val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
    print("Total Training Images: ", len(train_dataset))
    print("Total Validation Images: ", len(val_dataset))

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)

    # Instantiate the classifier model (transfer learning from pretrained weights)
    clf = Classifier(transfer=True)

    # Keep the three best checkpoints by validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=checkpoint_dir,
        filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
        auto_insert_metric_name=False,
        save_weights_only=False,
        save_top_k=3,
        mode='min'
    )

    # Set up the PyTorch Lightning Trainer with mixed precision and gradient accumulation.
    trainer = pl.Trainer(
        devices="auto",
        precision="16-mixed",
        accelerator="auto",
        max_epochs=epochs,
        accumulate_grad_batches=10,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        benchmark=True,
        logger=clf_wandb_logger,
        callbacks=[checkpoint_callback],
    )

    # Train the classifier (optionally resuming from a checkpoint).
    trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
    wandb.finish()
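

# Illustrative call of train_classifier (a sketch only; the paths and hyperparameters
# below are placeholders for this example, not values defined elsewhere in this repo):
#
#   train_classifier(
#       image_size=512,
#       batch_size=16,
#       epochs=10,
#       resume_ckpt_path=None,
#       train_dir="train-data/train",
#       val_dir="train-data/val",
#       checkpoint_dir="checkpoints/classifier",
#       project="cfe-gen",
#       job_name="classifier-baseline",
#   )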


def train_cyclegan(image_size,
                   batch_size,
                   epochs,
                   classifier_path,
                   resume_ckpt_path,
                   train_dir,
                   val_dir,
                   test_dir,
                   checkpoint_dir,
                   project,
                   job_name,
                   ):
    # Build the test dataloader; "0" and "1" are the subdirectory names for the
    # N and P image domains.
    testdata_dir = test_dir
    train_N = "0"
    train_P = "1"
    img_res = (image_size, image_size)
    test_dataset = CustomDataset(root_dir=testdata_dir, train_N=train_N, train_P=train_P, img_res=img_res)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Log the CycleGAN run to Weights & Biases.
    wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    print(classifier_path)

    # The CycleGAN LightningModule receives the data directories and the pretrained
    # classifier path directly; no dataloaders are passed to trainer.fit below.
    cyclegan = CycleGAN(train_dir=train_dir, val_dir=val_dir, test_dataloader=test_dataloader,
                        classifier_path=classifier_path, checkpoint_dir=checkpoint_dir,
                        gf=CFG.GAN_FILTERS, df=CFG.DIS_FILTERS)

    # Keep the three best checkpoints by validation generator loss, plus the last one.
    gan_checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,
                                              filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
                                              monitor='val_generator_loss',
                                              save_top_k=3,
                                              save_last=True,
                                              save_weights_only=False,
                                              verbose=True,
                                              mode='min')

    # Create the trainer
    trainer = pl.Trainer(
        accelerator="auto",
        precision="16-mixed",
        max_epochs=epochs,
        log_every_n_steps=1,
        benchmark=True,
        devices="auto",
        logger=wandb_logger,
        callbacks=[gan_checkpoint_callback]
    )

    # Train the CycleGAN model (optionally resuming from a checkpoint).
    trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
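

# Illustrative call of train_cyclegan (a sketch only; the checkpoint path and other
# values below are placeholders for this example, not values defined in this repo):
#
#   train_cyclegan(
#       image_size=256,
#       batch_size=4,
#       epochs=100,
#       classifier_path="checkpoints/classifier/best.ckpt",
#       resume_ckpt_path=None,
#       train_dir="train-data/train",
#       val_dir="train-data/val",
#       test_dir="train-data/test",
#       checkpoint_dir="checkpoints/cyclegan",
#       project="cfe-gen",
#       job_name="cyclegan-baseline",
#   )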