Swekerr commited on
Commit
c1e0a70
·
verified ·
1 Parent(s): 6241dff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -44,27 +44,26 @@ class SqueezeNet(nn.Module):
44
  x = self.avgpool(self.conv10(self.fire9(x)))
45
  return torch.flatten(x, start_dim=1)
46
 
47
- # Initialize the model and load weights
48
- model = SqueezeNet(out_channels=1) # Adjust output channels if needed
49
  model.load_state_dict(torch.load("squeezenet.pth", map_location=torch.device('cpu')))
50
  model.eval()
51
 
52
  transform = transforms.Compose([
53
- transforms.Resize((128, 128)), # Resize to match model's input size
54
- transforms.ToTensor(), # Convert to tensor
55
- transforms.Normalize([0.5], [0.5]) # Normalize based on training dataset
56
  ])
57
 
58
  def classify_brain_tumor(image):
59
- image = transform(image).unsqueeze(0) # Preprocess and add batch dimension
60
  with torch.no_grad():
61
  output = model(image)
62
- prediction = torch.sigmoid(output).item() # Apply sigmoid for binary classification
63
  return "Tumor" if prediction >= 0.5 else "No Tumor"
64
 
65
  interface = gr.Interface(
66
  fn=classify_brain_tumor,
67
- inputs=gr.inputs.Image(type="pil"),
68
  outputs="text",
69
  title="Brain Tumor Classification",
70
  description="Upload an MRI image to classify if it has a tumor or not."
 
44
  x = self.avgpool(self.conv10(self.fire9(x)))
45
  return torch.flatten(x, start_dim=1)
46
 
47
+ model = SqueezeNet(out_channels=1)
 
48
  model.load_state_dict(torch.load("squeezenet.pth", map_location=torch.device('cpu')))
49
  model.eval()
50
 
51
  transform = transforms.Compose([
52
+ transforms.Resize((128, 128)),
53
+ transforms.ToTensor(),
54
+ transforms.Normalize([0.5], [0.5])
55
  ])
56
 
57
  def classify_brain_tumor(image):
58
+ image = transform(image).unsqueeze(0)
59
  with torch.no_grad():
60
  output = model(image)
61
+ prediction = torch.sigmoid(output).item()
62
  return "Tumor" if prediction >= 0.5 else "No Tumor"
63
 
64
  interface = gr.Interface(
65
  fn=classify_brain_tumor,
66
+ inputs=gr.Image(type="pil"),
67
  outputs="text",
68
  title="Brain Tumor Classification",
69
  description="Upload an MRI image to classify if it has a tumor or not."