File size: 2,304 Bytes
92ac70f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os
import numpy as np
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import torch
import gradio as gr
# Preprocessing applied to every image before inference:
#   1. resize so the shorter side is 256 px,
#   2. centre-crop to 224x224,
#   3. convert to a float tensor with values in [0, 1],
#   4. map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): these 0.5/0.5 statistics differ from the usual ImageNet
# mean/std; presumably they match how the checkpoint was trained — confirm
# before changing either side.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
data_transform = transforms.Compose(_preprocess_steps)
# Model architecture: a VGG16 convolutional backbone followed by a
# VGG-style fully connected head sized for `num_classes` outputs.
class CustomVGG16(nn.Module):
    """VGG16 feature extractor with a VGG-style classification head.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        pretrained: If True (the default, matching the previous behaviour),
            initialise the convolutional backbone with torchvision's ImageNet
            weights, which may require downloading them. Pass False when every
            parameter will be overwritten from a checkpoint anyway, to skip
            that work.
    """

    def __init__(self, num_classes, *, pretrained=True):
        super().__init__()
        self.features = models.vgg16(pretrained=pretrained).features
        # Pool to a fixed 7x7 map so the head's input size does not depend
        # on the spatial size of the input image.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits for the batch *x*.

        Assumes x is a (batch, 3, H, W) image tensor; the result has shape
        (batch, num_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten the rest
        return self.classifier(x)
# Class labels, indexed by the model's output position.
classes = ['No Tumour', 'Tumour Present']
# Derive the head size from the label list so the two cannot drift apart.
num_classes = len(classes)
model = CustomVGG16(num_classes)

# Restore the fine-tuned weights from the checkpoint file.
pathofpretrain = './models/brain_active_classifiaction_on_big.pth'
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load
# on a CPU-only host; nothing in this script moves the model to a GPU.
checkpoint = torch.load(pathofpretrain, map_location="cpu")['model']
# Checkpoint keys contain a "parent_module." component (presumably added by a
# wrapper used during training — TODO confirm); strip it so the keys match
# CustomVGG16's own parameter names. load_state_dict is strict by default, so
# any remaining mismatch raises instead of silently keeping stale weights.
state_dict = {k.replace("parent_module.", ""): v for k, v in checkpoint.items()}
model.load_state_dict(state_dict)
# Put Dropout layers into inference mode.
model.eval()
# Inference callback used by the Gradio interface.
def classify_image(input_image):
    """Classify one image and return the predicted class name.

    Args:
        input_image: A PIL image (the Gradio input component uses
            ``type="pil"``), or None when the user submits without an image.

    Returns:
        The entry of ``classes`` with the highest predicted probability.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a generic failure.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to classify.")
    # The network and the 3-channel Normalize expect RGB; converting here
    # keeps grayscale ("L") or RGBA images from failing in preprocessing.
    input_tensor = data_transform(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)  # add a batch dimension of 1
    # Gradients are not needed for inference.
    with torch.no_grad():
        output = model(input_batch)
    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    return classes[predicted_class_index]
# Gradio UI: an uploaded image is passed to classify_image and the returned
# class name is displayed in a Label component. The app is then launched
# with share=True (Gradio's public share link).
input_image = gr.Image(type="pil")
output_text = gr.Label(num_top_classes=2)
demo = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_text,
    title="Brain Tumor Classification ACTIVE LEARNING",
)
demo.launch(share=True)
|