File size: 4,962 Bytes
2f686da c4b9624 965e7ba c4b9624 2f686da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import os
# from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
class ImageClassificationBase(torch.nn.Module):
    """Base module adding validation and epoch-logging helpers to a classifier.

    Subclasses implement ``forward`` returning class scores with classes along
    dim 1 (the layout ``cross_entropy`` and the argmax below use).
    """

    @staticmethod
    def _batch_accuracy(outputs, labels):
        """Return the fraction of argmax-over-dim-1 predictions equal to ``labels``.

        The result is a 0-d float tensor, so per-batch values can be
        ``torch.stack``-ed in ``validation_epoch_end``.
        """
        predictions = torch.argmax(outputs, dim=1)
        return (predictions == labels).float().mean()

    def validation_step(self, batch):
        """Evaluate one validation batch.

        Args:
            batch: an ``(images, labels)`` pair.

        Returns:
            ``{"val_loss": <detached loss tensor>, "val_accuracy": <0-d tensor>}``.
        """
        images, labels = batch
        out = self(images)
        loss = torch.nn.functional.cross_entropy(out, labels)
        # Accuracy is computed locally: this module does not define or import
        # any ``accuracy`` helper.
        acc = self._batch_accuracy(out, labels)
        return {"val_loss": loss.detach(), "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results returned by ``validation_step``.

        Args:
            outputs: list of dicts produced by ``validation_step``; must be
                non-empty (``torch.stack`` raises on an empty list).

        Returns:
            ``{"val_loss": mean loss, "val_accuracy": mean accuracy}`` as tensors.
        """
        batch_losses = [x["val_loss"] for x in outputs]
        batch_accuracy = [x["val_accuracy"] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        epoch_accuracy = torch.stack(batch_accuracy).mean()
        return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}

    def epoch_end(self, epoch, result):
        """Print a one-line summary for an epoch.

        Args:
            epoch: epoch number to display.
            result: dict with keys ``'lrs'`` (sequence; the last value is shown),
                ``'train_loss'``, ``'val_loss'`` and ``'val_accuracy'``.
        """
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))
def ConvBlock(in_channels, out_channels, pool=False):
    """Build a 3x3 conv -> batch-norm -> ReLU stack, optionally max-pooled.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        pool: when true, append a ``MaxPool2d(4)`` that downsamples the spatial
            size by 4 in each dimension.

    Returns:
        A ``torch.nn.Sequential`` of the layers above (3 layers, or 4 with pooling).
    """
    # padding=1 with a 3x3 kernel preserves the spatial size.
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = torch.nn.BatchNorm2d(out_channels)
    activation = torch.nn.ReLU(inplace=True)
    tail = [torch.nn.MaxPool2d(4)] if pool else []
    return torch.nn.Sequential(conv, norm, activation, *tail)
class ResNet9(ImageClassificationBase):
    """ResNet-9-style classifier: ConvBlock stages with two residual blocks.

    Args:
        in_channels: number of channels in the input images.
        num_diseases: number of output scores (classes) from the final Linear layer.

    The head's ``Linear(512, ...)`` expects the final ``MaxPool2d(4)`` to leave a
    1x1 map. Four 4x downsamplings give exactly that for a 256x256 input
    (256 / 4**4 == 1).
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()
        # Attribute names and construction order define the submodule tree
        # (and state_dict keys) that saved checkpoints depend on, so they are fixed.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = torch.nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = torch.nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
        self.classifier = torch.nn.Sequential(
            torch.nn.MaxPool2d(4),
            torch.nn.Flatten(),
            torch.nn.Linear(512, num_diseases),
        )

    def forward(self, xb):
        """Map an image batch to per-class scores (one row per input)."""
        features = xb
        for stage in (self.conv1, self.conv2):
            features = stage(features)
        # Residual connection around the first pair of 128-channel blocks.
        features = self.res1(features) + features
        for stage in (self.conv3, self.conv4):
            features = stage(features)
        # Residual connection around the pair of 512-channel blocks.
        features = self.res2(features) + features
        return self.classifier(features)
# Human-readable labels, indexed by the model's output position.
# NOTE(review): only 7 labels are listed. Indexing this list with a model output
# index fails if the checkpoint has more than 7 outputs, and the order must match
# the label order used in training. Confirm both against the checkpoint.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
]
def predict_image(image_path, model):
    """Classify one image and return its label from ``CLASS_NAMES``.

    Args:
        image_path: a path accepted by ``PIL.Image.open``, or an already-loaded
            ``PIL.Image.Image`` (used directly, without touching the filesystem).
        model: module mapping a ``(1, 3, 256, 256)`` float batch to per-class
            scores of shape ``(1, num_classes)``.

    Returns:
        The ``CLASS_NAMES`` entry at the index of the highest score.

    Raises:
        IndexError: if the winning index is outside ``CLASS_NAMES``.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    if isinstance(image_path, Image.Image):
        img = image_path.convert('RGB')
    else:
        # The context manager closes the file handle deterministically; PIL opens
        # lazily and does not always close the file after loading (e.g. multi-frame images).
        with Image.open(image_path) as src:
            img = src.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    _, predicted = torch.max(outputs, 1)
    return CLASS_NAMES[predicted.item()]
def load_model(model_path):
    """Load a whole pickled ``torch.nn.Module`` onto the CPU in evaluation mode.

    Args:
        model_path: path to a checkpoint holding an entire module (its ``.eval()``
            is called), e.g. one written with ``torch.save(model, path)``.

    Returns:
        The loaded module, switched to evaluation mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    import inspect

    load_kwargs = {"map_location": torch.device('cpu')}
    # PyTorch >= 2.6 defaults to torch.load(weights_only=True), which refuses to
    # unpickle arbitrary nn.Module classes. Opt out explicitly where the flag
    # exists; older releases without it always do a full unpickle.
    # SECURITY: full unpickling can execute arbitrary code. Only load trusted,
    # locally controlled checkpoints here, never user-supplied files.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    return model
# Load the classifier once when the module runs; predict() reads this global.
model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
# load_model already switches to eval mode; nn.Module.eval() returns the module
# itself, so the explicit (repeat) call can be chained onto the load.
model = load_model(model_path).eval()
def predict(image):
    """Classify an image received from the Gradio ``gr.Image`` input.

    Args:
        image: the array Gradio passes for the upload (cast to ``uint8`` here),
            or ``None`` when the form is submitted without an image.

    Returns:
        ``(json_text, pil_image)``. ``json_text`` is a JSON string of
        ``{"prediction": label, "confidence_scores": {label: percent, ...}}``
        covering up to the five highest-probability classes. ``pil_image`` is the
        RGB image that was classified. For a missing image, ``json_text`` holds
        an ``"error"`` message and ``pil_image`` is ``None``.
    """
    if image is None:
        # Report a missing upload instead of failing on ``None.astype``.
        return json.dumps({"error": "No image provided."}), None

    # Let PIL infer the mode from the array (RGB, RGBA or greyscale), then
    # normalise to RGB, the same conversion predict_image() applies to files.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    batch = preprocess(pil_image).unsqueeze(0)

    # A single forward pass on the in-memory image gives both the headline
    # prediction and the scores, so they cannot disagree. No temporary file is
    # written: a fixed filename in the working directory would be shared by
    # concurrent requests and left behind if inference raised.
    with torch.no_grad():
        outputs = model(batch)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so cap it.
    top_prob, top_idx = torch.topk(probabilities, min(5, probabilities.numel()))

    def label_for(idx):
        # NOTE(review): CLASS_NAMES may be shorter than the model's output layer;
        # show a generic name rather than raising IndexError.
        return CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"class_{idx}"

    top_predictions = {
        label_for(idx): prob * 100
        for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
    }
    response = {
        "prediction": label_for(top_idx[0].item()),
        "confidence_scores": top_predictions,
    }
    # The image output echoes the classified image. A model that provides
    # localization could draw a bounding box on it here instead.
    return json.dumps(response), pil_image
# Gradio UI: one image in; a JSON result plus the classified image out.
iface = gr.Interface(
    predict,
    gr.Image(),
    [
        gr.JSON(label="Prediction Result"),
        gr.Image(label="Processed Image"),
    ],
    title="Plant Disease Predictor",
    description="Upload an image of a plant leaf to predict if it has a disease.",
)

# Start the Gradio server for the interface above.
iface.launch()
|