import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr

class FireModule(nn.Module):
    """SqueezeNet Fire module: a 1x1 squeeze layer followed by parallel 1x1 and 3x3 expand layers."""

    def __init__(self, in_channels, s1x1, e1x1, e3x3):
        super(FireModule, self).__init__()
        self.squeeze = nn.Conv2d(in_channels=in_channels, out_channels=s1x1, kernel_size=1, stride=1)
        self.expand1x1 = nn.Conv2d(in_channels=s1x1, out_channels=e1x1, kernel_size=1)
        self.expand3x3 = nn.Conv2d(in_channels=s1x1, out_channels=e3x3, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.squeeze(x))
        x1 = self.expand1x1(x)
        x2 = self.expand3x3(x)
        # Concatenate the two expand branches along the channel dimension.
        x = F.relu(torch.cat((x1, x2), dim=1))
        return x
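
# Illustrative shape check for the Fire module (the input size below is an
# assumption for this sketch, not something the app itself uses):
#   >>> fire = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
#   >>> fire(torch.randn(1, 96, 54, 54)).shape
#   torch.Size([1, 128, 54, 54])   # e1x1 + e3x3 = 128 channels, spatial size preserved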

class SqueezeNet(nn.Module):
    """SqueezeNet v1.0 backbone with a 1x1 convolutional classifier head."""

    def __init__(self, out_channels):
        super(SqueezeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2)
        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire2 = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
        self.fire3 = FireModule(in_channels=128, s1x1=16, e1x1=64, e3x3=64)
        self.fire4 = FireModule(in_channels=128, s1x1=32, e1x1=128, e3x3=128)
        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire5 = FireModule(in_channels=256, s1x1=32, e1x1=128, e3x3=128)
        self.fire6 = FireModule(in_channels=256, s1x1=48, e1x1=192, e3x3=192)
        self.fire7 = FireModule(in_channels=384, s1x1=48, e1x1=192, e3x3=192)
        self.fire8 = FireModule(in_channels=384, s1x1=64, e1x1=256, e3x3=256)
        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire9 = FireModule(in_channels=512, s1x1=64, e1x1=256, e3x3=256)
        # 1x1 convolution maps 512 channels to one logit map per output class.
        self.conv10 = nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=1, stride=1)
        self.avgpool = nn.AvgPool2d(kernel_size=12, stride=1)

    def forward(self, x):
        x = self.max_pool1(self.conv1(x))
        x = self.max_pool2(self.fire4(self.fire3(self.fire2(x))))
        x = self.max_pool3(self.fire8(self.fire7(self.fire6(self.fire5(x)))))
        x = self.avgpool(self.conv10(self.fire9(x)))
        # Flatten (N, out_channels, 1, 1) to (N, out_channels).
        return torch.flatten(x, start_dim=1)
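
# Spatial-size trace for a 224x224 input (matching the Resize below), which is
# why the final average pool uses kernel_size=12:
#   conv1 (k=7, s=2):            224 -> 109
#   max_pool1 (k=3, s=2):        109 -> 54
#   fire2-4 keep 54; max_pool2:   54 -> 26
#   fire5-8 keep 26; max_pool3:   26 -> 12
#   fire9 + conv10 (1x1) keep 12; avgpool (k=12): 12 -> 1, i.e. one logit per class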

# Binary classifier: a single output logit (tumor vs. no tumor), loaded on CPU.
model = SqueezeNet(out_channels=1)
model.load_state_dict(torch.load("squeezenet.pth", map_location=torch.device('cpu')))
model.eval()

# Deterministic inference-time preprocessing: resize to the 224x224 resolution
# the network expects and convert to a tensor (random augmentations are a
# training-time concern and would make predictions non-deterministic here).
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor()
                                ])

def classify_brain_tumor(image):
    # Ensure 3 channels (MRI uploads are often grayscale) and add a batch dimension.
    image = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = model(image)
        # Single-logit output: sigmoid gives the probability of the "Tumor" class.
        prediction = torch.sigmoid(output).item()
    return "Tumor" if prediction >= 0.5 else "No Tumor"

interface = gr.Interface(
    fn=classify_brain_tumor,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Brain Tumor Classification",
    description="Upload a brain MRI image to classify whether it shows a tumor."
)

interface.launch()