anindya-hf-2002 committed on
Commit 4236119
1 Parent(s): 7f5db1f

delete files

Files changed (1)
  1. src/generate_images.py +0 -96
src/generate_images.py DELETED
@@ -1,96 +0,0 @@
-import os
-import torch
-from PIL import Image
-import numpy as np
-from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from tqdm import tqdm
-
-from src.models import ResUNetGenerator
-
-# Custom Dataset
-class ImageDataset(Dataset):
-    def __init__(self, image_paths, transform=None):
-        self.image_paths = image_paths
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        img_path = self.image_paths[idx]
-        image = Image.open(img_path).convert('L')
-        if self.transform:
-            image = self.transform(image)
-        return image, img_path
-
-# Function to save image
-def save_image(tensor, path):
-    if tensor.is_cuda:
-        tensor = tensor.cpu()
-
-    array = tensor.permute(1, 2, 0).detach().numpy()
-    array = (array * 0.5 + 0.5) * 255
-    array = array.astype(np.uint8)
-    if array.shape[2] == 1:
-        array = array.squeeze(2)
-        image = Image.fromarray(array, mode='L')
-    else:
-        image = Image.fromarray(array)
-    image.save(path)
-
-# Function to load model
-def load_model(checkpoint_path, model_class, device):
-    model = model_class().to(device)
-    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
-    model.eval()
-    return model
-
-def generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint, output_dir='data/translated_images', batch_size=16):
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-    # Load models
-    g_NP = load_model(g_NP_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
-    g_PN = load_model(g_PN_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
-
-    # Create output directories
-    os.makedirs(os.path.join(output_dir, '0'), exist_ok=True)
-    os.makedirs(os.path.join(output_dir, '1'), exist_ok=True)
-
-    # Collect image paths
-    image_paths_0 = [os.path.join(image_folder, '0', fname) for fname in os.listdir(os.path.join(image_folder, '0')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
-    image_paths_1 = [os.path.join(image_folder, '1', fname) for fname in os.listdir(os.path.join(image_folder, '1')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
-
-    # Prepare dataset and dataloader
-    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485], std=[0.229])])
-    dataset_0 = ImageDataset(image_paths_0, transform)
-    dataset_1 = ImageDataset(image_paths_1, transform)
-    dataloader_0 = DataLoader(dataset_0, batch_size=batch_size, shuffle=False)
-    dataloader_1 = DataLoader(dataset_1, batch_size=batch_size, shuffle=False)
-
-    # Process images from negative (0) to positive (1)
-    with torch.no_grad():
-        for batch, paths in tqdm(dataloader_0, desc="Converting N to P: "):
-            batch = batch.to(device)
-            translated_images = g_NP(batch)
-            translated_images = g_PN(translated_images)
-            for img, path in zip(translated_images, paths):
-                save_path = os.path.join(output_dir, '1', os.path.basename(path))
-                save_image(img, save_path)
-
-        # Process images from positive (1) to negative (0)
-        for batch, paths in tqdm(dataloader_1, desc="Converting P to N: "):
-            batch = batch.to(device)
-            translated_images = g_PN(batch)
-            translated_images = g_NP(translated_images)
-            for img, path in zip(translated_images, paths):
-                save_path = os.path.join(output_dir, '0', os.path.basename(path))
-                save_image(img, save_path)
-
-if __name__ == '__main__':
-    image_folder = r'data\rsna-pneumonia-dataset\train'
-    g_NP_checkpoint = 'models\g_NP_best.ckpt'
-    g_PN_checkpoint = 'models\g_PN_best.ckpt'
-
-
-    generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint)
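For anyone pinning a revision prior to this commit (e.g. parent 7f5db1f), the deleted module exposed a single entry point, generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint, output_dir='data/translated_images', batch_size=16). Below is a minimal invocation sketch based on the deleted __main__ block; the dataset and checkpoint locations are the ones hard-coded there, rewritten with os.path.join so they are not Windows-only, and they are illustrative rather than guaranteed to exist in this repository.

# Usage sketch for the deleted src/generate_images.py (valid only on revisions
# that still contain the file). Paths mirror the deleted __main__ block;
# adjust them to your local dataset and checkpoint layout.
import os

from src.generate_images import generate_images

if __name__ == "__main__":
    image_folder = os.path.join("data", "rsna-pneumonia-dataset", "train")
    g_NP_checkpoint = os.path.join("models", "g_NP_best.ckpt")
    g_PN_checkpoint = os.path.join("models", "g_PN_best.ckpt")

    generate_images(
        image_folder,
        g_NP_checkpoint,
        g_PN_checkpoint,
        output_dir=os.path.join("data", "translated_images"),
        batch_size=16,
    )

As the diff shows, each loop chains both generators before saving (g_NP then g_PN for the '0' folder inputs, and the reverse for '1'), so the written images are cycle-passed outputs rather than single-generator translations.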