|
import os |
|
import numpy as np |
|
from torchvision import transforms |
|
from torch import nn |
|
import torchvision.models as models |
|
from PIL import Image |
|
import torch |
|
import gradio as gr |
|
|
|
|
|
|
|
# Preprocessing applied to every image before it reaches the model: scale the
# shorter side to 256 px, keep the central 224x224 patch, convert to a float
# tensor in [0, 1], then map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): these are not the ImageNet mean/std the VGG16 backbone was
# originally pretrained with; presumably they match what the checkpoint was
# fine-tuned with — confirm before changing.
data_transform = transforms.Compose(
    [
        transforms.Resize(256),       # shorter side -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),   # central 224x224 patch (classic VGG16 size)
        transforms.ToTensor(),        # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
|
|
class CustomVGG16(nn.Module):
    """VGG16 convolutional backbone with a classification head of configurable size.

    The attribute names and layer ordering (``features``, ``avgpool`` and the
    seven-element ``classifier`` Sequential) determine the state-dict keys,
    e.g. ``classifier.6.weight`` for the final Linear layer. Checkpoints are
    loaded by key, so do not rename or reorder these layers.

    Args:
        num_classes: number of logits produced by the final Linear layer.
        pretrained: keyword-only. When True (the default, preserving the
            original behaviour), ``features`` is initialised with torchvision's
            ImageNet VGG16 weights, which may require a download. Pass False
            when the weights will be overwritten by ``load_state_dict`` anyway.
    """

    def __init__(self, num_classes, *, pretrained=True):
        super().__init__()
        self.features = models.vgg16(pretrained=pretrained).features
        # Pool to a fixed 7x7 map so the first Linear's 512 * 7 * 7 input size
        # holds for any input resolution the backbone can downsample.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits, one row per image in ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension, flatten channels x 7 x 7 into one vector.
        x = torch.flatten(x, 1)
        return self.classifier(x)
|
|
|
|
|
|
|
|
|
|
|
# Human-readable labels, indexed by the position of the model's output logit.
classes = ['No Tumour', 'Tumour Present']
# Derived rather than hard-coded so the head size can never drift from the labels.
num_classes = len(classes)

model = CustomVGG16(num_classes)

pathofpretrain = './models/brain_active_classifiaction_on_big.pth'

# map_location="cpu": the model is never moved to a GPU in this file, and
# without it a checkpoint saved from a CUDA session fails to deserialize on a
# CPU-only host ("Attempting to deserialize object on a CUDA device").
checkpoint = torch.load(pathofpretrain, map_location="cpu")['model']

# Strip the "parent_module." segment from every key so it matches this bare
# model's parameter names (presumably the weights were saved from a wrapper
# module used during active-learning training — confirm).
state_dict = {k.replace("parent_module.", ""): v for k, v in checkpoint.items()}

model.load_state_dict(state_dict)
# Inference mode: disables the classifier's Dropout layers.
model.eval()
|
|
|
|
|
def classify_image(input_image):
    """Classify one image with the module-level ``model``.

    Args:
        input_image: a PIL image, as delivered by ``gr.Image(type="pil")``.
            Images in any PIL mode (grayscale, RGBA, palette, ...) are
            converted to RGB before preprocessing.

    Returns:
        The predicted label: the entry of ``classes`` at the index of the
        highest-probability output.

    Raises:
        ValueError: if ``input_image`` is None (e.g. Submit pressed without an
            uploaded image).
    """
    if input_image is None:
        raise ValueError("No image provided; please upload an image to classify.")

    # data_transform normalises exactly three channels and the VGG16 backbone
    # expects 3-channel input, so "L" (grayscale) or "RGBA" images would
    # otherwise fail inside Normalize or the first convolution.
    if isinstance(input_image, Image.Image) and input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    # Add a leading batch dimension of size 1.
    input_batch = data_transform(input_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    return classes[predicted_class_index]
|
|
|
|
|
# Gradio UI: a single PIL image in, a label out.
input_image = gr.Image(type="pil")
# NOTE(review): num_top_classes only has an effect when fn returns a
# {label: confidence} dict; classify_image returns a single label string.
output_text = gr.Label(num_top_classes=2)

demo = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_text,
    title="Brain Tumor Classification ACTIVE LEARNING",
)

# Launches as soon as this module runs. share=True also requests a public
# Gradio share link, so anyone holding the URL can reach the app.
demo.launch(share=True)
|
|