Fawazzx commited on
Commit
0933763
1 Parent(s): 7854ef8

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -57
app.py DELETED
@@ -1,57 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torchvision.transforms as transforms
5
- from PIL import Image
6
- from torchvision import models
7
- import gradio as gr
8
-
9
- # Define transformations (must be the same as those used during training)
10
- transform = transforms.Compose([
11
- transforms.Resize((224, 224)),
12
- transforms.ToTensor(),
13
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
14
- ])
15
-
16
- # Load the model architecture and weights
17
- model = models.resnet50(weights=None) # Initialize model without pretrained weights
18
- model.fc = nn.Linear(model.fc.in_features, 4) # Adjust final layer for 4 classes
19
-
20
- # Load the state dictionary with map_location for CPU
21
- model.load_state_dict(torch.load("alzheimer_model_resnet50.pth", map_location=torch.device('cpu')))
22
- model.eval() # Set model to evaluation mode
23
-
24
- # Define class labels (must match the dataset used during training)
25
- class_labels = ["Mild_Demented 0", "Moderate_Demented 1", "Non_Demented 2", "Very_Mild_Demented 3"] # Replace with your class names
26
-
27
- # Define the prediction function
28
- def predict(image):
29
- if isinstance(image, np.ndarray):
30
- image = Image.fromarray(image.astype('uint8'), 'RGB')
31
- else:
32
- image = Image.open(image).convert("RGB")
33
- image = transform(image).unsqueeze(0) # Add batch dimension
34
-
35
- with torch.no_grad():
36
- outputs = model(image)
37
- _, predicted = torch.max(outputs.data, 1)
38
- label = class_labels[predicted.item()]
39
- return label
40
-
41
- # Create a Gradio interface with examples
42
- examples = [
43
- ["0.jpg"],
44
- ["f1.jpg"],
45
- ["image.jpg"]
46
- ]
47
-
48
- iface = gr.Interface(
49
- fn=predict,
50
- inputs=gr.inputs.Image(type="numpy", label="Upload an MRI Image"),
51
- outputs="text",
52
- title="Alzheimer MRI Classification",
53
- examples=examples
54
- )
55
-
56
- if __name__ == "__main__":
57
- iface.launch()