File size: 2,020 Bytes
6895a87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
import torch.nn as nn
from torchvision import models, transforms
import gradio as gr
from PIL import Image
# Human-readable class labels, indexed by the classifier's output position.
# The order must match the output indices of the trained classifier head:
# predict_image maps the predicted index directly into this list, and
# load_model uses len(class_names) as the number of outputs of the head.
# NOTE(review): the alphabetical order suggests ImageFolder-style label
# indices from training — confirm against the training script.
class_names = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]
def load_model(weights_path='model_alexnet.pth', device=None):
    """Build the AlexNet classifier and load the fine-tuned weights into it.

    The architecture must match the one used when the checkpoint was saved:
    a torchvision AlexNet whose last classifier layer (``classifier[6]``) is
    replaced by a ``Linear`` layer with one output per entry in
    ``class_names``.

    Args:
        weights_path: Path to the saved ``state_dict``. Defaults to
            ``'model_alexnet.pth'``, resolved relative to the working
            directory.
        device: Device to run the model on (a ``torch.device`` or a string
            such as ``"cpu"``). When ``None``, ``cuda:0`` is used if CUDA is
            available, otherwise the CPU.

    Returns:
        A ``(model, device)`` tuple. The model is in eval mode and has
        already been moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
    # favour of `weights=None`; kept here for compatibility with older
    # torchvision releases.
    model = models.alexnet(pretrained=False)
    num_ftrs = model.classifier[6].in_features
    # Swap the default output layer for one sized to our disease classes.
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    # map_location remaps stored tensors onto the chosen device. Without it,
    # a checkpoint saved from a GPU fails to load on a CPU-only host.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout for inference
    return model, device
# Load the classifier once at import time; every prediction request reuses it.
model_alexnet, device = load_model()

# Preprocessing applied, in order, to each incoming PIL image:
#   1. scale so the shorter edge is 256 px,
#   2. take the central 224x224 crop,
#   3. convert to a float tensor with values in [0, 1],
#   4. normalize each channel with the standard ImageNet mean/std
#      (presumably the same statistics used during fine-tuning — confirm).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def predict_image(image):
    """Classify a banana-leaf image and return the predicted disease label.

    Args:
        image: The uploaded image, either as a NumPy array (the value the
            Gradio ``"image"`` input passes by default) or as a
            ``PIL.Image.Image``. Grayscale and RGBA inputs are converted to
            3-channel RGB before preprocessing.

    Returns:
        The predicted class name (an element of ``class_names``), or ``None``
        when no image was supplied.
    """
    if image is None:
        # Gradio may invoke the function with None when the user submits
        # without an image; return an empty result instead of crashing.
        return None

    if not isinstance(image, Image.Image):
        # Let Pillow infer the mode from the array shape ("L" for 2-D,
        # "RGB"/"RGBA" for 3/4 channels). Forcing mode 'RGB' here would
        # reject grayscale arrays and misread RGBA buffers.
        image = Image.fromarray(image.astype('uint8'))
    image = image.convert('RGB')  # the model expects 3 input channels

    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        outputs = model_alexnet(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted_index = outputs.argmax(dim=1).item()
    return class_names[predicted_index]
# Web UI: takes one image, shows the predicted label, and offers one bundled
# example image per class.
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description="This model is a fine-tuned version of AlexNet specifically designed to identify four types of diseases in banana tree leaves. It can classify the leaves as Cordana, Healthy, Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the model will help you determine its health condition.",
    examples=[
        'data/test/cordana/1.jpeg',
        'data/test/healthy/5.jpeg',
        'data/test/pestalotiopsis/5.jpeg',
        'data/test/sigatoka/1.jpeg'
    ]
)

# Only start the (blocking) web server when run as a script, so importing
# this module does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()
|