Fawazzx commited on
Commit
7854ef8
1 Parent(s): b1a2603

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ from torchvision import models
7
+ import gradio as gr
8
+
9
+ # Define transformations (must be the same as those used during training)
10
+ transform = transforms.Compose([
11
+ transforms.Resize((224, 224)),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
14
+ ])
15
+
16
+ # Load the model architecture and weights
17
+ model = models.resnet50(weights=None) # Initialize model without pretrained weights
18
+ model.fc = nn.Linear(model.fc.in_features, 4) # Adjust final layer for 4 classes
19
+
20
+ # Load the state dictionary with map_location for CPU
21
+ model.load_state_dict(torch.load("alzheimer_model_resnet50.pth", map_location=torch.device('cpu')))
22
+ model.eval() # Set model to evaluation mode
23
+
24
+ # Define class labels (must match the dataset used during training)
25
+ class_labels = ["Mild_Demented 0", "Moderate_Demented 1", "Non_Demented 2", "Very_Mild_Demented 3"] # Replace with your class names
26
+
27
+ # Define the prediction function
28
+ def predict(image):
29
+ if isinstance(image, np.ndarray):
30
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
31
+ else:
32
+ image = Image.open(image).convert("RGB")
33
+ image = transform(image).unsqueeze(0) # Add batch dimension
34
+
35
+ with torch.no_grad():
36
+ outputs = model(image)
37
+ _, predicted = torch.max(outputs.data, 1)
38
+ label = class_labels[predicted.item()]
39
+ return label
40
+
41
+ # Create a Gradio interface with examples
42
+ examples = [
43
+ ["0.jpg"],
44
+ ["f1.jpg"],
45
+ ["image.jpg"]
46
+ ]
47
+
48
+ iface = gr.Interface(
49
+ fn=predict,
50
+ inputs=gr.inputs.Image(type="numpy", label="Upload an MRI Image"),
51
+ outputs="text",
52
+ title="Alzheimer MRI Classification",
53
+ examples=examples
54
+ )
55
+
56
+ if __name__ == "__main__":
57
+ iface.launch()