"""Gradio demo that classifies an uploaded MRI image with a ResNet-50 model.

At import time this module builds the preprocessing pipeline, constructs a
ResNet-50 with a 4-way output layer, loads its weights from
``alzheimer_model_resnet50.pth`` onto the CPU, and creates the Gradio
interface. Running the file as a script launches the web UI.
"""

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# Define transformations (must be the same as those used during training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the model architecture and weights
model = models.resnet50(weights=None)  # Initialize model without pretrained weights
model.fc = nn.Linear(model.fc.in_features, 4)  # Adjust final layer for 4 classes

# Load the state dictionary with map_location for CPU.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source; consider passing weights_only=True where the installed
# PyTorch version supports it.
model.load_state_dict(
    torch.load("alzheimer_model_resnet50.pth", map_location=torch.device('cpu'))
)
model.eval()  # Set model to evaluation mode

# Define class labels (must match the dataset used during training)
class_labels = [
    "Mild_Demented 0",
    "Moderate_Demented 1",
    "Non_Demented 2",
    "Very_Mild_Demented 3",
]  # Replace with your class names


# Define the prediction function
def predict(image):
    """Return the predicted class label for one image.

    Args:
        image: Either a NumPy array holding the pixel data (what the Gradio
            ``Image`` input below delivers, since it is created with
            ``type="numpy"``) or anything ``PIL.Image.open`` accepts, such as
            a file path. ``None`` means nothing was uploaded.

    Returns:
        One entry of ``class_labels``, or a short prompt string when
        ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading a file.
        return "Please upload an image."

    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape, then convert, so
        # grayscale (H, W) and RGBA (H, W, 4) inputs work as well as RGB.
        pil_image = Image.fromarray(image.astype('uint8')).convert("RGB")
    else:
        # Close the underlying file as soon as the pixels have been copied.
        with Image.open(image) as opened:
            pil_image = opened.convert("RGB")

    batch = transform(pil_image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        logits = model(batch)

    # Index of the highest-scoring class for the single image in the batch.
    predicted_index = int(logits.argmax(dim=1).item())
    return class_labels[predicted_index]


# Create a Gradio interface with examples
examples = [
    ["image1.jpg"],
    ["image2.jpg"],
    ["image3.jpg"],
]

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload an MRI Image"),
    outputs=gr.Textbox(label="Prediction"),
    title="Alzheimer MRI Classification",
    examples=examples,
)

if __name__ == "__main__":
    iface.launch()