File size: 2,304 Bytes
92ac70f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import numpy as np
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import torch
import gradio as gr


# Preprocessing applied to every image before it reaches the model:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert to a float tensor (C, H, W) with values in [0, 1],
#   4. normalise each channel as (x - 0.5) / 0.5, mapping values to [-1, 1].
# NOTE(review): these statistics presumably match the ones used when the
# checkpoint was trained (they are not the ImageNet defaults) — confirm
# before changing them.
data_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# VGG16-based classifier used for the brain-tumour predictions.
class CustomVGG16(nn.Module):
    """VGG16 convolutional backbone with a classification head sized for ``num_classes``.

    ``features`` is taken from torchvision's VGG16. ``avgpool`` and ``classifier``
    follow torchvision's VGG head layout, except that the final linear layer
    emits ``num_classes`` logits. The attribute names and layer order define the
    state-dict keys, so they must stay stable for saved checkpoints to load.

    Args:
        num_classes: number of output logits.
        pretrained: if True, initialise the backbone with torchvision's ImageNet
            weights (which may trigger a download). Defaults to True, matching
            the previous hard-coded behaviour. Pass False when a full checkpoint
            is loaded right afterwards and the ImageNet weights would be
            overwritten anyway.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.features = models.vgg16(pretrained=pretrained).features
        # Pool to a fixed 7x7 grid so the classifier's input size does not
        # depend on the spatial size of the input image.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 512 feature maps x 7 x 7 grid
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits for the image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        return self.classifier(x)

# Label names, indexed by the position of the model's output logit.
# NOTE(review): the order presumably matches the label encoding used at
# training time — confirm before reordering.
classes = ['No Tumour', 'Tumour Present']
# Derived from the label list so the two cannot drift apart.
num_classes = len(classes)
model = CustomVGG16(num_classes)

# Load the fine-tuned weights. The checkpoint file holds a dict whose 'model'
# entry is the state dict. Every "parent_module." occurrence is removed from its
# keys (presumably left by a wrapper used during training) so they match
# CustomVGG16's parameter names.
pathofpretrain = './models/brain_active_classifiaction_on_big.pth'
# map_location="cpu": the model is never moved off the CPU in this script, and
# this lets a checkpoint that was saved from a GPU load on hosts without CUDA.
checkpoint = torch.load(pathofpretrain, map_location="cpu")['model']
state_dict = {k.replace("parent_module.", ""): v for k, v in checkpoint.items()}
model.load_state_dict(state_dict)
# Switch to inference mode (disables the Dropout layers in the classifier head).
model.eval()

# Function to predict image class
def classify_image(input_image):
    """Predict the class label for a single image.

    Args:
        input_image: a ``PIL.Image.Image``. Any PIL mode is accepted; the image
            is converted to RGB before preprocessing.

    Returns:
        The name (from ``classes``) of the class with the highest probability.

    Raises:
        ValueError: if ``input_image`` is None (no image supplied).
    """
    if input_image is None:
        # e.g. the form was submitted without an upload.
        raise ValueError("No image provided; please upload an image to classify.")

    # Grayscale ('L'), palette ('P') or RGBA images would otherwise become 1- or
    # 4-channel tensors, which break the 3-channel Normalize and VGG16's first
    # convolution.
    rgb_image = input_image.convert("RGB")
    # Add a leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    input_batch = data_transform(rgb_image).unsqueeze(0)

    # Perform inference without tracking gradients.
    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    return classes[predicted_class_index]

# Input component: an uploaded image, handed to classify_image as a PIL image.
input_image = gr.Image(type="pil")
# Output component: a Label showing the predicted class name.
output_text = gr.Label(num_top_classes=2)

# Wire the classifier into a Gradio interface and start it.
# share=True asks Gradio to create a public share link as well.
demo = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_text,
    title="Brain Tumor Classification ACTIVE LEARNING",
)
demo.launch(share=True)