anindya-hf-2002 commited on
Commit
4aac425
1 Parent(s): 4236119

delete files

Browse files
Files changed (1) hide show
  1. src/train.py +0 -124
src/train.py DELETED
@@ -1,124 +0,0 @@
1
- from torchvision import transforms
2
- from torch.utils.data import DataLoader
3
- from lightning.pytorch.loggers.wandb import WandbLogger
4
- from lightning.pytorch.callbacks import ModelCheckpoint
5
- import lightning as pl
6
- import wandb
7
-
8
- from src.dataset import ClassifierDataset, CustomDataset
9
- from src.classifier import Classifier
10
- from src.models import CycleGAN
11
- from src.config import CFG
12
-
13
- def train_classifier(image_size,
14
- batch_size,
15
- epochs,
16
- resume_ckpt_path,
17
- train_dir,
18
- val_dir,
19
- checkpoint_dir,
20
- project,
21
- job_name):
22
-
23
- clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
24
-
25
- transform = transforms.Compose([
26
- transforms.Resize((image_size, image_size)), # Resize image to 512x512
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485], std=[0.229]) # Normalize image
29
- ])
30
-
31
- # Define dataset paths
32
- # train_dir = "/kaggle/working/CycleGan-CFE/train-data/train"
33
- # val_dir = "/kaggle/working/CycleGan-CFE/train-data/val"
34
-
35
- # Create datasets
36
- train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
37
- val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
38
- print("Total Training Images: ",len(train_dataset))
39
- print("Total Validation Images: ",len(val_dataset))
40
-
41
- # Create data loaders
42
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
43
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)
44
- # Instantiate the classifier model
45
- clf = Classifier(transfer=True)
46
-
47
- checkpoint_callback = ModelCheckpoint(
48
- monitor='val_loss',
49
- dirpath=checkpoint_dir,
50
- filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
51
- auto_insert_metric_name=False,
52
- save_weights_only=False,
53
- save_top_k=3,
54
- mode='min'
55
- )
56
- # Set up PyTorch Lightning Trainer with multiple GPUs and tqdm progress bar
57
- trainer = pl.Trainer(
58
- devices="auto",
59
- precision="16-mixed",
60
- accelerator="auto",
61
- max_epochs=epochs,
62
- accumulate_grad_batches=10,
63
- log_every_n_steps=1,
64
- check_val_every_n_epoch=1,
65
- benchmark=True,
66
- logger=clf_wandb_logger,
67
- callbacks=[checkpoint_callback],
68
- )
69
-
70
- # Train the classifier
71
- trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
72
- wandb.finish()
73
-
74
-
75
- def train_cyclegan(image_size,
76
- batch_size,
77
- epochs,
78
- classifier_path,
79
- resume_ckpt_path,
80
- train_dir,
81
- val_dir,
82
- test_dir,
83
- checkpoint_dir,
84
- project,
85
- job_name,
86
- ):
87
-
88
-
89
- testdata_dir = test_dir
90
- train_N = "0"
91
- train_P = "1"
92
- img_res = (image_size, image_size)
93
-
94
- test_dataset = CustomDataset(root_dir=testdata_dir, train_N=train_N, train_P=train_P, img_res=img_res)
95
- test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
96
-
97
- wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
98
- print(classifier_path)
99
- cyclegan = CycleGAN(train_dir=train_dir, val_dir=val_dir, test_dataloader=test_dataloader, classifier_path=classifier_path, checkpoint_dir=checkpoint_dir, gf=CFG.GAN_FILTERS, df=CFG.DIS_FILTERS)
100
-
101
- gan_checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,
102
- filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
103
- monitor='val_generator_loss',
104
- save_top_k=3,
105
- save_last=True,
106
- save_weights_only=False,
107
- verbose=True,
108
- mode='min')
109
-
110
-
111
- # Create the trainer
112
- trainer = pl.Trainer(
113
- accelerator="auto",
114
- precision="16-mixed",
115
- max_epochs=epochs,
116
- log_every_n_steps=1,
117
- benchmark=True,
118
- devices="auto",
119
- logger=wandb_logger,
120
- callbacks= [gan_checkpoint_callback]
121
- )
122
-
123
- # Train the CycleGAN model
124
- trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)