brain_cancer_detection / gradio_web_app.py
ANeuronI
initial commit
92ac70f
import os
import numpy as np
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import torch
import gradio as gr
# Define the data transformation
data_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Load the pre-trained model
class CustomVGG16(nn.Module):
def __init__(self, num_classes):
super(CustomVGG16, self).__init__()
self.features = models.vgg16(pretrained=True).features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# Load the trained model and classes
# data_dir = 'data'
# classes = os.listdir(data_dir)
# num_classes = len(classes)
classes = ['No Tumour','Tumour Present']
num_classes = 2
model = CustomVGG16(num_classes)
# Load the model weights
pathofpretrain = './models/brain_active_classifiaction_on_big.pth'
checkpoint = torch.load(pathofpretrain)['model']
state_dict = {k.replace("parent_module.", ""): v for k, v in checkpoint.items()}
model.load_state_dict(state_dict)
model.eval()
# Function to predict image class
def classify_image(input_image):
input_tensor = data_transform(input_image)
input_batch = input_tensor.unsqueeze(0)
# Perform inference
with torch.no_grad():
output = model(input_batch)
probabilities = torch.softmax(output, dim=1)[0]
predicted_class_index = torch.argmax(probabilities).item()
predicted_class = classes[predicted_class_index]
return predicted_class
# Define Gradio interface
input_image = gr.Image(type="pil")
output_text = gr.Label(num_top_classes=2)
# Create the Gradio app
gr.Interface(fn=classify_image, inputs=input_image, outputs=output_text, title="Brain Tumor Classification ACTIVE LEARNING").launch(share=True)