File size: 3,375 Bytes
7b34bdd
 
6241dff
7b34bdd
 
 
 
6241dff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e0a70
6241dff
 
7b34bdd
d74135d
 
 
 
 
 
 
7b34bdd
 
c1e0a70
7b34bdd
 
c1e0a70
6241dff
7b34bdd
 
 
c1e0a70
7b34bdd
 
6241dff
7b34bdd
 
6241dff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr

class FireModule(nn.Module):
    """SqueezeNet "Fire" block: squeeze with 1x1 convs, then expand in two branches.

    The squeezed feature map is fed to a 1x1 expand conv and a 3x3 expand conv
    in parallel; their outputs are concatenated along the channel axis.

    Args:
        in_channels: number of channels of the incoming feature map.
        s1x1: number of filters in the 1x1 squeeze layer.
        e1x1: number of filters in the 1x1 expand branch.
        e3x3: number of filters in the 3x3 expand branch (padding=1, so the
            spatial size is preserved).

    The output has ``e1x1 + e3x3`` channels and the same spatial size as the input.
    """

    def __init__(self, in_channels, s1x1, e1x1, e3x3):
        super().__init__()
        # Layers are created in this order so state_dict keys and parameter
        # initialisation order stay stable for saved checkpoints.
        self.squeeze = nn.Conv2d(in_channels, s1x1, kernel_size=1, stride=1)
        self.expand1x1 = nn.Conv2d(s1x1, e1x1, kernel_size=1)
        self.expand3x3 = nn.Conv2d(s1x1, e3x3, kernel_size=3, padding=1)

    def forward(self, x):
        squeezed = F.relu(self.squeeze(x))
        branches = [self.expand1x1(squeezed), self.expand3x3(squeezed)]
        # ReLU is elementwise, so one activation after the concat is the same as
        # activating each expand branch separately.
        return F.relu(torch.cat(branches, dim=1))

class SqueezeNet(nn.Module):
    """SqueezeNet-style CNN producing ``out_channels`` logits per image.

    Args:
        out_channels: number of output logits (filters of the final 1x1 conv).

    ``forward(x)`` takes a batch shaped ``(N, 3, H, W)`` (conv1 has 3 input
    channels) and returns a tensor shaped ``(N, out_channels)``. H and W must be
    large enough to survive conv1 (7x7, stride 2) and the three 3x3/stride-2
    max pools.

    NOTE(review): unlike torchvision's reference SqueezeNet, no ReLU follows
    conv1 or conv10 here. Left unchanged because the loaded checkpoint was
    presumably trained with exactly this graph — confirm before altering.
    """

    def __init__(self, out_channels):
        super(SqueezeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2)
        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire2 = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
        self.fire3 = FireModule(in_channels=128, s1x1=16, e1x1=64, e3x3=64)
        self.fire4 = FireModule(in_channels=128, s1x1=32, e1x1=128, e3x3=128)
        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire5 = FireModule(in_channels=256, s1x1=32, e1x1=128, e3x3=128)
        self.fire6 = FireModule(in_channels=256, s1x1=48, e1x1=192, e3x3=192)
        self.fire7 = FireModule(in_channels=384, s1x1=48, e1x1=192, e3x3=192)
        self.fire8 = FireModule(in_channels=384, s1x1=64, e1x1=256, e3x3=256)
        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire9 = FireModule(in_channels=512, s1x1=64, e1x1=256, e3x3=256)
        self.conv10 = nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=1, stride=1)
        # Global average pooling to a 1x1 map. A fixed AvgPool2d(kernel_size=12)
        # only reaches 1x1 for 224x224 inputs (224 -> conv1 109 -> pool 54 ->
        # pool 26 -> pool 12); for any other size the flattened output would not
        # be (N, out_channels). Adaptive pooling is identical at 224 and has no
        # parameters, so existing state_dicts still load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.max_pool1(self.conv1(x))
        x = self.max_pool2(self.fire4(self.fire3(self.fire2(x))))
        x = self.max_pool3(self.fire8(self.fire7(self.fire6(self.fire5(x)))))
        x = self.avgpool(self.conv10(self.fire9(x)))
        # (N, out_channels, 1, 1) -> (N, out_channels)
        return torch.flatten(x, start_dim=1)

# Single-logit network: classify_brain_tumor thresholds sigmoid(logit) at 0.5.
# Weights are mapped onto the CPU so the app runs without a GPU.
model = SqueezeNet(out_channels=1)
model.load_state_dict(
    torch.load("squeezenet.pth", map_location=torch.device("cpu"))
)
# Put the module in inference mode before serving predictions.
model.eval()

# Deterministic inference-time preprocessing: resize to 224x224 (the size the
# original pipeline used, presumably the training resolution) and convert the
# PIL image to a tensor with ToTensor.
#
# The random flips, autocontrast and sharpness adjustments that were here are
# training-time augmentations. Applying them at inference made the same upload
# get different predictions on different calls. RandomAdjustSharpness(0.3) also
# set sharpness_factor=0.3, i.e. blurring, rather than a probability.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def classify_brain_tumor(image):
    """Classify a single MRI image as "Tumor" or "No Tumor".

    Args:
        image: a ``PIL.Image.Image`` (as produced by ``gr.Image(type="pil")``),
            or ``None`` when the form is submitted without an upload.

    Returns:
        "Tumor" if the sigmoid of the model's single output logit is >= 0.5,
        otherwise "No Tumor". If no image was supplied, a short prompt asking
        the user to upload one is returned instead of raising.
    """
    if image is None:
        # transform(None) would raise inside torchvision; give the user a
        # readable message instead of an opaque error.
        return "Please upload an MRI image."
    # conv1 expects exactly 3 input channels. Normalise grayscale ("L"), RGBA
    # or palette images so they don't fail with a channel-mismatch error.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logit = model(batch)
    probability = torch.sigmoid(logit).item()
    return "Tumor" if probability >= 0.5 else "No Tumor"

# Web UI: one PIL image in, one text label out, handled by classify_brain_tumor.
interface = gr.Interface(
    classify_brain_tumor,
    gr.Image(type="pil"),
    "text",
    title="Brain Tumor Classification",
    description="Upload an MRI image to classify if it has a tumor or not.",
)

# Start serving the Gradio app.
interface.launch()