File size: 3,186 Bytes
65eeb0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

from src.classifier import Classifier

# Ensure the necessary directories exist
# os.makedirs('results/translated_N', exist_ok=True)
# os.makedirs('results/translated_P', exist_ok=True)

# Load the classifier model
def load_classifier(classifier_path):
    """Build a ``Classifier`` and restore its weights from a checkpoint file.

    The checkpoint is deserialized onto the CPU and must be a mapping that
    stores the model weights under the ``'state_dict'`` key.

    Args:
        classifier_path: Path to the saved classifier checkpoint.

    Returns:
        The restored ``Classifier`` instance, switched to evaluation mode.
    """
    model = Classifier()
    cpu = torch.device('cpu')
    checkpoint = torch.load(classifier_path, map_location=cpu)
    weights = checkpoint['state_dict']
    model.load_state_dict(weights)
    model.eval()
    return model

# Load the generator models
def load_model(checkpoint_path, model):
    """Restore ``model``'s weights from a raw state-dict checkpoint.

    The file is loaded (onto the CPU) and used directly as the state dict;
    there is no wrapper key around the weights.

    Args:
        checkpoint_path: Path to the saved state dict.
        model: Module whose parameters are overwritten in place.

    Returns:
        The same ``model`` object, switched to evaluation mode.
    """
    weights = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    model.eval()
    return model

def load_image(input_image, image_size):
    """Turn a PIL image into a normalized, single-channel batch tensor.

    The image is converted to grayscale (mode ``'L'``), resized to a square of
    side ``image_size``, converted to a float tensor and normalized with
    mean 0.485 / std 0.229. A leading batch dimension of size 1 is added.

    Args:
        input_image: PIL image in any mode accepted by ``Image.convert('L')``.
        image_size: Side length, in pixels, of the square output.

    Returns:
        Tensor shaped ``(1, 1, image_size, image_size)``.
    """
    grayscale = input_image.convert('L')
    pipeline = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # NOTE(review): 0.485 / 0.229 are ImageNet-style statistics, whereas
        # convert_into_image undoes a mean/std of 0.5 — confirm that each
        # model was trained with the normalization it receives here.
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])
    batch = pipeline(grayscale)
    return batch.unsqueeze(0)

def convert_into_image(tensor):
    """Convert a generator output tensor back into a PIL image.

    Expects a batch holding a single image, ``(1, C, H, W)``, normalized so that
    ``x * 0.5 + 0.5`` maps it to ``[0, 1]`` (i.e. values in ``[-1, 1]``).
    Values outside that range are clipped instead of wrapping around when
    cast to 8-bit.

    Args:
        tensor: Image tensor on any device; it is detached from autograd.

    Returns:
        A grayscale (``'L'``) image when ``C == 1``; otherwise the image PIL
        builds from the ``(H, W, C)`` uint8 array.
    """
    # .cpu() is a no-op for tensors already on the CPU, so no device check is needed.
    array = tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    array = (array * 0.5 + 0.5) * 255
    # Clip before casting: out-of-range floats cast to uint8 wrap around or are
    # undefined (e.g. 256 -> 0, -1 -> 255), producing speckle artifacts.
    array = np.clip(array, 0, 255).astype(np.uint8)

    if array.shape[2] == 1:
        # A 2-D uint8 array is always interpreted as mode 'L'; passing mode=
        # explicitly is deprecated in recent Pillow releases.
        return Image.fromarray(array.squeeze(2))
    return Image.fromarray(array)

def generate_images(input_image, classifier, g_PN, g_NP, image_size=512):
    """Classify an image, translate it to the other domain, and translate it back.

    The classifier's arg-max class picks the direction. Any class index above 0
    is treated as domain P and translated with ``g_PN``. Class 0 is treated as
    domain N and translated with ``g_NP``. The translated image is then passed
    through the opposite generator to produce a reconstruction.

    Args:
        input_image: PIL image to process.
        classifier: Module returning per-class scores indexed as
            ``(batch, class)``.
        g_PN: Generator used to translate domain P images to domain N.
        g_NP: Generator used to translate domain N images to domain P.
        image_size: Side length the input is resized to before inference.

    Returns:
        ``(translated_image, recon_image)`` as raw generator output tensors.
    """
    image = load_image(input_image, image_size)

    # Inference only: disable autograd for the classifier as well as the
    # generators so no computation graph is built and kept alive.
    with torch.no_grad():
        scores = classifier(image).cpu().numpy()
        pred = np.argmax(scores, axis=1)[0]

        if pred > 0.5:
            print("Classified as Domain P")
            forward, backward = g_PN, g_NP
        else:
            print("Classified as Domain N")
            forward, backward = g_NP, g_PN

        translated_image = forward(image)
        # Cycle back to the original domain to obtain the reconstruction.
        recon_image = backward(translated_image)

    return translated_image, recon_image

def classify_image(input_image, classifier, image_size=512):
    """Classify an image as "Normal" or "Pneumonia".

    Args:
        input_image: PIL image to classify.
        classifier: Module returning per-class scores indexed as
            ``(batch, class)``.
        image_size: Side length the input is resized to before inference.

    Returns:
        ``(label_scores, pred)``.
        - If the arg-max class index is above 0: ``({"Pneumonia": score of
          class 1}, 1)``.
        - Otherwise: ``({"Normal": score of class 0}, 0)``.

        NOTE(review): scores are the classifier's raw outputs, which may be
        logits rather than probabilities — confirm before showing them as
        confidences.
    """
    image = load_image(input_image, image_size)
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        scores = classifier(image).cpu().numpy()
    pred = np.argmax(scores, axis=1)[0]
    if pred > 0.5:
        return {"Pneumonia": scores[0][1]}, 1
    return {"Normal": scores[0][0]}, 0