|
|
|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import json |
|
import os |
|
|
|
|
|
class ImageClassificationBase(torch.nn.Module):
    """Base ``nn.Module`` adding validation and epoch-reporting helpers.

    Subclasses supply ``forward``. Nothing in this file's inference path calls
    these helpers; they appear to be kept from the training code.
    """

    def validation_step(self, batch):
        """Evaluate one validation batch.

        Args:
            batch: ``(images, labels)`` pair. ``labels`` must hold integer class
                indices, as required by ``cross_entropy``. The model output is
                assumed to be ``(batch, num_classes)`` logits.

        Returns:
            ``{"val_loss": <0-d tensor>, "val_accuracy": <0-d tensor>}``.
            Neither value tracks gradients, so they can be stacked across
            batches by ``validation_epoch_end``.
        """
        images, labels = batch
        out = self(images)
        loss = torch.nn.functional.cross_entropy(out, labels)
        # Top-1 accuracy is computed inline so the class is self-contained: no
        # ``accuracy`` helper is defined or imported in this module. The result
        # is a 0-d float tensor, so ``torch.stack`` can combine batches later.
        preds = out.argmax(dim=1)
        acc = (preds == labels).float().mean()
        return {"val_loss": loss.detach(), "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results of ``validation_step``.

        Args:
            outputs: List of dicts as returned by ``validation_step``.

        Returns:
            Dict with the mean ``val_loss`` and mean ``val_accuracy`` tensors.
        """
        batch_losses = [x["val_loss"] for x in outputs]
        batch_accuracy = [x["val_accuracy"] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        epoch_accuracy = torch.stack(batch_accuracy).mean()
        return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}

    def epoch_end(self, epoch, result):
        """Print a one-line summary for *epoch*.

        ``result`` must contain ``'lrs'`` (a sequence; the last entry is
        printed), ``'train_loss'``, ``'val_loss'`` and ``'val_accuracy'``.
        """
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))
|
|
|
def ConvBlock(in_channels, out_channels, pool=False):
    """Build a 3x3 Conv -> BatchNorm -> ReLU stack, optionally followed by a pool.

    Args:
        in_channels: Input channel count for the convolution.
        out_channels: Output channel count (also the BatchNorm feature count).
        pool: When truthy, append ``MaxPool2d(4)``. With the default stride
            equal to the kernel size, this shrinks each spatial side by 4x.

    Returns:
        ``torch.nn.Sequential`` whose child indices (0: conv, 1: norm, 2: relu,
        3: pool) are the key prefixes used in saved checkpoints.
    """
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = torch.nn.BatchNorm2d(out_channels)
    act = torch.nn.ReLU(inplace=True)
    tail = (torch.nn.MaxPool2d(4),) if pool else ()
    return torch.nn.Sequential(conv, norm, act, *tail)
|
|
|
class ResNet9(ImageClassificationBase):
    """Compact residual CNN mapping an image batch to per-class logits.

    Pipeline: conv1 -> conv2 (pool /4) -> res1 (+ skip) -> conv3 (pool /4)
    -> conv4 (pool /4) -> res2 (+ skip) -> classifier (pool /4, flatten,
    linear). The final ``Linear(512, ...)`` needs a 1x1 feature map after the
    last pool, which holds for 256x256 inputs (256 -> 64 -> 16 -> 4 -> 1).

    The submodule attribute names below become the parameter keys of any saved
    model or state_dict, so they must stay unchanged for checkpoints to load.
    """

    def __init__(self, in_channels, num_diseases):
        """
        Args:
            in_channels: Number of channels in the input images.
            num_diseases: Number of output logits (classes).
        """
        super().__init__()

        # Stem plus first residual stage at 128 channels.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = torch.nn.Sequential(*(ConvBlock(128, 128) for _ in range(2)))

        # Widen to 512 channels, then the second residual stage.
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = torch.nn.Sequential(*(ConvBlock(512, 512) for _ in range(2)))

        # Head: pool the final map (4x4 for 256x256 inputs), flatten, project.
        self.classifier = torch.nn.Sequential(
            torch.nn.MaxPool2d(4),
            torch.nn.Flatten(),
            torch.nn.Linear(512, num_diseases),
        )

    def forward(self, xb):
        """Return class logits for *xb* ((batch, num_diseases) for 256x256 input)."""
        h = self.conv2(self.conv1(xb))
        h = self.res1(h) + h  # skip connection around res1
        h = self.conv4(self.conv3(h))
        h = self.res2(h) + h  # skip connection around res2
        return self.classifier(h)
|
|
|
# Human-readable labels, looked up by the model's output index.
# NOTE(review): only 7 labels are listed here, and the blank lines in the
# original literal suggest it was truncated. If the model emits an index
# >= len(CLASS_NAMES), indexing this list raises IndexError. Confirm the
# list against the label order used in training.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
]
|
|
|
def predict_image(image_path, model):
    """Classify the image stored at *image_path* and return its label.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        model: Module mapping a (1, 3, 256, 256) float batch to class scores.
            This function does not call ``model.eval()``; the caller should.

    Returns:
        The ``CLASS_NAMES`` entry at the highest-scoring output index.

    Raises:
        IndexError: if that index is outside ``CLASS_NAMES``.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    # Use a context manager so the file handle is released deterministically,
    # including when decoding fails. convert() returns an independent
    # in-memory copy, so using `img` after the `with` block is safe.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        outputs = model(img_tensor)
    _, predicted = torch.max(outputs, 1)

    return CLASS_NAMES[predicted.item()]
|
|
|
|
|
def load_model(model_path):
    """Load a fully pickled model from *model_path* onto the CPU in eval mode.

    The file must contain a whole ``torch.nn.Module`` (as written by
    ``torch.save(model, path)``), not just a ``state_dict``. Every class it
    references must be importable at load time (presumably why ``ResNet9`` is
    defined in this module).

    Args:
        model_path: Path to the saved model.

    Returns:
        The loaded module, with ``eval()`` already called.
    """
    import inspect

    load_kwargs = {'map_location': torch.device('cpu')}
    # PyTorch 2.6 changed torch.load's default to weights_only=True, which
    # refuses to unpickle arbitrary nn.Module objects. Opt out explicitly on
    # versions that know the flag; older versions already unpickle fully.
    # SECURITY: full unpickling can execute arbitrary code, so only load
    # checkpoints from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
# Load the checkpoint once at import time; every request reuses this model.
# NOTE(review): the filename mentions "res50", but the only architecture
# defined in this file is ResNet9. Confirm the checkpoint matches.
model_path = 'models/leaf_disease_res50_model_epoch_10.pth'

# load_model() already calls model.eval(), so no second call is needed here.
model = load_model(model_path)
|
|
|
|
|
def _class_label(idx):
    """Map a model output index to a label, tolerating a short CLASS_NAMES."""
    # Fall back to a generic name rather than raising IndexError when the
    # model has more outputs than CLASS_NAMES has entries.
    return CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"class_{idx}"


def predict(image):
    """Classify an uploaded leaf image for the Gradio interface.

    Args:
        image: Pixel array passed by the ``gr.Image()`` input (a NumPy array,
            presumably H x W x C; verify against the Gradio version in use),
            or ``None`` when the user submits without uploading anything.

    Returns:
        ``(json_text, pil_image)``. ``json_text`` is a JSON string containing
        ``"prediction"`` (the top label) and ``"confidence_scores"`` (up to five
        labels mapped to softmax probability in percent). ``pil_image`` is the
        RGB image that was classified. If no image is given, returns a JSON
        error message and ``None``.
    """
    if image is None:
        return json.dumps({"error": "No image provided."}), None

    # Classify the in-memory image with a single forward pass. There is no
    # temp file (nothing shared between concurrent requests, nothing to leak),
    # and the reported label comes from the same scores as the confidences.
    # fromarray() infers the mode and convert() normalises to 3-channel RGB.
    image = Image.fromarray(image.astype('uint8')).convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        outputs = model(img_tensor)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so cap it.
    k = min(5, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    top_predictions = {
        _class_label(idx): prob * 100
        for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
    }

    response = {
        "prediction": _class_label(top_idx[0].item()),
        "confidence_scores": top_predictions,
    }
    return json.dumps(response), image
|
|
|
|
|
# Gradio front end: one image upload in; the JSON result and the classified
# image out. The server is launched as a side effect of running this module.
# NOTE(review): consider wrapping launch() in an `if __name__ == "__main__":`
# guard so importing this module does not start a server.
upload_input = gr.Image()
result_outputs = [
    gr.JSON(label="Prediction Result"),
    gr.Image(label="Processed Image"),
]
iface = gr.Interface(
    fn=predict,
    inputs=upload_input,
    outputs=result_outputs,
    title="Plant Disease Predictor",
    description="Upload an image of a plant leaf to predict if it has a disease.",
)

iface.launch()
|
|
|
|
|
|
|
|