Fawazzx commited on
Commit
9ec92cf
·
verified ·
1 Parent(s): 6609b34

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from torchvision import models
6
+ import gradio as gr
7
+
8
+ # Define transformations (must be the same as those used during training)
9
+ transform = transforms.Compose([
10
+ transforms.Resize((224, 224)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13
+ ])
14
+
15
+ # Load the model architecture and weights
16
+ model = models.resnet50(weights=None) # Initialize model without pretrained weights
17
+ model.fc = nn.Linear(model.fc.in_features, 4) # Adjust final layer for 4 classes
18
+
19
+ # Load the state dictionary with map_location for CPU
20
+ model.load_state_dict(torch.load("alzheimer_model_resnet50.pth", map_location=torch.device('cpu')))
21
+ model.eval() # Set model to evaluation mode
22
+
23
+ # Define class labels (must match the dataset used during training)
24
+ class_labels = ["Mild_Demented 0", "Moderate_Demented 1", "Non_Demented 2", "Very_Mild_Demented 3"] # Replace with your class names
25
+
26
+ # Define the prediction function
27
+ def predict(image):
28
+ image = Image.open(image).convert("RGB")
29
+ image = transform(image).unsqueeze(0) # Add batch dimension
30
+
31
+ with torch.no_grad():
32
+ outputs = model(image)
33
+ _, predicted = torch.max(outputs.data, 1)
34
+ label = class_labels[predicted.item()]
35
+ return label
36
+
37
+ # Create a Gradio interface
38
+ iface = gr.Interface(fn=predict, inputs="image", outputs="text", title="Alzheimer MRI Classification")
39
+
40
+ if __name__ == "__main__":
41
+ iface.launch()