# Last commit: 965e7ba "added conv block" by 4darsh-Dev (verified)
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import os
# from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
class ImageClassificationBase(torch.nn.Module):
    """Base ``nn.Module`` adding validation/epoch-logging helpers for classifiers.

    Subclasses supply ``forward``; the helpers below compute per-batch and
    per-epoch validation metrics and print an epoch summary line.
    """

    @staticmethod
    def _batch_accuracy(outputs, labels):
        """Return the fraction of rows whose arg-max over dim 1 equals *labels*.

        Returned as a 0-dim float tensor so per-batch values can be combined
        with ``torch.stack`` in ``validation_epoch_end``.
        """
        _, preds = torch.max(outputs, dim=1)
        return (preds == labels).float().mean()

    def validation_step(self, batch):
        """Evaluate one ``(images, labels)`` batch.

        Returns:
            dict with ``"val_loss"`` (detached cross-entropy loss) and
            ``"val_accuracy"`` (batch accuracy as a 0-dim tensor).
        """
        images, labels = batch
        out = self(images)
        loss = torch.nn.functional.cross_entropy(out, labels)
        # Accuracy is computed by a class-local helper so this method does not
        # depend on a module-level ``accuracy`` function being defined.
        acc = self._batch_accuracy(out, labels)
        return {"val_loss": loss.detach(), "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts from ``validation_step`` into epoch metrics.

        Each batch counts equally (a plain mean of batch means), regardless of
        how many samples it contained.
        """
        batch_losses = [x["val_loss"] for x in outputs]
        batch_accuracy = [x["val_accuracy"] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        epoch_accuracy = torch.stack(batch_accuracy).mean()
        return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}

    def epoch_end(self, epoch, result):
        """Print a one-line summary for *epoch*.

        *result* must provide ``"lrs"`` (a sequence; its last entry is printed),
        ``"train_loss"``, ``"val_loss"`` and ``"val_accuracy"``.
        """
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))
def ConvBlock(in_channels, out_channels, pool=False):
    """Build a 3x3 conv -> batch-norm -> ReLU stack, optionally followed by a 4x4 max-pool.

    The 3x3 convolution uses ``padding=1``, so it keeps the spatial size; only
    the optional ``MaxPool2d(4)`` shrinks it (by a factor of 4 per side).
    Layer order is fixed (conv, bn, relu[, pool]) because ``Sequential``
    indices become part of the saved ``state_dict`` keys.
    """
    core = [
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
    ]
    tail = [torch.nn.MaxPool2d(4)] if pool else []
    return torch.nn.Sequential(*(core + tail))
class ResNet9(ImageClassificationBase):
    """Nine-layer residual CNN: conv stem, two residual stages and a linear head.

    With the 4x4 max-pools used by ``ConvBlock(..., pool=True)`` and by the
    classifier head, a 256x256 input is reduced 256 -> 64 -> 16 -> 4 -> 1 per
    side, leaving 512 features for the final ``Linear`` layer.
    ``state_dict`` keys are derived from the attribute names below, so they
    must stay unchanged for existing checkpoints to load.
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()
        # Stem: lift to 64 channels, then widen to 128 with a pooled block.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        # First residual pair (shape-preserving, so it can be added to its input).
        self.res1 = torch.nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128),
        )
        # Widen to 512 channels through two pooled blocks.
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        # Second residual pair.
        self.res2 = torch.nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512),
        )
        # Head: final 4x4 pool, flatten, project to one score per class.
        self.classifier = torch.nn.Sequential(
            torch.nn.MaxPool2d(4),
            torch.nn.Flatten(),
            torch.nn.Linear(512, num_diseases),
        )

    def forward(self, xb):
        features = self.conv2(self.conv1(xb))
        features = self.res1(features) + features  # identity shortcut
        features = self.conv4(self.conv3(features))
        features = self.res2(features) + features  # identity shortcut
        return self.classifier(features)
# Class labels, looked up by the integer id the model predicts.
# NOTE(review): only 7 labels are listed; if the checkpoint has more output
# classes, predicted ids above 6 have no name here — confirm against the
# trained model.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
]
def predict_image(image_path, model):
    """Return the predicted class label for the image file at *image_path*.

    The image is converted to RGB, resized to 256x256 and turned into a
    ``(1, 3, 256, 256)`` float tensor before being passed to *model*, whose
    output is expected to hold one score per class along dim 1.

    Returns:
        The ``CLASS_NAMES`` entry for the highest-scoring class, or
        ``"class_<id>"`` when the predicted id has no entry in ``CLASS_NAMES``.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # Close the file handle promptly; convert() loads the pixels into a new
    # image that stays valid after the source file is closed.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    _, predicted = torch.max(outputs, 1)
    class_id = predicted.item()
    if class_id < len(CLASS_NAMES):
        return CLASS_NAMES[class_id]
    # Avoid an IndexError when the model has more outputs than CLASS_NAMES lists.
    return f"class_{class_id}"
def load_model(model_path):
    """Load a fully pickled model from *model_path* onto the CPU and set eval mode.

    The checkpoint is expected to contain a whole ``torch.nn.Module`` (as saved
    by ``torch.save(model, path)``), not just a ``state_dict``.

    Security: this unpickles arbitrary Python objects and can run code
    embedded in the file — only load checkpoints from a trusted source.
    """
    import inspect

    load_kwargs = {"map_location": torch.device('cpu')}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses pickled
    # nn.Module objects. Opt out explicitly, but only when torch.load accepts
    # the parameter (older releases would reject the unknown keyword).
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    return model
# Load the checkpoint once at import time; the global `model` is shared by
# every call to predict(). load_model() already switches to eval mode, and
# nn.Module.eval() returns the module itself, so chaining it is harmless.
# NOTE(review): the filename mentions "res50" while this file defines ResNet9;
# unpickling needs the saved class to be importable under its original name.
model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
model = load_model(model_path).eval()
# Define the prediction function
def predict(image):
    """Classify an uploaded leaf image and report up to five class confidences.

    Args:
        image: The array delivered by the ``gr.Image()`` input (numpy by
            default), or ``None`` when the user submits without an image.

    Returns:
        A ``(json_string, pil_image)`` pair. The JSON object holds
        ``"prediction"`` (the top label) and ``"confidence_scores"`` (up to five
        labels mapped to percentages in 0-100); the image is the RGB input
        echoed back.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image of a plant leaf.")

    # Let PIL infer the mode from the array shape, then normalise to RGB so
    # grayscale or RGBA uploads are handled too.
    image = Image.fromarray(image.astype('uint8')).convert('RGB')

    # Score the image once, entirely in memory: the label and the confidence
    # scores then come from the same forward pass, and there is no shared temp
    # file for concurrent requests to overwrite or leak.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # topk raises if k exceeds the number of classes, so cap it.
    k = min(5, probabilities.numel())
    top_probs, top_ids = torch.topk(probabilities, k)

    def label(class_id):
        # Fall back to a placeholder when CLASS_NAMES has no entry for this id.
        return CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"

    top_predictions = {
        label(int(class_id)): prob.item() * 100
        for prob, class_id in zip(top_probs, top_ids)
    }
    response = {
        # topk returns values sorted in descending order, so the first id is the arg-max.
        "prediction": label(int(top_ids[0])),
        "confidence_scores": top_predictions,
    }
    # The image output just echoes the input; a localization model could
    # annotate it (e.g. bounding boxes) here.
    return json.dumps(response), image
# Gradio UI: a single image input; outputs are the JSON result and the echoed image.
_leaf_input = gr.Image()
_result_outputs = [
    gr.JSON(label="Prediction Result"),
    gr.Image(label="Processed Image"),
]
iface = gr.Interface(
    fn=predict,
    inputs=_leaf_input,
    outputs=_result_outputs,
    title="Plant Disease Predictor",
    description="Upload an image of a plant leaf to predict if it has a disease.",
)

# Start the web app.
iface.launch()