Spaces:
Sleeping
Sleeping
anindya-hf-2002
commited on
Commit
•
4236119
1
Parent(s):
7f5db1f
delete files
Browse files- src/generate_images.py +0 -96
src/generate_images.py
DELETED
@@ -1,96 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
from PIL import Image
|
4 |
-
import numpy as np
|
5 |
-
from torch.utils.data import DataLoader, Dataset
|
6 |
-
from torchvision import transforms
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
from src.models import ResUNetGenerator
|
10 |
-
|
11 |
-
# Custom Dataset
|
12 |
-
class ImageDataset(Dataset):
|
13 |
-
def __init__(self, image_paths, transform=None):
|
14 |
-
self.image_paths = image_paths
|
15 |
-
self.transform = transform
|
16 |
-
|
17 |
-
def __len__(self):
|
18 |
-
return len(self.image_paths)
|
19 |
-
|
20 |
-
def __getitem__(self, idx):
|
21 |
-
img_path = self.image_paths[idx]
|
22 |
-
image = Image.open(img_path).convert('L')
|
23 |
-
if self.transform:
|
24 |
-
image = self.transform(image)
|
25 |
-
return image, img_path
|
26 |
-
|
27 |
-
# Function to save image
|
28 |
-
def save_image(tensor, path):
|
29 |
-
if tensor.is_cuda:
|
30 |
-
tensor = tensor.cpu()
|
31 |
-
|
32 |
-
array = tensor.permute(1, 2, 0).detach().numpy()
|
33 |
-
array = (array * 0.5 + 0.5) * 255
|
34 |
-
array = array.astype(np.uint8)
|
35 |
-
if array.shape[2] == 1:
|
36 |
-
array = array.squeeze(2)
|
37 |
-
image = Image.fromarray(array, mode='L')
|
38 |
-
else:
|
39 |
-
image = Image.fromarray(array)
|
40 |
-
image.save(path)
|
41 |
-
|
42 |
-
# Function to load model
|
43 |
-
def load_model(checkpoint_path, model_class, device):
|
44 |
-
model = model_class().to(device)
|
45 |
-
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
|
46 |
-
model.eval()
|
47 |
-
return model
|
48 |
-
|
49 |
-
def generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint, output_dir='data/translated_images', batch_size=16):
|
50 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
51 |
-
|
52 |
-
# Load models
|
53 |
-
g_NP = load_model(g_NP_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
|
54 |
-
g_PN = load_model(g_PN_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
|
55 |
-
|
56 |
-
# Create output directories
|
57 |
-
os.makedirs(os.path.join(output_dir, '0'), exist_ok=True)
|
58 |
-
os.makedirs(os.path.join(output_dir, '1'), exist_ok=True)
|
59 |
-
|
60 |
-
# Collect image paths
|
61 |
-
image_paths_0 = [os.path.join(image_folder, '0', fname) for fname in os.listdir(os.path.join(image_folder, '0')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
|
62 |
-
image_paths_1 = [os.path.join(image_folder, '1', fname) for fname in os.listdir(os.path.join(image_folder, '1')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
|
63 |
-
|
64 |
-
# Prepare dataset and dataloader
|
65 |
-
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485], std=[0.229])])
|
66 |
-
dataset_0 = ImageDataset(image_paths_0, transform)
|
67 |
-
dataset_1 = ImageDataset(image_paths_1, transform)
|
68 |
-
dataloader_0 = DataLoader(dataset_0, batch_size=batch_size, shuffle=False)
|
69 |
-
dataloader_1 = DataLoader(dataset_1, batch_size=batch_size, shuffle=False)
|
70 |
-
|
71 |
-
# Process images from negative (0) to positive (1)
|
72 |
-
with torch.no_grad():
|
73 |
-
for batch, paths in tqdm(dataloader_0, desc="Converting N to P: "):
|
74 |
-
batch = batch.to(device)
|
75 |
-
translated_images = g_NP(batch)
|
76 |
-
translated_images = g_PN(translated_images)
|
77 |
-
for img, path in zip(translated_images, paths):
|
78 |
-
save_path = os.path.join(output_dir, '1', os.path.basename(path))
|
79 |
-
save_image(img, save_path)
|
80 |
-
|
81 |
-
# Process images from positive (1) to negative (0)
|
82 |
-
for batch, paths in tqdm(dataloader_1, desc="Converting P to N: "):
|
83 |
-
batch = batch.to(device)
|
84 |
-
translated_images = g_PN(batch)
|
85 |
-
translated_images = g_NP(translated_images)
|
86 |
-
for img, path in zip(translated_images, paths):
|
87 |
-
save_path = os.path.join(output_dir, '0', os.path.basename(path))
|
88 |
-
save_image(img, save_path)
|
89 |
-
|
90 |
-
if __name__ == '__main__':
|
91 |
-
image_folder = r'data\rsna-pneumonia-dataset\train'
|
92 |
-
g_NP_checkpoint = 'models\g_NP_best.ckpt'
|
93 |
-
g_PN_checkpoint = 'models\g_PN_best.ckpt'
|
94 |
-
|
95 |
-
|
96 |
-
generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|