|
import torch |
|
import torch.nn as nn |
|
from torchvision import models, transforms |
|
import gradio as gr |
|
from PIL import Image |
|
|
|
|
|
# Labels the classifier can emit. predict_image indexes this list with the
# arg-max output index, and load_model sizes the final layer from its length.
# NOTE(review): the order presumably matches the label order used when the
# checkpoint was trained -- confirm before reordering.
class_names = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]
|
def load_model():
    """Build the AlexNet classifier and load its fine-tuned weights.

    The last classifier layer is replaced so the network emits one logit per
    entry in ``class_names``; the weights are then read from
    ``model_alexnet.pth`` (resolved relative to the working directory).

    Returns:
        tuple: ``(model, device)`` -- the network in eval mode and already
        moved to ``device``, where ``device`` is ``cuda:0`` when CUDA is
        available and the CPU otherwise.

    Raises:
        FileNotFoundError: if ``model_alexnet.pth`` cannot be opened.
        RuntimeError: if the checkpoint keys/shapes do not match the model.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Architecture only: the no-argument call builds AlexNet without
    # downloading pretrained weights (same result as ``pretrained=False``,
    # without the deprecation warning on recent torchvision releases).
    model = models.alexnet()
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    # Always remap storages onto the selected device. Without map_location
    # the tensors are restored to the device they were saved from, which
    # fails if that device does not exist on this machine.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = torch.load("model_alexnet.pth", map_location=device)
    model.load_state_dict(state_dict)

    # Inference mode: disables dropout in the classifier head.
    model.eval()
    model.to(device)
    return model, device
|
|
|
|
|
# Load the network once at import time so every prediction request reuses it.
model_alexnet, device = load_model()

# Preprocessing applied to each input image before inference: resize the
# shorter side to 256, take the central 224x224 crop, convert to a float
# tensor, then normalise per channel.
# NOTE(review): the mean/std values are the usual ImageNet statistics --
# presumably the same ones used during fine-tuning; confirm.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
|
|
|
|
|
def predict_image(image):
    """Classify one image and return the name of the predicted class.

    Args:
        image: Image as a NumPy array, as delivered by the Gradio ``"image"``
            input. ``None`` is what Gradio passes when nothing was uploaded.
            NOTE(review): presumably an H x W x 3 RGB array; grayscale and
            RGBA arrays are also accepted and converted to RGB.

    Returns:
        str: The entry of ``class_names`` whose logit is highest.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before submitting.")

    # Let Pillow infer the mode from the array shape, then force RGB so the
    # 3-channel normalisation in ``transform`` always applies.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    # Add the batch dimension the network expects and move to its device.
    batch = transform(pil_image).unsqueeze(0).to(device)

    # Pure inference: no gradients needed.
    with torch.no_grad():
        outputs = model_alexnet(batch)

    # Index of the highest logit for the single image in the batch.
    predicted = outputs.argmax(dim=1).item()
    return class_names[predicted]
|
|
|
|
|
# Gradio UI: a single image input wired to predict_image, with the predicted
# class shown in a label component. The example entries are image paths
# resolved relative to the working directory.
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description=(
        "This model is a fine-tuned version of AlexNet specifically designed "
        "to identify four types of diseases in banana tree leaves. It can "
        "classify the leaves as Cordana, Healthy, Pestalotiopsis, or Sigatoka. "
        "Upload a photo of a banana leaf and the model will help you determine "
        "its health condition."
    ),
    examples=[
        'cordana.jpeg',
        'healthy.jpeg',
        'pestalotiopsis.jpeg',
        'sigatoka.jpeg',
    ],
)

# Start the web server for the interface.
iface.launch()
|
|