import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import os
# from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
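
# `accuracy` is used by validation_step below but is not defined elsewhere in this file;
# this is the standard multiclass top-1 accuracy helper assumed here.
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))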

class ImageClassificationBase(torch.nn.Module):
    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = torch.nn.functional.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {"val_loss": loss.detach(), "val_accuracy": acc}
    
    def validation_epoch_end(self, outputs):
        batch_losses = [x["val_loss"] for x in outputs]
        batch_accuracy = [x["val_accuracy"] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        epoch_accuracy = torch.stack(batch_accuracy).mean()
        return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))

# ConvBlock: 3x3 Conv2d -> BatchNorm -> ReLU, with an optional 4x max-pool to downsample
def ConvBlock(in_channels, out_channels, pool=False):
    layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
             torch.nn.BatchNorm2d(out_channels),
             torch.nn.ReLU(inplace=True)]
    if pool:
        layers.append(torch.nn.MaxPool2d(4))
    return torch.nn.Sequential(*layers)

# ResNet9: compact residual CNN; with 256x256 inputs the classifier pools down to a 512-dim feature vector
class ResNet9(ImageClassificationBase):
    def __init__(self, in_channels, num_diseases):
        super().__init__()
        
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = torch.nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
        
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = torch.nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
        
        self.classifier = torch.nn.Sequential(torch.nn.MaxPool2d(4),
                                       torch.nn.Flatten(),
                                       torch.nn.Linear(512, num_diseases))
        
    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)
        return out

CLASS_NAMES = [
    'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
    'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
    # ... remaining class names omitted here; the full list must cover every output index
    # of the loaded model, in the order used during training.
]

def predict_image(image_path, model):
    # Preprocessing must match what the model saw during training (256x256, ToTensor)
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    
    img = Image.open(image_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0)
    
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs, 1)
        
    return CLASS_NAMES[predicted.item()]


def load_model(model_path):
    # Expects a fully serialized model object (saved with torch.save(model, path)), not a
    # bare state_dict; on recent PyTorch versions torch.load may need weights_only=False
    # to unpickle such an object.
    model = torch.load(model_path, map_location=torch.device('cpu'))
    model.eval()
    return model



# Load the model (load_model already switches it to eval mode)
model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
model = load_model(model_path)

# Define the prediction function
def predict(image):
    # Convert Gradio image input to PIL image
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    
    # Save the uploaded file temporarily
    temp_image_path = "temp_image.jpg"
    image.save(temp_image_path)
    
    # Make prediction
    prediction = predict_image(temp_image_path, model)
    
    # Remove temporary file
    os.remove(temp_image_path)
    
    # Get confidence scores
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    
    # Get the top predictions (at most 5, capped by the number of classes the model outputs)
    top_k = min(5, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, top_k)
    top_predictions = {CLASS_NAMES[top_catid[i]]: top_prob[i].item() * 100 for i in range(top_k)}
    
    # Create a JSON response
    response = {
        "prediction": prediction,
        "confidence_scores": top_predictions
    }
    
    # For the image output, we'll just return the original image for now
    # You can modify this part to add a bounding box if your model provides localization
    return json.dumps(response), image

# Define Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.JSON(label="Prediction Result"), gr.Image(label="Processed Image")],
    title="Plant Disease Predictor",
    description="Upload an image of a plant leaf to predict if it has a disease."
)

# Launch the app
iface.launch()
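
# To try the app locally (assuming the checkpoint exists at model_path above):
#   pip install gradio torch torchvision pillow
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) where the interface is served.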