# Swekerr's picture
# Update app.py
# 92eea01 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
class FireModule(nn.Module):
    """Squeeze-then-expand convolution block.

    The input is first reduced to ``s1x1`` channels by a 1x1 "squeeze"
    convolution followed by ReLU. Two parallel "expand" convolutions (one 1x1,
    one 3x3) are then applied to the squeezed tensor, their outputs are
    concatenated along the channel dimension, and ReLU is applied to the
    result.

    Args:
        in_channels: Number of channels of the incoming feature map.
        s1x1: Output channels of the 1x1 squeeze convolution.
        e1x1: Output channels of the 1x1 expand convolution.
        e3x3: Output channels of the 3x3 expand convolution.

    The output has ``e1x1 + e3x3`` channels and the same spatial size as the
    input.
    """

    def __init__(self, in_channels: int, s1x1: int, e1x1: int, e3x3: int) -> None:
        super().__init__()
        self.squeeze = nn.Conv2d(
            in_channels=in_channels, out_channels=s1x1, kernel_size=1, stride=1
        )
        self.expand1x1 = nn.Conv2d(in_channels=s1x1, out_channels=e1x1, kernel_size=1)
        # padding=1 keeps the 3x3 branch at the same height/width as the 1x1
        # branch, which is required for the channel-wise concatenation below.
        self.expand3x3 = nn.Conv2d(
            in_channels=s1x1, out_channels=e3x3, kernel_size=3, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ReLU-activated concatenation of both expand branches."""
        squeezed = F.relu(self.squeeze(x))
        branches = (self.expand1x1(squeezed), self.expand3x3(squeezed))
        return F.relu(torch.cat(branches, dim=1))
class SqueezeNet(nn.Module):
    """SqueezeNet-style classifier built from ``FireModule`` blocks.

    Layout: a 7x7 stride-2 stem convolution, three max-pool stages separated
    by fire modules (fire2-4, fire5-8, fire9), a 1x1 convolution that maps the
    512 feature channels to ``out_channels``, and a global average pool.

    Args:
        out_channels: Number of values produced per image (the width of the
            final 1x1 convolution).

    ``forward`` returns a tensor of shape ``(batch, out_channels)``.
    """

    def __init__(self, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2)
        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.fire2 = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
        self.fire3 = FireModule(in_channels=128, s1x1=16, e1x1=64, e3x3=64)
        self.fire4 = FireModule(in_channels=128, s1x1=32, e1x1=128, e3x3=128)
        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.fire5 = FireModule(in_channels=256, s1x1=32, e1x1=128, e3x3=128)
        self.fire6 = FireModule(in_channels=256, s1x1=48, e1x1=192, e3x3=192)
        self.fire7 = FireModule(in_channels=384, s1x1=48, e1x1=192, e3x3=192)
        self.fire8 = FireModule(in_channels=384, s1x1=64, e1x1=256, e3x3=256)
        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.fire9 = FireModule(in_channels=512, s1x1=64, e1x1=256, e3x3=256)
        self.conv10 = nn.Conv2d(
            in_channels=512, out_channels=out_channels, kernel_size=1, stride=1
        )
        # Global average pool. A 224x224 input reaches this layer as a 12x12
        # map, so this matches the former fixed AvgPool2d(kernel_size=12) for
        # that size while also collapsing other input sizes to 1x1. The layer
        # has no parameters, so existing checkpoints load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a ``(batch, 3, H, W)`` tensor and return raw scores."""
        # Stem: convolution followed directly by max-pooling (no activation).
        x = self.max_pool1(self.conv1(x))

        for fire in (self.fire2, self.fire3, self.fire4):
            x = fire(x)
        x = self.max_pool2(x)

        for fire in (self.fire5, self.fire6, self.fire7, self.fire8):
            x = fire(x)
        x = self.max_pool3(x)

        x = self.conv10(self.fire9(x))
        x = self.avgpool(x)
        # Drop the trailing 1x1 spatial dimensions -> (batch, out_channels).
        return torch.flatten(x, start_dim=1)
# Build the single-output network and restore its trained weights. The
# checkpoint is mapped onto the CPU so no GPU is needed to load it.
model = SqueezeNet(out_channels=1)
model.load_state_dict(
    torch.load("squeezenet.pth", map_location="cpu")
)
# Inference only: put every submodule into evaluation mode.
model.eval()
# Inference-time preprocessing. Only deterministic steps belong here: the
# random flip / autocontrast / sharpness augmentations that used to be in this
# pipeline are training-time augmentations and made the same upload produce
# different input tensors (and potentially different predictions) on each call.
#
# Resize to 224x224 because that is the input size for which SqueezeNet's
# final feature map is 12x12; ToTensor converts the PIL image to a float
# tensor in [0, 1] with shape (C, H, W).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
def classify_brain_tumor(image):
    """Classify an uploaded image as "Tumor" or "No Tumor".

    Args:
        image: A ``PIL.Image`` supplied by the Gradio image input, or ``None``
            when the form is submitted without an upload.

    Returns:
        ``"Tumor"`` when the sigmoid of the model's single output is at least
        0.5, ``"No Tumor"`` otherwise, or a short prompt asking for an image
        when ``image`` is ``None``.
    """
    # Submitting without an upload passes None; answer with a message instead
    # of letting the transform raise.
    if image is None:
        return "Please upload an image."

    # The first convolution expects 3 input channels, so normalise grayscale,
    # palette or RGBA images to RGB before preprocessing.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logit = model(batch)

    # The model emits one value for the single image, so .item() is valid.
    prediction = torch.sigmoid(logit).item()
    return "Tumor" if prediction >= 0.5 else "No Tumor"
# Gradio front end: one image upload wired to the classifier, with the
# verdict shown as plain text.
interface = gr.Interface(
    classify_brain_tumor,
    gr.Image(type="pil"),  # deliver uploads to the callback as PIL images
    "text",
    description="Upload an MRI image to classify if it has a tumor or not.",
    title="Brain Tumor Classification",
)

# Start serving the app as soon as this script runs.
interface.launch()