Spaces:
Sleeping
Sleeping
File size: 4,508 Bytes
65eeb0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from torchvision import transforms
from torch.utils.data import DataLoader
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import lightning as pl
import wandb
from src.dataset import ClassifierDataset, CustomDataset
from src.classifier import Classifier
from src.models import CycleGAN
from src.config import CFG
def train_classifier(image_size,
                     batch_size,
                     epochs,
                     resume_ckpt_path,
                     train_dir,
                     val_dir,
                     checkpoint_dir,
                     project,
                     job_name,
                     num_workers=4):
    """Train the ``Classifier`` model and log the run to Weights & Biases.

    Args:
        image_size: Side length in pixels; every image is resized to
            ``(image_size, image_size)``.
        batch_size: Mini-batch size used by both the train and validation
            ``DataLoader``.
        epochs: Maximum number of epochs (``Trainer(max_epochs=...)``).
        resume_ckpt_path: Checkpoint path forwarded to ``trainer.fit`` as
            ``ckpt_path`` to resume training; ``None`` starts from scratch.
        train_dir: Root directory of the training ``ClassifierDataset``.
        val_dir: Root directory of the validation ``ClassifierDataset``.
        checkpoint_dir: Directory where ``ModelCheckpoint`` keeps the 3
            checkpoints with the lowest ``val_loss``.
        project: W&B project name.
        job_name: W&B run name.
        num_workers: Worker processes per ``DataLoader`` (default 4, the
            value previously hard-coded).
    """
    clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    try:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Single-element mean/std. NOTE(review): presumably the images are
            # single-channel (grayscale); confirm against ClassifierDataset.
            transforms.Normalize(mean=[0.485], std=[0.229]),
        ])

        train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
        val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
        print("Total Training Images: ", len(train_dataset))
        print("Total Validation Images: ", len(val_dataset))

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  pin_memory=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                pin_memory=True, num_workers=num_workers)

        clf = Classifier(transfer=True)

        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=checkpoint_dir,
            filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
            # The template already spells out the metric names, so don't let
            # Lightning prepend "epoch=" / "val_loss=" a second time.
            auto_insert_metric_name=False,
            save_weights_only=False,
            save_top_k=3,
            mode='min',
        )

        trainer = pl.Trainer(
            devices="auto",
            precision="16-mixed",
            accelerator="auto",
            max_epochs=epochs,
            # Optimizer steps once every 10 batches, i.e. an effective batch
            # of batch_size * 10 samples.
            accumulate_grad_batches=10,
            log_every_n_steps=1,
            check_val_every_n_epoch=1,
            benchmark=True,
            logger=clf_wandb_logger,
            callbacks=[checkpoint_callback],
        )
        trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
    finally:
        # Always close the W&B run, even if setup or training raised;
        # otherwise the run stays open and a later run in this process
        # would be affected. No-op if no run was started.
        wandb.finish()
def train_cyclegan(image_size,
                   batch_size,
                   epochs,
                   classifier_path,
                   resume_ckpt_path,
                   train_dir,
                   val_dir,
                   test_dir,
                   checkpoint_dir,
                   project,
                   job_name,
                   num_workers=4,
                   ):
    """Train the ``CycleGAN`` model and log the run to Weights & Biases.

    Args:
        image_size: Side length in pixels; passed to ``CustomDataset`` as
            ``img_res=(image_size, image_size)``.
        batch_size: Batch size of the test ``DataLoader`` built here.
        epochs: Maximum number of epochs (``Trainer(max_epochs=...)``).
        classifier_path: Path to the classifier checkpoint handed to
            ``CycleGAN``.
        resume_ckpt_path: Checkpoint path forwarded to ``trainer.fit`` as
            ``ckpt_path`` to resume training; ``None`` starts from scratch.
        train_dir: Training data root, passed through to ``CycleGAN``.
        val_dir: Validation data root, passed through to ``CycleGAN``.
        test_dir: Root directory of the test ``CustomDataset``.
        checkpoint_dir: Directory for checkpoints. It is used both by
            ``ModelCheckpoint`` (3 lowest ``val_generator_loss`` plus
            ``last``) and by ``CycleGAN`` itself.
        project: W&B project name.
        job_name: W&B run name.
        num_workers: Worker processes for the test ``DataLoader`` built here
            (default 4, the value previously hard-coded).
    """
    # Labels/sub-folder names for the two domains. NOTE(review): presumably
    # "0" = negative and "1" = positive; confirm against CustomDataset.
    train_N = "0"
    train_P = "1"
    img_res = (image_size, image_size)

    test_dataset = CustomDataset(root_dir=test_dir, train_N=train_N,
                                 train_P=train_P, img_res=img_res)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=True)

    wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    print(f"Classifier checkpoint: {classifier_path}")
    try:
        cyclegan = CycleGAN(train_dir=train_dir, val_dir=val_dir,
                            test_dataloader=test_dataloader,
                            classifier_path=classifier_path,
                            checkpoint_dir=checkpoint_dir,
                            gf=CFG.GAN_FILTERS, df=CFG.DIS_FILTERS)

        gan_checkpoint_callback = ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
            # The template already names each field ("epoch_", "vloss_");
            # without this Lightning would also insert "epoch=" and
            # "val_generator_loss=", duplicating the labels.
            auto_insert_metric_name=False,
            monitor='val_generator_loss',
            save_top_k=3,
            save_last=True,
            save_weights_only=False,
            verbose=True,
            mode='min',
        )

        trainer = pl.Trainer(
            accelerator="auto",
            precision="16-mixed",
            max_epochs=epochs,
            log_every_n_steps=1,
            benchmark=True,
            devices="auto",
            logger=wandb_logger,
            callbacks=[gan_checkpoint_callback],
        )
        # CycleGAN is given its data directories directly, so fit() is
        # called without explicit dataloaders.
        trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
    finally:
        # Close the W&B run (as train_classifier does) so it is not left
        # open after success or failure. No-op if no run was started.
        wandb.finish()
|