Swekerr commited on
Commit
6241dff
·
verified ·
1 Parent(s): c007375

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -10
app.py CHANGED
@@ -1,31 +1,73 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- model = torch.load("squeezenet.pth", map_location=torch.device('cpu'))
8
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  transform = transforms.Compose([
11
- transforms.Resize((128, 128)),
12
- transforms.ToTensor(),
13
- transforms.Normalize([0.5], [0.5])
14
  ])
15
 
16
  def classify_brain_tumor(image):
17
- image = transform(image).unsqueeze(0)
18
  with torch.no_grad():
19
  output = model(image)
20
- _, predicted = torch.max(output, 1)
21
- return "Tumor" if predicted.item() == 1 else "No Tumor"
22
 
23
  interface = gr.Interface(
24
  fn=classify_brain_tumor,
25
  inputs=gr.inputs.Image(type="pil"),
26
  outputs="text",
27
  title="Brain Tumor Classification",
28
- description="Upload an MRI image to classify if it has a tumor or not. The Model is SqueezeNet."
29
  )
30
 
31
- interface.launch()
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ class FireModule(nn.Module):
9
+ def __init__(self, in_channels, s1x1, e1x1, e3x3):
10
+ super(FireModule, self).__init__()
11
+ self.squeeze = nn.Conv2d(in_channels=in_channels, out_channels=s1x1, kernel_size=1, stride=1)
12
+ self.expand1x1 = nn.Conv2d(in_channels=s1x1, out_channels=e1x1, kernel_size=1)
13
+ self.expand3x3 = nn.Conv2d(in_channels=s1x1, out_channels=e3x3, kernel_size=3, padding=1)
14
+
15
+ def forward(self, x):
16
+ x = F.relu(self.squeeze(x))
17
+ x1 = self.expand1x1(x)
18
+ x2 = self.expand3x3(x)
19
+ x = F.relu(torch.cat((x1, x2), dim=1))
20
+ return x
21
+
22
+ class SqueezeNet(nn.Module):
23
+ def __init__(self, out_channels):
24
+ super(SqueezeNet, self).__init__()
25
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2)
26
+ self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
27
+ self.fire2 = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
28
+ self.fire3 = FireModule(in_channels=128, s1x1=16, e1x1=64, e3x3=64)
29
+ self.fire4 = FireModule(in_channels=128, s1x1=32, e1x1=128, e3x3=128)
30
+ self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
31
+ self.fire5 = FireModule(in_channels=256, s1x1=32, e1x1=128, e3x3=128)
32
+ self.fire6 = FireModule(in_channels=256, s1x1=48, e1x1=192, e3x3=192)
33
+ self.fire7 = FireModule(in_channels=384, s1x1=48, e1x1=192, e3x3=192)
34
+ self.fire8 = FireModule(in_channels=384, s1x1=64, e1x1=256, e3x3=256)
35
+ self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
36
+ self.fire9 = FireModule(in_channels=512, s1x1=64, e1x1=256, e3x3=256)
37
+ self.conv10 = nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=1, stride=1)
38
+ self.avgpool = nn.AvgPool2d(kernel_size=12, stride=1)
39
+
40
+ def forward(self, x):
41
+ x = self.max_pool1(self.conv1(x))
42
+ x = self.max_pool2(self.fire4(self.fire3(self.fire2(x))))
43
+ x = self.max_pool3(self.fire8(self.fire7(self.fire6(self.fire5(x)))))
44
+ x = self.avgpool(self.conv10(self.fire9(x)))
45
+ return torch.flatten(x, start_dim=1)
46
+
47
+ # Initialize the model and load weights
48
+ model = SqueezeNet(out_channels=1) # Adjust output channels if needed
49
+ model.load_state_dict(torch.load("squeezenet.pth", map_location=torch.device('cpu')))
50
+ model.eval()
51
 
52
  transform = transforms.Compose([
53
+ transforms.Resize((128, 128)), # Resize to match model's input size
54
+ transforms.ToTensor(), # Convert to tensor
55
+ transforms.Normalize([0.5], [0.5]) # Normalize based on training dataset
56
  ])
57
 
58
  def classify_brain_tumor(image):
59
+ image = transform(image).unsqueeze(0) # Preprocess and add batch dimension
60
  with torch.no_grad():
61
  output = model(image)
62
+ prediction = torch.sigmoid(output).item() # Apply sigmoid for binary classification
63
+ return "Tumor" if prediction >= 0.5 else "No Tumor"
64
 
65
  interface = gr.Interface(
66
  fn=classify_brain_tumor,
67
  inputs=gr.inputs.Image(type="pil"),
68
  outputs="text",
69
  title="Brain Tumor Classification",
70
+ description="Upload an MRI image to classify if it has a tumor or not."
71
  )
72
 
73
+ interface.launch()